 from .moe_silu_and_mul import silu_and_mul_fwd
 from .moe_sum_reduce import moe_sum_reduce
 from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import per_token_group_quant_fp8
-from lightllm.utils.dist_utils import get_current_rank_in_dp

 FFN_MOE_CHUNK_SIZE = 8 * 1024

@@ -221,13 +220,8 @@ def moe_align1(
 @triton.jit
 def moe_align2_kernel(
     experts_token_num_ptr,  # [expert_num,]
-    expert_to_token_index_ptr,  # [expert_num, token_num * topk_num]
-    expert_to_token_index_stride_0,
-    expert_to_weights_ptr,
-    expert_to_weights_stride_0,
-    mblocks_to_expert_id_ptr,  # [max_num_m_blocks,]
-    padded_expert_to_token_index_ptr,
-    padded_expert_to_weights_ptr,
+    mblocks_to_expert_id,  # [max_num_m_blocks,]
+    mblocks_to_m_index,  # [max_num_m_blocks,]
     expert_num,
     max_num_m_blocks,
     BLOCK_M: tl.constexpr,
@@ -247,70 +241,42 @@ def moe_align2_kernel(
     block_off = tl.arange(0, 128)
     for start_loc in range(0, cur_block_num, 128):
         tl.store(
-            mblocks_to_expert_id_ptr + block_start + start_loc + block_off,
+            mblocks_to_expert_id + block_start + start_loc + block_off,
             expert_id,
             mask=start_loc + block_off < cur_block_num,
         )
-
-        cur_expert_to_token_index_ptr = expert_to_token_index_ptr + expert_id * expert_to_token_index_stride_0
-        for start_loc in range(0, cur_block_num):
-            offset = start_loc * BLOCK_M + tl.arange(0, BLOCK_M)
-            m_index = tl.load(cur_expert_to_token_index_ptr + offset, mask=offset < cur_expert_token_num, other=0)
-            tl.store(
-                padded_expert_to_token_index_ptr + block_start * BLOCK_M + offset,
-                m_index,
-                mask=offset < cur_expert_token_num,
-            )
-
-            m_weight = tl.load(
-                expert_to_weights_ptr + expert_id * expert_to_weights_stride_0 + offset,
-                mask=offset < cur_expert_token_num,
-                other=0.0,
-            )
         tl.store(
-            padded_expert_to_weights_ptr + block_start * BLOCK_M + offset,
-            m_weight,
-            mask=offset < cur_expert_token_num,
+            mblocks_to_m_index + block_start + start_loc + block_off,
+            start_loc + block_off,
+            mask=start_loc + block_off < cur_block_num,
         )

     if expert_id == expert_num - 1:
         for extra_fill_start in range(block_start + cur_block_num, max_num_m_blocks, 128):
             tl.store(
-                mblocks_to_expert_id_ptr + extra_fill_start + block_off,
+                mblocks_to_expert_id + extra_fill_start + block_off,
                 -1,
                 mask=extra_fill_start + block_off < max_num_m_blocks,
             )
     return


-def moe_align2(
-    token_num_mul_topk_num: int,
-    exports_token_num: torch.Tensor,
-    block_m: int,
-    expert_to_token_index: torch.Tensor,
-    expert_to_weights: torch.Tensor,
-):
+def moe_align2(token_num_mul_topk_num: int, exports_token_num: torch.Tensor, block_m: int):
293265 """
294266 exports_token_num is tensor shape [expert_num] , will get expert need handle token num.
295267 out tensor is a tensor that contain block schduel infos tensor.
296268 """
     max_num_tokens_padded = token_num_mul_topk_num + exports_token_num.shape[0] * (block_m - 1)
     max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_m)
     mblocks_to_expert_id = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda")
-    padded_expert_to_token_index = torch.empty(max_num_tokens_padded, dtype=torch.int32, device="cuda").fill_(-1)
-    padded_expert_to_weights = torch.empty(max_num_tokens_padded, dtype=torch.float32, device="cuda")
+    mblocks_to_m_index = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda")
     expert_num = exports_token_num.shape[0]

     grid = (expert_num,)
     moe_align2_kernel[grid](
         exports_token_num,
-        expert_to_token_index,
-        expert_to_token_index.stride(0),
-        expert_to_weights,
-        expert_to_weights.stride(0),
         mblocks_to_expert_id,
-        padded_expert_to_token_index,
-        padded_expert_to_weights,
+        mblocks_to_m_index,
         expert_num,
         max_num_m_blocks,
         BLOCK_M=block_m,
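
To see what the leaner schedule encodes: moe_align2 rounds each expert's token count up to a multiple of block_m, so in the worst case every expert wastes block_m - 1 slots, which is exactly the max_num_tokens_padded = token_num_mul_topk_num + expert_num * (block_m - 1) bound above. A minimal pure-Python sketch of the two arrays the kernel now produces (illustrative reference only, not part of the PR):

    def moe_align2_reference(expert_to_token_num, block_m, max_num_m_blocks):
        # CPU reference for the schedule written by moe_align2_kernel.
        mblocks_to_expert_id = [-1] * max_num_m_blocks
        mblocks_to_m_index = [0] * max_num_m_blocks  # tail entries are never read
        block_start = 0
        for expert_id, token_num in enumerate(expert_to_token_num):
            cur_block_num = -(-token_num // block_m)  # ceil div, like triton.cdiv
            for i in range(cur_block_num):
                mblocks_to_expert_id[block_start + i] = expert_id  # owning expert
                mblocks_to_m_index[block_start + i] = i  # block index inside that expert
            block_start += cur_block_num
        return mblocks_to_expert_id, mblocks_to_m_index

    # expert_to_token_num=[5, 0, 3], block_m=4 -> max_num_m_blocks=5 (worst-case bound)
    # mblocks_to_expert_id=[0, 0, 2, -1, -1], mblocks_to_m_index=[0, 1, 0, 0, 0]

Blocks tagged -1 let grouped_matmul_kernel exit immediately, which is why mblocks_to_m_index needs no sentinel of its own.
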
@@ -319,14 +285,13 @@ def moe_align2(
         num_stages=1,
     )

-    return mblocks_to_expert_id, padded_expert_to_token_index, padded_expert_to_weights
+    return mblocks_to_expert_id, mblocks_to_m_index


 @triton.jit
 def grouped_matmul_kernel(
     mblocks_to_expert_id,  # [max_m_block_size]
-    padded_expert_to_token_index,  # [max_m_block_size]
-    padded_expert_to_weights,  # [max_m_block_size]
+    mblocks_to_m_index,  # [max_m_block_size]
     k,  # int
     n,  # int
     topk_num,  # int
@@ -342,7 +307,12 @@ def grouped_matmul_kernel(
     weight_stride_0,
     weight_stride_1,
     weight_stride_2,
+    expert_to_weights_ptr,  # [expert_num, token_num * topk]
+    expert_to_weights_stride0,
+    expert_to_weights_stride1,
     expert_to_token_num,  # [expert_num]
+    expert_to_token_index,  # [expert_num, token_num * topk_num]
+    expert_to_token_index_stride_0,
     out_ptr,  # [token_num * topk_num, n]
     out_stride_0,
     out_stride_1,
@@ -380,14 +350,28 @@ def grouped_matmul_kernel(

     if expert_id == -1:
         return
+
+    tile_m_idx = tl.load(mblocks_to_m_index + pid_m)
     tile_n_idx = pid_n
+
+    # get the gemm size of the current problem
+    cur_m = tl.load(expert_to_token_num + expert_id, eviction_policy="evict_last")
+
     # do regular gemm here
-    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
-    # token_mask = offs_am < cur_m
+    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    token_mask = offs_am < cur_m
     a_m_index = tl.load(
-        padded_expert_to_token_index + offs_am,
+        expert_to_token_index + expert_id * expert_to_token_index_stride_0 + offs_am,
+        mask=token_mask,
+        other=0,
     )
-    token_mask = a_m_index != -1
+    if MUL_ROUTED_WEIGHT:
+        a_m_scale = tl.load(
+            expert_to_weights_ptr + expert_id * expert_to_weights_stride0 + offs_am,
+            mask=token_mask,
+            other=0.0,
+        )
+
     offs_bn = (tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % n
     offs_k = tl.arange(0, BLOCK_SIZE_K)

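
Note the masking rework in this hunk: the old kernel gathered from a padded index array pre-filled with -1 and derived token_mask from that sentinel, while the new kernel derives token_mask from the expert's token count and gathers straight from expert_to_token_index. A rough NumPy sketch of the new gather for one m-block (toy values; names follow the kernel):

    import numpy as np

    BLOCK_SIZE_M = 4
    cur_m = 6       # tokens routed to this expert
    tile_m_idx = 1  # this program handles the expert's second m-block
    expert_row = np.array([9, 3, 7, 0, 5, 2, 0, 0])  # one row of expert_to_token_index

    offs_am = tile_m_idx * BLOCK_SIZE_M + np.arange(BLOCK_SIZE_M)  # [4 5 6 7]
    token_mask = offs_am < cur_m                                   # [T T F F]
    a_m_index = np.where(token_mask, expert_row[offs_am], 0)       # [5 2 0 0]

This removes the padded_expert_to_token_index / padded_expert_to_weights copies entirely; the routed weight is likewise loaded per tile from expert_to_weights_ptr under the same mask.
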
@@ -453,11 +437,6 @@ def grouped_matmul_kernel(
         accumulator *= ab_scale

     if MUL_ROUTED_WEIGHT:
-        a_m_scale = tl.load(
-            padded_expert_to_weights + offs_am,
-            mask=token_mask,
-            other=0.0,
-        )
         accumulator *= a_m_scale[:, None]

     c = accumulator.to(compute_type)
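
With the a_m_scale load hoisted next to the index gather (previous hunk), all that remains here is a per-row rescale of the finished tile. A toy sketch of the broadcast, with illustrative values; masked rows were loaded as 0.0, so they scale to zero:

    import numpy as np

    accumulator = np.ones((4, 3))               # a BLOCK_SIZE_M x BLOCK_SIZE_N tile
    a_m_scale = np.array([0.7, 0.3, 0.0, 0.0])  # routing weights; masked rows are 0.0
    accumulator *= a_m_scale[:, None]           # row i scaled by token i's routing weight
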
@@ -551,22 +530,16 @@ def grouped_matmul(
         token_inputs, token_input_scale = qinput_tensor, input_scale

     if reused_mblock_infos is None:
-        mblocks_to_expert_id, padded_expert_to_token_index, padded_expert_to_weights = moe_align2(
-            token_num_mul_topk_num, expert_to_token_num, BLOCK_SIZE_M, expert_to_token_index, expert_to_weights
-        )
+        mblocks_to_expert_id, mblocks_to_m_index = moe_align2(token_num_mul_topk_num, expert_to_token_num, BLOCK_SIZE_M)
     else:
         # when the up group gemm and the down group gemm use the same BLOCK_SIZE_M,
         # they can reuse the (mblocks_to_expert_id, mblocks_to_m_index) created by the moe_align2 kernel.
-        (
-            mblocks_to_expert_id,
-            padded_expert_to_token_index,
-            padded_expert_to_weights,
-            reused_block_size_m,
-        ) = reused_mblock_infos
+        mblocks_to_expert_id, mblocks_to_m_index, reused_block_size_m = reused_mblock_infos
         if reused_block_size_m != BLOCK_SIZE_M:
-            mblocks_to_expert_id, padded_expert_to_token_index, padded_expert_to_weights = moe_align2(
-                token_num_mul_topk_num, expert_to_token_num, BLOCK_SIZE_M, expert_to_token_index, expert_to_weights
+            mblocks_to_expert_id, mblocks_to_m_index = moe_align2(
+                token_num_mul_topk_num, expert_to_token_num, BLOCK_SIZE_M
             )
+
     block_num = triton.cdiv(n, BLOCK_SIZE_N) * mblocks_to_expert_id.shape[0]

     grid = (block_num,)
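
To make the reuse contract concrete: grouped_matmul returns (mblocks_to_expert_id, mblocks_to_m_index, BLOCK_SIZE_M) (see the return a few hunks below), so a caller running the up and down projections back-to-back can thread one call's schedule into the next. A schematic sketch only, with every unrelated argument elided as `...`:

    # Schematic; "..." stands for grouped_matmul's many tensor/config arguments.
    mblock_infos = grouped_matmul(..., reused_mblock_infos=None)  # runs moe_align2
    grouped_matmul(..., reused_mblock_infos=mblock_infos)         # reuses the schedule
    # If the second call tunes a different BLOCK_SIZE_M, the guard above rebuilds it.
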
@@ -575,8 +548,7 @@ def grouped_matmul(

     grouped_matmul_kernel[grid](
         mblocks_to_expert_id,
-        padded_expert_to_token_index,
-        padded_expert_to_weights,
+        mblocks_to_m_index,
         k,
         n,
         topk_num,
@@ -598,7 +570,12 @@ def grouped_matmul(
         expert_weights.stride(0),
         expert_weights.stride(1),
         expert_weights.stride(2),
+        expert_to_weights,
+        expert_to_weights.stride(0),
+        expert_to_weights.stride(1),
         expert_to_token_num,
+        expert_to_token_index,
+        expert_to_token_index.stride(0),
         out,
         out.stride(0),
         out.stride(1),
@@ -617,7 +594,7 @@ def grouped_matmul(
         num_warps=num_warps,
         num_stages=num_stages,
     )
-    return (mblocks_to_expert_id, padded_expert_to_token_index, padded_expert_to_weights, BLOCK_SIZE_M)
+    return (mblocks_to_expert_id, mblocks_to_m_index, BLOCK_SIZE_M)


 def fused_experts_impl(
@@ -648,6 +625,7 @@ def fused_experts_impl(
     CHUNK_SIZE = FFN_MOE_CHUNK_SIZE
     topk_num = topk_ids.shape[1]
     M = min(num_tokens, CHUNK_SIZE)
+
     intermediate_cache1 = alloc_tensor_func((M, topk_num, N), device=hidden_states.device, dtype=hidden_states.dtype)
     intermediate_cache2 = alloc_tensor_func(
         (M, topk_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype
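
The chunking above keeps the intermediate caches bounded: buffers are sized for one chunk of at most FFN_MOE_CHUNK_SIZE tokens rather than for the whole batch (the chunk loop itself sits outside this hunk). A worked sizing example with illustrative values; N // 2 reflects the gated activation (silu_and_mul_fwd imported at the top) halving the hidden width:

    FFN_MOE_CHUNK_SIZE = 8 * 1024
    num_tokens, topk_num, N = 20_000, 2, 4096

    M = min(num_tokens, FFN_MOE_CHUNK_SIZE)            # 8192 rows per chunk
    cache1_shape = (M, topk_num, N)                    # (8192, 2, 4096)
    cache2_shape = (M, topk_num, N // 2)               # (8192, 2, 2048)
    num_chunks = -(-num_tokens // FFN_MOE_CHUNK_SIZE)  # 3 chunks of <= 8192 tokens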