
Commit 7acbe57

cleanup
Signed-off-by: Zhiyu Cheng <[email protected]>
1 parent c1b53cb commit 7acbe57

File tree

1 file changed: +23 -66 lines changed

modelopt/torch/export/unified_export_hf.py

Lines changed: 23 additions & 66 deletions
@@ -62,6 +62,7 @@
     get_weight_block_size,
     get_weight_scaling_factor,
     get_weight_scaling_factor_2,
+    maybe_transpose_expert_weight_dimensions,
     postprocess_state_dict,
     preprocess_linear_fusion,
     to_quantized_weight,
@@ -292,78 +293,34 @@ def _export_quantized_weight(
     weight_scale: torch.Tensor | None = getattr(sub_module, quantizer_attrs.weight_scale, None)
     weight_scale_2: torch.Tensor | None = getattr(sub_module, quantizer_attrs.weight_scale_2, None)
 
-    # For NVFP4 quantization of expert weights, transpose to (num_experts, out_dim, in_dim)
-    # because ModelOpt assumes in_dim is the last dimension for scaling factor computation
+    # Transpose weight for bmm-style expert quantization (llama4, gpt-oss)
     if quantization_format in [QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ]:
-        is_expert_weight = weight.dim() == 3 and any(
+        # Transpose weight from (num_experts, input_dim, output_dim) to (num_experts, output_dim, input_dim)
+        # for NVFP4 quantization functions that expect input_dim as the last dimension for block quantization
+        is_bmm_expert_weight = weight.dim() == 3 and any(
             expert_type in type(sub_module).__name__
             for expert_type in ["Llama4TextExperts", "GptOssExperts"]
         )
+        weight, _ = maybe_transpose_expert_weight_dimensions(
+            weight, is_bmm_expert_weight=is_bmm_expert_weight
+        )
+        weight_scale = NVFP4QTensor.get_weights_scaling_factor(
+            weight,
+            block_size=block_size,
+            weights_scaling_factor_2=weight_scale_2,
+        )[0]
 
-        if is_expert_weight:
-            # Apply BMM transposition for both Llama4TextExperts and GptOssExperts
-            print(
-                f"DEBUG: Original weight shape for {type(sub_module).__name__}.{weight_name}: {weight.shape}"
-            )
-
-            # Transpose from (num_experts, in_dim, out_dim) to (num_experts, out_dim, in_dim)
-            transposed_weight = weight.transpose(-2, -1).contiguous()
-            print(f"DEBUG: Transposed weight shape: {transposed_weight.shape}")
-
-            # Compute scaling factor from transposed weight
-            weight_scale = NVFP4QTensor.get_weights_scaling_factor(
-                transposed_weight,
-                block_size=block_size,
-                weights_scaling_factor_2=weight_scale_2,
-            )[0]
-            print(f"DEBUG: Scaling factor shape from transposed weight: {weight_scale.shape}")
-
-            # Test: what would scaling factor be if we transpose it back?
-            if weight_scale.dim() == 3:
-                transposed_back_scale = weight_scale.transpose(-2, -1)
-                print(
-                    f"DEBUG: Scaling factor shape if transposed back: {transposed_back_scale.shape}"
-                )
-
-            # Quantize using transposed weight and scaling factor
-            quantized_weight = to_quantized_weight(
-                transposed_weight.to(dtype),
-                weight_scale,
-                quantization_format,
-                weight_scale_2,
-                block_size,
-            )
-
-            # Transpose quantized weight back to original format
-            quantized_weight = quantized_weight.transpose(-2, -1).contiguous()
-            print(f"DEBUG: Final quantized weight shape: {quantized_weight.shape}")
+        quantized_weight = to_quantized_weight(
+            weight.to(dtype),
+            weight_scale,
+            quantization_format,
+            weight_scale_2,
+            block_size,
+        )
 
-            # Transpose scaling factor back to match original weight dimensions
-            if weight_scale.dim() == 3:
-                weight_scale = weight_scale.transpose(-2, -1).contiguous()
-                print(
-                    f"DEBUG: Final scaling factor shape (after transposing back): {weight_scale.shape}"
-                )
-            else:
-                print(
-                    f"DEBUG: Final scaling factor shape (no transpose needed): {weight_scale.shape}"
-                )
-            print("=" * 80)
-        else:
-            # Regular weight quantization (non-expert)
-            weight_scale = NVFP4QTensor.get_weights_scaling_factor(
-                weight,
-                block_size=block_size,
-                weights_scaling_factor_2=weight_scale_2,
-            )[0]
-
-            quantized_weight = to_quantized_weight(
-                weight.to(dtype),
-                weight_scale,
-                quantization_format,
-                weight_scale_2,
-                block_size,
-            )
+        quantized_weight, weight_scale = maybe_transpose_expert_weight_dimensions(
+            quantized_weight, weight_scale, is_bmm_expert_weight=is_bmm_expert_weight
+        )
     else:
         quantized_weight = to_quantized_weight(
             weight.to(dtype),
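The removed inline transpose-and-debug logic is now factored into the maybe_transpose_expert_weight_dimensions helper imported at the top of the file. The helper itself is not part of this diff; a minimal sketch of what it is assumed to do, reconstructed from the two call sites above and from the deleted code (signature and details are assumptions, not the actual ModelOpt implementation):

import torch


def maybe_transpose_expert_weight_dimensions(
    weight: torch.Tensor,
    weight_scale: torch.Tensor | None = None,
    *,
    is_bmm_expert_weight: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Assumed behavior: swap the last two dims of bmm-style expert tensors.

    Llama4TextExperts / GptOssExperts store weights as (num_experts, input_dim, output_dim),
    while the NVFP4 block-quantization utilities expect input_dim last, so the same call is
    used to transpose before quantization and to transpose the quantized weight and its
    scaling factor back afterwards.
    """
    if not is_bmm_expert_weight:
        # Non-expert weights pass through unchanged.
        return weight, weight_scale

    weight = weight.transpose(-2, -1).contiguous()
    if weight_scale is not None and weight_scale.dim() == 3:
        # Per-block scaling factors share the expert layout and are transposed along with the weight.
        weight_scale = weight_scale.transpose(-2, -1).contiguous()
    return weight, weight_scale

Calling the helper symmetrically, once on the raw weight before NVFP4QTensor.get_weights_scaling_factor / to_quantized_weight and once on the quantized weight and scale afterwards, is what lets the per-expert branch and its DEBUG prints collapse out of _export_quantized_weight.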

0 commit comments
