 import triton_kernels
 import triton_kernels.swiglu
 from triton_kernels.reduce import reduce
-from triton_kernels.matmul_ogs import RoutingData, GatherIndx, ScatterIndx
 from triton_kernels.topk import topk
-from triton_kernels.matmul_ogs import matmul_ogs, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
+from triton_kernels.matmul import matmul, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
 from triton_kernels.target_info import get_cdna_version, is_hip, is_cuda, cuda_capability_geq
 from triton_kernels.tensor_details import layout
-from triton_kernels.tensor import make_ragged_tensor_metadata, remap_ragged_tensor_metadata
+from triton_kernels.tensor import RaggedTensorMetadata, make_ragged_tensor_metadata, remap_ragged_tensor_metadata
 from triton_kernels.distributed import make_expt_dict_uniform, make_expt_assignment, convert_dp_to_ep, convert_ep_to_dp, ExptAssignment, symm_mem_pool

 from bench_utils import quantize_weight
@@ -40,7 +39,7 @@ def create_expt_assignment(EP: int, n_expts_tot: int, device: torch.device) -> O
     return make_expt_assignment(EP, n_expts_tot, expt_dict, device)


-def initialize_matmul_ogs(
+def initialize_matmul(
     batch: int,
     dim1: int,
     dim2: int,
@@ -52,7 +51,7 @@ def initialize_matmul_ogs(
         return
     world_size = dist.get_world_size()
     device = torch.cuda.current_device()
-    symm_mem_pool.initialize_matmul_ogs(
+    symm_mem_pool.initialize_matmul(
         n_tokens_global=batch,
         d_input=dim1,
         d_model=dim2,
@@ -146,8 +145,7 @@ def routing(
     TP: int = 1,
     expt_assignment: Optional[ExptAssignment] = None,
     mode: Optional[str] = None,
-) -> Tuple[torch.Tensor, RoutingData, GatherIndx, ScatterIndx, Optional[ReduceScatterMetadata]]:
-    n_expts_tot = logits.shape[-1]
+) -> Tuple[torch.Tensor, RaggedTensorMetadata, torch.Tensor, torch.Tensor, Optional[ReduceScatterMetadata]]:
     if _is_distributed_launch() and mode:
         if mode == "ep_sharding":
             if not expt_assignment:
@@ -170,29 +168,24 @@ def routing(
             logits_global_metadata = make_ragged_tensor_metadata(expt_sizes, dispatch_indx.shape[0])
             x = convert_dp_to_ep(x, expt_assignment, active_indx, dispatch_indx)
             logits_local_metadata = remap_ragged_tensor_metadata(logits_global_metadata, expt_map)
-            gate_scal = logits_global.vals.flatten()[combine_indx]
-            rdata = RoutingData(gate_scal, expt_sizes, n_expts_tot // EP, n_expts_act, logits_local_metadata)
             reduce_scatter_metadata = ReduceScatterMetadata(
                 mode=mode,
                 active_indx=active_indx,
                 dispatch_indx=dispatch_indx,
                 combine_indx=combine_indx,
             )
-            return x, rdata, None, None, reduce_scatter_metadata
+            return x, logits_local_metadata, None, None, reduce_scatter_metadata
         else:
             raise NotImplementedError(f"Distributed routing mode {mode} is not implemented yet.")
     else:
         # If mode is not specified or we have a single process, we do single-GPU routing.
         logits = topk(logits, n_expts_act, y_indx=y_indx, apply_softmax=not sm_first)
         dispatch_indx = logits.mask_metadata.row_sorted_indx
         combine_indx = logits.mask_metadata.col_sorted_indx
-        ragged_batch_metadata = make_ragged_tensor_metadata(logits.mask_metadata.col_sum, dispatch_indx.shape[0])
-        gate_scal = logits.vals.flatten()[combine_indx]
-        routing_data = RoutingData(gate_scal, ragged_batch_metadata.slice_sizes, n_expts_tot, n_expts_act,
-                                   ragged_batch_metadata)
-        gather_indx = GatherIndx(combine_indx, dispatch_indx)
-        scatter_indx = ScatterIndx(dispatch_indx, combine_indx)
-        return x, routing_data, gather_indx, scatter_indx, None
+        ragged_metadata = make_ragged_tensor_metadata(logits.mask_metadata.col_sum, dispatch_indx.shape[0])
+        gather_indx = combine_indx // n_expts_act
+        scatter_indx = combine_indx
+        return x, ragged_metadata, gather_indx, scatter_indx, None


 def gather_ep(rank, world_size, param, TP, EP):
@@ -276,14 +269,14 @@ def distributed_run(rank, world_size, batch, dim1, dim2, n_expts_tot, n_expts_ac
         w1_full = w2_full = w1_flex_full = w2_flex_full = w1_scale_full = w2_scale_full = None

     # precision configs
-    pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=wg_flex), weight_scale=wg_scale)
+    pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=wg_flex), b_mx_scale=wg_scale)
     act = FusedActivation(FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit"), reduction_n=2),
                           (1.0, 1.0))
-    pc1 = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex), weight_scale=w1_scale)
-    pc2 = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex), weight_scale=w2_scale)
+    pc1 = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex), b_mx_scale=w1_scale)
+    pc2 = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex), b_mx_scale=w2_scale)
     if rank == 0:
-        pc1_full = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex_full), weight_scale=w1_scale_full)
-        pc2_full = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex_full), weight_scale=w2_scale_full)
+        pc1_full = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex_full), b_mx_scale=w1_scale_full)
+        pc2_full = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex_full), b_mx_scale=w2_scale_full)
     else:
         pc1_full = pc2_full = None

@@ -296,7 +289,7 @@ def distributed_run(rank, world_size, batch, dim1, dim2, n_expts_tot, n_expts_ac
     xd = torch.randn((batch // world_size, dim1), device=dev).to(dtype_map[x_dtype])
     x0 = all_gather(xd, dim=0)
     expt_assignment = create_expt_assignment(EP, n_expts_tot, torch.device(dev))
-    symm_mem_pool.initialize_matmul_ogs(
+    symm_mem_pool.initialize_matmul(
         n_tokens_global=batch,
         d_input=dim1,
         d_model=dim2,
@@ -312,25 +305,25 @@ def distributed_run(rank, world_size, batch, dim1, dim2, n_expts_tot, n_expts_ac
     def single(x):
         xg = x.to(wg.dtype if n_expts_tot > 1 else x.dtype)
         if n_expts_tot > 1:
-            logits = matmul_ogs(xg, wg, bg, precision_config=pcg)
+            logits = matmul(xg, wg, bg, precision_config=pcg)
             x, rdata, gi, si, _ = routing(x, logits, n_expts_act)
         else:
             rdata = gi = si = None
-        x = matmul_ogs(x, w1_full, b1_full, rdata, gather_indx=gi, precision_config=pc1_full, fused_activation=act)
-        return matmul_ogs(x, w2_full, b2_full, rdata, scatter_indx=si, precision_config=pc2_full)
+        x = matmul(x, w1_full, b1_full, rdata, gather_indx=gi, precision_config=pc1_full, fused_activation=act)
+        return matmul(x, w2_full, b2_full, rdata, scatter_indx=si, precision_config=pc2_full)

     # distributed pass
     def distributed(x):
         xg = x.to(wg.dtype if n_expts_tot > 1 else x.dtype)
         if n_expts_tot > 1:  # sparse
-            logits = matmul_ogs(xg, wg, bg, precision_config=pcg)
+            logits = matmul(xg, wg, bg, precision_config=pcg)
             x, rdata, gi, si, metadata = routing(x, logits, n_expts_act, EP=EP, TP=TP, expt_assignment=expt_assignment,
                                                  mode="ep_sharding")
         else:  # dense
             x = all_gather(x, dim=0)
             rdata = gi = si = metadata = None
-        x = matmul_ogs(x, w1, b1, rdata, gather_indx=gi, precision_config=pc1, fused_activation=act)
-        x = matmul_ogs(x, w2, b2 if rank % TP == 0 else None, rdata, scatter_indx=si, precision_config=pc2)
+        x = matmul(x, w1, b1, rdata, gather_indx=gi, precision_config=pc1, fused_activation=act)
+        x = matmul(x, w2, b2 if rank % TP == 0 else None, rdata, scatter_indx=si, precision_config=pc2)
         x = reduce_scatter(x, n_expts_act, metadata=metadata, expt_assignment=expt_assignment)
         # gather the result from all GPUs, just for verification
         return all_gather(x, dim=0)
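
For readers skimming the hunks above, here is a consolidated sketch of the single-GPU MoE forward path as it looks after this rename. It is not a definitive reference for the library: it only restates calls that already appear in this diff (`matmul`, `topk`, `make_ragged_tensor_metadata`, the plain index-tensor routing outputs), and it leaves the weight/bias tensors, precision configs, and `n_expts_act` as caller-supplied assumptions rather than constructing them here.

```python
# Hedged sketch of the post-rename API, assembled from the calls shown in this diff.
# Tensor shapes/dtypes and the precision configs (pcg, pc1, pc2) are assumed to be
# prepared by the caller exactly as in the benchmark above.
import triton_kernels.swiglu
from triton_kernels.matmul import matmul, FnSpecs, FusedActivation
from triton_kernels.topk import topk
from triton_kernels.tensor import make_ragged_tensor_metadata


def moe_forward_sketch(x, wg, bg, w1, b1, w2, b2, n_expts_act, pcg, pc1, pc2):
    # gating: dense matmul producing per-token expert logits
    logits = matmul(x.to(wg.dtype), wg, bg, precision_config=pcg)
    # routing: top-k plus ragged metadata and plain index tensors,
    # replacing the old RoutingData / GatherIndx / ScatterIndx wrappers
    logits = topk(logits, n_expts_act, apply_softmax=True)
    dispatch_indx = logits.mask_metadata.row_sorted_indx
    combine_indx = logits.mask_metadata.col_sorted_indx
    rdata = make_ragged_tensor_metadata(logits.mask_metadata.col_sum, dispatch_indx.shape[0])
    gather_indx = combine_indx // n_expts_act
    scatter_indx = combine_indx
    # expert MLP: gathered matmul with fused swiglu, then scattered matmul
    act = FusedActivation(FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit"), reduction_n=2),
                          (1.0, 1.0))
    x = matmul(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1, fused_activation=act)
    return matmul(x, w2, b2, rdata, scatter_indx=scatter_indx, precision_config=pc2)
```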