
Commit 40202eb

e2e checkpoint tested for nvfp4 and fp8
Signed-off-by: Suguna Velury <[email protected]>
1 parent 7310346 commit 40202eb

File tree

6 files changed: +83 -73 lines changed

examples/llm_ptq/example_utils.py
examples/llm_ptq/hf_ptq.py
modelopt/torch/export/quant_utils.py
modelopt/torch/export/unified_export_hf.py
modelopt/torch/quantization/plugins/transformers_trainer.py
modelopt/torch/quantization/qtensor/nvfp4_tensor.py

examples/llm_ptq/example_utils.py

Lines changed: 6 additions & 5 deletions
@@ -118,15 +118,11 @@ def get_dtype(dtype):
 
 def get_lora_model(
     ckpt_path: str,
-    device="cuda",
+    device_map="cuda",
 ):
     """
     Loads a QLoRA model that has been trained using modelopt trainer.
     """
-    device_map = "auto"
-    if device == "cpu":
-        device_map = "cpu"
-
     # Load model with adapters
     model = AutoModelForCausalLM.from_pretrained(ckpt_path, device_map=device_map)
@@ -148,13 +144,18 @@ def get_model(
     trust_remote_code=False,
     use_seq_device_map=False,
     attn_implementation=None,
+    is_lora=False,
 ):
     print(f"Initializing model from {ckpt_path}")
 
     device_map = "auto"
     if device == "cpu":
         device_map = "cpu"
 
+    if is_lora:
+        model = get_lora_model(ckpt_path, device_map)
+        return model
+
     config_kwargs = {"trust_remote_code": trust_remote_code} if trust_remote_code else {}
     if attn_implementation is not None:
         config_kwargs["attn_implementation"] = attn_implementation
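Usage note (not part of the diff): with this change, callers select the QLoRA path through get_model instead of calling get_lora_model directly; is_lora=True forwards the device_map that get_model already resolved from the device argument. A minimal sketch, with a placeholder checkpoint path:

from example_utils import get_model

# Hypothetical checkpoint path for a modelopt-trained QLoRA model.
model = get_model(
    "/path/to/qlora_checkpoint",
    device="cuda",
    is_lora=True,  # delegates to get_lora_model(ckpt_path, device_map)
)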

examples/llm_ptq/hf_ptq.py

Lines changed: 17 additions & 19 deletions
@@ -241,20 +241,15 @@ def main(args):
     # If low memory mode is enabled, we compress the model while loading the HF checkpoint.
     calibration_only = False
     if not args.low_memory_mode:
-        if args.lora:
-            model = get_lora_model(
-                args.pyt_ckpt_path,
-                args.device,
-            )
-        else:
-            model = get_model(
-                args.pyt_ckpt_path,
-                args.device,
-                gpu_mem_percentage=args.gpu_max_mem_percentage,
-                trust_remote_code=args.trust_remote_code,
-                use_seq_device_map=args.use_seq_device_map,
-                attn_implementation=args.attn_implementation,
-            )
+        model = get_model(
+            args.pyt_ckpt_path,
+            args.device,
+            gpu_mem_percentage=args.gpu_max_mem_percentage,
+            trust_remote_code=args.trust_remote_code,
+            use_seq_device_map=args.use_seq_device_map,
+            attn_implementation=args.attn_implementation,
+            is_lora=args.lora,
+        )
     else:
         assert args.qformat in QUANT_CFG_CHOICES, (
             f"Quantization format is not supported for low memory mode. Supported formats: {QUANT_CFG_CHOICES.keys()}"
@@ -395,9 +390,6 @@ def main(args):
     sample_input_single_batch = None
 
     run_auto_quant = args.auto_quantize_bits is not None
-    print("DEBUG LOG: Entereing here")
-    for k, v in model.state_dict().items():
-        print(k, v.shape, v.dtype, v.device)
 
     args.batch_size = get_max_batch_size(
         model,
@@ -493,7 +485,7 @@ def main(args):
         quant_cfg["quant_cfg"]["*image*"] = {"enable": False}
         quant_cfg["quant_cfg"]["*vision*"] = {"enable": False}
 
-    if not model_is_already_quantized or calibration_only:
+    if calibration_only:
         # Only run single sample for preview
         input_ids = next(iter(calib_dataloader))[
             "input_features" if model_type == "whisper" else "input_ids"
@@ -567,7 +559,12 @@ def output_decode(generated_ids, input_shape):
 
     else:
         assert model_type != "dbrx", f"Does not support export {model_type} without quantizaton"
-        print(f"qformat: {args.qformat}. No quantization applied, export {device} model")
+        if model_is_already_quantized:
+            warnings.warn(
+                "Skipping quantization: Model is already quantized. Exporting the model..."
+            )
+        else:
+            print(f"qformat: {args.qformat}. No quantization applied, export {device} model")
 
     with torch.inference_mode():
         if model_type is None:
@@ -643,6 +640,7 @@ def output_decode(generated_ids, input_shape):
         export_hf_checkpoint(
             full_model,
             export_dir=export_path,
+            is_modelopt_trained_lora=args.lora,
         )
 
         # Copy custom model files (Python files and JSON configs) if trust_remote_code is used

modelopt/torch/export/quant_utils.py

Lines changed: 40 additions & 20 deletions
@@ -270,23 +270,28 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
         QUANTIZATION_NVFP4_AWQ,
         QUANTIZATION_W4A8_NVFP4_FP8,
     ]:
+        # If scale is already registered, indicates weights are already compressed.
+        # We convert to modelopt scale if necessary and return
         if hasattr(weight_quantizer, "_scale"):
             return NVFP4QTensor.get_modelopt_weights_scaling_factor(
                 weight_quantizer._scale, weight.metadata["shape"]
             )
-
-        return NVFP4QTensor.get_weights_scaling_factor(
-            weight,
-            weight_quantizer.block_sizes[-1],
-            NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer).to(
-                weight.device
-            ),
-        )[0]
+        else:
+            return NVFP4QTensor.get_weights_scaling_factor(
+                weight,
+                weight_quantizer.block_sizes[-1],
+                NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer).to(
+                    weight.device
+                ),
+            )[0]
 
     if quantization_format in [QUANTIZATION_W4A8_MXFP4_FP8, QUANTIZATION_MXFP4]:
-        return MXFP4QTensor.quantize(weight, block_size=weight_quantizer.block_sizes[-1])[
-            1
-        ].reshape(*weight.shape[:-1], -1)
+        if hasattr(weight_quantizer, "_scale"):
+            return weight_quantizer._scale.reshape(*weight.shape[:-1], -1)
+        else:
+            return MXFP4QTensor.quantize(weight, block_size=weight_quantizer.block_sizes[-1])[
+                1
+            ].reshape(*weight.shape[:-1], -1)
     return get_scaling_factor(weight_quantizer)
@@ -302,7 +307,10 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")
         QUANTIZATION_NVFP4_AWQ,
         QUANTIZATION_W4A8_NVFP4_FP8,
     ]:
-        return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)
+        if hasattr(weight_quantizer, "_double_scale"):
+            return weight_quantizer._double_scale
+        else:
+            return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)
 
     # SequentialQuantizer is required
     if not isinstance(weight_quantizer, SequentialQuantizer) or not weight_quantizer[-1].is_enabled:
@@ -824,7 +832,12 @@ def from_quantized_weight(
     raise NotImplementedError(f"quantization format {quantization} not supported")
 
 
-def postprocess_state_dict(state_dict: dict, maxbound: float, quantization: str | None) -> dict:
+def postprocess_state_dict(
+    state_dict: dict,
+    maxbound: float,
+    quantization: str | None,
+    is_modelopt_trained_lora: bool = False,
+) -> dict:
     """Filters out keys related to weight quantizers and updates KV cache related keys.
 
     Args:
@@ -841,11 +854,18 @@ def postprocess_state_dict(state_dict: dict, maxbound: float, quantization: str
         "k_bmm_quantizer._bias_value": "k_proj.k_bias",
         "v_bmm_quantizer._bias_value": "v_proj.v_bias",
         "input_quantizer._pre_quant_scale": "pre_quant_scale",
-        "base_layer.weight": "weight",
-        "base_layer.input_scale": "input_scale",
-        "base_layer.weight_scale": "weight_scale",
     }
 
+    # For modelopt-trained LoRA models, we need to remove the base_layer prefix from the keys for deployment
+    if is_modelopt_trained_lora:
+        replacements.update(
+            {
+                "base_layer.weight": "weight",
+                "base_layer.input_scale": "input_scale",
+                "base_layer.weight_scale": "weight_scale",
+            }
+        )
+
     post_state_dict = {}
 
     for key, value in state_dict.items():
@@ -908,10 +928,10 @@ def postprocess_state_dict(state_dict: dict, maxbound: float, quantization: str
             keys_to_delete.append(key)
 
     # remove LoRA adapters from state dict
-    for key, value in post_state_dict.items():
-        if "lora" in key and key not in keys_to_delete:
-            keys_to_delete.append(key)
-
+    if is_modelopt_trained_lora:
+        for key, value in post_state_dict.items():
+            if "lora" in key and key not in keys_to_delete:
+                keys_to_delete.append(key)
     # Check for tied weights and remove duplicates
     seen_tensors = {}

modelopt/torch/export/unified_export_hf.py

Lines changed: 14 additions & 11 deletions
@@ -85,9 +85,6 @@ def _is_enabled_quantizer(quantizer):
 
 def requantize_resmooth_fused_llm_layers(model: torch.nn.Module):
     """Group modules that take the same input and register shared parameters in module."""
-    # Skip for LoRA finetuned models
-    if hasattr(model, "base_model"):
-        return
     # TODO: Handle DBRX MoE
     input_to_linear = defaultdict(list)
     output_to_layernorm = defaultdict(None)
@@ -343,7 +340,7 @@ def _export_quantized_weight(
 
 
 def _export_hf_checkpoint(
-    model: nn.Module, dtype: torch.dtype | None = None
+    model: nn.Module, dtype: torch.dtype | None = None, is_modelopt_trained_lora: bool = False
 ) -> tuple[dict[str, Any], dict[str, Any]]:
     """Exports the torch model to the packed checkpoint with original HF naming.
 
@@ -435,7 +432,9 @@ def _export_hf_checkpoint(
 
     # Resmooth and requantize fused layers
     # TODO: Handle mixed precision
-    requantize_resmooth_fused_llm_layers(model)
+    # TODO: Support requantize and resmooth for modelopt-trained LoRA models
+    if not is_modelopt_trained_lora:
+        requantize_resmooth_fused_llm_layers(model)
 
     # Remove all hooks from the model
     try:
@@ -494,7 +493,7 @@ def _export_hf_checkpoint(
     quantized_state_dict = model.state_dict()
 
     quantized_state_dict = postprocess_state_dict(
-        quantized_state_dict, kv_cache_max_bound, kv_cache_format
+        quantized_state_dict, kv_cache_max_bound, kv_cache_format, is_modelopt_trained_lora
     )
 
     # Check if any layers are quantized
@@ -509,6 +508,7 @@ def export_hf_checkpoint(
     dtype: torch.dtype | None = None,
     export_dir: Path | str = tempfile.gettempdir(),
     save_modelopt_state: bool = False,
+    is_modelopt_trained_lora: bool = False,
 ):
     """Exports the torch model to unified checkpoint and saves to export_dir.
 
@@ -518,15 +518,18 @@
         export_dir: the target export path.
         save_modelopt_state: whether to save the modelopt state_dict.
     """
-    is_lora = hasattr(model, "base_model")
-    base_export_dir: Path | str = f"{export_dir}/base_model" if is_lora else export_dir
+    base_export_dir: Path | str = (
+        f"{export_dir}/base_model" if is_modelopt_trained_lora else export_dir
+    )
     export_dir = Path(export_dir)
     export_dir.mkdir(parents=True, exist_ok=True)
     base_export_dir = Path(base_export_dir)
     base_export_dir.mkdir(parents=True, exist_ok=True)
 
     try:
-        post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype)
+        post_state_dict, hf_quant_config = _export_hf_checkpoint(
+            model, dtype, is_modelopt_trained_lora
+        )
 
         # NOTE: (hg) Should we save hf_quant_config when there's no quantization applied?
         # Save hf_quant_config.json for backward compatibility
@@ -538,11 +541,11 @@ def export_hf_checkpoint(
         post_state_dict = rename_and_prune_if_spec_decoding(model, post_state_dict)
 
         # In the case of LoRA model, we save the base model
-        if is_lora:
+        if is_modelopt_trained_lora:
            model.base_model.save_pretrained(
                base_export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state
            )
-            model.save_pretrained(export_dir, save_modelopt_state=save_modelopt_state)
+            model.save_pretrained(export_dir)
         else:
            model.save_pretrained(
                export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state
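Usage note (not part of the diff): export_hf_checkpoint no longer infers the LoRA case from hasattr(model, "base_model"); the caller states it explicitly. A minimal sketch, assuming the public import path and with model being the quantized modelopt-trained LoRA model produced earlier (e.g. by hf_ptq.py):

from modelopt.torch.export import export_hf_checkpoint  # assumed import path

# The quantized base model is written to <export_dir>/base_model,
# the adapter checkpoint to <export_dir> itself.
export_hf_checkpoint(
    model,
    export_dir="exported_qlora",
    is_modelopt_trained_lora=True,
)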

modelopt/torch/quantization/plugins/transformers_trainer.py

Lines changed: 3 additions & 4 deletions
@@ -185,7 +185,6 @@ def _save_modelopt_state_with_weights(self):
         # Save base model compressed weights for QLoRA
         if getattr(self.quant_args, "compress", False):
             # Save base model config.json
-            # weight_quantizer = self.quant_cfg["quant_cfg"]["*weight_quantizer"]
             self.model.config.save_pretrained(self.args.output_dir)
 
             # Save base model compressed weights excluding lora weights
@@ -292,14 +291,14 @@ def save_model(self, *args, **kwargs):
     def _load_best_model(self, *args, **kwargs):
         """Load the best model for final evaluation."""
         is_lora = getattr(self.args, "lora", None)
-        if not is_lora:
-            super()._load_best_model(*args, **kwargs)
-        else:
+        if is_lora and not self.is_fsdp_enabled:
             # Custom logic for loading best model with LoRA
             # TODO: Remove once we migrate to using get_peft_model()
             adapter_name = self.model.active_adapter()
             self.model.delete_adapter(adapter_name)
             self.model.load_adapter(self.state.best_model_checkpoint, adapter_name)
+        else:
+            super()._load_best_model(*args, **kwargs)
 
     def _patch_accelerate_for_fsdp2_fix(self):
         """Fixes for accelerate prepare.

modelopt/torch/quantization/qtensor/nvfp4_tensor.py

Lines changed: 3 additions & 14 deletions
@@ -270,20 +270,9 @@ def _unpack_tensor(input: torch.Tensor):
             return unpacked.reshape(unpacked_shape)
 
         # Get scales from kwargs
-        if kwarg["scale"].dtype == torch.uint8 and kwarg["scale"].ndim == 1:
-            # If quantization is done by trtllm, convert cutlass fp4 scale to modelopt fp4 scale
-            try:
-                from tensorrt_llm._torch.auto_deploy.utils.quantization_utils import (
-                    cutlass_fp4_scale_to_modelopt_fp4_scale,
-                )
-
-                kwarg["scale"] = cutlass_fp4_scale_to_modelopt_fp4_scale(
-                    kwarg["scale"], self.metadata["shape"][-2:]
-                )
-            except ImportError as e:
-                raise ImportError(
-                    "This tensor is quantized by trtllm, but tensorrt_llm cannot be imported."
-                ) from e
+        kwarg["scale"] = self.get_modelopt_weights_scaling_factor(
+            kwarg["scale"], self.metadata["shape"]
+        )
 
         if fast:
             from ..triton.fp4_kernel import fp4_dequantize
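Background (not part of the diff): the removed branch detected a cutlass-layout FP4 scale by dtype and rank before converting it; that conversion is now expected to happen inside get_modelopt_weights_scaling_factor. For reference, a minimal sketch of the removed detection check only, with a hypothetical helper name:

import torch

def looks_like_cutlass_fp4_scale(scale: torch.Tensor) -> bool:
    # Cutlass-style FP4 scales arrive as a flat uint8 buffer, while
    # modelopt-style scales are already shaped per weight block.
    return scale.dtype == torch.uint8 and scale.ndim == 1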
