Commit 2304e4d

e2e example for qlora ddp export

Signed-off-by: Suguna Velury <[email protected]>
1 parent: 340eb7a

4 files changed, +48 -5 lines
examples/llm_qat/main.py

Lines changed: 3 additions & 0 deletions
@@ -273,6 +273,9 @@ def train():
     kwargs = {"export_student": True} if training_args.distill else {}
     trainer.save_model(training_args.output_dir, **kwargs)
 
+    if training_args.lora and getattr(quant_args, "compress", False):
+        trainer.export_base_model_hf_checkpoint()
+
 
 if __name__ == "__main__":
     train()

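With this change, a QLoRA run that enables weight compression also writes a deployable base-model checkpoint next to the adapter output. A rough sketch of how those artifacts could be loaded afterwards; the paths, model classes, and the base_model/ layout are assumptions based on this commit, not a documented API:

    from peft import PeftModel
    from transformers import AutoModelForCausalLM

    # Quantized base checkpoint written by trainer.export_base_model_hf_checkpoint() (assumed path)
    base = AutoModelForCausalLM.from_pretrained("output_dir/base_model")
    # LoRA adapters saved by trainer.save_model() into the regular output directory (assumed path)
    model = PeftModel.from_pretrained(base, "output_dir")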
modelopt/torch/export/quant_utils.py

Lines changed: 9 additions & 0 deletions
@@ -832,6 +832,9 @@ def postprocess_state_dict(state_dict: dict, maxbound: float, quantization: str
         "k_bmm_quantizer._bias_value": "k_proj.k_bias",
         "v_bmm_quantizer._bias_value": "v_proj.v_bias",
         "input_quantizer._pre_quant_scale": "pre_quant_scale",
+        "base_layer.weight": "weight",
+        "base_layer.input_scale": "input_scale",
+        "base_layer.weight_scale": "weight_scale",
     }
 
     post_state_dict = {}
@@ -843,6 +846,7 @@ def postprocess_state_dict(state_dict: dict, maxbound: float, quantization: str
             and "_amax" not in key
             and "_bias_value" not in key
             and "input_quantizer._pre_quant_scale" not in key
+            and "base_layer" not in key
         ):
             post_state_dict[key] = value
             continue
@@ -894,6 +898,11 @@ def postprocess_state_dict(state_dict: dict, maxbound: float, quantization: str
         ):
             keys_to_delete.append(key)
 
+    # remove LoRA adapters from state dict
+    for key, value in post_state_dict.items():
+        if "lora" in key and key not in keys_to_delete:
+            keys_to_delete.append(key)
+
     # Check for tied weights and remove duplicates
     seen_tensors = {}

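Taken together, the new mapping entries and filters collapse a QLoRA checkpoint back to plain base-model keys: base_layer.* tensors are renamed to their unwrapped names and LoRA adapter tensors are excluded. A toy standalone sketch of that intent (illustrative key names and helper, not the actual postprocess_state_dict):

    # Toy illustration only: rename "base_layer.*" suffixes and drop LoRA tensors.
    def toy_postprocess(state_dict):
        remap = {
            "base_layer.weight": "weight",
            "base_layer.input_scale": "input_scale",
            "base_layer.weight_scale": "weight_scale",
        }
        out = {}
        for key, value in state_dict.items():
            if "lora" in key:
                continue  # LoRA adapters are not part of the exported base checkpoint
            for old, new in remap.items():
                if key.endswith(old):
                    key = key[: -len(old)] + new
                    break
            out[key] = value
        return out

    # Example: QLoRA layer keys collapse back to plain base-model keys.
    sd = {
        "model.layers.0.q_proj.base_layer.weight": "W4",
        "model.layers.0.q_proj.base_layer.weight_scale": "s",
        "model.layers.0.q_proj.lora_A.default.weight": "A",
    }
    assert set(toy_postprocess(sd)) == {
        "model.layers.0.q_proj.weight",
        "model.layers.0.q_proj.weight_scale",
    }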
modelopt/torch/export/unified_export_hf.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
from modelopt.torch.quantization import set_quantizer_by_cfg_context
3131
from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer
32-
from modelopt.torch.quantization.qtensor import NVFP4QTensor
32+
from modelopt.torch.quantization.qtensor import NVFP4QTensor, QTensorWrapper
3333
from modelopt.torch.quantization.utils import quantizer_attr_names
3434

3535
from .convert_hf_config import convert_hf_quant_config_format
@@ -85,6 +85,9 @@ def _is_enabled_quantizer(quantizer):
8585

8686
def requantize_resmooth_fused_llm_layers(model: torch.nn.Module):
8787
"""Group modules that take the same input and register shared parameters in module."""
88+
# Skip for LoRA finetuned models
89+
if hasattr(model, "base_model"):
90+
return
8891
# TODO: Handle DBRX MoE
8992
input_to_linear = defaultdict(list)
9093
output_to_layernorm = defaultdict(None)
@@ -311,7 +314,7 @@ def _export_quantized_weight(
311314
)[0]
312315

313316
quantized_weight = to_quantized_weight(
314-
weight.to(dtype),
317+
weight.to(dtype) if not isinstance(weight, QTensorWrapper) else weight,
315318
weight_scale,
316319
quantization_format,
317320
weight_scale_2,
@@ -323,7 +326,7 @@ def _export_quantized_weight(
323326
)
324327
else:
325328
quantized_weight = to_quantized_weight(
326-
weight.to(dtype),
329+
weight.to(dtype) if not isinstance(weight, QTensorWrapper) else weight,
327330
weight_scale,
328331
quantization_format,
329332
weight_scale_2,
@@ -461,7 +464,11 @@ def _export_hf_checkpoint(
461464
for name, sub_module in layer_pool.items():
462465
if get_quantization_format(sub_module) != QUANTIZATION_NONE:
463466
has_quantized_layers = True
464-
if is_quantlinear(sub_module):
467+
if (
468+
is_quantlinear(sub_module)
469+
and hasattr(sub_module, "weight_quantizer")
470+
and sub_module.weight_quantizer.is_enabled
471+
):
465472
_export_quantized_weight(sub_module, dtype)
466473
elif (
467474
"Llama4TextExperts" in type(sub_module).__name__
@@ -523,7 +530,9 @@ def export_hf_checkpoint(
523530

524531
post_state_dict = rename_and_prune_if_spec_decoding(model, post_state_dict)
525532

526-
# Save model
533+
# For QLoRA models we export the base model
534+
if hasattr(model, "base_model"):
535+
model = model.base_model
527536
model.save_pretrained(
528537
export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state
529538
)

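In short, export now detects a PEFT-wrapped model via hasattr(model, "base_model"), skips the resmoothing pass for it, passes already-compressed QTensorWrapper weights through without a dtype cast, and serializes only the base model. A rough usage sketch; the qlora_model object is assumed to be the PEFT-wrapped, compressed model produced by the trainer plugin below:

    from modelopt.torch.export.unified_export_hf import export_hf_checkpoint

    # qlora_model: PEFT-wrapped QLoRA model with compressed weights (assumed)
    export_hf_checkpoint(qlora_model, export_dir="output_dir/base_model")
    # The PEFT wrapper is unwrapped internally and LoRA tensors are pruned from the state dict.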
modelopt/torch/quantization/plugins/transformers_trainer.py

Lines changed: 22 additions & 0 deletions
@@ -16,6 +16,7 @@
 """ModelOpt plugin for transformers Trainer."""
 
 import gc
+import json
 import os
 import types
 from dataclasses import dataclass, field
@@ -28,6 +29,7 @@
 from modelopt.torch.distill import KDLossConfig
 from modelopt.torch.distill.mode import _convert_for_kd
 from modelopt.torch.distill.plugins.huggingface import KDTrainer
+from modelopt.torch.export.unified_export_hf import export_hf_checkpoint
 from modelopt.torch.opt.conversion import restore_from_modelopt_state
 from modelopt.torch.opt.plugins import ModelOptHFTrainer
 from modelopt.torch.quantization.config import QuantizeConfig
@@ -217,6 +219,7 @@ def forward_loop(model):
            gc.collect()
 
        self._save_modelopt_state_with_weights()
+
        torch.cuda.empty_cache()
 
        if self.accelerator.is_main_process:
@@ -275,6 +278,25 @@ def save_model(self, *args, **kwargs):
         outputs = super().save_model(*args, **kwargs)
         return outputs
 
+    def _load_best_model(self, *args, **kwargs):
+        """Load the best model."""
+        is_lora = getattr(self.args, "lora", None)
+        if not is_lora:
+            super()._load_best_model(*args, **kwargs)
+        else:
+            # Custom logic for loading best model with LoRA
+            adapter_name = self.model.active_adapter()
+            self.model.delete_adapter(adapter_name)
+            self.model.load_adapter(self.state.best_model_checkpoint, adapter_name)
+
+    def export_base_model_hf_checkpoint(self):
+        """Export the base model to HF checkpoint for deployment."""
+        # Save config.json
+        if self.accelerator.is_main_process:
+            with open(f"{self.args.output_dir}/config.json", "w") as f:
+                json.dump(self.model.config.to_dict(), f, indent=2)
+        export_hf_checkpoint(self.model, export_dir=f"{self.args.output_dir}/base_model")
+
     def _patch_accelerate_for_fsdp2_fix(self):
         """Fixes for accelerate prepare.

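A rough end-to-end sketch of the trainer-side flow this commit enables; trainer construction and argument names are assumptions, only export_base_model_hf_checkpoint() is introduced here:

    trainer.train()
    # Saves the LoRA adapters to the output directory
    trainer.save_model(training_args.output_dir)
    # Writes <output_dir>/config.json (main process only) and the quantized
    # base-model HF checkpoint to <output_dir>/base_model via export_hf_checkpoint()
    trainer.export_base_model_hf_checkpoint()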