Commit 263b53a ("fix")

1 parent: e3a44a1

5 files changed: +9 additions, -17 deletions

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py
Lines changed: 3 additions & 2 deletions

@@ -21,8 +21,9 @@
 )
 from lightllm.common.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather
 from lightllm.common.basemodel.triton_kernel.redundancy_topk_ids_repair import redundancy_topk_ids_repair
-from lightllm.utils.envs_utils import enable_triton_autotune
 from lightllm.utils.log_utils import init_logger
+from lightllm.common.triton_utils.autotuner import Autotuner
+

 logger = init_logger(__name__)

@@ -355,7 +356,7 @@ def prefilled_group_gemm(
 ######################################## warning ##################################################
 # here is used to match autotune feature, make moe model run same triton kernel in different rank.
 # in some special case, one rank will recv 0 token, so add a token to make it run triton kernel.
-if enable_triton_autotune():
+if Autotuner.is_autotune_warmup():
     _gemm_out_a = torch.zeros((1, N), device=device, dtype=hidden_dtype)
     _silu_out = torch.zeros((1, N // 2), device=device, dtype=hidden_dtype)
     silu_and_mul_fwd(_gemm_out_a.view(-1, N), _silu_out)
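
The context lines above document the trick itself: under expert parallelism a rank may receive zero tokens for a step, so during autotune warmup every rank still launches the kernel on a one-token placeholder to keep tuning in lockstep across ranks. A minimal sketch of that pattern, assuming the fused SiLU-and-mul Triton kernel is passed in as a callable (the helper name below is illustrative, not part of the repository):

    import torch
    from typing import Callable

    def warmup_zero_token_rank(
        silu_and_mul_kernel: Callable[[torch.Tensor, torch.Tensor], None],
        N: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> None:
        # Illustrative helper: launch the fused SiLU-and-mul kernel on a single
        # dummy token so a rank that received zero tokens during autotune warmup
        # still runs (and tunes) the same kernel as its peers.
        gemm_out = torch.zeros((1, N), device=device, dtype=dtype)       # fake gate/up GEMM output
        silu_out = torch.zeros((1, N // 2), device=device, dtype=dtype)  # destination buffer
        silu_and_mul_kernel(gemm_out.view(-1, N), silu_out)

In the diff itself this branch is now gated on Autotuner.is_autotune_warmup() instead of the removed enable_triton_autotune() helper.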

lightllm/common/fused_moe/grouped_fused_moe_ep.py
Lines changed: 2 additions & 2 deletions

@@ -14,7 +14,7 @@
 )
 from lightllm.common.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather
 from lightllm.utils.envs_utils import get_deepep_num_max_dispatch_tokens_per_rank
-from lightllm.utils.envs_utils import enable_triton_autotune
+from lightllm.common.triton_utils.autotuner import Autotuner
 import numpy as np

 logger = init_logger(__name__)

@@ -189,7 +189,7 @@ def fused_experts_impl(
 ######################################## warning ##################################################
 # here is used to match autotune feature, make moe model run same triton kernel in different rank.
 # in some special case, one rank will recv 0 token, so add a token to make it run triton kernel.
-if enable_triton_autotune():
+if Autotuner.is_autotune_warmup():
     _gemm_out_a = torch.zeros((1, N), device=hidden_states.device, dtype=hidden_states.dtype)
     _silu_out = torch.zeros((1, N // 2), device=hidden_states.device, dtype=hidden_states.dtype)
     silu_and_mul_fwd(_gemm_out_a.view(-1, N), _silu_out)

lightllm/common/fused_moe/topk_select.py
Lines changed: 2 additions & 2 deletions

@@ -23,7 +23,7 @@
 from lightllm.utils.light_utils import light_ops
 from typing import Callable, List, Optional, Tuple
 from lightllm.common.fused_moe.softmax_topk import softmax_topk
-from lightllm.utils.envs_utils import enable_triton_autotune
+from lightllm.common.triton_utils.autotuner import Autotuner

 use_cuda_grouped_topk = os.getenv("LIGHTLLM_CUDA_GROUPED_TOPK", "False").upper() in ["ON", "TRUE", "1"]

@@ -224,7 +224,7 @@ def select_experts(

 ######################################## warning ##################################################
 # here is used to match autotune feature, make topk_ids more random
-if enable_triton_autotune():
+if Autotuner.is_autotune_warmup():
     rand_gen = torch.Generator(device="cuda")
     rand_gen.manual_seed(router_logits.shape[0])
     router_logits = torch.randn(size=router_logits.shape, generator=rand_gen, dtype=torch.float32, device="cuda")
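
Seeding the generator with router_logits.shape[0] makes the replacement logits deterministic for a given token count while still spreading topk_ids across experts during warmup. A standalone sketch of the same idea, assuming only a router_logits tensor as input:

    import torch

    def randomize_router_logits_for_warmup(router_logits: torch.Tensor) -> torch.Tensor:
        # Sketch of the warmup branch in select_experts: derive the seed from the
        # number of tokens, then draw fresh logits so the selected experts vary
        # more widely while the autotuner warms up.
        rand_gen = torch.Generator(device=router_logits.device)
        rand_gen.manual_seed(router_logits.shape[0])
        return torch.randn(
            size=router_logits.shape,
            generator=rand_gen,
            dtype=torch.float32,
            device=router_logits.device,
        )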

lightllm/common/triton_utils/autotuner.py
Lines changed: 2 additions & 2 deletions

@@ -12,7 +12,7 @@
 from lightllm.utils.device_utils import get_current_device_name
 from lightllm.utils.log_utils import init_logger
 from typing import Callable, Optional, Union, List
-from lightllm.utils.envs_utils import enable_triton_autotune, get_triton_autotune_level
+from lightllm.utils.envs_utils import get_triton_autotune_level
 from lightllm.common.kernel_config import KernelConfigs
 from lightllm.utils.dist_utils import get_global_world_size, get_global_rank, get_current_rank_in_node

@@ -172,7 +172,7 @@ def __call__(self, *args, **kwargs):
 )
 self.cached_configs[static_key] = {}

-if enable_triton_autotune():
+if Autotuner.is_autotune_warmup():
     need_tuning = (autotune_level == AutotuneLevel.FORCE_AUTOTUNE) or (
         run_key not in self.cached_configs.get(static_key, {})
     )
lightllm/utils/envs_utils.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -154,15 +154,6 @@ def get_triton_autotune_level():
154154
return int(os.getenv("LIGHTLLM_TRITON_AUTOTUNE_LEVEL", 0))
155155

156156

157-
def enable_triton_autotune():
158-
from lightllm.common.triton_utils.autotuner import AutotuneLevel, Autotuner
159-
160-
return (
161-
get_triton_autotune_level() in [AutotuneLevel.ADAPTIVE_AUTOTUNE, AutotuneLevel.FORCE_AUTOTUNE]
162-
and Autotuner.is_autotune_warmup()
163-
)
164-
165-
166157
g_model_init_done = False
167158

168159

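Taken together, the commit replaces the env-utils helper with a direct check on the Autotuner class at every call site; the before/after pattern, with the call-site body elided:

    # Before this commit:
    # from lightllm.utils.envs_utils import enable_triton_autotune
    # if enable_triton_autotune():
    #     ...  # warmup-only branch

    # After this commit:
    from lightllm.common.triton_utils.autotuner import Autotuner

    if Autotuner.is_autotune_warmup():
        ...  # warmup-only branch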