Commit 035117f (1 parent: 74061f5)

e2e example for qlora ddp export

4 files changed (+48 -5 lines)

examples/llm_qat/main.py

Lines changed: 3 additions & 0 deletions

@@ -273,6 +273,9 @@ def train():
     kwargs = {"export_student": True} if training_args.distill else {}
     trainer.save_model(training_args.output_dir, **kwargs)
 
+    if training_args.lora and getattr(quant_args, "compress", False):
+        trainer.export_base_model_hf_checkpoint()
+
 
 if __name__ == "__main__":
     train()
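
For context, the gate added here only fires for compressed QLoRA runs. A runnable toy mirroring the condition; the real objects are the example's HF trainer and parsed argument dataclasses, whose shapes are assumed here:

    # Toy sketch, not the shipped example: training_args/quant_args stand in
    # for the dataclasses produced by the example's HF argument parser.
    from types import SimpleNamespace

    training_args = SimpleNamespace(lora=True, output_dir="./out")
    quant_args = SimpleNamespace(compress=True)

    def export_base_model_hf_checkpoint():
        print("exporting quantized base model to ./out/base_model")

    # Same condition as the diff: only LoRA runs with weight compression
    # enabled trigger the base-model export after training.
    if training_args.lora and getattr(quant_args, "compress", False):
        export_base_model_hf_checkpoint()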

modelopt/torch/export/quant_utils.py

Lines changed: 9 additions & 0 deletions

@@ -826,6 +826,9 @@ def postprocess_state_dict(state_dict: dict, maxbound: float, quantization: str
         "k_bmm_quantizer._bias_value": "k_proj.k_bias",
         "v_bmm_quantizer._bias_value": "v_proj.v_bias",
         "input_quantizer._pre_quant_scale": "pre_quant_scale",
+        "base_layer.weight": "weight",
+        "base_layer.input_scale": "input_scale",
+        "base_layer.weight_scale": "weight_scale",
     }
 
     post_state_dict = {}
@@ -837,6 +840,7 @@ def postprocess_state_dict(state_dict: dict, maxbound: float, quantization: str
             and "_amax" not in key
             and "_bias_value" not in key
             and "input_quantizer._pre_quant_scale" not in key
+            and "base_layer" not in key
         ):
             post_state_dict[key] = value
             continue
@@ -888,6 +892,11 @@ def postprocess_state_dict(state_dict: dict, maxbound: float, quantization: str
         ):
             keys_to_delete.append(key)
 
+    # remove LoRA adapters from state dict
+    for key, value in post_state_dict.items():
+        if "lora" in key and key not in keys_to_delete:
+            keys_to_delete.append(key)
+
     # Check for tied weights and remove duplicates
     seen_tensors = {}
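
To see what the new base_layer handling does, consider a PEFT-style QLoRA state dict: the quantized weights live under base_layer.* and the adapters under lora_* keys. A toy sketch of the remap-and-prune behavior (key names and values are hypothetical; the real function also handles amax/bias keys not shown here):

    # Toy illustration of the remapping postprocess_state_dict now performs.
    replacements = {
        "base_layer.weight": "weight",
        "base_layer.input_scale": "input_scale",
        "base_layer.weight_scale": "weight_scale",
    }

    state_dict = {
        "layers.0.q_proj.base_layer.weight": "packed-int4",
        "layers.0.q_proj.base_layer.weight_scale": "scale",
        "layers.0.q_proj.lora_A.default.weight": "A",
        "layers.0.q_proj.lora_B.default.weight": "B",
    }

    post_state_dict = {}
    for key, value in state_dict.items():
        if "lora" in key:  # adapters are stripped from the exported base model
            continue
        for old, new in replacements.items():
            if key.endswith(old):  # "...q_proj.base_layer.weight" -> "...q_proj.weight"
                key = key[: -len(old)] + new
        post_state_dict[key] = value

    # post_state_dict == {"layers.0.q_proj.weight": "packed-int4",
    #                     "layers.0.q_proj.weight_scale": "scale"}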

modelopt/torch/export/unified_export_hf.py

Lines changed: 14 additions & 5 deletions

@@ -29,7 +29,7 @@
 
 from modelopt.torch.quantization import set_quantizer_by_cfg_context
 from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer
-from modelopt.torch.quantization.qtensor import NVFP4QTensor
+from modelopt.torch.quantization.qtensor import NVFP4QTensor, QTensorWrapper
 from modelopt.torch.quantization.utils import quantizer_attr_names
 
 from .convert_hf_config import convert_hf_quant_config_format
@@ -85,6 +85,9 @@ def _is_enabled_quantizer(quantizer):
 
 def requantize_resmooth_fused_llm_layers(model: torch.nn.Module):
     """Group modules that take the same input and register shared parameters in module."""
+    # Skip for LoRA finetuned models
+    if hasattr(model, "base_model"):
+        return
     # TODO: Handle DBRX MoE
     input_to_linear = defaultdict(list)
     output_to_layernorm = defaultdict(None)
@@ -311,7 +314,7 @@ def _export_quantized_weight(
         )[0]
 
         quantized_weight = to_quantized_weight(
-            weight.to(dtype),
+            weight.to(dtype) if not isinstance(weight, QTensorWrapper) else weight,
             weight_scale,
             quantization_format,
             weight_scale_2,
@@ -323,7 +326,7 @@ def _export_quantized_weight(
         )
     else:
         quantized_weight = to_quantized_weight(
-            weight.to(dtype),
+            weight.to(dtype) if not isinstance(weight, QTensorWrapper) else weight,
            weight_scale,
             quantization_format,
             weight_scale_2,
@@ -457,7 +460,11 @@ def _export_hf_checkpoint(
     for name, sub_module in layer_pool.items():
         if get_quantization_format(sub_module) != QUANTIZATION_NONE:
             has_quantized_layers = True
-            if is_quantlinear(sub_module):
+            if (
+                is_quantlinear(sub_module)
+                and hasattr(sub_module, "weight_quantizer")
+                and sub_module.weight_quantizer.is_enabled
+            ):
                 _export_quantized_weight(sub_module, dtype)
             elif (
                 "Llama4TextExperts" in type(sub_module).__name__
@@ -519,7 +526,9 @@ def export_hf_checkpoint(
 
         post_state_dict = rename_and_prune_if_spec_decoding(model, post_state_dict)
 
-        # Save model
+        # For QLoRA models we export the base model
+        if hasattr(model, "base_model"):
+            model = model.base_model
         model.save_pretrained(
             export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state
         )
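
Two of the hunks above replace weight.to(dtype) with a guarded form. The reason, presumably, is that with compression enabled a QLoRA base weight arrives already packed as a quantized tensor, so casting the packed payload to the export dtype would mangle it. A minimal runnable sketch of the dispatch, using a stand-in wrapper class rather than the real modelopt type:

    import torch

    class QTensorWrapper:  # stand-in for modelopt's wrapper, illustration only
        def __init__(self, packed: torch.Tensor):
            self.packed = packed  # e.g. a uint8-packed NVFP4 payload

    def prepare_weight(weight, dtype=torch.bfloat16):
        # Plain tensors are cast to the export dtype; already-packed weights
        # pass through untouched, mirroring the isinstance check in the diff.
        return weight.to(dtype) if not isinstance(weight, QTensorWrapper) else weight

    assert prepare_weight(torch.randn(4, 4)).dtype == torch.bfloat16
    packed = QTensorWrapper(torch.zeros(4, 2, dtype=torch.uint8))
    assert prepare_weight(packed) is packed  # no cast applied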

modelopt/torch/quantization/plugins/transformers_trainer.py

Lines changed: 22 additions & 0 deletions

@@ -16,6 +16,7 @@
 """ModelOpt plugin for transformers Trainer."""
 
 import gc
+import json
 import os
 import types
 from dataclasses import dataclass, field
@@ -28,6 +29,7 @@
 from modelopt.torch.distill import KDLossConfig
 from modelopt.torch.distill.mode import _convert_for_kd
 from modelopt.torch.distill.plugins.huggingface import KDTrainer
+from modelopt.torch.export.unified_export_hf import export_hf_checkpoint
 from modelopt.torch.opt.conversion import restore_from_modelopt_state
 from modelopt.torch.opt.plugins import ModelOptHFTrainer
 from modelopt.torch.quantization.config import QuantizeConfig
@@ -217,6 +219,7 @@ def forward_loop(model):
         gc.collect()
 
         self._save_modelopt_state_with_weights()
+
         torch.cuda.empty_cache()
 
         if self.accelerator.is_main_process:
@@ -275,6 +278,25 @@ def save_model(self, *args, **kwargs):
         outputs = super().save_model(*args, **kwargs)
         return outputs
 
+    def _load_best_model(self, *args, **kwargs):
+        """Load the best model."""
+        is_lora = getattr(self.args, "lora", None)
+        if not is_lora:
+            super()._load_best_model(*args, **kwargs)
+        else:
+            # Custom logic for loading best model with LoRA
+            adapter_name = self.model.active_adapter()
+            self.model.delete_adapter(adapter_name)
+            self.model.load_adapter(self.state.best_model_checkpoint, adapter_name)
+
+    def export_base_model_hf_checkpoint(self):
+        """Export the base model to a HF checkpoint for deployment."""
+        # Save config.json
+        if self.accelerator.is_main_process:
+            with open(f"{self.args.output_dir}/config.json", "w") as f:
+                json.dump(self.model.config.to_dict(), f, indent=2)
+        export_hf_checkpoint(self.model, export_dir=f"{self.args.output_dir}/base_model")
+
     def _patch_accelerate_for_fsdp2_fix(self):
         """Fixes for accelerate prepare.
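
Put together, a compressed QLoRA run now leaves two deployable artifacts under the output directory. A hedged sketch of the resulting layout (only config.json and base_model/ are named in the diff; the adapter file names are PEFT's defaults and assumed here):

    output_dir/
        config.json               # base model config, written on the main process
        adapter_config.json       # LoRA adapter from trainer.save_model (assumed name)
        adapter_model.safetensors # adapter weights (assumed name)
        base_model/               # quantized base model via export_hf_checkpoint

The adapter remains loadable for further finetuning, while base_model/ is the plain HF checkpoint a deployment runtime would consume.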
