@@ -301,15 +301,29 @@ def _export_quantized_weight(
    )

    if is_expert_weight:
+        # Apply BMM transposition for both Llama4TextExperts and GptOssExperts
+        print(
+            f"DEBUG: Original weight shape for {type(sub_module).__name__}.{weight_name}: {weight.shape}"
+        )
+
        # Transpose from (num_experts, in_dim, out_dim) to (num_experts, out_dim, in_dim)
        transposed_weight = weight.transpose(-2, -1).contiguous()
+        print(f"DEBUG: Transposed weight shape: {transposed_weight.shape}")

        # Compute scaling factor from transposed weight
        weight_scale = NVFP4QTensor.get_weights_scaling_factor(
            transposed_weight,
            block_size=block_size,
            weights_scaling_factor_2=weight_scale_2,
        )[0]
+        print(f"DEBUG: Scaling factor shape from transposed weight: {weight_scale.shape}")
+
+        # Test: what would the scaling factor be if we transpose it back?
+        if weight_scale.dim() == 3:
+            transposed_back_scale = weight_scale.transpose(-2, -1)
+            print(
+                f"DEBUG: Scaling factor shape if transposed back: {transposed_back_scale.shape}"
+            )

        # Quantize using transposed weight and scaling factor
        quantized_weight = to_quantized_weight(
@@ -320,11 +334,21 @@ def _export_quantized_weight(
            block_size,
        )

-        # Transpose quantized weight back to original format (num_experts, in_dim, out_dim)
+        # Transpose quantized weight back to original format
        quantized_weight = quantized_weight.transpose(-2, -1).contiguous()
+        print(f"DEBUG: Final quantized weight shape: {quantized_weight.shape}")

        # Transpose scaling factor back to match original weight dimensions
-        weight_scale = weight_scale.transpose(-2, -1).contiguous()
+        if weight_scale.dim() == 3:
+            weight_scale = weight_scale.transpose(-2, -1).contiguous()
+            print(
+                f"DEBUG: Final scaling factor shape (after transposing back): {weight_scale.shape}"
+            )
+        else:
+            print(
+                f"DEBUG: Final scaling factor shape (no transpose needed): {weight_scale.shape}"
+            )
+        print("=" * 80)
    else:
        # Regular weight quantization (non-expert)
        weight_scale = NVFP4QTensor.get_weights_scaling_factor(
@@ -351,6 +375,10 @@ def _export_quantized_weight(

    setattr(sub_module, weight_name, nn.Parameter(quantized_weight, requires_grad=False))

+    # Register the corrected weight_scale as a buffer
+    if weight_scale is not None:
+        sub_module.register_buffer(quantizer_attrs.weight_scale, weight_scale)
+


def _export_hf_checkpoint(
    model: nn.Module, dtype: torch.dtype | None = None
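
For readers following the shape handling above, here is a minimal, self-contained sketch of the BMM expert-weight flow this diff implements: transpose (num_experts, in_dim, out_dim) so the blocked axis is last, compute a per-block scaling factor, then transpose both the quantized weight and the scale back to the original layout. The tensor sizes and the amax-based stand-in scale are illustrative assumptions only; the real export uses NVFP4QTensor.get_weights_scaling_factor and to_quantized_weight as shown in the diff.

import torch

# Illustrative sizes only; real expert weights come from the model being exported.
num_experts, in_dim, out_dim, block_size = 4, 128, 256, 16

# BMM-style expert weight stored as (num_experts, in_dim, out_dim).
weight = torch.randn(num_experts, in_dim, out_dim)

# Transpose so the dimension that gets blocked is last: (num_experts, out_dim, in_dim).
transposed_weight = weight.transpose(-2, -1).contiguous()

# Stand-in per-block scale (amax over each block of the last axis); the real export
# computes this with NVFP4QTensor.get_weights_scaling_factor.
weight_scale = (
    transposed_weight.reshape(num_experts, out_dim, in_dim // block_size, block_size)
    .abs()
    .amax(dim=-1)
)  # (num_experts, out_dim, in_dim // block_size)

# Quantization (to_quantized_weight in the real code) operates on transposed_weight;
# afterwards both tensors are transposed back to (num_experts, in_dim, out_dim).
quantized_weight = transposed_weight.transpose(-2, -1).contiguous()
if weight_scale.dim() == 3:
    weight_scale = weight_scale.transpose(-2, -1).contiguous()

print(quantized_weight.shape)  # torch.Size([4, 128, 256])
print(weight_scale.shape)      # torch.Size([4, 8, 256])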