@@ -544,8 +544,8 @@ def test_moe_sharded_state_dict(need_8_gpus, tmp_path, config):
544544 )
545545
546546
547- def _test_grouped_vs_non_grouped_quantize_helper (tp_size , ep_size , etp_size , rank , size ):
548- """Test that grouped and non-grouped MoE models produce similar amax values."""
547+ def _test_te_grouped_vs_sequential_quantize_helper (tp_size , ep_size , etp_size , rank , size ):
548+ """Test that TEGrouped and sequential MoE models produce matching outputs before and after quantization."""
549549 initialize_for_megatron (
550550 tensor_model_parallel_size = tp_size ,
551551 expert_model_parallel_size = ep_size ,
@@ -559,8 +559,8 @@ def _test_grouped_vs_non_grouped_quantize_helper(tp_size, ep_size, etp_size, ran
559559 def forward_fn (model ):
560560 return megatron_prefill (model , prompt_tokens )
561561
562- # Create grouped MoE model
563- grouped_moe_model = _gpt_model_provider (
562+ # Create TEGrouped MoE model
563+ te_grouped_moe_model = _gpt_model_provider (
564564 tp_size = tp_size ,
565565 ep_size = ep_size ,
566566 etp_size = etp_size ,
@@ -569,14 +569,14 @@ def forward_fn(model):
569569 use_te = True ,
570570 num_moe_experts = 4 ,
571571 )
572- num_grouped_mlp = sum (
573- isinstance (module , TEGroupedMLP ) for module in grouped_moe_model .modules ()
572+ num_te_grouped_mlp = sum (
573+ isinstance (module , TEGroupedMLP ) for module in te_grouped_moe_model .modules ()
574574 )
575- assert num_grouped_mlp == 4 , (
576- f"TEGrupedMoEModel has { num_grouped_mlp } TEGroupedMLP modules, it should have 4"
575+ assert num_te_grouped_mlp == 4 , (
576+ f"TEGrupedMoEModel has { num_te_grouped_mlp } TEGroupedMLP modules, it should have 4"
577577 )
578578
579- # Create non-grouped MoE model
579+ # Create sequential MoE model
580580 sequential_moe_model = _gpt_model_provider (
581581 tp_size = tp_size ,
582582 ep_size = ep_size ,
@@ -592,37 +592,37 @@ def forward_fn(model):
592592 f"SequentialMoEModel has { num_sequential_mlp } SequentialMLP modules, it should have 4"
593593 )
594594 # Copy weights from TEGrouped to sequential model
595- copy_weights_from_grouped_to_non_grouped (grouped_moe_model , sequential_moe_model )
595+ copy_weights_from_grouped_to_non_grouped (te_grouped_moe_model , sequential_moe_model )
596596
597597 # Compare model outputs before quantization
598- grouped_moe_output = forward_fn (grouped_moe_model )
599- non_grouped_moe_output = forward_fn (sequential_moe_model )
600- assert torch .allclose (grouped_moe_output , non_grouped_moe_output , atol = 1e-6 , rtol = 1e-6 )
598+ te_grouped_moe_output = forward_fn (te_grouped_moe_model )
599+ sequential_moe_output = forward_fn (sequential_moe_model )
600+ assert torch .allclose (te_grouped_moe_output , sequential_moe_output , atol = 1e-6 , rtol = 1e-6 )
601601
602602 # Quantize TEGrouped model
603- mtq .quantize (grouped_moe_model , mtq .FP8_DEFAULT_CFG , forward_fn )
603+ mtq .quantize (te_grouped_moe_model , mtq .FP8_DEFAULT_CFG , forward_fn )
604604
605605 # Quantize sequential model
606606 mtq .quantize (sequential_moe_model , mtq .FP8_DEFAULT_CFG , forward_fn )
607607
608608 # Compare model outputs after quantization
609- grouped_moe_quant_output = forward_fn (grouped_moe_model )
610- non_grouped_moe_quant_output = forward_fn (sequential_moe_model )
609+ te_grouped_moe_quant_output = forward_fn (te_grouped_moe_model )
610+ sequential_moe_quant_output = forward_fn (sequential_moe_model )
611611 assert torch .allclose (
612- grouped_moe_quant_output , non_grouped_moe_quant_output , atol = 1e-6 , rtol = 1e-6
612+ te_grouped_moe_quant_output , sequential_moe_quant_output , atol = 1e-6 , rtol = 1e-6
613613 )
614614
615615
616- def test_grouped_vs_non_grouped_quantize ():
617- """Test that grouped and non-grouped MoE models produce similar quantized models."""
616+ def test_te_grouped_vs_sequential_quantize ():
617+ """Test that TEGrouped and sequential MoE models produce similar quantized models."""
618618
619619 size = torch .cuda .device_count ()
620620 if size < 4 :
621621 pytest .skip ("Requires at least 4 GPUs for expert parallel test" )
622622
623623 spawn_multiprocess_job (
624624 size = size ,
625- job = partial (_test_grouped_vs_non_grouped_quantize_helper , 1 , 2 , 2 ),
625+ job = partial (_test_te_grouped_vs_sequential_quantize_helper , 1 , 2 , 2 ),
626626 backend = "nccl" ,
627627 )
628628
@@ -666,8 +666,8 @@ def forward_fn(model):
666666 if isinstance (module , mtq .nn .TensorQuantizer ):
667667 # Check if this is an expert quantizer
668668 is_expert_quantizer = (
669- "local_experts" in name # Non-grouped MoE
670- or ("experts" in name and "linear_fc" in name ) # Grouped MoE
669+ "local_experts" in name # sequential MoE
670+ or ("experts" in name and "linear_fc" in name ) # TEGrouped MoE
671671 )
672672
673673 if is_expert_quantizer and hasattr (module , "_amax" ):
0 commit comments