Commit 33fdcf3

added support for nvfp4 export
Signed-off-by: Suguna Velury <[email protected]>
1 parent e18c323 commit 33fdcf3

4 files changed: +44 −8 lines changed

examples/llm_ptq/hf_ptq.py

Lines changed: 6 additions & 2 deletions
@@ -357,7 +357,9 @@ def main(args):
         )
         mts.export(model)

-    if args.auto_quantize_bits or args.qformat in QUANT_CFG_CHOICES:
+    if (
+        args.auto_quantize_bits or args.qformat in QUANT_CFG_CHOICES
+    ) and not model_is_already_quantized:
         if "awq" in args.qformat:
             print(
                 "\n####\nAWQ calibration could take longer than other calibration methods. "
@@ -386,6 +388,9 @@ def main(args):
         sample_input_single_batch = None

     run_auto_quant = args.auto_quantize_bits is not None
+    print("DEBUG LOG: Entereing here")
+    for k, v in model.state_dict().items():
+        print(k, v.shape, v.dtype, v.device)

     args.batch_size = get_max_batch_size(
         model,
@@ -628,7 +633,6 @@ def output_decode(generated_ids, input_shape):
             "They will be set at deployment time."
         )

-    print("DEBUG LOG: Calling unified export hf checkpoint")
     export_hf_checkpoint(
         full_model,
         export_dir=export_path,
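The guard added above lets hf_ptq.py skip re-calibration when the loaded checkpoint is already quantized and pass the model straight to export. A minimal, self-contained sketch of that control flow; the QUANT_CFG_CHOICES values and the run_calibration helper below are placeholders, not the script's real objects:

# Sketch only: illustrates the guard, not the script's actual calibration path.
QUANT_CFG_CHOICES = {"fp8", "nvfp4", "int4_awq"}  # illustrative keys

def run_calibration(model):
    # Placeholder for the real mtq.quantize(...) calibration pass.
    return model

def maybe_quantize(model, qformat, auto_quantize_bits, model_is_already_quantized):
    if (auto_quantize_bits or qformat in QUANT_CFG_CHOICES) and not model_is_already_quantized:
        model = run_calibration(model)
    # An already-quantized checkpoint falls through unchanged and is exported directly.
    return model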

modelopt/torch/export/quant_utils.py

Lines changed: 12 additions & 2 deletions
@@ -269,6 +269,15 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
         QUANTIZATION_NVFP4_AWQ,
         QUANTIZATION_W4A8_NVFP4_FP8,
     ]:
+        if hasattr(weight_quantizer, "_scale"):
+            # In this case, weight must be a QTensorWrapper
+            original_shape = weight.metadata["shape"]
+            ws = NVFP4QTensor.get_modelopt_weights_scaling_factor(
+                weight_quantizer._scale, original_shape
+            )
+            print(f"weight_quantizer._scale: {ws.shape}")
+            return ws
+
         return NVFP4QTensor.get_weights_scaling_factor(
             weight,
             weight_quantizer.block_sizes[-1],
@@ -608,8 +617,6 @@ def process_layer_quant_config(layer_config_dict):
         # Get the corresponding AWQ block size
         block_size_value = layer_config_dict.get(awq_key, 0)

-        # print(f"DEBUG LOG: Processing layer {k} with quantization {v}, block size {block_size_value}")
-
         if v == "fp8":
             layer_config = {"quant_algo": "FP8"}
         elif v == "fp8_pc_pt":
@@ -1082,6 +1089,9 @@ def get_quant_config(named_modules: nn.Module | dict[str, nn.Module]) -> dict[st
         block_size = get_weight_block_size(module)

         # Construct per layer config dictionary
+        if block_size == 0 and quantization_format != QUANTIZATION_FP8:
+            continue
+
         layer_config_dict[name + ".quantization"] = quantization_format
         layer_config_dict[name + ".awq_block_size"] = block_size

modelopt/torch/export/unified_export_hf.py

Lines changed: 5 additions & 4 deletions
@@ -538,10 +538,11 @@ def export_hf_checkpoint(
         model.base_model.save_pretrained(
             base_export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state
         )
-
-        model.save_pretrained(
-            export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state
-        )
+        model.save_pretrained(export_dir, save_modelopt_state=save_modelopt_state)
+    else:
+        model.save_pretrained(
+            export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state
+        )

     original_config = f"{base_export_dir}/config.json"
     config_data = {}
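With this split, the wrapped case saves the base model with the processed quantized state dict and the wrapper with only its own state, while the plain case keeps the previous single save_pretrained call. A hedged usage sketch of the exporter; the import path and keyword argument follow the call sites visible in this commit, and the model is assumed to be a Hugging Face model already quantized with modelopt:

from modelopt.torch.export import export_hf_checkpoint

def export_quantized(model, export_path: str):
    # model: a Hugging Face model that has already been quantized/calibrated with modelopt.
    export_hf_checkpoint(model, export_dir=export_path)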

modelopt/torch/quantization/qtensor/nvfp4_tensor.py

Lines changed: 21 additions & 0 deletions
@@ -94,6 +94,27 @@ def get_weights_scaling_factor_2(cls, input: torch.Tensor):
         """Returns per tensor weight scaling factor."""
         return reduce_amax(input).float() / (6.0 * 448.0)

+    @classmethod
+    def get_modelopt_weights_scaling_factor(cls, weight_scaling_factor: torch.Tensor, weight_shape):
+        """Returns the modelopt weights scaling factor if the quantization is done by trtllm."""
+        if weight_scaling_factor.dtype == torch.float8_e4m3fn:
+            return weight_scaling_factor
+
+        if weight_scaling_factor.dtype == torch.uint8 and weight_scaling_factor.ndim == 1:
+            # If quantization is done by trtllm, convert cutlass fp4 scale to modelopt fp4 scale
+            try:
+                from tensorrt_llm._torch.auto_deploy.utils.quantization_utils import (
+                    cutlass_fp4_scale_to_modelopt_fp4_scale,
+                )
+
+                return cutlass_fp4_scale_to_modelopt_fp4_scale(
+                    weight_scaling_factor, weight_shape[-2:]
+                )
+            except ImportError as e:
+                raise ImportError(
+                    "This tensor is quantized by trtllm, but tensorrt_llm cannot be imported."
+                ) from e
+
     @classmethod
     def get_activation_scaling_factor(cls, quantizer):
         """Returns the activation scaling factor for export."""
