Commit 061f2e5

fix bmm experts export issue with nvfp4 scales
Signed-off-by: Zhiyu Cheng <[email protected]>
1 parent ad091e8 commit 061f2e5
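
Context for the fix: bmm-style expert modules (Llama4TextExperts, GptOssExperts) store their weights as 3-D tensors of shape (num_experts, in_dim, out_dim), while the NVFP4 block scales are computed over the last dimension, so the weight has to be transposed before the scales are derived. Below is a minimal sketch of that shape argument, using plain PyTorch and an illustrative block size of 16; per_block_amax is a stand-in for the scale computation, not ModelOpt code.

```python
# Sketch (not ModelOpt code): NVFP4 block scales are taken over the LAST
# dimension, so for bmm-style expert weights in_dim must be moved there.
import torch

num_experts, in_dim, out_dim, block_size = 4, 64, 32, 16

# bmm-style expert weight as stored by Llama4TextExperts / GptOssExperts
weight = torch.randn(num_experts, in_dim, out_dim)

def per_block_amax(t: torch.Tensor, block: int) -> torch.Tensor:
    """Group the last dimension into blocks of `block` and take amax per block."""
    *lead, last = t.shape
    return t.reshape(*lead, last // block, block).abs().amax(dim=-1)

# As stored: blocks would span out_dim, not in_dim.
scales_as_stored = per_block_amax(weight, block_size)                    # (E, in_dim, out_dim // 16)

# Transposed: blocks span in_dim, which is what the scale computation expects.
scales_transposed = per_block_amax(weight.transpose(-2, -1), block_size)  # (E, out_dim, in_dim // 16)

print(scales_as_stored.shape)   # torch.Size([4, 64, 2])
print(scales_transposed.shape)  # torch.Size([4, 32, 4])
```

Only the transposed layout groups the blocks along in_dim, which is what the scaling-factor computation in the diff below assumes.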

File tree

1 file changed: +44 -25 lines changed


modelopt/torch/export/unified_export_hf.py

Lines changed: 44 additions & 25 deletions
@@ -62,7 +62,6 @@
     get_weight_block_size,
     get_weight_scaling_factor,
     get_weight_scaling_factor_2,
-    maybe_transpose_expert_weight_dimensions,
     postprocess_state_dict,
     preprocess_linear_fusion,
     to_quantized_weight,
@@ -293,34 +292,54 @@ def _export_quantized_weight(
     weight_scale: torch.Tensor | None = getattr(sub_module, quantizer_attrs.weight_scale, None)
     weight_scale_2: torch.Tensor | None = getattr(sub_module, quantizer_attrs.weight_scale_2, None)
 
-    # Transpose weight for bmm-style expert quantization (llama4, gpt-oss)
+    # For NVFP4 quantization of expert weights, transpose to (num_experts, out_dim, in_dim)
+    # because ModelOpt assumes in_dim is the last dimension for scaling factor computation
     if quantization_format in [QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ]:
-        # Transpose weight from (num_experts, input_dim, output_dim) to (num_experts, output_dim, input_dim)
-        # for NVFP4 quantization functions that expect input_dim as the last dimension for block quantization
-        is_bmm_expert_weight = weight.dim() == 3 and any(
+        is_expert_weight = weight.dim() == 3 and any(
             expert_type in type(sub_module).__name__
             for expert_type in ["Llama4TextExperts", "GptOssExperts"]
         )
-        weight, _ = maybe_transpose_expert_weight_dimensions(
-            weight, is_bmm_expert_weight=is_bmm_expert_weight
-        )
-        weight_scale = NVFP4QTensor.get_weights_scaling_factor(
-            weight,
-            block_size=block_size,
-            weights_scaling_factor_2=weight_scale_2,
-        )[0]
-
-        quantized_weight = to_quantized_weight(
-            weight.to(dtype),
-            weight_scale,
-            quantization_format,
-            weight_scale_2,
-            block_size,
-        )
-
-        quantized_weight, weight_scale = maybe_transpose_expert_weight_dimensions(
-            quantized_weight, weight_scale, is_bmm_expert_weight=is_bmm_expert_weight
-        )
+
+        if is_expert_weight:
+            # Transpose from (num_experts, in_dim, out_dim) to (num_experts, out_dim, in_dim)
+            transposed_weight = weight.transpose(-2, -1).contiguous()
+
+            # Compute scaling factor from transposed weight
+            weight_scale = NVFP4QTensor.get_weights_scaling_factor(
+                transposed_weight,
+                block_size=block_size,
+                weights_scaling_factor_2=weight_scale_2,
+            )[0]
+
+            # Quantize using transposed weight and scaling factor
+            quantized_weight = to_quantized_weight(
+                transposed_weight.to(dtype),
+                weight_scale,
+                quantization_format,
+                weight_scale_2,
+                block_size,
+            )
+
+            # Transpose quantized weight back to original format (num_experts, in_dim, out_dim)
+            quantized_weight = quantized_weight.transpose(-2, -1).contiguous()
+
+            # Transpose scaling factor back to match original weight dimensions
+            weight_scale = weight_scale.transpose(-2, -1).contiguous()
+        else:
+            # Regular weight quantization (non-expert)
+            weight_scale = NVFP4QTensor.get_weights_scaling_factor(
+                weight,
+                block_size=block_size,
+                weights_scaling_factor_2=weight_scale_2,
+            )[0]
+
+            quantized_weight = to_quantized_weight(
+                weight.to(dtype),
+                weight_scale,
+                quantization_format,
+                weight_scale_2,
+                block_size,
+            )
     else:
         quantized_weight = to_quantized_weight(
             weight.to(dtype),