
Commit e18c323 (parent: afba3a9)

export for fp8 lora base model

Signed-off-by: Suguna Velury <[email protected]>

5 files changed: +77, -23 lines changed

examples/llm_ptq/example_utils.py
Lines changed: 27 additions & 0 deletions

@@ -22,8 +22,10 @@
 import transformers
 from accelerate import infer_auto_device_map, init_empty_weights
 from accelerate.utils import get_max_memory
+from safetensors.torch import load_file
 from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor, AutoTokenizer
 
+from modelopt.torch.opt.conversion import restore_from_modelopt_state
 from modelopt.torch.utils.image_processor import MllamaImageProcessor
 
 SPECULATIVE_MODEL_LIST = ["Eagle", "Medusa"]

@@ -122,6 +124,31 @@ def get_dtype(dtype):
     return dtype
 
 
+def get_lora_model(
+    ckpt_path: str,
+    device="cuda",
+):
+    """
+    Loads a QLoRA model that has been trained using modelopt trainer.
+    """
+    device_map = "auto"
+    if device == "cpu":
+        device_map = "cpu"
+
+    # Load model with adapters
+    model = AutoModelForCausalLM.from_pretrained(ckpt_path, device_map=device_map)
+
+    # Restore modelopt state
+    modelopt_state = torch.load(f"{ckpt_path}/modelopt_state.pth", weights_only=False)
+    restore_from_modelopt_state(model, modelopt_state)
+
+    # Load compressed weights
+    state_dict = load_file(f"{ckpt_path}/model.safetensors")
+    model.load_state_dict(state_dict, strict=False)
+
+    return model
+
+
 def get_model(
     ckpt_path,
     device="cuda",

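For reference, a minimal usage sketch of the new helper (not part of the commit): the checkpoint path is a placeholder and is assumed to contain the modelopt_state.pth and model.safetensors files that get_lora_model reads above.

    # Sketch: load a QLoRA checkpoint produced by the modelopt trainer.
    from example_utils import get_lora_model

    model = get_lora_model("/path/to/qlora_checkpoint", device="cuda")
    model.eval()  # ready for calibration/export in hf_ptq.py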
examples/llm_ptq/hf_ptq.py
Lines changed: 29 additions & 9 deletions

@@ -23,7 +23,14 @@
 import numpy as np
 import torch
 from accelerate.hooks import remove_hook_from_module
-from example_utils import apply_kv_cache_quant, get_model, get_processor, get_tokenizer, is_enc_dec
+from example_utils import (
+    apply_kv_cache_quant,
+    get_lora_model,
+    get_model,
+    get_processor,
+    get_tokenizer,
+    is_enc_dec,
+)
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,

@@ -231,14 +238,20 @@ def main(args):
     # If low memory mode is enabled, we compress the model while loading the HF checkpoint.
     calibration_only = False
     if not args.low_memory_mode:
-        model = get_model(
-            args.pyt_ckpt_path,
-            args.device,
-            gpu_mem_percentage=args.gpu_max_mem_percentage,
-            trust_remote_code=args.trust_remote_code,
-            use_seq_device_map=args.use_seq_device_map,
-            attn_implementation=args.attn_implementation,
-        )
+        if args.lora:
+            model = get_lora_model(
+                args.pyt_ckpt_path,
+                args.device,
+            )
+        else:
+            model = get_model(
+                args.pyt_ckpt_path,
+                args.device,
+                gpu_mem_percentage=args.gpu_max_mem_percentage,
+                trust_remote_code=args.trust_remote_code,
+                use_seq_device_map=args.use_seq_device_map,
+                attn_implementation=args.attn_implementation,
+            )
     else:
         assert args.qformat in QUANT_CFG_CHOICES, (
             f"Quantization format is not supported for low memory mode. Supported formats: {QUANT_CFG_CHOICES.keys()}"

@@ -615,6 +628,7 @@ def output_decode(generated_ids, input_shape):
                 "They will be set at deployment time."
             )
 
+        print("DEBUG LOG: Calling unified export hf checkpoint")
         export_hf_checkpoint(
             full_model,
             export_dir=export_path,

@@ -755,6 +769,12 @@
         default=None,
         type=str,
     )
+    parser.add_argument(
+        "--lora",
+        help="Specify the model to be exported is a LoRA model trained using modelopt.",
+        default=False,
+        action="store_true",
+    )
 
     args = parser.parse_args()
 

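As a rough, self-contained sketch (not part of the commit), the new flag selects between the two loaders roughly as follows; the flag spellings are inferred from the args.* names in the diff and the argument definitions are abbreviated.

    import argparse

    from example_utils import get_lora_model, get_model

    parser = argparse.ArgumentParser()
    parser.add_argument("--pyt_ckpt_path", type=str, required=True)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--lora", default=False, action="store_true")
    args = parser.parse_args()

    if args.lora:
        # QLoRA checkpoint from the modelopt trainer: adapters plus compressed base weights
        model = get_lora_model(args.pyt_ckpt_path, args.device)
    else:
        # Plain HF checkpoint: the script's device-map and memory options still apply
        model = get_model(args.pyt_ckpt_path, args.device)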
modelopt/torch/export/quant_utils.py
Lines changed: 3 additions & 5 deletions

@@ -1081,11 +1081,9 @@ def get_quant_config(named_modules: nn.Module | dict[str, nn.Module]) -> dict[st
         if block_size == 0:
             block_size = get_weight_block_size(module)
 
-        # Handles case if default weight quantizer is not enabled or is None
-        if block_size != 0:
-            # Construct per layer config dictionary
-            layer_config_dict[name + ".quantization"] = quantization_format
-            layer_config_dict[name + ".awq_block_size"] = block_size
+        # Construct per layer config dictionary
+        layer_config_dict[name + ".quantization"] = quantization_format
+        layer_config_dict[name + ".awq_block_size"] = block_size
 
         # Find kv cache quant format
         if (

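The removed guard means per-layer entries are now recorded even when the weight quantizer has no AWQ block size (block_size == 0), e.g. for FP8 layers. An illustrative, assumed shape of the resulting dictionary (layer names and format strings are examples, not taken from the commit):

    layer_config_dict = {
        "model.layers.0.mlp.down_proj.quantization": "fp8",
        "model.layers.0.mlp.down_proj.awq_block_size": 0,  # previously skipped when 0
        "model.layers.0.mlp.up_proj.quantization": "int4_awq",
        "model.layers.0.mlp.up_proj.awq_block_size": 128,
    }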
modelopt/torch/export/unified_export_hf.py
Lines changed: 15 additions & 7 deletions

@@ -514,32 +514,40 @@ def export_hf_checkpoint(
         export_dir: the target export path.
         save_modelopt_state: whether to save the modelopt state_dict.
     """
+    is_lora = hasattr(model, "base_model")
+    base_export_dir: Path | str = f"{export_dir}/base_model" if is_lora else export_dir
     export_dir = Path(export_dir)
     export_dir.mkdir(parents=True, exist_ok=True)
+    base_export_dir = Path(base_export_dir)
+    base_export_dir.mkdir(parents=True, exist_ok=True)
+
     try:
         post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype)
 
         # NOTE: (hg) Should we save hf_quant_config when there's no quantization applied?
         # Save hf_quant_config.json for backward compatibility
-        with open(f"{export_dir}/hf_quant_config.json", "w") as file:
+        with open(f"{base_export_dir}/hf_quant_config.json", "w") as file:
             json.dump(hf_quant_config, file, indent=4)
 
         hf_quant_config = convert_hf_quant_config_format(hf_quant_config)
 
         post_state_dict = rename_and_prune_if_spec_decoding(model, post_state_dict)
 
-        # For QLoRA models we export the base model
-        if hasattr(model, "base_model"):
-            model = model.base_model
+        # In the case of LoRA model, we save the base model
+        if is_lora:
+            model.base_model.save_pretrained(
+                base_export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state
+            )
+
         model.save_pretrained(
             export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state
         )
 
-        original_config = f"{export_dir}/config.json"
+        original_config = f"{base_export_dir}/config.json"
         config_data = {}
 
-        with open(original_config) as file:
-            config_data = json.load(file)
+        # In the case of LoRA model.save_pretrained does not save the correct config.json
+        config_data = model.config.to_dict()
 
         config_data["quantization_config"] = hf_quant_config
 

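A hedged sketch of the export flow and the resulting on-disk layout for a LoRA model, inferred from this diff; the paths and the exact set of files written are assumptions, not a verified listing of the exporter's output.

    from example_utils import get_lora_model
    from modelopt.torch.export.unified_export_hf import export_hf_checkpoint

    # Placeholder path: a QLoRA checkpoint saved by the modelopt trainer.
    model = get_lora_model("/path/to/qlora_checkpoint", device="cuda")
    export_hf_checkpoint(model, export_dir="exported_qlora")

    # Expected layout (sketch):
    #   exported_qlora/             adapter checkpoint from model.save_pretrained()
    #   exported_qlora/base_model/  quantized base model from model.base_model.save_pretrained()
    #       config.json             rebuilt from model.config, with "quantization_config" added
    #       hf_quant_config.json    kept for backward compatibility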
modelopt/torch/quantization/plugins/transformers_trainer.py
Lines changed: 3 additions & 2 deletions

@@ -147,7 +147,7 @@ def __init__(
             self.model, "peft_config"
         ):
             # TODO: use get_peft_model here instead of add_adapter
-            self.model.add_adapter(self.args.lora_config, adapter_name="adapter")
+            self.model.add_adapter(self.args.lora_config)
             print_rank_0("Lora adapter added.")
 
         if hasattr(self.model, "peft_config") and self.quant_cfg is not None:

@@ -185,6 +185,7 @@ def _save_modelopt_state_with_weights(self):
         # Save base model compressed weights for QLoRA
         if getattr(self.quant_args, "compress", False):
             # Save base model config.json
+            # weight_quantizer = self.quant_cfg["quant_cfg"]["*weight_quantizer"]
             self.model.config.save_pretrained(self.args.output_dir)
 
             # Save base model compressed weights excluding lora weights

@@ -362,7 +363,7 @@ def __init__(
         if self.quant_cfg is not None and not is_quantized(self.model):
             self._quantize_model()
         if getattr(self.args, "lora_config", None) is not None:
-            self.model.add_adapter(self.args.lora_config, adapter_name="adapter")
+            self.model.add_adapter(self.args.lora_config)
             print_rank_0("Lora adapter added.")
         self._convert_to_distillation_model()
 

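The trainer now calls add_adapter without an explicit adapter_name, so the adapter is registered under the default name chosen by transformers/PEFT. A small illustrative sketch (model id and LoRA hyperparameters are placeholders, not from the commit):

    from peft import LoraConfig
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("org/base-model")  # placeholder id
    lora_config = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"])

    # Before this commit: model.add_adapter(lora_config, adapter_name="adapter")
    # After: let transformers/PEFT assign its default adapter name.
    model.add_adapter(lora_config)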