Commit 3cf89cf

export for fp8 lora base model
Signed-off-by: Suguna Velury <[email protected]>
1 parent 70200f1 commit 3cf89cf

File tree: 5 files changed (+68, -22)

examples/llm_ptq/example_utils.py

Lines changed: 26 additions & 0 deletions
@@ -25,6 +25,7 @@
 import transformers
 from accelerate import infer_auto_device_map, init_empty_weights
 from accelerate.utils import get_max_memory
+from safetensors.torch import load_file
 from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor, AutoTokenizer

 try:

@@ -115,6 +116,31 @@ def get_dtype(dtype):
     return dtype


+def get_lora_model(
+    ckpt_path: str,
+    device="cuda",
+):
+    """
+    Loads a QLoRA model that has been trained using modelopt trainer.
+    """
+    device_map = "auto"
+    if device == "cpu":
+        device_map = "cpu"
+
+    # Load model with adapters
+    model = AutoModelForCausalLM.from_pretrained(ckpt_path, device_map=device_map)
+
+    # Restore modelopt state
+    modelopt_state = torch.load(f"{ckpt_path}/modelopt_state.pth", weights_only=False)
+    restore_from_modelopt_state(model, modelopt_state)
+
+    # Load compressed weights
+    state_dict = load_file(f"{ckpt_path}/model.safetensors")
+    model.load_state_dict(state_dict, strict=False)
+
+    return model
+
+
 def get_model(
     ckpt_path,
     device="cuda",

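Note: get_lora_model assumes the checkpoint directory written by the modelopt trainer contains modelopt_state.pth and model.safetensors alongside the adapter files. A minimal usage sketch (the checkpoint path is a placeholder, and the import assumes the caller sits next to example_utils.py, as hf_ptq.py does):

    from example_utils import get_lora_model

    # Placeholder path to a QLoRA checkpoint saved by the modelopt trainer; the
    # directory is expected to contain modelopt_state.pth and model.safetensors.
    model = get_lora_model("./qlora_ckpt", device="cuda")
    model.eval()
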
examples/llm_ptq/hf_ptq.py

Lines changed: 21 additions & 8 deletions
@@ -241,14 +241,20 @@ def main(args):
     # If low memory mode is enabled, we compress the model while loading the HF checkpoint.
     calibration_only = False
     if not args.low_memory_mode:
-        model = get_model(
-            args.pyt_ckpt_path,
-            args.device,
-            gpu_mem_percentage=args.gpu_max_mem_percentage,
-            trust_remote_code=args.trust_remote_code,
-            use_seq_device_map=args.use_seq_device_map,
-            attn_implementation=args.attn_implementation,
-        )
+        if args.lora:
+            model = get_lora_model(
+                args.pyt_ckpt_path,
+                args.device,
+            )
+        else:
+            model = get_model(
+                args.pyt_ckpt_path,
+                args.device,
+                gpu_mem_percentage=args.gpu_max_mem_percentage,
+                trust_remote_code=args.trust_remote_code,
+                use_seq_device_map=args.use_seq_device_map,
+                attn_implementation=args.attn_implementation,
+            )
     else:
         assert args.qformat in QUANT_CFG_CHOICES, (
             f"Quantization format is not supported for low memory mode. Supported formats: {QUANT_CFG_CHOICES.keys()}"

@@ -629,6 +635,7 @@ def output_decode(generated_ids, input_shape):
             "They will be set at deployment time."
         )

+    print("DEBUG LOG: Calling unified export hf checkpoint")
     export_hf_checkpoint(
         full_model,
         export_dir=export_path,

@@ -772,6 +779,12 @@ def output_decode(generated_ids, input_shape):
         default=None,
         type=str,
     )
+    parser.add_argument(
+        "--lora",
+        help="Specify the model to be exported is a LoRA model trained using modelopt.",
+        default=False,
+        action="store_true",
+    )

     args = parser.parse_args()

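Note: with the new flag, a QLoRA checkpoint produced by the modelopt trainer can be pushed through the existing PTQ export flow roughly as sketched below. The paths and the chosen qformat are placeholders; --pyt_ckpt_path and --qformat already exist in the script, and only --lora is introduced by this commit.

    python hf_ptq.py --pyt_ckpt_path ./qlora_ckpt --qformat fp8 --lora
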
modelopt/torch/export/quant_utils.py

Lines changed: 3 additions & 5 deletions
@@ -1087,11 +1087,9 @@ def get_quant_config(named_modules: nn.Module | dict[str, nn.Module]) -> dict[st
         if block_size == 0:
             block_size = get_weight_block_size(module)

-        # Handles case if default weight quantizer is not enabled or is None
-        if block_size != 0:
-            # Construct per layer config dictionary
-            layer_config_dict[name + ".quantization"] = quantization_format
-            layer_config_dict[name + ".awq_block_size"] = block_size
+        # Construct per layer config dictionary
+        layer_config_dict[name + ".quantization"] = quantization_format
+        layer_config_dict[name + ".awq_block_size"] = block_size

         # Find kv cache quant format
         if (

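Note: with the guard removed, get_quant_config records both per-layer keys for every quantized module it visits, leaving awq_block_size at 0 for formats such as FP8 that do not use block quantization. An illustrative sketch of the resulting entries (the layer name and format string are hypothetical):

    # Hypothetical per-layer entries in layer_config_dict after this change
    layer_config_dict = {
        "model.layers.0.mlp.down_proj.quantization": "fp8",
        "model.layers.0.mlp.down_proj.awq_block_size": 0,
    }
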
modelopt/torch/export/unified_export_hf.py

Lines changed: 15 additions & 7 deletions
@@ -518,32 +518,40 @@ def export_hf_checkpoint(
         export_dir: the target export path.
         save_modelopt_state: whether to save the modelopt state_dict.
     """
+    is_lora = hasattr(model, "base_model")
+    base_export_dir: Path | str = f"{export_dir}/base_model" if is_lora else export_dir
     export_dir = Path(export_dir)
     export_dir.mkdir(parents=True, exist_ok=True)
+    base_export_dir = Path(base_export_dir)
+    base_export_dir.mkdir(parents=True, exist_ok=True)
+
     try:
         post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype)

         # NOTE: (hg) Should we save hf_quant_config when there's no quantization applied?
         # Save hf_quant_config.json for backward compatibility
-        with open(f"{export_dir}/hf_quant_config.json", "w") as file:
+        with open(f"{base_export_dir}/hf_quant_config.json", "w") as file:
             json.dump(hf_quant_config, file, indent=4)

         hf_quant_config = convert_hf_quant_config_format(hf_quant_config)

         post_state_dict = rename_and_prune_if_spec_decoding(model, post_state_dict)

-        # For QLoRA models we export the base model
-        if hasattr(model, "base_model"):
-            model = model.base_model
+        # In the case of LoRA model, we save the base model
+        if is_lora:
+            model.base_model.save_pretrained(
+                base_export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state
+            )
+
         model.save_pretrained(
             export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state
         )

-        original_config = f"{export_dir}/config.json"
+        original_config = f"{base_export_dir}/config.json"
         config_data = {}

-        with open(original_config) as file:
-            config_data = json.load(file)
+        # In the case of LoRA model.save_pretrained does not save the correct config.json
+        config_data = model.config.to_dict()

         config_data["quantization_config"] = hf_quant_config
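Note: a sketch of how the updated export is invoked for a QLoRA model and where the artifacts land, assuming export_hf_checkpoint is imported as in hf_ptq.py and model is the object returned by get_lora_model (the output directory is a placeholder):

    from modelopt.torch.export import export_hf_checkpoint

    # model has a .base_model attribute, so is_lora is True: the LoRA/adapter save
    # goes to ./exported, while hf_quant_config.json, the patched config.json, and
    # the quantized base-model weights go to ./exported/base_model (per the diff above).
    export_hf_checkpoint(model, export_dir="./exported")
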
modelopt/torch/quantization/plugins/transformers_trainer.py

Lines changed: 3 additions & 2 deletions
@@ -147,7 +147,7 @@ def __init__(
             self.model, "peft_config"
         ):
             # TODO: use get_peft_model here instead of add_adapter
-            self.model.add_adapter(self.args.lora_config, adapter_name="adapter")
+            self.model.add_adapter(self.args.lora_config)
             print_rank_0("Lora adapter added.")

         if hasattr(self.model, "peft_config") and self.quant_cfg is not None:

@@ -185,6 +185,7 @@ def _save_modelopt_state_with_weights(self):
         # Save base model compressed weights for QLoRA
         if getattr(self.quant_args, "compress", False):
             # Save base model config.json
+            # weight_quantizer = self.quant_cfg["quant_cfg"]["*weight_quantizer"]
             self.model.config.save_pretrained(self.args.output_dir)

             # Save base model compressed weights excluding lora weights

@@ -362,7 +363,7 @@ def __init__(
         if self.quant_cfg is not None and not is_quantized(self.model):
             self._quantize_model()
         if getattr(self.args, "lora_config", None) is not None:
-            self.model.add_adapter(self.args.lora_config, adapter_name="adapter")
+            self.model.add_adapter(self.args.lora_config)
             print_rank_0("Lora adapter added.")
         self._convert_to_distillation_model()

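Note: dropping the explicit adapter_name means transformers' add_adapter registers the LoRA adapter under its default name ("default"), which is what PEFT tooling assumes when no name is given. A minimal sketch (the base model path and LoraConfig values are illustrative, not the trainer's defaults):

    from peft import LoraConfig
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("path/to/base_model")  # placeholder
    lora_config = LoraConfig(r=16, lora_alpha=32, target_modules="all-linear")

    # With no adapter_name argument, transformers registers the adapter as "default",
    # mirroring the trainer's self.model.add_adapter(self.args.lora_config) call.
    model.add_adapter(lora_config)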