     KV_CACHE_NVFP4_AFFINE,
     QUANTIZATION_FP8,
     QUANTIZATION_FP8_PB_REAL,
+    QUANTIZATION_FP8_PC_PT,
     QUANTIZATION_NONE,
     QUANTIZATION_NVFP4,
     QUANTIZATION_NVFP4_AWQ,
@@ -323,13 +324,15 @@ def _export_quantized_weight(
     weight_scale_2: torch.Tensor | None = getattr(sub_module, quantizer_attrs.weight_scale_2, None)
 
     # Transpose weight for bmm-style expert quantization (llama4, gpt-oss)
+    # Check if this is a BMM-style expert weight that needs transposition
+    is_bmm_expert_weight = weight.dim() == 3 and any(
+        expert_type in type(sub_module).__name__
+        for expert_type in ["Llama4TextExperts", "GptOssExperts"]
+    )
+
     if quantization_format in [QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ]:
         # Transpose weight from (num_experts, input_dim, output_dim) to (num_experts, output_dim, input_dim)
         # for NVFP4 quantization functions that expect input_dim as the last dimension for block quantization
-        is_bmm_expert_weight = weight.dim() == 3 and any(
-            expert_type in type(sub_module).__name__
-            for expert_type in ["Llama4TextExperts", "GptOssExperts"]
-        )
         weight, _ = maybe_transpose_expert_weight_dimensions(
             weight, is_bmm_expert_weight=is_bmm_expert_weight
         )
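
For reference, a minimal sketch of what the transposition helper could look like. The helper name and its call pattern come from the diff above; the body, signature defaults, and scale-handling rule are assumptions, not the actual export-utility implementation:

```python
import torch


def maybe_transpose_expert_weight_dimensions(
    weight: torch.Tensor,
    weight_scale: torch.Tensor | None = None,
    is_bmm_expert_weight: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Swap the last two dims of BMM-style expert tensors (sketch).

    Maps (num_experts, input_dim, output_dim) <-> (num_experts, output_dim, input_dim).
    When is_bmm_expert_weight is False, tensors are returned unchanged.
    """
    if not is_bmm_expert_weight:
        return weight, weight_scale

    weight = weight.transpose(-2, -1).contiguous()
    # Assumption: only 3D (block) scales carry both transposed dims; per-channel
    # scales such as (num_experts, output_dim) are passed through untouched.
    if weight_scale is not None and weight_scale.dim() == 3:
        weight_scale = weight_scale.transpose(-2, -1).contiguous()
    return weight, weight_scale
```

Computing `is_bmm_expert_weight` once before the format branch, as the diff now does, lets both the NVFP4 and FP8_PC_PT paths reuse the same flag.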
@@ -350,6 +353,26 @@ def _export_quantized_weight(
         quantized_weight, weight_scale = maybe_transpose_expert_weight_dimensions(
             quantized_weight, weight_scale, is_bmm_expert_weight=is_bmm_expert_weight
         )
+    elif quantization_format == QUANTIZATION_FP8_PC_PT and is_bmm_expert_weight:
+        # For FP8_PC_PT with BMM-style experts, transpose only the weight (not weight_scale)
+        # Transpose weight from (num_experts, input_dim, output_dim) to (num_experts, output_dim, input_dim)
+        # weight_scale remains (num_experts, output_dim) for per-channel quantization
+        weight, _ = maybe_transpose_expert_weight_dimensions(
+            weight, is_bmm_expert_weight=is_bmm_expert_weight
+        )
+
+        quantized_weight = to_quantized_weight(
+            weight.to(dtype),
+            weight_scale,
+            quantization_format,
+            weight_scale_2,
+            block_size,
+        )
+
+        # Transpose back to original BMM format
+        quantized_weight, _ = maybe_transpose_expert_weight_dimensions(
+            quantized_weight, is_bmm_expert_weight=is_bmm_expert_weight
+        )
     else:
         quantized_weight = to_quantized_weight(
             weight.to(dtype),
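
As a shape sanity check for the new FP8_PC_PT branch (illustrative dimensions only; the per-channel scaling and float8 cast below are a stand-in for `to_quantized_weight`, not its real behavior):

```python
import torch

num_experts, input_dim, output_dim = 8, 4096, 14336  # made-up sizes

# BMM-style expert weight layout used by Llama4TextExperts / GptOssExperts.
weight = torch.randn(num_experts, input_dim, output_dim)
# Per-channel FP8 scale: one scale per expert per output channel.
weight_scale = torch.rand(num_experts, output_dim) + 0.5

# Transpose only the weight so the quantizer sees input_dim as the last dim;
# the per-channel scale already indexes the output_dim axis and stays as-is.
w_t = weight.transpose(-2, -1)
assert w_t.shape == (num_experts, output_dim, input_dim)
assert weight_scale.shape == (num_experts, output_dim)

# Stand-in for to_quantized_weight: scale per output channel, cast to FP8.
q_t = (w_t / weight_scale.unsqueeze(-1)).to(torch.float8_e4m3fn)

# Transpose back to the original BMM layout before writing the checkpoint.
quantized_weight = q_t.transpose(-2, -1)
assert quantized_weight.shape == (num_experts, input_dim, output_dim)
```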