Commit 387ea68

refactored
Signed-off-by: Suguna Velury <[email protected]>
1 parent d23963b commit 387ea68

File tree

6 files changed (+157, -110 lines)

examples/llm_ptq/example_utils.py

Lines changed: 0 additions & 28 deletions

@@ -22,10 +22,8 @@
 import transformers
 from accelerate import infer_auto_device_map, init_empty_weights
 from accelerate.utils import get_max_memory
-from safetensors.torch import load_file
 from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor, AutoTokenizer

-from modelopt.torch.opt.conversion import restore_from_modelopt_state
 from modelopt.torch.utils.image_processor import MllamaImageProcessor

 SPECULATIVE_MODEL_LIST = ["Eagle", "Medusa"]
@@ -124,46 +122,20 @@ def get_dtype(dtype):
     return dtype


-def get_lora_model(
-    ckpt_path: str,
-    device_map="cuda",
-):
-    """
-    Loads a QLoRA model that has been trained using modelopt trainer.
-    """
-    # Load model with adapters
-    model = AutoModelForCausalLM.from_pretrained(ckpt_path, device_map=device_map)
-
-    # Restore modelopt state
-    modelopt_state = torch.load(f"{ckpt_path}/modelopt_state.pth", weights_only=False)
-    restore_from_modelopt_state(model, modelopt_state)
-
-    # Load compressed weights
-    state_dict = load_file(f"{ckpt_path}/model.safetensors")
-    model.load_state_dict(state_dict, strict=False)
-
-    return model
-
-
 def get_model(
     ckpt_path,
     device="cuda",
     gpu_mem_percentage=0.8,
     trust_remote_code=False,
     use_seq_device_map=False,
     attn_implementation=None,
-    is_modelopt_qlora=False,
 ):
     print(f"Initializing model from {ckpt_path}")

     device_map = "auto"
     if device == "cpu":
         device_map = "cpu"

-    if is_modelopt_qlora:
-        model = get_lora_model(ckpt_path, device_map)
-        return model
-
     config_kwargs = {"trust_remote_code": trust_remote_code} if trust_remote_code else {}
     if attn_implementation is not None:
         config_kwargs["attn_implementation"] = attn_implementation
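
For context, a minimal sketch of how a caller uses the trimmed-down get_model after this change; the QLoRA-specific loading now lives in the new examples/llm_qat/export.py. The checkpoint path below is a placeholder, and the import assumes the code runs from examples/llm_ptq/ where example_utils.py resides.

# Sketch only: get_model keeps its generic arguments; the removed
# is_modelopt_qlora flag is no longer passed.
from example_utils import get_model  # assumes examples/llm_ptq/ is on sys.path

model = get_model(
    "path/to/hf_checkpoint",  # placeholder checkpoint path
    device="cuda",
    gpu_mem_percentage=0.8,
    trust_remote_code=False,
)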

examples/llm_ptq/hf_ptq.py

Lines changed: 3 additions & 12 deletions

@@ -238,7 +238,6 @@ def main(args):
             trust_remote_code=args.trust_remote_code,
             use_seq_device_map=args.use_seq_device_map,
             attn_implementation=args.attn_implementation,
-            is_modelopt_qlora=args.qlora,
         )
     else:
         assert args.qformat in QUANT_CFG_CHOICES, (
@@ -345,9 +344,7 @@
         )
         mts.export(model)

-    if (
-        args.auto_quantize_bits or args.qformat in QUANT_CFG_CHOICES
-    ) and not model_is_already_quantized:
+    if args.auto_quantize_bits or args.qformat in QUANT_CFG_CHOICES:
         if "awq" in args.qformat:
             print(
                 "\n####\nAWQ calibration could take longer than other calibration methods. "
@@ -474,7 +471,7 @@
                 "Please set the default input_mode to InputMode.LANGUAGE before quantizing."
             )

-        if calibration_only:
+        if not model_is_already_quantized and calibration_only:
             # Only run single sample for preview
             input_ids = next(iter(calib_dataloader))[
                 "input_features" if model_type == "whisper" else "input_ids"
@@ -548,12 +545,7 @@ def output_decode(generated_ids, input_shape):

     else:
         assert model_type != "dbrx", f"Does not support export {model_type} without quantizaton"
-        if model_is_already_quantized:
-            warnings.warn(
-                "Skipping quantization: Model is already quantized. Exporting the model..."
-            )
-        else:
-            print(f"qformat: {args.qformat}. No quantization applied, export {device} model")
+        print(f"qformat: {args.qformat}. No quantization applied, export {device} model")

     with torch.inference_mode():
         if model_type is None:
@@ -626,7 +618,6 @@ def output_decode(generated_ids, input_shape):
             export_hf_checkpoint(
                 full_model,
                 export_dir=export_path,
-                is_modelopt_qlora=args.qlora,
             )

     # Restore default padding and export the tokenizer as well.
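
Read together, these hf_ptq.py hunks change where an already-quantized checkpoint is handled: quantization setup still runs whenever a supported format is requested, but calibration is now skipped for such checkpoints, and the separate "already quantized" warning branch in the export path is gone. A condensed sketch of the resulting gating follows (variable names mirror the diff; this is not the literal script code).

# Condensed sketch only; mirrors the conditions visible in the hunks above.
def gate(args, quant_cfg_choices, model_is_already_quantized, calibration_only):
    if args.auto_quantize_bits or args.qformat in quant_cfg_choices:
        # Quantization config is prepared for any requested format ...
        run_calibration = not model_is_already_quantized and calibration_only
        return "quantize", run_calibration
    # ... otherwise the model is exported as-is, with no special-case warning.
    return "export_only", False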

examples/llm_qat/export.py

Lines changed: 119 additions & 0 deletions

@@ -0,0 +1,119 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import json
+import warnings
+from pathlib import Path
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format
+from modelopt.torch.export.unified_export_hf import _export_hf_checkpoint
+from modelopt.torch.opt.conversion import restore_from_modelopt_state
+from modelopt.torch.quantization.utils import set_quantizer_state_dict
+
+RAND_SEED = 1234
+
+
+def get_lora_model(
+    ckpt_path: str,
+    device="cuda",
+):
+    """
+    Loads a QLoRA model that has been trained using modelopt trainer.
+    """
+    device_map = "auto"
+    if device == "cpu":
+        device_map = "cpu"
+
+    # Load model with adapters
+    model = AutoModelForCausalLM.from_pretrained(ckpt_path, device_map=device_map)
+
+    # Restore modelopt state
+    modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_calibration.pth", weights_only=False)
+    restore_from_modelopt_state(model, modelopt_state)
+
+    # Restore modelopt quantizer state dict
+    modelopt_weights = modelopt_state.pop("modelopt_state_weights", None)
+    if modelopt_weights is not None:
+        print("Restoring modelopt weights")
+        set_quantizer_state_dict(model, modelopt_weights)
+
+    return model
+
+
+def main(args):
+    # Load model
+    model = get_lora_model(args.pyt_ckpt_path, args.device)
+    tokenizer = AutoTokenizer.from_pretrained(args.pyt_ckpt_path)
+
+    # Export HF checkpoint
+    export_dir = Path(args.export_path)
+    export_dir.mkdir(parents=True, exist_ok=True)
+    base_model_dir = export_dir / "base_model"
+    base_model_dir.mkdir(parents=True, exist_ok=True)
+
+    try:
+        post_state_dict, hf_quant_config = _export_hf_checkpoint(model, is_lora=True)
+
+        with open(f"{export_dir}/base_model/hf_quant_config.json", "w") as file:
+            json.dump(hf_quant_config, file, indent=4)
+
+        hf_quant_config = convert_hf_quant_config_format(hf_quant_config)
+
+        # Save base model
+        model.base_model.save_pretrained(f"{export_dir}/base_model", state_dict=post_state_dict)
+        # Save adapters
+        model.save_pretrained(export_dir)
+
+        config_path = f"{export_dir}/base_model/config.json"
+
+        # In the case of LoRA model.save_pretrained does not save the correct config.json
+        config_data = model.config.to_dict()
+        print(config_data)
+
+        config_data["quantization_config"] = hf_quant_config
+
+        with open(config_path, "w") as file:
+            json.dump(config_data, file, indent=4)
+
+        # Save tokenizer
+        tokenizer.save_pretrained(export_dir)
+
+    except Exception as e:
+        warnings.warn(
+            "Cannot export model to the model_config. The modelopt-optimized model state_dict"
+            " can be saved with torch.save for further inspection."
+        )
+        raise e
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument(
+        "--pyt_ckpt_path",
+        help="Specify where the PyTorch checkpoint path is",
+        required=True,
+    )
+
+    parser.add_argument("--device", default="cuda")
+
+    parser.add_argument("--export_path", default="exported_model")
+
+    args = parser.parse_args()
+
+    main(args)
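
A hypothetical way to drive the new exporter programmatically, mirroring the argparse flags defined at the bottom of the script; the checkpoint and output paths are placeholders, and the `from export import main` line assumes the interpreter runs from examples/llm_qat/.

# Hypothetical invocation; pyt_ckpt_path must point at a modelopt QLoRA checkpoint
# that contains modelopt_state_calibration.pth alongside the adapter weights.
from argparse import Namespace

from export import main  # examples/llm_qat/export.py

main(
    Namespace(
        pyt_ckpt_path="path/to/qlora_checkpoint",  # placeholder
        device="cuda",
        export_path="exported_model",
    )
)

Per the script, the quantized base model, its config.json (with quantization_config), and hf_quant_config.json land under <export_path>/base_model/, while the LoRA adapters and tokenizer are saved under <export_path>/ itself.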

modelopt/torch/export/quant_utils.py

Lines changed: 11 additions & 28 deletions

@@ -269,28 +269,18 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
         QUANTIZATION_NVFP4_AWQ,
         QUANTIZATION_W4A8_NVFP4_FP8,
     ]:
-        # If scale is already registered, indicates weights are already compressed.
-        # We convert to modelopt scale if necessary and return
-        if hasattr(weight_quantizer, "_scale"):
-            return NVFP4QTensor.get_modelopt_weights_scaling_factor(
-                weight_quantizer._scale, weight.metadata["shape"]
-            )
-        else:
-            return NVFP4QTensor.get_weights_scaling_factor(
-                weight,
-                weight_quantizer.block_sizes[-1],
-                NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer).to(
-                    weight.device
-                ),
-            )[0]
+        return NVFP4QTensor.get_weights_scaling_factor(
+            weight,
+            weight_quantizer.block_sizes[-1],
+            NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer).to(
+                weight.device
+            ),
+        )[0]

     if quantization_format in [QUANTIZATION_W4A8_MXFP4_FP8, QUANTIZATION_MXFP4]:
-        if hasattr(weight_quantizer, "_scale"):
-            return weight_quantizer._scale.reshape(*weight.shape[:-1], -1)
-        else:
-            return MXFP4QTensor.quantize(weight, block_size=weight_quantizer.block_sizes[-1])[
-                1
-            ].reshape(*weight.shape[:-1], -1)
+        return MXFP4QTensor.quantize(weight, block_size=weight_quantizer.block_sizes[-1])[
+            1
+        ].reshape(*weight.shape[:-1], -1)
     return get_scaling_factor(weight_quantizer)


@@ -306,10 +296,7 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")
         QUANTIZATION_NVFP4_AWQ,
         QUANTIZATION_W4A8_NVFP4_FP8,
     ]:
-        if hasattr(weight_quantizer, "_double_scale"):
-            return weight_quantizer._double_scale
-        else:
-            return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)
+        return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)

     # SequentialQuantizer is required
     if not isinstance(weight_quantizer, SequentialQuantizer) or not weight_quantizer[-1].is_enabled:
@@ -740,7 +727,6 @@ def to_quantized_weight(
     quantization: str,
     weights_scaling_factor2: torch.Tensor | None = None,
     block_size: int | None = None,
-    dtype: torch.dtype | None = None,
 ):
     """Converts the weight to the quantized (packed) format."""
     if weights_scaling_factor is not None:
@@ -753,9 +739,6 @@
     if isinstance(weight, QTensorWrapper):
         return weight.data

-    if dtype:
-        weight = weight.to(dtype)
-
     if quantization == QUANTIZATION_FP8:
         # Fix RuntimeError: Promotion for Float8 Types is not supported, attempted to promote Float8_e4m3fn and Float
         # in speculative decoding fp8 model export
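
The quant_utils.py change always recomputes NVFP4/MXFP4 per-block scales from the unquantized weight instead of reusing a pre-registered _scale buffer. As a rough illustration of what a per-block weight scale is, a generic sketch follows; this is not the modelopt API, and the FP4 (E2M1) maximum of 6.0 is an assumption about the target format.

import torch

# Generic illustration, not NVFP4QTensor.get_weights_scaling_factor: one scale per
# contiguous block of `block_size` weights along the last dimension, derived from
# the block's absolute maximum. Assumes the last dim is divisible by block_size.
def per_block_amax_scale(weight: torch.Tensor, block_size: int, fmt_max: float = 6.0) -> torch.Tensor:
    blocks = weight.reshape(*weight.shape[:-1], -1, block_size)
    return blocks.abs().amax(dim=-1) / fmt_max

# Example: a [128, 256] weight with block_size=16 yields a [128, 16] scale tensor.
scales = per_block_amax_scale(torch.randn(128, 256), block_size=16)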
