Skip to content

Commit 3c70cb9

Browse files
committed
fix
1 parent 202378f commit 3c70cb9

File tree

6 files changed: 28 additions (+28) and 52 deletions (-52).

lightllm/common/fused_moe/grouped_fused_moe.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -490,7 +490,7 @@ def _get_grouped_matmul_configs():
490490

491491
@autotune(
492492
name="grouped_matmul:v1",
493-
configs=_get_grouped_matmul_configs,
493+
configs_gen_func=_get_grouped_matmul_configs,
494494
static_key_func=_get_grouped_matmul_static_key,
495495
run_key_func=lambda token_num_mul_topk_num: token_num_mul_topk_num,
496496
)

lightllm/common/fused_moe/moe_silu_and_mul.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -79,7 +79,7 @@ def _get_silu_and_mul_static_key(input: torch.Tensor, output: torch.Tensor):
7979

8080
@autotune(
8181
name="silu_and_mul_fwd:v1",
82-
configs=_get_silu_and_mul_configs,
82+
configs_gen_func=_get_silu_and_mul_configs,
8383
static_key_func=_get_silu_and_mul_static_key,
8484
run_key_func=lambda input: input.shape[0],
8585
)

lightllm/common/fused_moe/moe_sum_reduce.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -64,7 +64,7 @@ def _get_moe_sum_reduce_configs():
6464

6565
@autotune(
6666
name="moe_sum_reduce:v1",
67-
configs=_get_moe_sum_reduce_configs,
67+
configs_gen_func=_get_moe_sum_reduce_configs,
6868
static_key_func=_get_moe_sum_reduce_static_key,
6969
run_key_func=lambda input: input.shape[0],
7070
)

lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_gemm_kernel.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -178,7 +178,7 @@ def _get_static_key(A, B, block_size, dtype):
178178

179179
@autotune(
180180
name="w8a8_block_fp8_matmul:v1",
181-
configs=get_test_configs,
181+
configs_gen_func=get_test_configs,
182182
static_key_func=_get_static_key,
183183
run_key_func=lambda M: M,
184184
)

lightllm/common/triton_utils/autotuner.py

Lines changed: 23 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -23,29 +23,23 @@
2323

2424
def autotune(
2525
name: str,
26-
configs: "Optional[Union[List, Callable[[], List]]]",
26+
configs_gen_func: Callable[[], List],
2727
static_key_func: "Optional[Callable]" = None,
2828
run_key_func: "Optional[Callable]" = None,
2929
run_key_distance_func: "Optional[Callable]" = None,
30-
reset_to_zero=None,
31-
restore_value=None,
32-
pre_hook: "Optional[Callable]" = None,
33-
post_hook: "Optional[Callable]" = None,
30+
mutates_args: List[str] = None,
3431
):
3532
def decorator(fn):
3633
arg_names = [param.name for param in inspect.signature(fn).parameters.values()]
3734
return Autotuner(
3835
fn,
3936
arg_names,
4037
name,
41-
configs,
38+
configs_gen_func,
4239
static_key_func,
4340
run_key_func,
4441
run_key_distance_func,
45-
reset_to_zero,
46-
restore_value,
47-
pre_hook=pre_hook,
48-
post_hook=post_hook,
42+
mutates_args,
4943
)
5044

5145
return decorator
@@ -76,21 +70,18 @@ def __init__(
7670
fn,
7771
arg_names,
7872
name,
79-
configs: "Optional[Union[List, Callable[[], List]]]",
80-
static_key_func: "Optional[Callable[[], dict]]" = None,
81-
run_key_func: "Optional[Callable]" = None,
82-
run_key_distance_func: "Optional[Callable]" = lambda a, b: abs(a - b),
83-
reset_to_zero: "Optional[List]" = None,
84-
restore_value: "Optional[List]" = None,
85-
pre_hook: "Optional[Callable]" = None,
86-
post_hook: "Optional[Callable]" = None,
73+
configs_gen_func: Callable[[], List],
74+
static_key_func: Optional[Callable[[], dict]] = None,
75+
run_key_func: Optional[Callable] = None,
76+
run_key_distance_func: Optional[Callable] = lambda run_key, run_key_cached: abs(run_key - run_key_cached),
77+
mutates_args: List[str] = None,
8778
):
8879
# Whether to use this autotune decorator
8980
self.disable_autotune = os.environ.get("DISABLE_AUTOTUNE_DECORATOR", "0") == "1"
9081

91-
self.configs: "Optional[Union[List, Callable[[], List]]]" = configs
92-
if not callable(self.configs):
93-
self.configs = split_configs(self.configs)
82+
self.configs = None
83+
self.configs_gen_func = configs_gen_func
84+
9485
self.name = name
9586
self.cache_dir = os.path.join(
9687
Path(__file__).parent, "all_kernel_configs", get_triton_version(), get_current_device_name(), self.name
@@ -106,34 +97,21 @@ def __init__(
10697
self._static_param_names = self._get_param_names(self.static_key_func)
10798
self._run_param_names = self._get_param_names(self.run_key_func)
10899

109-
self.reset_to_zero = []
110-
if reset_to_zero is not None:
111-
self.reset_to_zero = list(reset_to_zero)
112-
self.restore_value = []
113-
if restore_value is not None:
114-
self.restore_value = list(restore_value)
100+
self.mutates_args = []
101+
if mutates_args is not None:
102+
self.mutates_args = list(mutates_args)
115103

116104
self.pre_hook = lambda kwargs, reset_only=False: 0
117105
self.post_hook = lambda kwargs, exception: 0
118106
self.user_defined_pre_hook = False
119107
self.user_defined_post_hook = False
120-
if pre_hook:
121-
self.pre_hook = pre_hook
122-
self.user_defined_pre_hook = True
123-
elif len(self.reset_to_zero) > 0 or len(self.restore_value) > 0:
124108

125-
def _pre_hook(kwargs, reset_only=False):
126-
for name in self.reset_to_zero:
127-
kwargs[name].zero_()
128-
if not reset_only:
129-
self.restore_copies = {name: kwargs[name].clone() for name in self.restore_value}
109+
if len(self.mutates_args) > 0:
130110

131-
self.pre_hook = _pre_hook
111+
def _pre_hook(kwargs):
112+
self.restore_copies = {name: kwargs[name].clone() for name in self.mutates_args}
132113

133-
if post_hook:
134-
self.post_hook = post_hook
135-
self.user_defined_post_hook = True
136-
elif len(self.restore_value) > 0:
114+
self.pre_hook = _pre_hook
137115

138116
def _post_hook(kwargs, exception):
139117
for name in self.restore_value:
@@ -143,7 +121,7 @@ def _post_hook(kwargs, exception):
143121
self.post_hook = _post_hook
144122

145123
if not os.path.exists(self.cache_dir):
146-
if os.environ.get("LIGHTLLM_TRITON_AUTOTUNE", "0") == "1":
124+
if is_triton_autotune_enabled():
147125
os.makedirs(self.cache_dir, exist_ok=True)
148126

149127
self._loaded_static_keys = set()
@@ -207,11 +185,11 @@ def kernel_call():
207185
def _autotune(self, args, kwargs, static_key, run_key):
208186
from lightllm.utils.dist_utils import get_global_rank
209187

210-
if callable(self.configs):
211-
self.configs = split_configs(self.configs())
188+
if self.configs is None:
189+
self.configs = split_configs(self.configs_gen_func())
212190

213191
rank_id = get_global_rank()
214-
_best_config = self.default_config
192+
_best_config = None
215193
best_time = float("inf")
216194

217195
bar = tqdm(
@@ -261,8 +239,6 @@ def _autotune(self, args, kwargs, static_key, run_key):
261239
logger.info(f"Saved configs for {self.name} - {static_key} - {run_key}")
262240

263241
kwargs["run_config"] = self.cached_configs[static_key][run_key]
264-
full_nargs = {**self.nargs, **kwargs}
265-
self.pre_hook(full_nargs, reset_only=True)
266242

267243
@torch.no_grad()
268244
def __call__(self, *args, **kwargs):

lightllm/models/deepseek2/triton_kernel/rotary_emb.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def get_static_key(q, k):
120120

121121
@autotune(
122122
name="rotary_emb_fwd:v1",
123-
configs=get_test_configs,
123+
configs_gen_func=get_test_configs,
124124
static_key_func=get_static_key,
125125
run_key_func=lambda q: q.shape[0],
126126
)

0 commit comments

Comments
 (0)