 62 |  62 |     get_weight_block_size,
 63 |  63 |     get_weight_scaling_factor,
 64 |  64 |     get_weight_scaling_factor_2,
    |  65 | +     maybe_transpose_expert_weight_dimensions,
 65 |  66 |     postprocess_state_dict,
 66 |  67 |     preprocess_linear_fusion,
 67 |  68 |     to_quantized_weight,

@@ -292,78 +293,34 @@ def _export_quantized_weight(
292 | 293 |     weight_scale: torch.Tensor | None = getattr(sub_module, quantizer_attrs.weight_scale, None)
293 | 294 |     weight_scale_2: torch.Tensor | None = getattr(sub_module, quantizer_attrs.weight_scale_2, None)
294 | 295 |
295 |     | -     # For NVFP4 quantization of expert weights, transpose to (num_experts, out_dim, in_dim)
296 |     | -     # because ModelOpt assumes in_dim is the last dimension for scaling factor computation
    | 296 | +     # Transpose weight for bmm-style expert quantization (llama4, gpt-oss)
297 | 297 |     if quantization_format in [QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ]:
298 |     | -         is_expert_weight = weight.dim() == 3 and any(
    | 298 | +         # Transpose weight from (num_experts, input_dim, output_dim) to (num_experts, output_dim, input_dim)
    | 299 | +         # for NVFP4 quantization functions that expect input_dim as the last dimension for block quantization
    | 300 | +         is_bmm_expert_weight = weight.dim() == 3 and any(
299 | 301 |             expert_type in type(sub_module).__name__
300 | 302 |             for expert_type in ["Llama4TextExperts", "GptOssExperts"]
301 | 303 |         )
    | 304 | +         weight, _ = maybe_transpose_expert_weight_dimensions(
    | 305 | +             weight, is_bmm_expert_weight=is_bmm_expert_weight
    | 306 | +         )
    | 307 | +         weight_scale = NVFP4QTensor.get_weights_scaling_factor(
    | 308 | +             weight,
    | 309 | +             block_size=block_size,
    | 310 | +             weights_scaling_factor_2=weight_scale_2,
    | 311 | +         )[0]
302 | 312 |
303 |     | -         if is_expert_weight:
304 |     | -             # Apply BMM transposition for both Llama4TextExperts and GptOssExperts
305 |     | -             print(
306 |     | -                 f"DEBUG: Original weight shape for {type(sub_module).__name__}.{weight_name}: {weight.shape}"
307 |     | -             )
308 |     | -
309 |     | -             # Transpose from (num_experts, in_dim, out_dim) to (num_experts, out_dim, in_dim)
310 |     | -             transposed_weight = weight.transpose(-2, -1).contiguous()
311 |     | -             print(f"DEBUG: Transposed weight shape: {transposed_weight.shape}")
312 |     | -
313 |     | -             # Compute scaling factor from transposed weight
314 |     | -             weight_scale = NVFP4QTensor.get_weights_scaling_factor(
315 |     | -                 transposed_weight,
316 |     | -                 block_size=block_size,
317 |     | -                 weights_scaling_factor_2=weight_scale_2,
318 |     | -             )[0]
319 |     | -             print(f"DEBUG: Scaling factor shape from transposed weight: {weight_scale.shape}")
320 |     | -
321 |     | -             # Test: what would scaling factor be if we transpose it back?
322 |     | -             if weight_scale.dim() == 3:
323 |     | -                 transposed_back_scale = weight_scale.transpose(-2, -1)
324 |     | -                 print(
325 |     | -                     f"DEBUG: Scaling factor shape if transposed back: {transposed_back_scale.shape}"
326 |     | -                 )
327 |     | -
328 |     | -             # Quantize using transposed weight and scaling factor
329 |     | -             quantized_weight = to_quantized_weight(
330 |     | -                 transposed_weight.to(dtype),
331 |     | -                 weight_scale,
332 |     | -                 quantization_format,
333 |     | -                 weight_scale_2,
334 |     | -                 block_size,
335 |     | -             )
336 |     | -
337 |     | -             # Transpose quantized weight back to original format
338 |     | -             quantized_weight = quantized_weight.transpose(-2, -1).contiguous()
339 |     | -             print(f"DEBUG: Final quantized weight shape: {quantized_weight.shape}")
    | 313 | +         quantized_weight = to_quantized_weight(
    | 314 | +             weight.to(dtype),
    | 315 | +             weight_scale,
    | 316 | +             quantization_format,
    | 317 | +             weight_scale_2,
    | 318 | +             block_size,
    | 319 | +         )
340 | 320 |
341 |     | -             # Transpose scaling factor back to match original weight dimensions
342 |     | -             if weight_scale.dim() == 3:
343 |     | -                 weight_scale = weight_scale.transpose(-2, -1).contiguous()
344 |     | -                 print(
345 |     | -                     f"DEBUG: Final scaling factor shape (after transposing back): {weight_scale.shape}"
346 |     | -                 )
347 |     | -             else:
348 |     | -                 print(
349 |     | -                     f"DEBUG: Final scaling factor shape (no transpose needed): {weight_scale.shape}"
350 |     | -                 )
351 |     | -             print("=" * 80)
352 |     | -         else:
353 |     | -             # Regular weight quantization (non-expert)
354 |     | -             weight_scale = NVFP4QTensor.get_weights_scaling_factor(
355 |     | -                 weight,
356 |     | -                 block_size=block_size,
357 |     | -                 weights_scaling_factor_2=weight_scale_2,
358 |     | -             )[0]
359 |     | -
360 |     | -             quantized_weight = to_quantized_weight(
361 |     | -                 weight.to(dtype),
362 |     | -                 weight_scale,
363 |     | -                 quantization_format,
364 |     | -                 weight_scale_2,
365 |     | -                 block_size,
366 |     | -             )
    | 321 | +         quantized_weight, weight_scale = maybe_transpose_expert_weight_dimensions(
    | 322 | +             quantized_weight, weight_scale, is_bmm_expert_weight=is_bmm_expert_weight
    | 323 | +         )
367 | 324 |     else:
368 | 325 |         quantized_weight = to_quantized_weight(
369 | 326 |             weight.to(dtype),
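The new `maybe_transpose_expert_weight_dimensions` helper is imported from the quantization utilities, but its body is not part of this hunk. Going only by the two call sites above, here is a minimal sketch of what it plausibly does; the signature and behavior shown are assumptions, not the actual ModelOpt implementation:

```python
import torch


def maybe_transpose_expert_weight_dimensions(
    weight: torch.Tensor,
    weight_scale: torch.Tensor | None = None,
    *,
    is_bmm_expert_weight: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Hypothetical sketch: swap the last two dims of bmm-style expert tensors.

    Llama4TextExperts / GptOssExperts store expert weights as
    (num_experts, input_dim, output_dim), while the NVFP4 block-quantization
    helpers expect input_dim last. Calling the helper again after quantization
    restores the original layout for both the quantized weight and its scale.
    """
    if not is_bmm_expert_weight:
        # Non-expert (e.g. 2D linear) weights pass through untouched.
        return weight, weight_scale

    weight = weight.transpose(-2, -1).contiguous()
    if weight_scale is not None and weight_scale.dim() == 3:
        weight_scale = weight_scale.transpose(-2, -1).contiguous()
    return weight, weight_scale
```

Centralizing the transpose this way lets the expert and non-expert paths share the same `get_weights_scaling_factor` and `to_quantized_weight` calls, which is what allows the old duplicated branches and DEBUG prints to be dropped.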