Commit 0b19255

minor update

Signed-off-by: Suguna Velury <[email protected]>
1 parent: 3b7ef44

6 files changed: +23 −43 lines changed

examples/llm_ptq/hf_ptq.py

Lines changed: 1 addition & 7 deletions

@@ -482,7 +482,7 @@ def main(args):
         quant_cfg["quant_cfg"]["*image*"] = {"enable": False}
         quant_cfg["quant_cfg"]["*vision*"] = {"enable": False}

-    if not model_is_already_quantized and calibration_only:
+    if not model_is_already_quantized or calibration_only:
         # Only run single sample for preview
         input_ids = next(iter(calib_dataloader))[
             "input_features" if model_type == "whisper" else "input_ids"
@@ -772,12 +772,6 @@ def output_decode(generated_ids, input_shape):
         default=None,
         type=str,
     )
-    parser.add_argument(
-        "--qlora",
-        help="Specify the model to be exported is a QLoRA model trained using modelopt.",
-        default=False,
-        action="store_true",
-    )

     args = parser.parse_args()


examples/llm_qat/README.md

Lines changed: 1 addition & 5 deletions

@@ -348,13 +348,9 @@ To perform QLoRA training, run:
 After performing QLoRA training the final checkpoint can be exported for deployment with vLLM using the following command.

 ```sh
-cd ../llm_ptq
-
-python hf_ptq.py \
+python export.py \
     --pyt_ckpt_path llama3-fp4-qlora \
-    --qformat nvfp4 \
     --export_dir llama3-fp4-qlora-hf \
-    --qlora

 ```

examples/llm_qat/export.py

Lines changed: 1 addition & 1 deletion

@@ -68,7 +68,7 @@ def main(args):
         base_model_dir.mkdir(parents=True, exist_ok=True)

     try:
-        post_state_dict, hf_quant_config = _export_hf_checkpoint(model, is_lora=True)
+        post_state_dict, hf_quant_config = _export_hf_checkpoint(model, is_modelopt_qlora=True)

         with open(f"{export_dir}/base_model/hf_quant_config.json", "w") as file:
             json.dump(hf_quant_config, file, indent=4)

modelopt/torch/export/unified_export_hf.py

Lines changed: 3 additions & 3 deletions

@@ -338,7 +338,7 @@ def _export_quantized_weight(


 def _export_hf_checkpoint(
-    model: nn.Module, dtype: torch.dtype | None = None, is_lora: bool = False
+    model: nn.Module, dtype: torch.dtype | None = None, is_modelopt_qlora: bool = False
 ) -> tuple[dict[str, Any], dict[str, Any]]:
     """Exports the torch model to the packed checkpoint with original HF naming.

@@ -431,7 +431,7 @@ def _export_hf_checkpoint(
     # Resmooth and requantize fused layers
     # TODO: Handle mixed precision
     # TODO: Support requantize and resmooth for modelopt-trained LoRA models
-    if not is_lora:
+    if not is_modelopt_qlora:
         requantize_resmooth_fused_llm_layers(model)

     # Remove all hooks from the model
@@ -491,7 +491,7 @@ def _export_hf_checkpoint(
     quantized_state_dict = model.state_dict()

     quantized_state_dict = postprocess_state_dict(
-        quantized_state_dict, kv_cache_max_bound, kv_cache_format, is_lora
+        quantized_state_dict, kv_cache_max_bound, kv_cache_format, is_modelopt_qlora
     )

     # Check if any layers are quantized

modelopt/torch/quantization/plugins/transformers_trainer.py

Lines changed: 3 additions & 3 deletions

@@ -209,13 +209,13 @@ def forward_loop(model):
         print_rank_0("Quantizing the model...")
         mtq.quantize(self.model, self.quant_cfg, forward_loop)  # type: ignore [arg-type]

-        # Save modelopt state before compression
+        # Save modelopt state before compression. This is used to later export the model for deployment.
         modelopt_state = mto.modelopt_state(self.model)
         modelopt_state["modelopt_state_weights"] = get_quantizer_state_dict(self.model)
-        torch.save(modelopt_state, f"{self.args.output_dir}/modelopt_state_calibration.pth")
+        torch.save(modelopt_state, f"{self.args.output_dir}/modelopt_state_calib.pth")

         print_rank_0(
-            f"Saved modelopt state before compression to {f'{self.args.output_dir}/modelopt_state_calibration.pth'}"
+            f"Saved modelopt state before compression to {f'{self.args.output_dir}/modelopt_state_calib.pth'}"
         )

         if getattr(self.quant_args, "compress", False):
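For context, a minimal sketch of how the `modelopt_state_calib.pth` saved above could be reloaded before export. The restore call and the weight-reload step are assumptions about the ModelOpt API, not part of this commit.

```python
# Illustrative sketch only: reload the pre-compression calibration state.
# Assumes mto.restore_from_modelopt_state accepts this state dict and that the
# saved "modelopt_state_weights" entry can be loaded back as shown (hypothetical).
import torch
import modelopt.torch.opt as mto


def restore_calibrated_model(model, output_dir: str):
    state = torch.load(f"{output_dir}/modelopt_state_calib.pth", weights_only=False)
    quantizer_weights = state.pop("modelopt_state_weights", None)
    # Re-apply the quantizer structure recorded at calibration time.
    model = mto.restore_from_modelopt_state(model, state)
    if quantizer_weights is not None:
        # Hypothetical loading path; the real helper may differ.
        model.load_state_dict(quantizer_weights, strict=False)
    return model
```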

modelopt/torch/quantization/qtensor/nvfp4_tensor.py

Lines changed: 14 additions & 24 deletions

@@ -94,27 +94,6 @@ def get_weights_scaling_factor_2(cls, input: torch.Tensor):
         """Returns per tensor weight scaling factor."""
         return reduce_amax(input).float() / (6.0 * 448.0)

-    @classmethod
-    def get_modelopt_weights_scaling_factor(cls, weight_scaling_factor: torch.Tensor, weight_shape):
-        """Returns the modelopt weights scaling factor if the quantization is done by trtllm."""
-        if weight_scaling_factor.dtype == torch.float8_e4m3fn:
-            return weight_scaling_factor
-
-        if weight_scaling_factor.dtype == torch.uint8 and weight_scaling_factor.ndim == 1:
-            # If quantization is done by trtllm, convert cutlass fp4 scale to modelopt fp4 scale
-            try:
-                from tensorrt_llm._torch.auto_deploy.utils.quantization_utils import (
-                    cutlass_fp4_scale_to_modelopt_fp4_scale,
-                )
-
-                return cutlass_fp4_scale_to_modelopt_fp4_scale(
-                    weight_scaling_factor, weight_shape[-2:]
-                )
-            except ImportError as e:
-                raise ImportError(
-                    "This tensor is quantized by trtllm, but tensorrt_llm cannot be imported."
-                ) from e
-
     @classmethod
     def get_activation_scaling_factor(cls, quantizer):
         """Returns the activation scaling factor for export."""
@@ -270,9 +249,20 @@ def _unpack_tensor(input: torch.Tensor):
            return unpacked.reshape(unpacked_shape)

        # Get scales from kwargs
-        kwarg["scale"] = self.get_modelopt_weights_scaling_factor(
-            kwarg["scale"], self.metadata["shape"]
-        )
+        if kwarg["scale"].dtype == torch.uint8 and kwarg["scale"].ndim == 1:
+            # If quantization is done by trtllm, convert cutlass fp4 scale to modelopt fp4 scale
+            try:
+                from tensorrt_llm._torch.auto_deploy.utils.quantization_utils import (
+                    cutlass_fp4_scale_to_modelopt_fp4_scale,
+                )
+
+                kwarg["scale"] = cutlass_fp4_scale_to_modelopt_fp4_scale(
+                    kwarg["scale"], self.metadata["shape"][-2:]
+                )
+            except ImportError as e:
+                raise ImportError(
+                    "This tensor is quantized by trtllm, but tensorrt_llm cannot be imported."
+                ) from e

        if fast:
            from ..triton.fp4_kernel import fp4_dequantize
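As a side note, a small sketch of the layout check this hunk relies on when deciding whether a scale needs conversion. The shapes and the block size of 16 are illustrative assumptions, not values taken from this commit.

```python
# Illustrative sketch (not part of the commit): distinguishing the two FP4
# scale layouts the dequantize path above branches on.
import torch


def needs_cutlass_to_modelopt_conversion(scale: torch.Tensor) -> bool:
    # ModelOpt stores per-block scales as a float8_e4m3fn tensor shaped like the
    # weight with the last dim divided by the block size; TRT-LLM's cutlass path
    # stores a flat uint8 buffer instead, which is what triggers the conversion.
    return scale.dtype == torch.uint8 and scale.ndim == 1


# Example (assumed shapes): a [128, 4096] weight with 16-element blocks has a
# [128, 256] float8_e4m3fn scale in the ModelOpt layout, so no conversion runs.
modelopt_scale = torch.ones(128, 4096 // 16).to(torch.float8_e4m3fn)
assert not needs_cutlass_to_modelopt_conversion(modelopt_scale)
```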
