
Commit 3b7ef44

refactored
Signed-off-by: Suguna Velury <[email protected]>
1 parent f1a2ff3 commit 3b7ef44

File tree: 6 files changed, +157 -109 lines changed


examples/llm_ptq/example_utils.py

Lines changed: 0 additions & 27 deletions
@@ -25,7 +25,6 @@
 import transformers
 from accelerate import infer_auto_device_map, init_empty_weights
 from accelerate.utils import get_max_memory
-from safetensors.torch import load_file
 from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor, AutoTokenizer
 
 try:
@@ -116,46 +115,20 @@ def get_dtype(dtype):
     return dtype
 
 
-def get_lora_model(
-    ckpt_path: str,
-    device_map="cuda",
-):
-    """
-    Loads a QLoRA model that has been trained using modelopt trainer.
-    """
-    # Load model with adapters
-    model = AutoModelForCausalLM.from_pretrained(ckpt_path, device_map=device_map)
-
-    # Restore modelopt state
-    modelopt_state = torch.load(f"{ckpt_path}/modelopt_state.pth", weights_only=False)
-    restore_from_modelopt_state(model, modelopt_state)
-
-    # Load compressed weights
-    state_dict = load_file(f"{ckpt_path}/model.safetensors")
-    model.load_state_dict(state_dict, strict=False)
-
-    return model
-
-
 def get_model(
     ckpt_path,
     device="cuda",
     gpu_mem_percentage=0.8,
     trust_remote_code=False,
     use_seq_device_map=False,
     attn_implementation=None,
-    is_modelopt_qlora=False,
 ):
     print(f"Initializing model from {ckpt_path}")
 
     device_map = "auto"
     if device == "cpu":
         device_map = "cpu"
 
-    if is_modelopt_qlora:
-        model = get_lora_model(ckpt_path, device_map)
-        return model
-
     config_kwargs = {"trust_remote_code": trust_remote_code} if trust_remote_code else {}
     if attn_implementation is not None:
         config_kwargs["attn_implementation"] = attn_implementation
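Note: with get_lora_model and the is_modelopt_qlora flag removed here, get_model no longer has a QLoRA branch; QLoRA checkpoints are handled by the new examples/llm_qat/export.py script added later in this commit. A minimal caller-side sketch of the difference (the import path and checkpoint path below are placeholders, not part of the diff):

    from example_utils import get_model  # placeholder import, assuming examples/llm_ptq is on the path

    ckpt_path = "path/to/checkpoint"  # placeholder

    # Before this commit, a QLoRA checkpoint needed the extra flag:
    #   model = get_model(ckpt_path, device="cuda", is_modelopt_qlora=True)
    # After this commit the flag is gone and get_model only loads regular checkpoints:
    model = get_model(ckpt_path, device="cuda")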

examples/llm_ptq/hf_ptq.py

Lines changed: 3 additions & 12 deletions
@@ -248,7 +248,6 @@ def main(args):
             trust_remote_code=args.trust_remote_code,
             use_seq_device_map=args.use_seq_device_map,
             attn_implementation=args.attn_implementation,
-            is_modelopt_qlora=args.qlora,
         )
     else:
         assert args.qformat in QUANT_CFG_CHOICES, (
@@ -359,9 +358,7 @@
         )
         mts.export(model)
 
-    if (
-        args.auto_quantize_bits or args.qformat in QUANT_CFG_CHOICES
-    ) and not model_is_already_quantized:
+    if args.auto_quantize_bits or args.qformat in QUANT_CFG_CHOICES:
         if "awq" in args.qformat:
             print(
                 "\n####\nAWQ calibration could take longer than other calibration methods. "
@@ -485,7 +482,7 @@ def main(args):
             quant_cfg["quant_cfg"]["*image*"] = {"enable": False}
             quant_cfg["quant_cfg"]["*vision*"] = {"enable": False}
 
-        if calibration_only:
+        if not model_is_already_quantized and calibration_only:
             # Only run single sample for preview
             input_ids = next(iter(calib_dataloader))[
                 "input_features" if model_type == "whisper" else "input_ids"
@@ -559,12 +556,7 @@ def output_decode(generated_ids, input_shape):
 
     else:
        assert model_type != "dbrx", f"Does not support export {model_type} without quantizaton"
-        if model_is_already_quantized:
-            warnings.warn(
-                "Skipping quantization: Model is already quantized. Exporting the model..."
-            )
-        else:
-            print(f"qformat: {args.qformat}. No quantization applied, export {device} model")
+        print(f"qformat: {args.qformat}. No quantization applied, export {device} model")
 
    with torch.inference_mode():
        if model_type is None:
@@ -640,7 +632,6 @@ def output_decode(generated_ids, input_shape):
            export_hf_checkpoint(
                full_model,
                export_dir=export_path,
-                is_modelopt_qlora=args.qlora,
            )
 
    # Copy custom model files (Python files and JSON configs) if trust_remote_code is used
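Read together, these hunks change how an already-quantized checkpoint flows through main(): instead of being excluded from the quantization branch and warned about in the export-only else branch, it now enters the quantization branch and only the single-sample calibration preview is skipped. A simplified sketch of the resulting gating, assuming the else in the @@ -559 hunk pairs with the if rewritten in the @@ -359 hunk (all other logic elided):

    # Sketch only; condition and variable names are taken from the diff above.
    if args.auto_quantize_bits or args.qformat in QUANT_CFG_CHOICES:
        # Already-quantized checkpoints enter this branch too after this commit.
        if not model_is_already_quantized and calibration_only:
            ...  # run the single-sample calibration preview
        ...  # quantize and export as before
    else:
        # Plain export path; the dedicated "already quantized" warning is removed.
        print(f"qformat: {args.qformat}. No quantization applied, export {device} model")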

examples/llm_qat/export.py

Lines changed: 119 additions & 0 deletions
@@ -0,0 +1,119 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import json
+import warnings
+from pathlib import Path
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format
+from modelopt.torch.export.unified_export_hf import _export_hf_checkpoint
+from modelopt.torch.opt.conversion import restore_from_modelopt_state
+from modelopt.torch.quantization.utils import set_quantizer_state_dict
+
+RAND_SEED = 1234
+
+
+def get_lora_model(
+    ckpt_path: str,
+    device="cuda",
+):
+    """
+    Loads a QLoRA model that has been trained using modelopt trainer.
+    """
+    device_map = "auto"
+    if device == "cpu":
+        device_map = "cpu"
+
+    # Load model with adapters
+    model = AutoModelForCausalLM.from_pretrained(ckpt_path, device_map=device_map)
+
+    # Restore modelopt state
+    modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_calibration.pth", weights_only=False)
+    restore_from_modelopt_state(model, modelopt_state)
+
+    # Restore modelopt quantizer state dict
+    modelopt_weights = modelopt_state.pop("modelopt_state_weights", None)
+    if modelopt_weights is not None:
+        print("Restoring modelopt weights")
+        set_quantizer_state_dict(model, modelopt_weights)
+
+    return model
+
+
+def main(args):
+    # Load model
+    model = get_lora_model(args.pyt_ckpt_path, args.device)
+    tokenizer = AutoTokenizer.from_pretrained(args.pyt_ckpt_path)
+
+    # Export HF checkpoint
+    export_dir = Path(args.export_path)
+    export_dir.mkdir(parents=True, exist_ok=True)
+    base_model_dir = export_dir / "base_model"
+    base_model_dir.mkdir(parents=True, exist_ok=True)
+
+    try:
+        post_state_dict, hf_quant_config = _export_hf_checkpoint(model, is_lora=True)
+
+        with open(f"{export_dir}/base_model/hf_quant_config.json", "w") as file:
+            json.dump(hf_quant_config, file, indent=4)
+
+        hf_quant_config = convert_hf_quant_config_format(hf_quant_config)
+
+        # Save base model
+        model.base_model.save_pretrained(f"{export_dir}/base_model", state_dict=post_state_dict)
+        # Save adapters
+        model.save_pretrained(export_dir)
+
+        config_path = f"{export_dir}/base_model/config.json"
+
+        # In the case of LoRA model.save_pretrained does not save the correct config.json
+        config_data = model.config.to_dict()
+        print(config_data)
+
+        config_data["quantization_config"] = hf_quant_config
+
+        with open(config_path, "w") as file:
+            json.dump(config_data, file, indent=4)
+
+        # Save tokenizer
+        tokenizer.save_pretrained(export_dir)
+
+    except Exception as e:
+        warnings.warn(
+            "Cannot export model to the model_config. The modelopt-optimized model state_dict"
+            " can be saved with torch.save for further inspection."
+        )
+        raise e
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument(
+        "--pyt_ckpt_path",
+        help="Specify where the PyTorch checkpoint path is",
+        required=True,
+    )
+
+    parser.add_argument("--device", default="cuda")
+
+    parser.add_argument("--export_path", default="exported_model")
+
+    args = parser.parse_args()
+
+    main(args)
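Since examples/llm_qat/export.py is a new standalone script, a short usage sketch may help. The command line and the module import below are assumptions inferred from the argparse setup above (only --pyt_ckpt_path is required); the checkpoint and output paths are placeholders:

    # Assumed CLI, from the argparse flags defined in the script:
    #   python export.py --pyt_ckpt_path <qlora_ckpt_dir> --export_path <out_dir>
    #
    # Driving main() programmatically is equivalent:
    from argparse import Namespace

    import export  # hypothetical import of examples/llm_qat/export.py as a module

    args = Namespace(
        pyt_ckpt_path="qlora_checkpoint/",  # placeholder: checkpoint saved by the modelopt trainer
        device="cuda",
        export_path="exported_model/",  # placeholder: output directory
    )
    export.main(args)

    # Expected output layout, based on the code above:
    #   exported_model/base_model/  -> quantized base weights, config.json (with quantization_config), hf_quant_config.json
    #   exported_model/             -> LoRA adapter files and tokenizer files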

modelopt/torch/export/quant_utils.py

Lines changed: 11 additions & 28 deletions
@@ -270,28 +270,18 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
         QUANTIZATION_NVFP4_AWQ,
         QUANTIZATION_W4A8_NVFP4_FP8,
     ]:
-        # If scale is already registered, indicates weights are already compressed.
-        # We convert to modelopt scale if necessary and return
-        if hasattr(weight_quantizer, "_scale"):
-            return NVFP4QTensor.get_modelopt_weights_scaling_factor(
-                weight_quantizer._scale, weight.metadata["shape"]
-            )
-        else:
-            return NVFP4QTensor.get_weights_scaling_factor(
-                weight,
-                weight_quantizer.block_sizes[-1],
-                NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer).to(
-                    weight.device
-                ),
-            )[0]
+        return NVFP4QTensor.get_weights_scaling_factor(
+            weight,
+            weight_quantizer.block_sizes[-1],
+            NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer).to(
+                weight.device
+            ),
+        )[0]
 
     if quantization_format in [QUANTIZATION_W4A8_MXFP4_FP8, QUANTIZATION_MXFP4]:
-        if hasattr(weight_quantizer, "_scale"):
-            return weight_quantizer._scale.reshape(*weight.shape[:-1], -1)
-        else:
-            return MXFP4QTensor.quantize(weight, block_size=weight_quantizer.block_sizes[-1])[
-                1
-            ].reshape(*weight.shape[:-1], -1)
+        return MXFP4QTensor.quantize(weight, block_size=weight_quantizer.block_sizes[-1])[
+            1
+        ].reshape(*weight.shape[:-1], -1)
     return get_scaling_factor(weight_quantizer)
 
 
@@ -307,10 +297,7 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")
         QUANTIZATION_NVFP4_AWQ,
         QUANTIZATION_W4A8_NVFP4_FP8,
     ]:
-        if hasattr(weight_quantizer, "_double_scale"):
-            return weight_quantizer._double_scale
-        else:
-            return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)
+        return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)
 
     # SequentialQuantizer is required
     if not isinstance(weight_quantizer, SequentialQuantizer) or not weight_quantizer[-1].is_enabled:
@@ -746,7 +733,6 @@ def to_quantized_weight(
     quantization: str,
     weights_scaling_factor2: torch.Tensor | None = None,
     block_size: int | None = None,
-    dtype: torch.dtype | None = None,
 ):
     """Converts the weight to the quantized (packed) format."""
     if weights_scaling_factor is not None:
@@ -759,9 +745,6 @@
     if isinstance(weight, QTensorWrapper):
         return weight.data
 
-    if dtype:
-        weight = weight.to(dtype)
-
     if quantization == QUANTIZATION_FP8:
         # Fix RuntimeError: Promotion for Float8 Types is not supported, attempted to promote Float8_e4m3fn and Float
         # in speculative decoding fp8 model export
