@@ -324,6 +324,7 @@ def is_moe(module: nn.Module) -> bool:
324324 "PhimoeSparseMoeBlock" .lower (),
325325 "DeepseekMoE" .lower (),
326326 "Qwen2MoeSparseMoeBlock" .lower (),
327+ "Qwen3MoeSparseMoeBlock" .lower (),
327328 ]
328329
329330
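For reference, a minimal sketch of what a class-name based check like is_moe presumably does with this list; looks_like_moe and moe_block_names are illustrative names, not the library's own API.

    import torch.nn as nn

    # Illustrative copy of the detection list; the real list lives inside is_moe.
    moe_block_names = [
        "PhimoeSparseMoeBlock".lower(),
        "DeepseekMoE".lower(),
        "Qwen2MoeSparseMoeBlock".lower(),
        "Qwen3MoeSparseMoeBlock".lower(),
    ]

    def looks_like_moe(module: nn.Module) -> bool:
        # Lower-casing both sides keeps the match robust to capitalization differences.
        return type(module).__name__.lower() in moe_block_names
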
@@ -969,26 +970,75 @@ def get_stacked_scaling_factors(experts, get_function, module_name):
     return config


-@contextmanager
-def set_zero_amax_for_uncalibrated_experts(experts: nn.Module):
-    """For experts that does not have valid amax value of input quantizer, we set them to 0."""
+def get_expert_linear_names(module: nn.Module) -> list[str]:
+    """Get the list of linear names for the experts."""
+    if type(module).__name__.lower() in [
+        "Qwen2MoeSparseMoeBlock".lower(),
+        "Qwen3MoeSparseMoeBlock".lower(),
+        "DeepseekMoE".lower(),
+    ]:
+        return ["gate_proj", "down_proj", "up_proj"]
+    elif type(module).__name__.lower() == "MixtralMoeSparseMoeBlock".lower():
+        return ["linear_fc1", "linear_fc2"]
+    elif type(module).__name__.lower() == "DBRXMoeSparseMoeBlock".lower():
+        return ["w1_linear", "w2_linear", "v1_linear"]
+    else:
+        # assuming w1, w2, w3 by default
+        return ["w1", "w2", "w3"]
+
+
+def set_amax_for_uncalibrated_experts(experts: list[nn.Module], set_amax_value: float | None = None):
+    """Set amax of uncalibrated experts to a given value or to the max amax of the other experts.
+
+    Args:
+        experts: a list of expert (linear) modules
+        set_amax_value: the amax value to assign to uncalibrated experts.
+            If None, use the max of the existing amax values from the other experts.
+
+    Returns:
+        uncalibrated_experts: the list of experts whose amax was backfilled
+    """
     uncalibrated_experts = []
+    # get the max amax value from all experts
+    if set_amax_value is None:
+        amax_values = [
+            module.input_quantizer.amax
+            for module in experts
+            if (
+                hasattr(module, "input_quantizer")
+                and module.input_quantizer is not None
+                and module.input_quantizer.is_enabled
+            )
+            and module.input_quantizer.amax is not None
+        ]
+        if len(amax_values) == 0:
+            return uncalibrated_experts
+        set_amax_value = torch.max(torch.stack(amax_values))
+
     for module in experts:
         if (
             hasattr(module, "input_quantizer")
             and module.input_quantizer is not None
             and module.input_quantizer.is_enabled
         ) and module.input_quantizer.amax is None:
             warn(
-                f"Missing amax value for {module} input_quantizer. Setting it to 0 for checkpoint export. "
+                f"Missing amax value for {module} input_quantizer. Setting it to {set_amax_value} for export. "
                 f"This typically occurs in MoE models when certain experts are not activated during calibration. "
                 f"Consider increasing your calibration dataset size to ensure all experts are exercised."
             )
             # Use float32 dtype explicitly to ensure we create a floating point tensor
             module.input_quantizer.amax = torch.tensor(
-                0.0, dtype=torch.float32, device=module.weight_quantizer.amax.device
+                set_amax_value, dtype=torch.float32, device=module.weight_quantizer.amax.device
             )
             uncalibrated_experts.append(module)
+    return uncalibrated_experts
+
+
+@contextmanager
+def set_amax_for_uncalibrated_experts_context(
+    experts: list[nn.Module], set_amax_value: float | None = None
+):
+    """Set amax for uncalibrated experts in a context manager."""
+    uncalibrated_experts = set_amax_for_uncalibrated_experts(experts, set_amax_value)
     yield
     if uncalibrated_experts:
         for module in uncalibrated_experts:
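To make the backfill behavior above concrete, here is a hedged, self-contained sketch using stub quantizer objects rather than the library's real quantizer classes; it assumes set_amax_for_uncalibrated_experts is in scope (e.g. run from the same module) and, per its docstring, returns the list of patched experts.

    import torch
    import torch.nn as nn

    class _StubQuantizer:
        """Bare-bones stand-in exposing only the attributes the helper reads."""
        def __init__(self, amax=None):
            self.is_enabled = True
            self.amax = None if amax is None else torch.tensor(amax)

    class _StubExpert(nn.Module):
        """Fake expert linear layer carrying input/weight quantizers."""
        def __init__(self, input_amax=None):
            super().__init__()
            self.input_quantizer = _StubQuantizer(input_amax)
            self.weight_quantizer = _StubQuantizer(1.0)  # helper reads .amax.device from here

    experts = [_StubExpert(0.5), _StubExpert(2.0), _StubExpert(None)]  # third expert uncalibrated

    # With set_amax_value=None, the missing amax is backfilled with max(0.5, 2.0) = 2.0
    # (a warning is emitted for the uncalibrated expert).
    patched = set_amax_for_uncalibrated_experts(experts, set_amax_value=None)
    print(experts[2].input_quantizer.amax)  # tensor(2.)
    print(len(patched))                     # 1
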
@@ -1022,12 +1072,13 @@ def build_stacked_experts(
     )

     # Set amax to 0 for uncalibrated experts
-    with set_zero_amax_for_uncalibrated_experts(
+    with set_amax_for_uncalibrated_experts_context(
         [
             expert_getter(experts, i, module_name)
             for module_name in linear_names
             for i in range(num_experts)
-        ]
+        ],
+        0,  # set amax to 0 for uncalibrated experts as we will calculate max across all experts later
     ):
         # Pre-fuse W1 and W3
         if len(linear_names) == 3:
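The explicit 0 passed above works because amax values are non-negative, so a zero placeholder can never win the max that is taken across experts afterwards. A small arithmetic sketch with made-up values:

    import torch

    # Hypothetical per-expert input amax values: expert 1 was never activated
    # during calibration, so the context manager backfilled its amax with 0.
    per_expert_amax = torch.tensor([1.7, 0.0, 2.3])

    # The later max-reduction across experts is unaffected by the zero placeholder.
    print(per_expert_amax.max())  # tensor(2.3000)
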
@@ -1121,12 +1172,14 @@ def build_moe_config(module: nn.Module, decoder_type) -> MOEConfig:
         )
     elif decoder_type == "qwen":
         config.router = build_linear_config(module.gate, LINEAR_ROW)
-        preprocess_linear_fusion([module.shared_expert.gate_proj, module.shared_expert.up_proj])
-        config.shared_expert = build_mlp_config(
-            module.shared_expert, decoder_type, merge_gate_fc=True
-        )
-        config.shared_expert_gate = build_linear_config(module.shared_expert_gate, LINEAR_ROW)
-        config.shared_expert_gate.tp = False
+        # Qwen3 MoE does not have a shared expert
+        if hasattr(module, "shared_expert"):
+            preprocess_linear_fusion([module.shared_expert.gate_proj, module.shared_expert.up_proj])
+            config.shared_expert = build_mlp_config(
+                module.shared_expert, decoder_type, merge_gate_fc=True
+            )
+            config.shared_expert_gate = build_linear_config(module.shared_expert_gate, LINEAR_ROW)
+            config.shared_expert_gate.tp = False
     else:
         raise NotImplementedError(f"{decoder_type} not supported")

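The hasattr guard above is what lets the same "qwen" branch serve both families; _Qwen2LikeBlock and _Qwen3LikeBlock below are made-up stand-ins (not the transformers classes) used only to illustrate which path each one takes.

    import torch.nn as nn

    class _Qwen2LikeBlock(nn.Module):
        """Toy stand-in: carries a shared expert, like Qwen2-MoE blocks."""
        def __init__(self):
            super().__init__()
            self.shared_expert = nn.Linear(8, 8)       # placeholder for the real shared MLP
            self.shared_expert_gate = nn.Linear(8, 1)

    class _Qwen3LikeBlock(nn.Module):
        """Toy stand-in: no shared expert, mirroring Qwen3-MoE blocks."""

    for block in (_Qwen2LikeBlock(), _Qwen3LikeBlock()):
        print(type(block).__name__, hasattr(block, "shared_expert"))
    # _Qwen2LikeBlock True   -> shared-expert config is built
    # _Qwen3LikeBlock False  -> the shared-expert section is skipped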