@@ -292,7 +292,6 @@ def per_rank_test_fused_moe_alltoall(job_id):
292292 assert r is None
293293
294294
295- @pytest .mark .skip (reason = "https://nvbugs/5467531" )
296295@pytest .mark .skipif (torch .cuda .device_count () < 4 ,
297296 reason = "needs 4 GPUs to run this test" )
298297@pytest .mark .parametrize ("alltoall_method_type" , [
@@ -304,7 +303,7 @@ def test_fused_moe_alltoall_fp4(alltoall_method_type):
304303
305304 world_size = 4
306305 dtype = torch .bfloat16
307- HIDDEN_SIZE = 2560
306+ HIDDEN_SIZE = 4096
308307 INTERMEDIATE_SIZE = 1536
309308 NUM_EXPERTS = 72
310309 TOP_K = 6
@@ -320,8 +319,8 @@ def test_fused_moe_alltoall_fp4(alltoall_method_type):
320319 x_list = []
321320 m = MAX_NUM_TOKENS
322321 while m >= 1 :
323- x = torch .randn ((m , HIDDEN_SIZE ), dtype = dtype , device = "cuda" )
324- x_list .append (x . cuda ( i ) )
322+ x = torch .randn ((m , HIDDEN_SIZE ), dtype = dtype )
323+ x_list .append (x )
325324 m //= 2
326325
327326 x_abs_max = torch .cat ([x .flatten () for x in x_list ]).abs ().max ().float ()
@@ -366,49 +365,37 @@ def test_fused_moe_alltoall_fp4(alltoall_method_type):
366365 w3_sf_block_unswizzled = torch .ops .trtllm .block_scale_interleave_reverse (
367366 w3_sf_block .cpu ().view (INTERMEDIATE_SIZE , - 1 ))
368367
369- w1_input_scale = x_sf_global .cuda (i )
370- w2_input_scale = x_sf_global .cuda (i )
371- w3_input_scale = x_sf_global .cuda (i )
368+ weights [f"{ expert_id } .w1.weight" ] = w1_weight_nvfp4 .cpu ()
369+ weights [f"{ expert_id } .w2.weight" ] = w2_weight_nvfp4 .cpu ()
370+ weights [f"{ expert_id } .w3.weight" ] = w3_weight_nvfp4 .cpu ()
371+ weights [f"{ expert_id } .w1.weight_scale" ] = w1_sf_block_unswizzled
372+ weights [f"{ expert_id } .w2.weight_scale" ] = w2_sf_block_unswizzled
373+ weights [f"{ expert_id } .w3.weight_scale" ] = w3_sf_block_unswizzled
372374
373- weights [f"{ expert_id } .w1.weight" ] = w1_weight_nvfp4 .cuda (i )
374- weights [f"{ expert_id } .w2.weight" ] = w2_weight_nvfp4 .cuda (i )
375- weights [f"{ expert_id } .w3.weight" ] = w3_weight_nvfp4 .cuda (i )
376- weights [
377- f"{ expert_id } .w1.weight_scale" ] = w1_sf_block_unswizzled .cuda (i )
378- weights [
379- f"{ expert_id } .w2.weight_scale" ] = w2_sf_block_unswizzled .cuda (i )
380- weights [
381- f"{ expert_id } .w3.weight_scale" ] = w3_sf_block_unswizzled .cuda (i )
382-
383- weights [f"{ expert_id } .w1.input_scale" ] = 1.0 / w1_input_scale .cuda (
384- i )
385- weights [f"{ expert_id } .w2.input_scale" ] = 1.0 / w2_input_scale .cuda (
386- i )
387- weights [f"{ expert_id } .w3.input_scale" ] = 1.0 / w3_input_scale .cuda (
388- i )
389- weights [f"{ expert_id } .w1.weight_scale_2" ] = 1.0 / w3_w1_global .cuda (
390- i )
391- weights [f"{ expert_id } .w2.weight_scale_2" ] = 1.0 / w2_sf_global .cuda (
392- i )
393- weights [f"{ expert_id } .w3.weight_scale_2" ] = 1.0 / w3_w1_global .cuda (
394- i )
375+ weights [f"{ expert_id } .w1.input_scale" ] = 1.0 / x_sf_global
376+ weights [f"{ expert_id } .w2.input_scale" ] = 1.0 / x_sf_global
377+ weights [f"{ expert_id } .w3.input_scale" ] = 1.0 / x_sf_global
378+ weights [f"{ expert_id } .w1.weight_scale_2" ] = 1.0 / w3_w1_global .cpu ()
379+ weights [f"{ expert_id } .w2.weight_scale_2" ] = 1.0 / w2_sf_global .cpu ()
380+ weights [f"{ expert_id } .w3.weight_scale_2" ] = 1.0 / w3_w1_global .cpu ()
395381
396382 x_list_world .append (x_list )
397383 weights_world .append (weights )
384+ torch .cuda .synchronize ()
398385
399- def per_rank_test_fused_moe_alltoall (job_id ):
386+ def per_rank_test_fused_moe_alltoall (job_id , weights , x_list ):
400387 routing_method = DefaultMoeRoutingMethod (top_k = TOP_K )
401388 mapping = Mapping (world_size = world_size ,
402- rank = mpi_rank () ,
389+ rank = job_id ,
403390 tp_size = world_size ,
404391 moe_ep_size = world_size ,
405392 moe_tp_size = 1 ,
406393 enable_attention_dp = True )
407394 torch .cuda .set_device (mapping .rank )
408395 torch .manual_seed (mapping .rank )
409396
410- x_list = x_list_world [ mapping . rank ]
411- weights = weights_world [ mapping . rank ]
397+ weights = { k : v . cuda () for k , v in weights . items ()}
398+ x_list = [ x . cuda () for x in x_list ]
412399
413400 quant_config = QuantConfig (quant_algo = QuantAlgo .NVFP4 )
414401 with mock .patch .object (WideEPMoE ,
@@ -459,6 +446,16 @@ def per_rank_test_fused_moe_alltoall(job_id):
459446 router_logits ,
460447 all_rank_num_tokens = all_rank_num_tokens ,
461448 use_dp_padding = False )
449+ # Verify the fake impl is correct.
450+ output_fake = alltoall_model .forward_fake (
451+ x ,
452+ router_logits ,
453+ all_rank_num_tokens = all_rank_num_tokens ,
454+ use_dp_padding = False )
455+ assert output_fake .shape == output .shape
456+ assert output_fake .dtype == output .dtype
457+ if len (output .shape ) == 3 :
458+ output = torch .sum (output , dim = 1 , keepdim = False )
462459 ref_output = ref_model .forward (
463460 x ,
464461 router_logits ,
@@ -470,8 +467,10 @@ def per_rank_test_fused_moe_alltoall(job_id):
470467 m //= 2
471468
472469 with MPIPoolExecutor (max_workers = world_size ) as executor :
473- results = executor .map (per_rank_test_fused_moe_alltoall ,
474- range (world_size ))
470+ results = executor .map (
471+ per_rank_test_fused_moe_alltoall ,
472+ * zip (* [(i , weights_world [i ], x_list_world [i ])
473+ for i in range (world_size )]))
475474 for r in results :
476475 assert r is None
477476
0 commit comments