@@ -108,15 +108,7 @@ def run_cutlass_moe_fp8(
108108 assert global_num_experts != - 1
109109 assert a1q_scale is not None
110110
111- if expert_map is not None :
112- "Translate info from expert_map to topk_ids"
113- local_topk_ids = torch .where (
114- expert_map [topk_ids ] != - 1 , expert_map [topk_ids ], - 1
115- )
116- else :
117- local_topk_ids = topk_ids
118-
119- topk = local_topk_ids .size (1 )
111+ topk = topk_ids .size (1 )
120112 local_E = w1 .size (0 )
121113
122114 if use_batched_format :
@@ -164,12 +156,8 @@ def run_cutlass_moe_fp8(
164156 # during offset calculations
165157 expert_offsets = expert_offsets .to (torch .int64 )
166158 else :
167- problem_sizes1 = torch .empty (
168- (global_num_experts , 3 ), dtype = torch .int32 , device = device
169- )
170- problem_sizes2 = torch .empty (
171- (global_num_experts , 3 ), dtype = torch .int32 , device = device
172- )
159+ problem_sizes1 = torch .empty ((local_E , 3 ), dtype = torch .int32 , device = device )
160+ problem_sizes2 = torch .empty ((local_E , 3 ), dtype = torch .int32 , device = device )
173161
174162 num_expert = global_num_experts if expert_map is None else expert_map .size (0 )
175163 # permuted a1q reuses workspace2
@@ -182,11 +170,12 @@ def run_cutlass_moe_fp8(
182170 expert_map ,
183171 permuted_hidden_states = a1q_perm ,
184172 )
185- expert_offsets = expert_first_token_offset [: - 1 ]
186-
187- ops .get_cutlass_moe_mm_problem_sizes (
188- local_topk_ids , problem_sizes1 , problem_sizes2 , global_num_experts , N , K
173+ # swap_ab is a CUTLASS grouped-GEMM optimization: swapping logical M, N when M <= 64 reduces padding.
174+ swap_ab = a1q . size ( 0 ) <= 64
175+ ops .get_cutlass_moe_mm_problem_sizes_from_expert_offsets (
176+ expert_first_token_offset , problem_sizes1 , problem_sizes2 , N , K , swap_ab
189177 )
178+ expert_offsets = expert_first_token_offset [:- 1 ]
190179
191180 if not per_act_token and (expert_map is not None or use_batched_format ):
192181 # this is necessary to avoid imprecise scale calculation caused by
@@ -240,9 +229,7 @@ def run_cutlass_moe_fp8(
240229 permuted_hidden_states = mm2_out ,
241230 topk_weights = topk_weights ,
242231 inv_permuted_idx = inv_perm ,
243- expert_first_token_offset = (
244- expert_first_token_offset if expert_map is not None else None
245- ),
232+ expert_first_token_offset = expert_first_token_offset ,
246233 )
247234
248235
@@ -772,15 +759,7 @@ def run_cutlass_moe_w4a8_fp8(
772759 f"w1 hidden size mismatch: got { w1 .size (2 ) * 8 } , expected { K = } "
773760 )
774761
775- # Translate info from expert_map to topk_ids
776- if expert_map is not None :
777- local_topk_ids = torch .where (
778- expert_map [topk_ids ] != - 1 , expert_map [topk_ids ], - 1
779- )
780- else :
781- local_topk_ids = topk_ids
782-
783- topk = local_topk_ids .size (1 )
762+ topk = topk_ids .size (1 )
784763 a1q_perm = _resize_cache (workspace2 .view (dtype = torch .float8_e4m3fn ), (M * topk , K ))
785764 mm1_out = _resize_cache (workspace13 , (M * topk , N * 2 ))
786765 act_out = _resize_cache (workspace2 , (M * topk , N ))
@@ -790,12 +769,8 @@ def run_cutlass_moe_w4a8_fp8(
790769 )
791770 mm2_out = _resize_cache (workspace2 , (M * topk , K ))
792771
793- problem_sizes1 = torch .empty (
794- (global_num_experts , 3 ), dtype = torch .int32 , device = device
795- )
796- problem_sizes2 = torch .empty (
797- (global_num_experts , 3 ), dtype = torch .int32 , device = device
798- )
772+ problem_sizes1 = torch .empty ((local_E , 3 ), dtype = torch .int32 , device = device )
773+ problem_sizes2 = torch .empty ((local_E , 3 ), dtype = torch .int32 , device = device )
799774
800775 num_expert = global_num_experts if expert_map is None else expert_map .size (0 )
801776 # permuted a1q reuses workspace2
@@ -808,18 +783,11 @@ def run_cutlass_moe_w4a8_fp8(
808783 expert_map ,
809784 permuted_hidden_states = a1q_perm ,
810785 )
811- expert_offsets = expert_first_token_offset [:- 1 ]
812-
813- # For RS gemm SwapAB is always enabled (swap logical M, N in the problem shape)
814- ops .get_cutlass_moe_mm_problem_sizes (
815- local_topk_ids ,
816- problem_sizes1 ,
817- problem_sizes2 ,
818- global_num_experts ,
819- N ,
820- K ,
821- force_swap_ab = True ,
786+ # For RS gemm, SwapAB is always enabled (swap logical M, N in the problem shape).
787+ ops .get_cutlass_moe_mm_problem_sizes_from_expert_offsets (
788+ expert_first_token_offset , problem_sizes1 , problem_sizes2 , N , K , True
822789 )
790+ expert_offsets = expert_first_token_offset [:- 1 ]
823791
824792 ops .cutlass_w4a8_moe_mm (
825793 mm1_out ,
@@ -866,9 +834,7 @@ def run_cutlass_moe_w4a8_fp8(
866834 permuted_hidden_states = mm2_out ,
867835 topk_weights = topk_weights ,
868836 inv_permuted_idx = inv_perm ,
869- expert_first_token_offset = (
870- expert_first_token_offset if expert_map is not None else None
871- ),
837+ expert_first_token_offset = expert_first_token_offset ,
872838 )
873839
874840
0 commit comments