@@ -324,6 +324,7 @@ def is_moe(module: nn.Module) -> bool:
         "PhimoeSparseMoeBlock".lower(),
         "DeepseekMoE".lower(),
         "Qwen2MoeSparseMoeBlock".lower(),
+        "Qwen3MoeSparseMoeBlock".lower(),
     ]
 
 
@@ -969,26 +970,75 @@ def get_stacked_scaling_factors(experts, get_function, module_name):
     return config
 
 
-@contextmanager
-def set_zero_amax_for_uncalibrated_experts(experts: nn.Module):
-    """For experts that does not have valid amax value of input quantizer, we set them to 0."""
+def get_expert_linear_names(module: nn.Module) -> list[str]:
+    """Get the list of linear names for the experts."""
+    if type(module).__name__.lower() in [
+        "Qwen2MoeSparseMoeBlock".lower(),
+        "Qwen3MoeSparseMoeBlock".lower(),
+        "DeepseekMoE".lower(),
+    ]:
+        return ["gate_proj", "down_proj", "up_proj"]
+    elif type(module).__name__.lower() == "MixtralMoeSparseMoeBlock".lower():
+        return ["linear_fc1", "linear_fc2"]
+    elif type(module).__name__.lower() == "DBRXMoeSparseMoeBlock".lower():
+        return ["w1_linear", "w2_linear", "v1_linear"]
+    else:
+        # assuming w1, w2, w3 by default
+        return ["w1", "w2", "w3"]
+
+
+def set_amax_for_uncalibrated_experts(experts: nn.Module, set_amax_value: float | None = None):
+    """Set amax of uncalibrated experts to a given value or the max of existing amax values from other experts.
+
+    Args:
+        experts: a list of experts
+        set_amax_value: set amax value to the given value.
+            If None, set amax value to the max of existing amax values from other experts.
+
+    Returns:
+        uncalibrated_experts: a list of uncalibrated experts
+    """
     uncalibrated_experts = []
+    # get the max amax value from all experts
+    if set_amax_value is None:
+        amax_values = [
+            module.input_quantizer.amax
+            for module in experts
+            if (
+                hasattr(module, "input_quantizer")
+                and module.input_quantizer is not None
+                and module.input_quantizer.is_enabled
+            )
+            and module.input_quantizer.amax is not None
+        ]
+        if len(amax_values) == 0:
+            return uncalibrated_experts
+        set_amax_value = torch.max(torch.stack(amax_values))
+
     for module in experts:
         if (
             hasattr(module, "input_quantizer")
             and module.input_quantizer is not None
             and module.input_quantizer.is_enabled
         ) and module.input_quantizer.amax is None:
             warn(
-                f"Missing amax value for {module} input_quantizer. Setting it to 0 for checkpoint export. "
+                f"Missing amax value for {module} input_quantizer. Setting it to {set_amax_value} for export. "
                 f"This typically occurs in MoE models when certain experts are not activated during calibration. "
                 f"Consider increasing your calibration dataset size to ensure all experts are exercised."
             )
             # Use float32 dtype explicitly to ensure we create a floating point tensor
             module.input_quantizer.amax = torch.tensor(
-                0.0, dtype=torch.float32, device=module.weight_quantizer.amax.device
+                set_amax_value, dtype=torch.float32, device=module.weight_quantizer.amax.device
             )
             uncalibrated_experts.append(module)
+
+
+@contextmanager
+def set_amax_for_uncalibrated_experts_context(
+    experts: nn.Module, set_amax_value: float | None = None
+):
+    """Set amax for uncalibrated experts in a context manager."""
+    uncalibrated_experts = set_amax_for_uncalibrated_experts(experts, set_amax_value)
     yield
     if uncalibrated_experts:
         for module in uncalibrated_experts:
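
For reference, a minimal usage sketch of the two new helpers. Only get_expert_linear_names and set_amax_for_uncalibrated_experts_context come from this change; moe_block, its .experts attribute, and export_checkpoint are hypothetical stand-ins.

# Hypothetical illustration: gather every expert linear layer from an MoE block
# and make sure each input_quantizer has an amax before export.
linear_names = get_expert_linear_names(moe_block)   # e.g. ["gate_proj", "down_proj", "up_proj"]
expert_linears = [
    getattr(expert, name)
    for expert in moe_block.experts                  # assumes the block exposes an .experts list
    for name in linear_names
]
# Uncalibrated experts get the max amax observed on the calibrated ones.
with set_amax_for_uncalibrated_experts_context(expert_linears):
    export_checkpoint(moe_block)                     # placeholder for the actual export call
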
@@ -1022,12 +1072,13 @@ def build_stacked_experts(
     )
 
     # Set amax to 0 for uncalibrated experts
-    with set_zero_amax_for_uncalibrated_experts(
+    with set_amax_for_uncalibrated_experts_context(
         [
             expert_getter(experts, i, module_name)
             for module_name in linear_names
             for i in range(num_experts)
-        ]
+        ],
+        0,  # set amax to 0 for uncalibrated experts as we will calculate the max across all experts later
     ):
         # Pre-fuse W1 and W3
         if len(linear_names) == 3:
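
A small sketch of why passing 0 here is harmless, assuming (as the added comment indicates) that the stacked scaling factors are later reduced with a max across experts.

# Hypothetical illustration: a 0 placeholder amax from an uncalibrated expert
# never wins over a real calibrated value in a max reduction.
import torch
amax_per_expert = torch.tensor([0.0, 2.5, 3.1])  # expert 0 was never activated
stacked_amax = torch.max(amax_per_expert)        # tensor(3.1000)
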
@@ -1121,12 +1172,14 @@ def build_moe_config(module: nn.Module, decoder_type) -> MOEConfig:
         )
     elif decoder_type == "qwen":
         config.router = build_linear_config(module.gate, LINEAR_ROW)
-        preprocess_linear_fusion([module.shared_expert.gate_proj, module.shared_expert.up_proj])
-        config.shared_expert = build_mlp_config(
-            module.shared_expert, decoder_type, merge_gate_fc=True
-        )
-        config.shared_expert_gate = build_linear_config(module.shared_expert_gate, LINEAR_ROW)
-        config.shared_expert_gate.tp = False
+        # Qwen3 doesn't have a shared expert
+        if hasattr(module, "shared_expert"):
+            preprocess_linear_fusion([module.shared_expert.gate_proj, module.shared_expert.up_proj])
+            config.shared_expert = build_mlp_config(
+                module.shared_expert, decoder_type, merge_gate_fc=True
+            )
+            config.shared_expert_gate = build_linear_config(module.shared_expert_gate, LINEAR_ROW)
+            config.shared_expert_gate.tp = False
     else:
         raise NotImplementedError(f"{decoder_type} not supported")
 
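
A self-contained sketch of what the new hasattr guard distinguishes; the two classes below are hypothetical mocks, not the real Qwen modules.

# Qwen2-style MoE blocks carry a shared expert, Qwen3-style blocks do not,
# so the shared-expert config is only built when the attribute exists.
class FakeQwen2Block:
    def __init__(self):
        self.shared_expert = object()  # present -> shared-expert branch runs

class FakeQwen3Block:
    pass                               # no shared_expert attribute -> branch is skipped

for block in (FakeQwen2Block(), FakeQwen3Block()):
    print(type(block).__name__, hasattr(block, "shared_expert"))
# FakeQwen2Block True
# FakeQwen3Block False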