
Commit e3a44a1

clean code
1 parent 07838f1 commit e3a44a1

5 files changed: +25 -17 lines

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py

Lines changed: 2 additions & 3 deletions
@@ -21,8 +21,7 @@
 )
 from lightllm.common.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather
 from lightllm.common.basemodel.triton_kernel.redundancy_topk_ids_repair import redundancy_topk_ids_repair
-from lightllm.utils.envs_utils import get_triton_autotune_level
-from lightllm.common.triton_utils.autotuner import AutotuneLevel
+from lightllm.utils.envs_utils import enable_triton_autotune
 from lightllm.utils.log_utils import init_logger
 
 logger = init_logger(__name__)
@@ -356,7 +355,7 @@ def prefilled_group_gemm(
     ######################################## warning ##################################################
     # here is used to match autotune feature, make moe model run same triton kernel in different rank.
     # in some special case, one rank will recv 0 token, so add a token to make it run triton kernel.
-    if get_triton_autotune_level() in [AutotuneLevel.ADAPTIVE_AUTOTUNE, AutotuneLevel.FORCE_AUTOTUNE]:
+    if enable_triton_autotune():
         _gemm_out_a = torch.zeros((1, N), device=device, dtype=hidden_dtype)
         _silu_out = torch.zeros((1, N // 2), device=device, dtype=hidden_dtype)
         silu_and_mul_fwd(_gemm_out_a.view(-1, N), _silu_out)

lightllm/common/fused_moe/grouped_fused_moe_ep.py

Lines changed: 2 additions & 3 deletions
@@ -14,8 +14,7 @@
 )
 from lightllm.common.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather
 from lightllm.utils.envs_utils import get_deepep_num_max_dispatch_tokens_per_rank
-from lightllm.utils.envs_utils import get_triton_autotune_level
-from lightllm.common.triton_utils.autotuner import AutotuneLevel
+from lightllm.utils.envs_utils import enable_triton_autotune
 import numpy as np
 
 logger = init_logger(__name__)
@@ -190,7 +189,7 @@ def fused_experts_impl(
     ######################################## warning ##################################################
     # here is used to match autotune feature, make moe model run same triton kernel in different rank.
     # in some special case, one rank will recv 0 token, so add a token to make it run triton kernel.
-    if get_triton_autotune_level() in [AutotuneLevel.ADAPTIVE_AUTOTUNE, AutotuneLevel.FORCE_AUTOTUNE]:
+    if enable_triton_autotune():
         _gemm_out_a = torch.zeros((1, N), device=hidden_states.device, dtype=hidden_states.dtype)
         _silu_out = torch.zeros((1, N // 2), device=hidden_states.device, dtype=hidden_states.dtype)
         silu_and_mul_fwd(_gemm_out_a.view(-1, N), _silu_out)
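
Both MoE paths above carry the same workaround: when Triton autotune warmup is active, every rank has to launch the same kernels, even a rank that received zero tokens from the DeepEP dispatch, so one zero row is pushed through silu_and_mul_fwd. A minimal sketch of the idea, with a plain-PyTorch stand-in for the Triton kernel (silu_and_mul and warmup_dummy_row are illustrative names, not lightllm APIs):

import torch
import torch.nn.functional as F


def silu_and_mul(gemm_out: torch.Tensor) -> torch.Tensor:
    # Stand-in for lightllm's silu_and_mul_fwd Triton kernel: split the last
    # dim in half, apply SiLU to the first half, multiply by the second half.
    gate, up = gemm_out.chunk(2, dim=-1)
    return F.silu(gate) * up


def warmup_dummy_row(N: int, device: torch.device, dtype: torch.dtype, warmup_active: bool) -> None:
    # During autotune warmup every rank must run the same kernel, so a rank
    # that received zero tokens still feeds one dummy row through it.
    if warmup_active:
        dummy = torch.zeros((1, N), device=device, dtype=dtype)
        silu_and_mul(dummy)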

lightllm/common/fused_moe/topk_select.py

Lines changed: 2 additions & 3 deletions
@@ -23,8 +23,7 @@
 from lightllm.utils.light_utils import light_ops
 from typing import Callable, List, Optional, Tuple
 from lightllm.common.fused_moe.softmax_topk import softmax_topk
-from lightllm.common.triton_utils.autotuner import AutotuneLevel
-from lightllm.utils.envs_utils import get_triton_autotune_level
+from lightllm.utils.envs_utils import enable_triton_autotune
 
 use_cuda_grouped_topk = os.getenv("LIGHTLLM_CUDA_GROUPED_TOPK", "False").upper() in ["ON", "TRUE", "1"]
 
@@ -225,7 +224,7 @@ def select_experts(
 
     ######################################## warning ##################################################
     # here is used to match autotune feature, make topk_ids more random
-    if get_triton_autotune_level() in [AutotuneLevel.ADAPTIVE_AUTOTUNE, AutotuneLevel.FORCE_AUTOTUNE]:
+    if enable_triton_autotune():
         rand_gen = torch.Generator(device="cuda")
         rand_gen.manual_seed(router_logits.shape[0])
         router_logits = torch.randn(size=router_logits.shape, generator=rand_gen, dtype=torch.float32, device="cuda")
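
In select_experts the same warmup guard also replaces the router logits with values drawn from a generator seeded by the token count, so the substitution is deterministic for a given batch size. A small standalone sketch of that seeding behaviour (the shapes and CPU device are illustrative; the real code draws on CUDA with router_logits.shape):

import torch


def randomize_router_logits(num_tokens: int, num_experts: int, device: str = "cpu") -> torch.Tensor:
    # Seeding the generator with the token count makes the "random" logits a
    # pure function of the batch size, so repeated calls with the same shape
    # (e.g. on different ranks) yield identical expert assignments.
    gen = torch.Generator(device=device)
    gen.manual_seed(num_tokens)
    return torch.randn((num_tokens, num_experts), generator=gen, dtype=torch.float32, device=device)


a = randomize_router_logits(8, 4)
b = randomize_router_logits(8, 4)
assert torch.equal(a, b)  # same token count -> same seed -> same logits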

lightllm/common/triton_utils/autotuner.py

Lines changed: 10 additions & 8 deletions
@@ -12,7 +12,7 @@
 from lightllm.utils.device_utils import get_current_device_name
 from lightllm.utils.log_utils import init_logger
 from typing import Callable, Optional, Union, List
-from lightllm.utils.envs_utils import get_triton_autotune_level
+from lightllm.utils.envs_utils import enable_triton_autotune, get_triton_autotune_level
 from lightllm.common.kernel_config import KernelConfigs
 from lightllm.utils.dist_utils import get_global_world_size, get_global_rank, get_current_rank_in_node
 
@@ -120,6 +120,7 @@ def __init__(
         self.run_key_distance_func = run_key_distance_func
         self.cached_configs = {}
         self.fast_match_configs = collections.defaultdict(dict)
+        self.warmuped_configs_set = set()
         self.arg_names = [param.name for param in inspect.signature(self.fn).parameters.values()]
         self._argname_to_pos = {name: idx for idx, name in enumerate(self.arg_names)}
         self._pos_to_argname = {idx: name for idx, name in enumerate(self.arg_names)}
@@ -160,7 +161,7 @@ def __call__(self, *args, **kwargs):
             for run_config in all_configs.values():
                 # warmup all configs
                 kwargs["run_config"] = run_config
-                self.kernel_warmup(*args, **kwargs)
+                self.kernel_warmup(static_key, *args, **kwargs)
 
         if static_key not in self.cached_configs and autotune_level == AutotuneLevel.USE_AUTOTUNE_HIS_CONFIG:
             if (dist.is_initialized() and get_current_rank_in_node() == 0) or not dist.is_initialized():
@@ -171,10 +172,7 @@
                 )
             self.cached_configs[static_key] = {}
 
-        if (
-            autotune_level in [AutotuneLevel.ADAPTIVE_AUTOTUNE, AutotuneLevel.FORCE_AUTOTUNE]
-            and Autotuner.is_autotune_warmup()
-        ):
+        if enable_triton_autotune():
             need_tuning = (autotune_level == AutotuneLevel.FORCE_AUTOTUNE) or (
                 run_key not in self.cached_configs.get(static_key, {})
             )
@@ -218,11 +216,15 @@ def _try_load_cache(self, static_key):
                 self.cached_configs[static_key] = orjson.loads(f.read())
             return True
 
-    def kernel_warmup(self, *args, **kwargs):
+    def kernel_warmup(self, static_key, *args, **kwargs):
         new_args, new_kwargs, origin_list, new_list = self._mutate_args_clone(args, kwargs)
-
+        run_config = kwargs.get("run_config", {})
+        hash_key = str(frozendict(run_config)) + str(static_key)
+        if hash_key in self.warmuped_configs_set:
+            return
         try:
             self.fn(*new_args, **new_kwargs)
+            self.warmuped_configs_set.add(hash_key)
         except:
             pass
         finally:
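
The other change here threads static_key into kernel_warmup and records each (run_config, static_key) pair in warmuped_configs_set, so a config that has already been warmed up is not run again. A minimal standalone sketch of that memoization (WarmupCache and run_fn are illustrative; lightllm builds the key with frozendict and also clones mutable arguments before running):

class WarmupCache:
    def __init__(self):
        # Keys of (run_config, static_key) pairs whose warmup already ran.
        self.warmuped_configs_set = set()

    def kernel_warmup(self, static_key, run_fn, run_config: dict) -> None:
        # Sorted items stand in for frozendict to keep the sketch dependency-free.
        hash_key = str(sorted(run_config.items())) + str(static_key)
        if hash_key in self.warmuped_configs_set:
            return  # skip configs that were already warmed up
        try:
            run_fn(run_config)
            # Only remember keys whose warmup actually succeeded.
            self.warmuped_configs_set.add(hash_key)
        except Exception:
            pass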

lightllm/utils/envs_utils.py

Lines changed: 9 additions & 0 deletions
@@ -154,6 +154,15 @@ def get_triton_autotune_level():
     return int(os.getenv("LIGHTLLM_TRITON_AUTOTUNE_LEVEL", 0))
 
 
+def enable_triton_autotune():
+    from lightllm.common.triton_utils.autotuner import AutotuneLevel, Autotuner
+
+    return (
+        get_triton_autotune_level() in [AutotuneLevel.ADAPTIVE_AUTOTUNE, AutotuneLevel.FORCE_AUTOTUNE]
+        and Autotuner.is_autotune_warmup()
+    )
+
+
 g_model_init_done = False
 
 
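
The new helper folds the repeated "level is ADAPTIVE or FORCE, and warmup is running" check into a single call. The import of AutotuneLevel and Autotuner sits inside the function body, presumably because autotuner.py itself imports from envs_utils.py and a module-level import would be circular. A call site then reduces to the pattern below (sketch; do_warmup_only_work is a hypothetical caller):

from lightllm.utils.envs_utils import enable_triton_autotune


def do_warmup_only_work():
    # Replaces the old two-part check spelled out at every call site.
    if enable_triton_autotune():
        ...  # e.g. launch dummy kernels or randomize router logits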
