Commit d00a7f6

e2e checkpoint tested for nvfp4 and fp8
Signed-off-by: Suguna Velury <[email protected]>
1 parent: f83f934

File tree

6 files changed: +84 -81 lines

examples/llm_ptq/example_utils.py

Lines changed: 6 additions & 5 deletions

@@ -126,15 +126,11 @@ def get_dtype(dtype):
 
 def get_lora_model(
     ckpt_path: str,
-    device="cuda",
+    device_map="cuda",
 ):
     """
     Loads a QLoRA model that has been trained using modelopt trainer.
     """
-    device_map = "auto"
-    if device == "cpu":
-        device_map = "cpu"
-
     # Load model with adapters
     model = AutoModelForCausalLM.from_pretrained(ckpt_path, device_map=device_map)
 
@@ -156,13 +152,18 @@ def get_model(
     trust_remote_code=False,
     use_seq_device_map=False,
     attn_implementation=None,
+    is_lora=False,
 ):
     print(f"Initializing model from {ckpt_path}")
 
     device_map = "auto"
     if device == "cpu":
         device_map = "cpu"
 
+    if is_lora:
+        model = get_lora_model(ckpt_path, device_map)
+        return model
+
     config_kwargs = {"trust_remote_code": trust_remote_code} if trust_remote_code else {}
     if attn_implementation is not None:
         config_kwargs["attn_implementation"] = attn_implementation
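
For reference, a minimal sketch of how the reworked helpers compose after this change: get_model now routes LoRA checkpoints to get_lora_model through the new is_lora flag, so callers no longer need to branch themselves. This is an illustration under the simplified signatures shown above, not the full example_utils.py implementation; the checkpoint path is a placeholder.

# Illustrative sketch only; the non-LoRA path is elided.
from transformers import AutoModelForCausalLM


def get_lora_model(ckpt_path: str, device_map="cuda"):
    """Load a QLoRA model trained with the modelopt trainer, adapters included."""
    return AutoModelForCausalLM.from_pretrained(ckpt_path, device_map=device_map)


def get_model(ckpt_path: str, device="cuda", is_lora=False, **kwargs):
    device_map = "cpu" if device == "cpu" else "auto"
    if is_lora:
        # Early-return path added in this commit.
        return get_lora_model(ckpt_path, device_map)
    raise NotImplementedError("regular (non-LoRA) loading path elided in this sketch")


# Hypothetical usage; "path/to/qlora_ckpt" is a placeholder checkpoint directory.
model = get_model("path/to/qlora_ckpt", is_lora=True)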

examples/llm_ptq/hf_ptq.py

Lines changed: 18 additions & 27 deletions

@@ -23,14 +23,7 @@
 import numpy as np
 import torch
 from accelerate.hooks import remove_hook_from_module
-from example_utils import (
-    apply_kv_cache_quant,
-    get_lora_model,
-    get_model,
-    get_processor,
-    get_tokenizer,
-    is_enc_dec,
-)
+from example_utils import apply_kv_cache_quant, get_model, get_processor, get_tokenizer, is_enc_dec
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
@@ -238,20 +231,15 @@ def main(args):
     # If low memory mode is enabled, we compress the model while loading the HF checkpoint.
     calibration_only = False
     if not args.low_memory_mode:
-        if args.lora:
-            model = get_lora_model(
-                args.pyt_ckpt_path,
-                args.device,
-            )
-        else:
-            model = get_model(
-                args.pyt_ckpt_path,
-                args.device,
-                gpu_mem_percentage=args.gpu_max_mem_percentage,
-                trust_remote_code=args.trust_remote_code,
-                use_seq_device_map=args.use_seq_device_map,
-                attn_implementation=args.attn_implementation,
-            )
+        model = get_model(
+            args.pyt_ckpt_path,
+            args.device,
+            gpu_mem_percentage=args.gpu_max_mem_percentage,
+            trust_remote_code=args.trust_remote_code,
+            use_seq_device_map=args.use_seq_device_map,
+            attn_implementation=args.attn_implementation,
+            is_lora=args.lora,
+        )
     else:
         assert args.qformat in QUANT_CFG_CHOICES, (
             f"Quantization format is not supported for low memory mode. Supported formats: {QUANT_CFG_CHOICES.keys()}"
@@ -388,9 +376,6 @@ def main(args):
     sample_input_single_batch = None
 
     run_auto_quant = args.auto_quantize_bits is not None
-    print("DEBUG LOG: Entereing here")
-    for k, v in model.state_dict().items():
-        print(k, v.shape, v.dtype, v.device)
 
     args.batch_size = get_max_batch_size(
         model,
@@ -489,7 +474,7 @@ def main(args):
                 "Please set the default input_mode to InputMode.LANGUAGE before quantizing."
             )
 
-        if not model_is_already_quantized or calibration_only:
+        if calibration_only:
             # Only run single sample for preview
             input_ids = next(iter(calib_dataloader))[
                 "input_features" if model_type == "whisper" else "input_ids"
@@ -563,7 +548,12 @@ def output_decode(generated_ids, input_shape):
 
     else:
         assert model_type != "dbrx", f"Does not support export {model_type} without quantizaton"
-        print(f"qformat: {args.qformat}. No quantization applied, export {device} model")
+        if model_is_already_quantized:
+            warnings.warn(
+                "Skipping quantization: Model is already quantized. Exporting the model..."
+            )
+        else:
+            print(f"qformat: {args.qformat}. No quantization applied, export {device} model")
 
     with torch.inference_mode():
         if model_type is None:
@@ -636,6 +626,7 @@ def output_decode(generated_ids, input_shape):
         export_hf_checkpoint(
             full_model,
             export_dir=export_path,
+            is_modelopt_trained_lora=args.lora,
         )
 
         # Restore default padding and export the tokenizer as well.

modelopt/torch/export/quant_utils.py

Lines changed: 40 additions & 20 deletions

@@ -269,23 +269,28 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
         QUANTIZATION_NVFP4_AWQ,
         QUANTIZATION_W4A8_NVFP4_FP8,
     ]:
+        # If scale is already registered, indicates weights are already compressed.
+        # We convert to modelopt scale if necessary and return
         if hasattr(weight_quantizer, "_scale"):
             return NVFP4QTensor.get_modelopt_weights_scaling_factor(
                 weight_quantizer._scale, weight.metadata["shape"]
             )
-
-        return NVFP4QTensor.get_weights_scaling_factor(
-            weight,
-            weight_quantizer.block_sizes[-1],
-            NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer).to(
-                weight.device
-            ),
-        )[0]
+        else:
+            return NVFP4QTensor.get_weights_scaling_factor(
+                weight,
+                weight_quantizer.block_sizes[-1],
+                NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer).to(
+                    weight.device
+                ),
+            )[0]
 
     if quantization_format in [QUANTIZATION_W4A8_MXFP4_FP8, QUANTIZATION_MXFP4]:
-        return MXFP4QTensor.quantize(weight, block_size=weight_quantizer.block_sizes[-1])[
-            1
-        ].reshape(*weight.shape[:-1], -1)
+        if hasattr(weight_quantizer, "_scale"):
+            return weight_quantizer._scale.reshape(*weight.shape[:-1], -1)
+        else:
+            return MXFP4QTensor.quantize(weight, block_size=weight_quantizer.block_sizes[-1])[
+                1
+            ].reshape(*weight.shape[:-1], -1)
     return get_scaling_factor(weight_quantizer)
 
 
@@ -301,7 +306,10 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")
         QUANTIZATION_NVFP4_AWQ,
         QUANTIZATION_W4A8_NVFP4_FP8,
     ]:
-        return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)
+        if hasattr(weight_quantizer, "_double_scale"):
+            return weight_quantizer._double_scale
+        else:
+            return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)
 
     # SequentialQuantizer is required
     if not isinstance(weight_quantizer, SequentialQuantizer) or not weight_quantizer[-1].is_enabled:
@@ -818,7 +826,12 @@ def from_quantized_weight(
     raise NotImplementedError(f"quantization format {quantization} not supported")
 
 
-def postprocess_state_dict(state_dict: dict, maxbound: float, quantization: str | None) -> dict:
+def postprocess_state_dict(
+    state_dict: dict,
+    maxbound: float,
+    quantization: str | None,
+    is_modelopt_trained_lora: bool = False,
+) -> dict:
     """Filters out keys related to weight quantizers and updates KV cache related keys.
 
     Args:
@@ -835,11 +848,18 @@ def postprocess_state_dict(state_dict: dict, maxbound: float, quantization: str
         "k_bmm_quantizer._bias_value": "k_proj.k_bias",
         "v_bmm_quantizer._bias_value": "v_proj.v_bias",
         "input_quantizer._pre_quant_scale": "pre_quant_scale",
-        "base_layer.weight": "weight",
-        "base_layer.input_scale": "input_scale",
-        "base_layer.weight_scale": "weight_scale",
     }
 
+    # For modelopt-trained LoRA models, we need to remove the base_layer prefix from the keys for deployment
+    if is_modelopt_trained_lora:
+        replacements.update(
+            {
+                "base_layer.weight": "weight",
+                "base_layer.input_scale": "input_scale",
+                "base_layer.weight_scale": "weight_scale",
+            }
+        )
+
     post_state_dict = {}
 
     for key, value in state_dict.items():
@@ -902,10 +922,10 @@ def postprocess_state_dict(state_dict: dict, maxbound: float, quantization: str
                 keys_to_delete.append(key)
 
     # remove LoRA adapters from state dict
-    for key, value in post_state_dict.items():
-        if "lora" in key and key not in keys_to_delete:
-            keys_to_delete.append(key)
-
+    if is_modelopt_trained_lora:
+        for key, value in post_state_dict.items():
+            if "lora" in key and key not in keys_to_delete:
+                keys_to_delete.append(key)
     # Check for tied weights and remove duplicates
     seen_tensors = {}
 
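
To make the new postprocess_state_dict behavior concrete, the standalone snippet below re-implements just the base_layer renaming and LoRA-key pruning in isolation. It is a simplified approximation for illustration only (the real function also handles KV-cache scales, quantizer keys, and tied weights), and the example state-dict keys are hypothetical.

# Simplified, standalone approximation of the new LoRA key handling; keys are made up.
import torch

state_dict = {
    "model.layers.0.q_proj.base_layer.weight": torch.zeros(8, 8),
    "model.layers.0.q_proj.base_layer.weight_scale": torch.ones(1),
    "model.layers.0.q_proj.lora_A.default.weight": torch.zeros(4, 8),
}

is_modelopt_trained_lora = True

replacements = {}
if is_modelopt_trained_lora:
    # Strip the PEFT "base_layer." prefix so deployment sees plain HF weight names.
    replacements.update(
        {
            "base_layer.weight": "weight",
            "base_layer.input_scale": "input_scale",
            "base_layer.weight_scale": "weight_scale",
        }
    )

post_state_dict = {}
for key, value in state_dict.items():
    new_key = key
    for old, new in replacements.items():
        if key.endswith(old):
            new_key = key[: -len(old)] + new
    # Drop LoRA adapter tensors from the exported base checkpoint.
    if is_modelopt_trained_lora and "lora" in new_key:
        continue
    post_state_dict[new_key] = value

print(sorted(post_state_dict))
# ['model.layers.0.q_proj.weight', 'model.layers.0.q_proj.weight_scale']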

modelopt/torch/export/unified_export_hf.py

Lines changed: 14 additions & 11 deletions

@@ -85,9 +85,6 @@ def _is_enabled_quantizer(quantizer):
 
 def requantize_resmooth_fused_llm_layers(model: torch.nn.Module):
     """Group modules that take the same input and register shared parameters in module."""
-    # Skip for LoRA finetuned models
-    if hasattr(model, "base_model"):
-        return
     # TODO: Handle DBRX MoE
     input_to_linear = defaultdict(list)
     output_to_layernorm = defaultdict(None)
@@ -339,7 +336,7 @@ def _export_quantized_weight(
 
 
 def _export_hf_checkpoint(
-    model: nn.Module, dtype: torch.dtype | None = None
+    model: nn.Module, dtype: torch.dtype | None = None, is_modelopt_trained_lora: bool = False
 ) -> tuple[dict[str, Any], dict[str, Any]]:
     """Exports the torch model to the packed checkpoint with original HF naming.
 
@@ -431,7 +428,9 @@ def _export_hf_checkpoint(
 
     # Resmooth and requantize fused layers
     # TODO: Handle mixed precision
-    requantize_resmooth_fused_llm_layers(model)
+    # TODO: Support requantize and resmooth for modelopt-trained LoRA models
+    if not is_modelopt_trained_lora:
+        requantize_resmooth_fused_llm_layers(model)
 
     # Remove all hooks from the model
     try:
@@ -490,7 +489,7 @@ def _export_hf_checkpoint(
     quantized_state_dict = model.state_dict()
 
     quantized_state_dict = postprocess_state_dict(
-        quantized_state_dict, kv_cache_max_bound, kv_cache_format
+        quantized_state_dict, kv_cache_max_bound, kv_cache_format, is_modelopt_trained_lora
     )
 
     # Check if any layers are quantized
@@ -505,6 +504,7 @@ def export_hf_checkpoint(
     dtype: torch.dtype | None = None,
     export_dir: Path | str = tempfile.gettempdir(),
     save_modelopt_state: bool = False,
+    is_modelopt_trained_lora: bool = False,
 ):
     """Exports the torch model to unified checkpoint and saves to export_dir.
 
@@ -514,15 +514,18 @@
         export_dir: the target export path.
         save_modelopt_state: whether to save the modelopt state_dict.
     """
-    is_lora = hasattr(model, "base_model")
-    base_export_dir: Path | str = f"{export_dir}/base_model" if is_lora else export_dir
+    base_export_dir: Path | str = (
+        f"{export_dir}/base_model" if is_modelopt_trained_lora else export_dir
+    )
     export_dir = Path(export_dir)
     export_dir.mkdir(parents=True, exist_ok=True)
     base_export_dir = Path(base_export_dir)
     base_export_dir.mkdir(parents=True, exist_ok=True)
 
     try:
-        post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype)
+        post_state_dict, hf_quant_config = _export_hf_checkpoint(
+            model, dtype, is_modelopt_trained_lora
+        )
 
         # NOTE: (hg) Should we save hf_quant_config when there's no quantization applied?
         # Save hf_quant_config.json for backward compatibility
@@ -534,11 +537,11 @@
         post_state_dict = rename_and_prune_if_spec_decoding(model, post_state_dict)
 
         # In the case of LoRA model, we save the base model
-        if is_lora:
+        if is_modelopt_trained_lora:
            model.base_model.save_pretrained(
                base_export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state
            )
-            model.save_pretrained(export_dir, save_modelopt_state=save_modelopt_state)
+            model.save_pretrained(export_dir)
        else:
            model.save_pretrained(
                export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state
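
A hedged usage sketch of the updated export entry point: the caller now states explicitly that the model is a modelopt-trained LoRA model instead of relying on the removed hasattr(model, "base_model") probe. The import path follows the repository's examples; the model object and output directory are placeholders.

# Usage sketch; `model` is assumed to be a PEFT-wrapped, quantization-calibrated model.
from modelopt.torch.export import export_hf_checkpoint

export_hf_checkpoint(
    model,
    export_dir="exported_qlora",    # placeholder path
    is_modelopt_trained_lora=True,  # replaces the old hasattr(model, "base_model") check
)

# Expected layout, per the branches above:
#   exported_qlora/            -> LoRA adapter weights (model.save_pretrained)
#   exported_qlora/base_model/ -> quantized base model with the post-processed state dict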

modelopt/torch/quantization/plugins/transformers_trainer.py

Lines changed: 3 additions & 4 deletions

@@ -185,7 +185,6 @@ def _save_modelopt_state_with_weights(self):
        # Save base model compressed weights for QLoRA
        if getattr(self.quant_args, "compress", False):
            # Save base model config.json
-            # weight_quantizer = self.quant_cfg["quant_cfg"]["*weight_quantizer"]
            self.model.config.save_pretrained(self.args.output_dir)
 
            # Save base model compressed weights excluding lora weights
@@ -292,14 +291,14 @@ def save_model(self, *args, **kwargs):
    def _load_best_model(self, *args, **kwargs):
        """Load the best model for final evaluation."""
        is_lora = getattr(self.args, "lora", None)
-        if not is_lora:
-            super()._load_best_model(*args, **kwargs)
-        else:
+        if is_lora and not self.is_fsdp_enabled:
            # Custom logic for loading best model with LoRA
            # TODO: Remove once we migrate to using get_peft_model()
            adapter_name = self.model.active_adapter()
            self.model.delete_adapter(adapter_name)
            self.model.load_adapter(self.state.best_model_checkpoint, adapter_name)
+        else:
+            super()._load_best_model(*args, **kwargs)
 
    def _patch_accelerate_for_fsdp2_fix(self):
        """Fixes for accelerate prepare.

modelopt/torch/quantization/qtensor/nvfp4_tensor.py

Lines changed: 3 additions & 14 deletions

@@ -270,20 +270,9 @@ def _unpack_tensor(input: torch.Tensor):
            return unpacked.reshape(unpacked_shape)
 
        # Get scales from kwargs
-        if kwarg["scale"].dtype == torch.uint8 and kwarg["scale"].ndim == 1:
-            # If quantization is done by trtllm, convert cutlass fp4 scale to modelopt fp4 scale
-            try:
-                from tensorrt_llm._torch.auto_deploy.utils.quantization_utils import (
-                    cutlass_fp4_scale_to_modelopt_fp4_scale,
-                )
-
-                kwarg["scale"] = cutlass_fp4_scale_to_modelopt_fp4_scale(
-                    kwarg["scale"], self.metadata["shape"][-2:]
-                )
-            except ImportError as e:
-                raise ImportError(
-                    "This tensor is quantized by trtllm, but tensorrt_llm cannot be imported."
-                ) from e
+        kwarg["scale"] = self.get_modelopt_weights_scaling_factor(
+            kwarg["scale"], self.metadata["shape"]
+        )
 
        if fast:
            from ..triton.fp4_kernel import fp4_dequantize
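
The dequantize path above now delegates scale normalization to get_modelopt_weights_scaling_factor. As a hedged sketch of what that helper is expected to cover, the function below reconstructs the dispatch from the inline code this commit removes; it is not the helper's actual implementation.

# Reconstruction of the removed inline logic, for illustration only.
import torch


def _to_modelopt_scale(scale: torch.Tensor, weight_shape) -> torch.Tensor:
    if scale.dtype == torch.uint8 and scale.ndim == 1:
        # A flat uint8 scale indicates a TRT-LLM / cutlass layout; convert it.
        from tensorrt_llm._torch.auto_deploy.utils.quantization_utils import (
            cutlass_fp4_scale_to_modelopt_fp4_scale,
        )

        return cutlass_fp4_scale_to_modelopt_fp4_scale(scale, weight_shape[-2:])
    # Already in modelopt layout; pass through unchanged.
    return scale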
