
Commit 5fcef97

added support for nvfp4 export
Signed-off-by: Suguna Velury <[email protected]>
1 parent 3cf89cf commit 5fcef97

4 files changed: +44 -8 lines changed


examples/llm_ptq/hf_ptq.py

Lines changed: 6 additions & 2 deletions
@@ -364,7 +364,9 @@ def main(args):
         )
         mts.export(model)
 
-    if args.auto_quantize_bits or args.qformat in QUANT_CFG_CHOICES:
+    if (
+        args.auto_quantize_bits or args.qformat in QUANT_CFG_CHOICES
+    ) and not model_is_already_quantized:
         if "awq" in args.qformat:
             print(
                 "\n####\nAWQ calibration could take longer than other calibration methods. "
@@ -393,6 +395,9 @@ def main(args):
         sample_input_single_batch = None
 
         run_auto_quant = args.auto_quantize_bits is not None
+        print("DEBUG LOG: Entereing here")
+        for k, v in model.state_dict().items():
+            print(k, v.shape, v.dtype, v.device)
 
         args.batch_size = get_max_batch_size(
             model,
@@ -635,7 +640,6 @@ def output_decode(generated_ids, input_shape):
                 "They will be set at deployment time."
            )
 
-        print("DEBUG LOG: Calling unified export hf checkpoint")
         export_hf_checkpoint(
             full_model,
             export_dir=export_path,
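The reworked guard relies on a model_is_already_quantized flag so that calibration is skipped for checkpoints that were quantized ahead of time. A minimal sketch of how such a flag could be derived from a checkpoint directory (hypothetical helper; the script may compute it differently):

import json
import os


def checkpoint_is_already_quantized(ckpt_dir: str) -> bool:
    """Heuristic: treat a checkpoint as pre-quantized if it ships a
    quantization config next to the weights (hypothetical helper)."""
    if os.path.isfile(os.path.join(ckpt_dir, "hf_quant_config.json")):
        return True
    config_path = os.path.join(ckpt_dir, "config.json")
    if os.path.isfile(config_path):
        with open(config_path) as f:
            return "quantization_config" in json.load(f)
    return False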

modelopt/torch/export/quant_utils.py

Lines changed: 12 additions & 2 deletions
@@ -270,6 +270,15 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
         QUANTIZATION_NVFP4_AWQ,
         QUANTIZATION_W4A8_NVFP4_FP8,
     ]:
+        if hasattr(weight_quantizer, "_scale"):
+            # In this case, weight must be a QTensorWrapper
+            original_shape = weight.metadata["shape"]
+            ws = NVFP4QTensor.get_modelopt_weights_scaling_factor(
+                weight_quantizer._scale, original_shape
+            )
+            print(f"weight_quantizer._scale: {ws.shape}")
+            return ws
+
         return NVFP4QTensor.get_weights_scaling_factor(
             weight,
             weight_quantizer.block_sizes[-1],
@@ -612,8 +621,6 @@ def process_layer_quant_config(layer_config_dict):
         # Get the corresponding AWQ block size
         block_size_value = layer_config_dict.get(awq_key, 0)
 
-        # print(f"DEBUG LOG: Processing layer {k} with quantization {v}, block size {block_size_value}")
-
         if v == "fp8":
             layer_config = {"quant_algo": "FP8"}
         elif v == "fp8_pc_pt":
@@ -1088,6 +1095,9 @@ def get_quant_config(named_modules: nn.Module | dict[str, nn.Module]) -> dict[st
             block_size = get_weight_block_size(module)
 
             # Construct per layer config dictionary
+            if block_size == 0 and quantization_format != QUANTIZATION_FP8:
+                continue
+
             layer_config_dict[name + ".quantization"] = quantization_format
             layer_config_dict[name + ".awq_block_size"] = block_size
 
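The new continue in get_quant_config drops layers whose recovered block size is 0 unless they are plain FP8, so they never reach the per-layer config. A standalone sketch of that filtering with made-up layer entries (constant and names are placeholders, not the module's real data):

QUANTIZATION_FP8 = "fp8"  # placeholder constant for this sketch

layers = {
    "model.layers.0.mlp.down_proj": ("nvfp4", 16),   # block-quantized, kept
    "model.layers.0.self_attn.q_proj": ("fp8", 0),   # FP8 has no block size, kept
    "lm_head": ("nvfp4", 0),                         # no block size recovered, skipped
}

layer_config_dict = {}
for name, (quantization_format, block_size) in layers.items():
    if block_size == 0 and quantization_format != QUANTIZATION_FP8:
        continue
    layer_config_dict[name + ".quantization"] = quantization_format
    layer_config_dict[name + ".awq_block_size"] = block_size

print(sorted(layer_config_dict))  # lm_head entries are absent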

modelopt/torch/export/unified_export_hf.py

Lines changed: 5 additions & 4 deletions
@@ -542,10 +542,11 @@ def export_hf_checkpoint(
             model.base_model.save_pretrained(
                 base_export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state
             )
-
-        model.save_pretrained(
-            export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state
-        )
+            model.save_pretrained(export_dir, save_modelopt_state=save_modelopt_state)
+        else:
+            model.save_pretrained(
+                export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state
+            )
 
         original_config = f"{base_export_dir}/config.json"
         config_data = {}
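The branch changes which weights Transformers serializes: with an explicit state_dict=, save_pretrained writes the processed export weights; without it, the model saves whatever state it currently holds. A minimal, self-contained sketch of that override semantics (a tiny GPT-2 config is used only to avoid downloading weights; directory names are made up):

from transformers import GPT2Config, GPT2LMHeadModel

model = GPT2LMHeadModel(GPT2Config(n_layer=1, n_head=2, n_embd=8, vocab_size=16))

# Stand-in for the exporter's post-processed weights (here just a copy).
post_state_dict = {k: v.clone() for k, v in model.state_dict().items()}

# With state_dict=..., save_pretrained serializes the override instead of the
# module's live parameters; without it, the current parameters are written.
model.save_pretrained("export/full", state_dict=post_state_dict)
model.save_pretrained("export/full_live")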

modelopt/torch/quantization/qtensor/nvfp4_tensor.py

Lines changed: 21 additions & 0 deletions
@@ -94,6 +94,27 @@ def get_weights_scaling_factor_2(cls, input: torch.Tensor):
         """Returns per tensor weight scaling factor."""
         return reduce_amax(input).float() / (6.0 * 448.0)
 
+    @classmethod
+    def get_modelopt_weights_scaling_factor(cls, weight_scaling_factor: torch.Tensor, weight_shape):
+        """Returns the modelopt weights scaling factor if the quantization is done by trtllm."""
+        if weight_scaling_factor.dtype == torch.float8_e4m3fn:
+            return weight_scaling_factor
+
+        if weight_scaling_factor.dtype == torch.uint8 and weight_scaling_factor.ndim == 1:
+            # If quantization is done by trtllm, convert cutlass fp4 scale to modelopt fp4 scale
+            try:
+                from tensorrt_llm._torch.auto_deploy.utils.quantization_utils import (
+                    cutlass_fp4_scale_to_modelopt_fp4_scale,
+                )
+
+                return cutlass_fp4_scale_to_modelopt_fp4_scale(
+                    weight_scaling_factor, weight_shape[-2:]
+                )
+            except ImportError as e:
+                raise ImportError(
+                    "This tensor is quantized by trtllm, but tensorrt_llm cannot be imported."
+                ) from e
+
     @classmethod
     def get_activation_scaling_factor(cls, quantizer):
         """Returns the activation scaling factor for export."""
