@@ -457,22 +457,23 @@ def grouped_matmul_kernel(
     tile_m_idx = tl.load(mblocks_to_tuple_info + pid_m * mblocks_to_tuple_info_stride_0 + 1)
     tile_n_idx = pid_n
 
-    if OUT_SORTED or TOKEN_INPUT_USE_TMA:
-        assert OUT_SORTED and TOKEN_INPUT_USE_TMA is False
-        # get token start index in inputs
-        token_start_index = tl.load(mblocks_to_tuple_info + pid_m * mblocks_to_tuple_info_stride_0 + 2)
+    # get token start index in inputs
+    token_start_index = tl.load(mblocks_to_tuple_info + pid_m * mblocks_to_tuple_info_stride_0 + 2)
 
     # get the gemm size of the current problem
     cur_m = tl.load(expert_to_token_num + expert_id)
 
     # do regular gemm here
     offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     token_mask = offs_am < cur_m
-    a_m_index = tl.load(
-        expert_to_token_index + expert_id * expert_to_token_index_stride_0 + offs_am,
-        mask=token_mask,
-        other=0,
-    )
+
+    if not OUT_SORTED or not TOKEN_INPUT_USE_TMA:
+        a_m_index = tl.load(
+            expert_to_token_index + expert_id * expert_to_token_index_stride_0 + offs_am,
+            mask=token_mask,
+            other=0,
+        )
+
     offs_bn = (tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % n
     offs_k = tl.arange(0, BLOCK_SIZE_K)
 
@@ -493,12 +494,17 @@ def grouped_matmul_kernel(
     ab_scale = a_scale * b_scale
 
     if NEED_TRANS:
-        a_ptrs = token_ptr + (a_m_index // topk_num)[None, :] * token_stride_0 + offs_k[:, None]
-        b_ptrs = weights_ptr + weight_stride_0 * expert_id + offs_k[None, :] + offs_bn[:, None] * weight_stride_1
+        if not TOKEN_INPUT_USE_TMA:
+            a_ptrs = token_ptr + (a_m_index // topk_num)[None, :] * token_stride_0 + offs_k[:, None]
+        if not WEIGHT_USE_TMA:
+            b_ptrs = weights_ptr + weight_stride_0 * expert_id + offs_k[None, :] + offs_bn[:, None] * weight_stride_1
         accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_M), dtype=tl.float32)
     else:
-        a_ptrs = token_ptr + (a_m_index // topk_num)[:, None] * token_stride_0 + offs_k[None, :]
-        b_ptrs = weights_ptr + weight_stride_0 * expert_id + offs_k[:, None] + offs_bn[None, :] * weight_stride_1
+        if not TOKEN_INPUT_USE_TMA:
+            a_ptrs = token_ptr + (a_m_index // topk_num)[:, None] * token_stride_0 + offs_k[None, :]
+        if not WEIGHT_USE_TMA:
+            b_ptrs = weights_ptr + weight_stride_0 * expert_id + offs_k[:, None] + offs_bn[None, :] * weight_stride_1
+
         accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
 
     for k_start in range(0, k, BLOCK_SIZE_K):
@@ -559,8 +565,10 @@ def grouped_matmul_kernel(
         else:
             accumulator += tl.dot(a, b)
 
-        a_ptrs += BLOCK_SIZE_K
-        b_ptrs += BLOCK_SIZE_K
+        if not TOKEN_INPUT_USE_TMA:
+            a_ptrs += BLOCK_SIZE_K
+        if not WEIGHT_USE_TMA:
+            b_ptrs += BLOCK_SIZE_K
 
     if NEED_TRANS:
         accumulator = accumulator.T
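
Taken together, these hunks gate every piece of manual pointer bookkeeping (`a_m_index`, `a_ptrs`, `b_ptrs`, and their per-K-step advancement) behind the `tl.constexpr` flags `TOKEN_INPUT_USE_TMA` / `WEIGHT_USE_TMA` / `OUT_SORTED`, so TMA specializations never materialize pointer tensors they will not use. This is safe because Triton resolves `tl.constexpr` branches at compile time: a variable defined under one constexpr guard may be read under another guard that implies it, which is why `a_m_index` can now be defined conditionally. Below is a minimal standalone sketch of that pattern, not the repo's kernel; the kernel name, the `USE_GATHER` flag, and the launch parameters are illustrative assumptions.

import torch
import triton
import triton.language as tl


@triton.jit
def gather_or_direct_kernel(
    src_ptr,
    index_ptr,
    dst_ptr,
    n,
    BLOCK: tl.constexpr,
    USE_GATHER: tl.constexpr,  # assumption: stands in for `not TOKEN_INPUT_USE_TMA`
):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    if USE_GATHER:
        # defined only under a constexpr guard, like a_m_index in the diff
        idx = tl.load(index_ptr + offs, mask=mask, other=0)
    if USE_GATHER:
        # read under a guard that implies the defining one; the specialized
        # kernel never sees an undefined variable because the False branch
        # is pruned at compile time
        x = tl.load(src_ptr + idx, mask=mask, other=0.0)
    else:
        # direct path: rows are already laid out contiguously, so no
        # per-row index tensor is ever built
        x = tl.load(src_ptr + offs, mask=mask, other=0.0)
    tl.store(dst_ptr + offs, x, mask=mask)


# usage sketch: each USE_GATHER value compiles a distinct specialization
if __name__ == "__main__":
    n = 1024
    src = torch.randn(n, device="cuda")
    idx = torch.randperm(n, device="cuda")
    dst = torch.empty_like(src)
    gather_or_direct_kernel[(triton.cdiv(n, 256),)](src, idx, dst, n, BLOCK=256, USE_GATHER=True)
    assert torch.allclose(dst, src[idx])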