@@ -44,7 +44,8 @@ def ref_expt_data(routing_data, n_gates, block_m):
4444@pytest .mark .parametrize ("n_expts_tot, n_expts_act" , [(128 , 4 ), (1500 , 8 )])
4545@pytest .mark .parametrize ("block_m" , [64 , 128 ])
4646@pytest .mark .parametrize ("use_expt_indx" , [False , True ])
47- def test_op (n_tokens , n_expts_tot , n_expts_act , block_m , use_expt_indx , device ):
47+ @pytest .mark .parametrize ("renormalize" , [True , False ])
48+ def test_op (n_tokens , n_expts_tot , n_expts_act , renormalize , block_m , use_expt_indx , device ):
4849 torch .manual_seed (2 )
4950 tri_logits = init_data (n_tokens , n_expts_tot , device = device ).detach ()
5051 ref_logits = tri_logits .clone ()
@@ -55,8 +56,11 @@ def test_op(n_tokens, n_expts_tot, n_expts_act, block_m, use_expt_indx, device):
5556 ref_expt_indx = tri_expt_indx [:n_tokens ]
5657 else :
5758 tri_expt_indx = ref_expt_indx = None
58- ref_routing_data , ref_gather , ref_scatter = routing_torch (ref_logits , n_expts_act , ref_expt_indx )
59- tri_routing_data , tri_gather , tri_scatter = routing (tri_logits , n_expts_act , tri_expt_indx )
59+ if not renormalize :
60+ tri_logits = torch .softmax (tri_logits , dim = - 1 )
61+ ref_logits = torch .softmax (ref_logits , dim = - 1 )
62+ ref_routing_data , ref_gather , ref_scatter = routing_torch (ref_logits , n_expts_act , renormalize , ref_expt_indx )
63+ tri_routing_data , tri_gather , tri_scatter = routing (tri_logits , n_expts_act , renormalize , tri_expt_indx )
6064 ref_metadata = ref_expt_data (ref_routing_data , n_tokens * n_expts_act , block_m )
6165 tri_metadata = compute_metadata (tri_routing_data , n_tokens * n_expts_act , block_m )
6266
0 commit comments