Skip to content

Commit 26a121b

Browse files
committed
fix
1 parent b4afd43 commit 26a121b

File tree

1 file changed

+9
-13
lines changed

1 file changed

+9
-13
lines changed

lightllm/common/triton_utils/autotuner.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,6 @@ def decorator(fn):
4141

4242

4343
class Autotuner:
44-
@staticmethod
45-
def _get_param_names(func: Callable) -> List[str]:
46-
sig = inspect.signature(func)
47-
return [name for name, p in sig.parameters.items()]
48-
4944
def __init__(
5045
self,
5146
fn,
@@ -73,26 +68,27 @@ def __init__(
7368
self.run_key_func = run_key_func
7469
self.run_key_distance_func = run_key_distance_func
7570
self.cached_configs = {}
76-
self.arg_names = [param.name for param in inspect.signature(fn).parameters.values()]
71+
self.arg_names = [param.name for param in inspect.signature(self.fn).parameters.values()]
7772
self._argname_to_pos = {name: idx for idx, name in enumerate(self.arg_names)}
7873
self._pos_to_argname = {idx: name for idx, name in enumerate(self.arg_names)}
7974

80-
self._static_key_func_param_names = self._get_param_names(self.static_key_func)
81-
self._run_key_func_param_names = self._get_param_names(self.run_key_func)
82-
75+
self._static_key_func_param_names = [
76+
name for name, _ in inspect.signature(self.static_key_func).parameters.items()
77+
]
78+
self._run_key_func_param_names = [name for name, _ in inspect.signature(self.run_key_func).parameters.items()]
8379
self.mutates_args = mutates_args
84-
85-
if not os.path.exists(self.cache_dir):
86-
if is_triton_autotune_enabled():
87-
os.makedirs(self.cache_dir, exist_ok=True)
80+
return
8881

8982
def _try_load_cache(self, static_key):
9083
if static_key in self.cached_configs:
9184
return
85+
86+
os.makedirs(self.cache_dir, exist_ok=True)
9287
cache_file = os.path.join(self.cache_dir, KernelConfigs.get_config_file_name(static_key))
9388
if os.path.exists(cache_file):
9489
with open(cache_file, "rb") as f:
9590
self.cached_configs[static_key] = orjson.loads(f.read())
91+
return
9692

9793
def _bench(self, *args, n_repeat=5, n_retries=1, **kwargs):
9894
from triton.compiler.errors import CompileTimeAssertionFailure

0 commit comments

Comments
 (0)