
Commit c1b53cb

debug
Signed-off-by: Zhiyu Cheng <[email protected]>
1 parent a27920d commit c1b53cb

File tree

1 file changed: +30 −2 lines changed


modelopt/torch/export/unified_export_hf.py

Lines changed: 30 additions & 2 deletions
@@ -301,15 +301,29 @@ def _export_quantized_weight(
         )

         if is_expert_weight:
+            # Apply BMM transposition for both Llama4TextExperts and GptOssExperts
+            print(
+                f"DEBUG: Original weight shape for {type(sub_module).__name__}.{weight_name}: {weight.shape}"
+            )
+
             # Transpose from (num_experts, in_dim, out_dim) to (num_experts, out_dim, in_dim)
             transposed_weight = weight.transpose(-2, -1).contiguous()
+            print(f"DEBUG: Transposed weight shape: {transposed_weight.shape}")

             # Compute scaling factor from transposed weight
             weight_scale = NVFP4QTensor.get_weights_scaling_factor(
                 transposed_weight,
                 block_size=block_size,
                 weights_scaling_factor_2=weight_scale_2,
             )[0]
+            print(f"DEBUG: Scaling factor shape from transposed weight: {weight_scale.shape}")
+
+            # Test: what would scaling factor be if we transpose it back?
+            if weight_scale.dim() == 3:
+                transposed_back_scale = weight_scale.transpose(-2, -1)
+                print(
+                    f"DEBUG: Scaling factor shape if transposed back: {transposed_back_scale.shape}"
+                )

             # Quantize using transposed weight and scaling factor
             quantized_weight = to_quantized_weight(
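As context for what these DEBUG prints are expected to report: the expert weight is stored in BMM layout (num_experts, in_dim, out_dim), the scaling factor is computed on the transposed layout, and with one scale per block of the last dimension the factor comes out as a 3-D tensor. A minimal shape-only sketch with made-up sizes, standing in for (not calling) the actual NVFP4QTensor API:

import torch

# Illustrative sizes only; the real values depend on the exported model.
num_experts, in_dim, out_dim, block_size = 16, 2048, 5120, 16

weight = torch.randn(num_experts, in_dim, out_dim)   # BMM expert layout
transposed = weight.transpose(-2, -1).contiguous()   # (num_experts, out_dim, in_dim)
print(transposed.shape)                              # torch.Size([16, 5120, 2048])

# Assuming one scale per block of the last (input) dimension, the scaling factor
# would have shape (num_experts, out_dim, in_dim // block_size) -> a 3-D tensor,
# which is what the `weight_scale.dim() == 3` check in the hunk above is probing.
expected_scale_shape = (num_experts, out_dim, in_dim // block_size)
print(expected_scale_shape)                          # (16, 5120, 128)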
@@ -320,11 +334,21 @@ def _export_quantized_weight(
                 block_size,
             )

-            # Transpose quantized weight back to original format (num_experts, in_dim, out_dim)
+            # Transpose quantized weight back to original format
             quantized_weight = quantized_weight.transpose(-2, -1).contiguous()
+            print(f"DEBUG: Final quantized weight shape: {quantized_weight.shape}")

             # Transpose scaling factor back to match original weight dimensions
-            weight_scale = weight_scale.transpose(-2, -1).contiguous()
+            if weight_scale.dim() == 3:
+                weight_scale = weight_scale.transpose(-2, -1).contiguous()
+                print(
+                    f"DEBUG: Final scaling factor shape (after transposing back): {weight_scale.shape}"
+                )
+            else:
+                print(
+                    f"DEBUG: Final scaling factor shape (no transpose needed): {weight_scale.shape}"
+                )
+            print("=" * 80)
         else:
             # Regular weight quantization (non-expert)
             weight_scale = NVFP4QTensor.get_weights_scaling_factor(
@@ -351,6 +375,10 @@ def _export_quantized_weight(

    setattr(sub_module, weight_name, nn.Parameter(quantized_weight, requires_grad=False))

+    # Register the corrected weight_scale as a buffer
+    if weight_scale is not None:
+        sub_module.register_buffer(quantizer_attrs.weight_scale, weight_scale)
+

 def _export_hf_checkpoint(
     model: nn.Module, dtype: torch.dtype | None = None
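The last hunk stores the scaling factor on the sub-module via register_buffer, so it appears in the module's state_dict (and therefore in the exported checkpoint) without becoming a trainable parameter. A small self-contained sketch of that pattern, with a made-up module and a literal buffer name (the commit itself uses the name held in quantizer_attrs.weight_scale):

import torch
import torch.nn as nn

class ExpertStub(nn.Module):
    """Stand-in for an expert sub-module being exported."""
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(4, 8, 8))

m = ExpertStub()
scale = torch.ones(4, 8, 2)  # made-up scaling-factor tensor

# A buffer is saved with the module but is not returned by m.parameters().
m.register_buffer("weight_scale", scale)
print("weight_scale" in m.state_dict())          # True
print(any(p is scale for p in m.parameters()))   # False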

0 commit comments