Commit fbd5417

fix bmm style moe export in fp8_pc_pt recipe
Signed-off-by: Zhiyu Cheng <[email protected]>
1 parent 484b95a commit fbd5417

2 files changed: +59 -5 lines changed
modelopt/torch/export/quant_utils.py

Lines changed: 32 additions & 1 deletion
@@ -775,7 +775,38 @@ def to_quantized_weight(
         )[0]._quantized_data

     if quantization == QUANTIZATION_FP8_PC_PT:
-        return (weight / weights_scaling_factor.unsqueeze(-1)).to(torch.float8_e4m3fn)
+        if weight.dim() == 3:
+            # for MOE stacked weights
+            # For standard MoE: weight (num_experts, output_dim, input_dim)
+            #                   scale (num_experts, output_dim)
+            # For BMM-style transposed experts: weight (num_experts, output_dim, input_dim)
+            #                                   scale (num_experts, input_dim)
+
+            # Handle different scale tensor shapes
+            if weights_scaling_factor.dim() == 1:
+                # Per-expert scaling only: (num_experts,) -> (num_experts, 1, 1)
+                return (weight / weights_scaling_factor[:, None, None]).to(torch.float8_e4m3fn)
+            elif weights_scaling_factor.dim() == 2:
+                # Per-channel scaling: check which dimension matches
+                if weights_scaling_factor.shape[-1] == weight.shape[-1]:
+                    # Scale matches last dim (input_dim) - BMM-style transposed case
+                    # (num_experts, input_dim) -> (num_experts, 1, input_dim)
+                    return (weight / weights_scaling_factor.unsqueeze(-2)).to(torch.float8_e4m3fn)
+                elif weights_scaling_factor.shape[-1] == weight.shape[-2]:
+                    # Scale matches second-to-last dim (output_dim) - standard MoE case
+                    # (num_experts, output_dim) -> (num_experts, output_dim, 1)
+                    return (weight / weights_scaling_factor.unsqueeze(-1)).to(torch.float8_e4m3fn)
+                else:
+                    # Shape mismatch - cannot infer the correct broadcasting
+                    raise ValueError(
+                        f"Cannot determine correct unsqueeze dimension for FP8_PC_PT quantization. "
+                        f"weight shape: {weight.shape}, scale shape: {weights_scaling_factor.shape}"
+                    )
+            else:
+                raise ValueError(
+                    f"Unexpected scaling factor dimension for 3D weight: {weights_scaling_factor.dim()}"
+                )
+        return (weight / weights_scaling_factor[:, None]).to(torch.float8_e4m3fn)

     if quantization in [QUANTIZATION_INT4_AWQ, QUANTIZATION_W4A8_AWQ]:
         return pack_int4_in_uint8(weight, weights_scaling_factor)
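
For context, here is a minimal standalone sketch of the broadcasting this branch performs on stacked 3D expert weights. The tensor sizes and the amax-based scale computation are illustrative assumptions, not the values ModelOpt produces:

import torch

num_experts, output_dim, input_dim = 4, 128, 256
weight = torch.randn(num_experts, output_dim, input_dim)

# Standard MoE layout: one scale per (expert, output channel).
# NOTE: amax / 448.0 is only an illustrative stand-in for the calibrated scale.
scale_out = weight.abs().amax(dim=-1) / 448.0        # (num_experts, output_dim)
q_std = (weight / scale_out.unsqueeze(-1)).to(torch.float8_e4m3fn)

# BMM-style transposed case: the scale follows the input dimension instead.
scale_in = weight.abs().amax(dim=-2) / 448.0         # (num_experts, input_dim)
q_bmm = (weight / scale_in.unsqueeze(-2)).to(torch.float8_e4m3fn)

# Per-expert-only scaling: a 1-D (num_experts,) scale broadcasts over both dims.
scale_exp = weight.abs().amax(dim=(-2, -1)) / 448.0  # (num_experts,)
q_exp = (weight / scale_exp[:, None, None]).to(torch.float8_e4m3fn)

assert q_std.shape == q_bmm.shape == q_exp.shape == weight.shape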

modelopt/torch/export/unified_export_hf.py

Lines changed: 27 additions & 4 deletions
@@ -50,6 +50,7 @@
     KV_CACHE_NVFP4_AFFINE,
     QUANTIZATION_FP8,
     QUANTIZATION_FP8_PB_REAL,
+    QUANTIZATION_FP8_PC_PT,
     QUANTIZATION_NONE,
     QUANTIZATION_NVFP4,
     QUANTIZATION_NVFP4_AWQ,

@@ -323,13 +324,15 @@ def _export_quantized_weight(
     weight_scale_2: torch.Tensor | None = getattr(sub_module, quantizer_attrs.weight_scale_2, None)

     # Transpose weight for bmm-style expert quantization (llama4, gpt-oss)
+    # Check if this is a BMM-style expert weight that needs transposition
+    is_bmm_expert_weight = weight.dim() == 3 and any(
+        expert_type in type(sub_module).__name__
+        for expert_type in ["Llama4TextExperts", "GptOssExperts"]
+    )
+
     if quantization_format in [QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ]:
         # Transpose weight from (num_experts, input_dim, output_dim) to (num_experts, output_dim, input_dim)
         # for NVFP4 quantization functions that expect input_dim as the last dimension for block quantization
-        is_bmm_expert_weight = weight.dim() == 3 and any(
-            expert_type in type(sub_module).__name__
-            for expert_type in ["Llama4TextExperts", "GptOssExperts"]
-        )
         weight, _ = maybe_transpose_expert_weight_dimensions(
             weight, is_bmm_expert_weight=is_bmm_expert_weight
         )

@@ -350,6 +353,26 @@ def _export_quantized_weight(
         quantized_weight, weight_scale = maybe_transpose_expert_weight_dimensions(
             quantized_weight, weight_scale, is_bmm_expert_weight=is_bmm_expert_weight
         )
+    elif quantization_format == QUANTIZATION_FP8_PC_PT and is_bmm_expert_weight:
+        # For FP8_PC_PT with BMM-style experts, transpose only the weight (not weight_scale)
+        # Transpose weight from (num_experts, input_dim, output_dim) to (num_experts, output_dim, input_dim)
+        # weight_scale remains (num_experts, output_dim) for per-channel quantization
+        weight, _ = maybe_transpose_expert_weight_dimensions(
+            weight, is_bmm_expert_weight=is_bmm_expert_weight
+        )
+
+        quantized_weight = to_quantized_weight(
+            weight.to(dtype),
+            weight_scale,
+            quantization_format,
+            weight_scale_2,
+            block_size,
+        )
+
+        # Transpose back to original BMM format
+        quantized_weight, _ = maybe_transpose_expert_weight_dimensions(
+            quantized_weight, is_bmm_expert_weight=is_bmm_expert_weight
+        )
     else:
         quantized_weight = to_quantized_weight(
             weight.to(dtype),
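
To make the new FP8_PC_PT branch concrete, here is a rough standalone sketch of the transpose, quantize, transpose-back flow for a BMM-style expert weight, with the library helpers replaced by plain tensor ops; the shapes and the amax-based scale are assumptions for illustration only:

import torch

# BMM-style expert weight as stored by the module (e.g. a Llama4TextExperts-like
# layout): (num_experts, input_dim, output_dim). Sizes are made up.
num_experts, input_dim, output_dim = 4, 256, 128
weight = torch.randn(num_experts, input_dim, output_dim)

# Assumed per-channel scale, one value per (expert, output channel):
weight_scale = weight.abs().amax(dim=-2) / 448.0  # (num_experts, output_dim)

# 1) Transpose to (num_experts, output_dim, input_dim); the scale is left alone,
#    mirroring the "transpose only the weight" comment in the diff above.
w_t = weight.transpose(-2, -1)

# 2) Quantize: the scale's last dim now matches w_t's second-to-last dim,
#    so it broadcasts as (num_experts, output_dim, 1).
q_t = (w_t / weight_scale.unsqueeze(-1)).to(torch.float8_e4m3fn)

# 3) Transpose back so the exported tensor keeps the original BMM layout.
quantized_weight = q_t.transpose(-2, -1)
assert quantized_weight.shape == weight.shape

Leaving the scale un-transposed is what the unsqueeze-dimension check added in quant_utils.py relies on to pick the correct broadcast axis.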
