5 files changed (+8, −23)
models/qwen3_moe/layer_infer

@@ -516,8 +516,7 @@ def grouped_matmul(
     if block_size_k != 0:
         # If block-wise quantization is used, the tile size must not exceed the block size
         BLOCK_SIZE_K = min(BLOCK_SIZE_K, block_size_k)
-        BLOCK_SIZE_K = triton.next_power_of_2(BLOCK_SIZE_K // 2 + 1)
-        # assert BLOCK_SIZE_K == triton.next_power_of_2(BLOCK_SIZE_K)
+        assert BLOCK_SIZE_K == triton.next_power_of_2(BLOCK_SIZE_K)
 
     if use_fp8_w8a8:
         # When the weights use block-wise quantization, activations are also quantized per token with group-size granularity
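Aside on the replaced rounding line: a minimal sketch in plain Python, so it runs without Triton (next_power_of_2 here mirrors triton.next_power_of_2; the wrapper name clamp_block_size_k is hypothetical), showing what the old line did and what the new assert demands:

def next_power_of_2(x: int) -> int:
    # Smallest power of two >= x; matches triton.next_power_of_2 for x >= 1.
    return 1 << (x - 1).bit_length()

def clamp_block_size_k(BLOCK_SIZE_K: int, block_size_k: int) -> int:
    # Hypothetical standalone version of the snippet above.
    if block_size_k != 0:
        # A K tile must not straddle a quantization block, so clamp it.
        BLOCK_SIZE_K = min(BLOCK_SIZE_K, block_size_k)
        # Old line: BLOCK_SIZE_K = next_power_of_2(BLOCK_SIZE_K // 2 + 1).
        # That is a no-op for powers of two (128 -> 128) but silently shrinks
        # anything else (96 -> 64). The new assert instead requires the clamped
        # size to already be a power of two, which holds when block_size_k is
        # itself a power of two such as 128.
        assert BLOCK_SIZE_K == next_power_of_2(BLOCK_SIZE_K)
    return BLOCK_SIZE_K

assert clamp_block_size_k(256, 128) == 128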
@@ -42,12 +42,12 @@ def try_to_get_best_config(
     else:
         if M <= expert_num:
             config = {
-                "BLOCK_SIZE_M": 32,
-                "BLOCK_SIZE_N": 128,
-                "BLOCK_SIZE_K": 128,
-                "GROUP_SIZE_M": 32,
+                "BLOCK_SIZE_M": 16,
+                "BLOCK_SIZE_N": 32,
+                "BLOCK_SIZE_K": 64,
+                "GROUP_SIZE_M": 1,
                 "num_warps": 4,
-                "num_stages": 3,
+                "num_stages": 1,
             }
         else:
             config = {
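Context for the retune: when M (total token rows) is at most expert_num, each expert sees only a handful of rows, so large tiles mostly multiply padding. A hedged sketch of the branch above as a standalone function (the name pick_moe_gemm_config is hypothetical; the values are the new ones from the diff):

def pick_moe_gemm_config(M: int, expert_num: int) -> dict:
    # Hypothetical standalone version of the fallback branch above.
    if M <= expert_num:
        # Few token rows per expert: small tiles, no tile grouping, and one
        # pipeline stage waste less work on padded rows and less shared memory.
        return {
            "BLOCK_SIZE_M": 16,
            "BLOCK_SIZE_N": 32,
            "BLOCK_SIZE_K": 64,
            "GROUP_SIZE_M": 1,
            "num_warps": 4,
            "num_stages": 1,
        }
    # Larger batches take the bigger-tile config, elided by the diff context.
    raise NotImplementedError("see the else branch in try_to_get_best_config")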
@@ -111,7 +111,7 @@ def _moe_ffn_edp(
         ep_output = layer_weight.experts.experts(
             hidden_states,
             router_logits=router_logits,
-            top_k=8,
+            top_k=self.num_experts_per_tok,
             renormalize=self.norm_topk_prob,
             use_grouped_topk=False,
             topk_group=None,
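Why this small change matters: top_k=8 only matches models that route exactly eight experts per token; any other Qwen3-MoE variant would silently compute the wrong mixture. A hypothetical sketch of where self.num_experts_per_tok would come from (field names follow the Hugging Face Qwen-MoE config convention, which the diff's attribute names mirror):

class Qwen3MoeFfnSketch:
    # Hypothetical class; only shows where the two attributes used above
    # would be populated from a Hugging Face style model config.
    def __init__(self, network_config: dict):
        self.num_experts_per_tok = network_config["num_experts_per_tok"]  # e.g. 8
        self.norm_topk_prob = network_config["norm_topk_prob"]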
@@ -26,9 +26,7 @@ def get_unique_server_name():
 def set_cuda_arch(args):
     if not torch.cuda.is_available():
         return
-    from lightllm.utils.sgl_utils import HAS_FLASHINFER
-
-    if HAS_FLASHINFER:
+    if args.enable_flashinfer_prefill or args.enable_flashinfer_decode:
         capability = torch.cuda.get_device_capability()
         arch = f"{capability[0]}.{capability[1]}"
         os.environ["TORCH_CUDA_ARCH_LIST"] = f"{arch}{'+PTX' if arch == '9.0' else ''}"
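The new gate, assembled as a self-contained sketch (the attribute names enable_flashinfer_prefill / enable_flashinfer_decode are taken from the diff; the argparse wiring that produces args is assumed):

import os

import torch

def set_cuda_arch(args):
    if not torch.cuda.is_available():
        return
    # Pin TORCH_CUDA_ARCH_LIST only when flashinfer is explicitly requested,
    # not merely because the package happens to be importable.
    if args.enable_flashinfer_prefill or args.enable_flashinfer_decode:
        capability = torch.cuda.get_device_capability()
        arch = f"{capability[0]}.{capability[1]}"
        # '+PTX' on 9.0 keeps the build forward compatible on Hopper.
        os.environ["TORCH_CUDA_ARCH_LIST"] = f"{arch}{'+PTX' if arch == '9.0' else ''}"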
@@ -30,15 +30,3 @@
         "sgl_kernel is not installed, or the installed version did not support fa3. \
         Try to upgrade it."
     )
-
-try:
-    import flashinfer
-    from flashinfer.norm import fused_add_rmsnorm, rmsnorm
-
-    HAS_FLASHINFER = True
-except:
-    HAS_FLASHINFER = False
-    logger.warning(
-        "flashinfer is not installed, you can't use the api of it. \
-        You can solve it by running `pip install flashinfer`."
-    )
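A hedged aside: if an availability flag like the deleted HAS_FLASHINFER were ever needed again elsewhere, importlib can probe the package without importing it at module load time and without a bare except swallowing unrelated errors:

import importlib.util

# True if flashinfer is installed, without importing it.
HAS_FLASHINFER = importlib.util.find_spec("flashinfer") is not None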