@@ -27,6 +27,23 @@ def _make_expt_dict_for_mode(n_shards, n_expts_tot, affinity_mode):
         raise ValueError(f"Unknown affinity mode: {affinity_mode}") from exc


+def _make_y_indx_for_mode(n_tokens_global, n_expts_tot, n_expts_act, n_shards, affinity_mode, dev):
+    y_indx_global = None
+    if affinity_mode == "uniform":
+        if n_expts_tot % n_shards != 0:
+            raise ValueError("uniform affinity requires experts evenly divisible by shards")
+        expts_per_rank = n_expts_tot // n_shards
+        rounds = (n_expts_act + n_shards - 1) // n_shards
+        if rounds > expts_per_rank:
+            raise ValueError("round-robin selection exceeds experts available per shard")
+        order = torch.arange(n_expts_act, device=dev, dtype=torch.int32)
+        shard_order = order % n_shards
+        intra_shard = order // n_shards
+        round_robin_indx = (shard_order * expts_per_rank + intra_shard).to(torch.int16)
+        y_indx_global = round_robin_indx.unsqueeze(0).expand(n_tokens_global, -1).contiguous()
+    return y_indx_global
+
+
 # ------------------------------------------------------------
 # fixture
 # ------------------------------------------------------------
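For intuition, the arithmetic in the new `_make_y_indx_for_mode` helper can be traced on a small, purely illustrative configuration. The sizes below are made up, and the sketch assumes global expert ids are laid out contiguously per shard (experts 0-1 on shard 0, 2-3 on shard 1, and so on), which is what `shard_order * expts_per_rank + intra_shard` implies:

import torch

# Hypothetical sizes for illustration only.
n_expts_tot, n_shards, n_expts_act, n_tokens_global = 8, 4, 4, 2
expts_per_rank = n_expts_tot // n_shards                # 2 experts per shard

order = torch.arange(n_expts_act, dtype=torch.int32)    # [0, 1, 2, 3]
shard_order = order % n_shards                          # [0, 1, 2, 3]: one pick per shard
intra_shard = order // n_shards                         # [0, 0, 0, 0]: all in round 0
round_robin_indx = (shard_order * expts_per_rank + intra_shard).to(torch.int16)
print(round_robin_indx)                                 # tensor([0, 2, 4, 6], dtype=torch.int16)

# Broadcast the same pinned selection to every token.
y_indx_global = round_robin_indx.unsqueeze(0).expand(n_tokens_global, -1).contiguous()
print(y_indx_global.shape)                              # torch.Size([2, 4])

Each token thus activates exactly one expert on each of the four shards, keeping per-shard load perfectly balanced, which is what the "uniform" affinity mode exercised by the test calls for.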
@@ -102,8 +119,8 @@ def test_make_expt_assignment(n_expts_shard, n_expts_tot, affinity_mode):
 # ------------------------------------------------------------


-def routing(logits, n_expts_act, all_gather=False):
-    sparse_logits = topk(logits, n_expts_act, all_gather=all_gather)
+def routing(logits, n_expts_act, all_gather=False, y_indx=None):
+    sparse_logits = topk(logits, n_expts_act, all_gather=all_gather, y_indx=y_indx)
     dispatch_indx = sparse_logits.mask_metadata.col_sorted_indx
     combine_indx = sparse_logits.mask_metadata.row_sorted_indx
     ragged_batch_metadata = make_ragged_tensor_metadata(sparse_logits.mask_metadata.col_sum, dispatch_indx.shape[0])
@@ -115,17 +132,18 @@ def routing(logits, n_expts_act, all_gather=False):
     return routing_data, gather_idx, scatter_idx, sparse_logits.indx


-def mixture_of_expt_nosharded(x_global, l_global, w_global, b_global, n_expts_act):
-    rdata, combine_indx, dispatch_indx, _ = routing(l_global, n_expts_act)
+def mixture_of_expt_nosharded(x_global, l_global, w_global, b_global, n_expts_act, y_indx=None):
+    rdata, combine_indx, dispatch_indx, _ = routing(l_global, n_expts_act, y_indx=y_indx)
     y_global = matmul_ogs(x_global, w_global, b_global, rdata, gather_indx=combine_indx, scatter_indx=dispatch_indx)
     return y_global


-def mixture_of_expt_epsharded(x_dp_local, l_dp_local, w_ep_local, b_ep_local, expt_assignment, n_expts_act):
+def mixture_of_expt_epsharded(x_dp_local, l_dp_local, w_ep_local, b_ep_local, expt_assignment, n_expts_act,
+                              y_indx=None):
     rank = dist.get_rank()
     expt_map = expt_assignment.expt_map[rank, :]
     # active global logits (sparse)
-    l_global_active = topk(l_dp_local, n_expts_act, apply_softmax=True, all_gather=True)
+    l_global_active = topk(l_dp_local, n_expts_act, apply_softmax=True, all_gather=True, y_indx=y_indx)
     # expert histogram, dispatch/combine indx
     active_indx = l_global_active.indx
     expt_sizes = l_global_active.mask_metadata.col_sum
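The `y_indx` argument threaded through `routing`, `mixture_of_expt_nosharded`, and `mixture_of_expt_epsharded` above all flows into `topk`. Judging from this usage, passing `y_indx` pins which experts get selected instead of letting the logits decide. A plain-torch sketch of that pinned-selection semantics (an assumption about what `topk` does with `y_indx`, not its actual implementation):

import torch

def pinned_topk_values(logits, y_indx):
    # Gather the logit at each pinned expert index instead of doing an
    # argmax-style top-k; the softmax over the pinned set mirrors the
    # apply_softmax=True path.
    vals = torch.gather(logits, 1, y_indx.long())
    return torch.softmax(vals, dim=-1)

logits = torch.randn(2, 8)                              # 2 tokens, 8 experts
y_indx = torch.tensor([[0, 2, 4, 6]] * 2, dtype=torch.int16)
print(pinned_topk_values(logits, y_indx))               # shape (2, 4), rows sum to 1

Because the unsharded reference and the EP-sharded path receive the same `y_indx_global`, their routing decisions are forced to agree, so any output mismatch must come from the sharded compute path itself.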
@@ -264,7 +282,15 @@ def _run_expert_sharding(rank, world_size, *, n_tokens, d_model, n_expts_tot, n_
     l_dp_local = l_global[first_token_indx:last_token_indx, :]
     # routing
     # test correctness
-    y_global_ref = mixture_of_expt_nosharded(x_global, l_global, w_global, b_global, n_expts_act)
+    y_indx_global = _make_y_indx_for_mode(n_tokens_global, n_expts_tot, n_expts_act, n_shards, affinity_mode, dev)
+    y_global_ref = mixture_of_expt_nosharded(
+        x_global,
+        l_global,
+        w_global,
+        b_global,
+        n_expts_act,
+        y_indx=y_indx_global,
+    )

     def run_mixture():
         return mixture_of_expt_epsharded(
@@ -274,6 +300,7 @@ def run_mixture():
             b_ep_local,
             expt_assignment,
             n_expts_act,
+            y_indx=y_indx_global,
         )

     # test cuda graph capture + replay with symmetric memory
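Pinning the routing also matters for the graph-capture check referenced by the trailing comment: a CUDA graph replays the exact kernels and indices recorded at capture time, so deterministic dispatch is a prerequisite for replay to be meaningful. A minimal capture/replay sketch using the public torch.cuda API, where `fn` stands in for something like `run_mixture` (names and warmup count are illustrative):

import torch

def capture_and_replay(fn, warmup=3):
    # Warm up on a side stream so lazy initialization and autotuning do
    # not happen inside the capture, then record fn into a CUDA graph
    # and replay it.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(warmup):
            fn()
    torch.cuda.current_stream().wait_stream(s)
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        out = fn()
    g.replay()                  # re-runs the recorded work, same routing
    return out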