
Commit db3bfb6

add support for loading qat_lora checkpoints
This PR adds support for loading QAT_LoRA checkpoints. It does two things:

- Refactors the existing SpinQuant quantization flow into a separate function, which is reused to load QAT checkpoints since they share the same format.
- For QAT_LoRA checkpoints, performs one extra step after quantization: it replaces `Int8DynActInt4WeightLinear` layers with `Int8DynActInt4WeightLinearLoRA`, which contains a LoRA adaptor (see the sketch below).

Differential Revision: [D63714794](https://our.internmc.facebook.com/intern/diff/D63714794/)

ghstack-source-id: 245945707
Pull Request resolved: #5823
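As orientation for the second bullet, a minimal sketch of what the swapped-in layer computes: the quantized base linear output plus a scaled low-rank correction. `LoRALinearSketch` and its parameters are illustrative stand-ins, not the `Int8DynActInt4WeightLinearLoRA` class added in this commit (shown in full in the lora.py diff below).

# Illustrative only; the real class is Int8DynActInt4WeightLinearLoRA in
# examples/models/llama2/source_transformation/lora.py (full diff below).
import torch
from torch import nn


class LoRALinearSketch(nn.Module):
    def __init__(self, base: nn.Module, in_features: int, out_features: int, rank: int, scale: float = 2.0):
        super().__init__()
        self.base = base  # stands in for the quantized Int8DynActInt4WeightLinear
        self.A = nn.Linear(in_features, rank, bias=False)  # down-projection to the LoRA rank
        self.B = nn.Linear(rank, out_features, bias=False)  # up-projection back to out_features
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Quantized base output plus the scaled low-rank LoRA correction.
        return self.base(x) + self.scale * self.B(self.A(x))


# Example: wrap a plain linear with a rank-16 adaptor (the rank comes from --use_lora).
layer = LoRALinearSketch(nn.Linear(64, 64, bias=False), 64, 64, rank=16)
out = layer(torch.randn(2, 64))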
1 parent 152e22d commit db3bfb6

4 files changed: +234, -57 lines

examples/models/llama2/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -80,6 +80,7 @@ runtime.python_library(
         "export_llama_lib.py",
         "model.py",
         "source_transformation/apply_spin_quant_r1_r2.py",
+        "source_transformation/lora.py",
         "source_transformation/pre_quantization.py",
         "source_transformation/prune_output.py",
         "source_transformation/quantize.py",

examples/models/llama2/export_llama_lib.py

Lines changed: 17 additions & 0 deletions
@@ -390,6 +390,23 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="Use SpinQuant for better quantization performance. Only support cuda and native.",
     )

+    parser.add_argument(
+        "-qat",
+        "--use_qat",
+        default=False,
+        action="store_true",
+        help="Whether the checkpoint is pre-quantized with QAT or not.",
+    )
+
+    parser.add_argument(
+        "-lora",
+        "--use_lora",
+        type=int,
+        default=0,
+        help="Whether the checkpoint contains LoRA adaptors or not. 0: no LoRA adaptors; "
+        "otherwise, it means the rank of LoRA adaptors. Currently it only works if QAT is enabled.",
+    )
+
     parser.add_argument(
         "--preq_mode",
         type=str,
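A quick standalone check of how the two new flags parse, using only standard argparse behavior rather than the full export_llama parser; the example values are arbitrary:

# Standalone sketch mirroring the add_argument calls above (not the real parser).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-qat", "--use_qat", default=False, action="store_true")
parser.add_argument("-lora", "--use_lora", type=int, default=0)

args = parser.parse_args(["-qat", "-lora", "16"])
assert args.use_qat is True
assert args.use_lora == 16  # a non-zero value doubles as the LoRA rank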

examples/models/llama2/model.py

Lines changed: 75 additions & 57 deletions
@@ -13,6 +13,7 @@
 import torch

 from executorch.examples.models.llama2.llama_transformer import ModelArgs, Transformer
+from executorch.extension.llm.export.builder import DType

 try:
     from .fairseq2 import convert_to_llama_checkpoint
@@ -191,73 +192,31 @@ def __init__(self, **kwargs):
             )
         elif hasattr(self.args, "use_spin_quant") and self.args.use_spin_quant:
             print("Using SPIN quantization.")
-            assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"
-            assert self.args.preq_mode in [
-                "8da4w",
-                "8da4w_output_8da8w",
-            ], f"Quantization mode {self.args.preq_mode} is not compatible with SpinQuant."
-            assert hasattr(
-                self.args, "preq_group_size"
-            ), "preq_group_size must be specified"
-            assert hasattr(
-                self.args, "dtype_override"
-            ), "dtype_override must be specified"
+            self._transform_for_pre_quantization(checkpoint)
+
             from .source_transformation.pre_quantization import (
                 sanitize_checkpoint_from_pre_quantization,
-                transform_linear_for_pre_quantization,
-            )
-
-            mapping = {
-                "fp32": torch.float32,
-                "fp16": torch.float16,
-                "bf16": torch.bfloat16,
-            }
-
-            # Transform the output layer first if needed.
-            if self.args.preq_mode == "8da4w_output_8da8w":
-                from .source_transformation.pre_quantization import (
-                    transform_output_linear_for_pre_quantization,
-                )
-
-                self.model_ = transform_output_linear_for_pre_quantization(
-                    module=self.model_,
-                    checkpoint=checkpoint,
-                    dtype=mapping[self.args.dtype_override],
-                )
-
-            self.model_ = transform_linear_for_pre_quantization(
-                self.model_,
-                checkpoint,
-                self.args.preq_group_size,
-                mapping[self.args.dtype_override],
             )

-            embedding_bit_width, embedding_group_size = None, None
-            if hasattr(self.args, "preq_embedding_quantize"):
-                embedding_bit_width, embedding_group_size = (
-                    self.args.preq_embedding_quantize.split(",")
-                )
-                from .source_transformation.pre_quantization import (
-                    transform_embedding_for_pre_quantization,
+            sanitize_checkpoint_from_pre_quantization(checkpoint)
+        elif hasattr(self.args, "use_qat") and self.args.use_qat:
+            print("Using QAT quantization.")
+            self._transform_for_pre_quantization(checkpoint)
+            if hasattr(self.args, "use_lora") and self.args.use_lora:
+                from .source_transformation.lora import (
+                    transform_linear_for_lora_after_quantization,
                 )

-                if (
-                    embedding_group_size == "none"
-                    or embedding_group_size == "None"
-                    or embedding_group_size == "0"
-                ):
-                    embedding_group_size = None
-                else:
-                    embedding_group_size = int(embedding_group_size)
-
-                self.model_ = transform_embedding_for_pre_quantization(
+                self.model_ = transform_linear_for_lora_after_quantization(
                     self.model_,
                     checkpoint,
-                    mapping[self.args.dtype_override],
-                    int(embedding_bit_width),
-                    embedding_group_size,
+                    self.args.use_lora,
                 )

+            from .source_transformation.pre_quantization import (
+                sanitize_checkpoint_from_pre_quantization,
+            )
+
             sanitize_checkpoint_from_pre_quantization(checkpoint)

         # assign=True: load params/buffers by assignment instead of performing an in-place copy.
@@ -318,3 +277,62 @@ def get_example_inputs_kvcache_sdpa(self):
                 [0], dtype=torch.long
             ), # start_pos, what token of output are we on.
         )
+
+    def _transform_for_pre_quantization(self, checkpoint):
+        assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"
+        assert self.args.preq_mode in [
+            "8da4w",
+            "8da4w_output_8da8w",
+        ], f"Quantization mode {self.args.preq_mode} is not compatible with SpinQuant."
+        assert hasattr(
+            self.args, "preq_group_size"
+        ), "preq_group_size must be specified"
+        assert hasattr(self.args, "dtype_override"), "dtype_override must be specified"
+        from .source_transformation.pre_quantization import (
+            transform_linear_for_pre_quantization,
+        )
+
+        # Transform the output layer first if needed.
+        if self.args.preq_mode == "8da4w_output_8da8w":
+            from .source_transformation.pre_quantization import (
+                transform_output_linear_for_pre_quantization,
+            )
+
+            self.model_ = transform_output_linear_for_pre_quantization(
+                module=self.model_,
+                checkpoint=checkpoint,
+                dtype=DType[self.args.dtype_override].to_torch_dtype(),
+            )
+
+        self.model_ = transform_linear_for_pre_quantization(
+            self.model_,
+            checkpoint,
+            self.args.preq_group_size,
+            DType[self.args.dtype_override].to_torch_dtype(),
+        )
+
+        embedding_bit_width, embedding_group_size = None, None
+        if hasattr(self.args, "preq_embedding_quantize"):
+            embedding_bit_width, embedding_group_size = (
+                self.args.preq_embedding_quantize.split(",")
+            )
+            from .source_transformation.pre_quantization import (
+                transform_embedding_for_pre_quantization,
+            )
+
+            if (
+                embedding_group_size == "none"
+                or embedding_group_size == "None"
+                or embedding_group_size == "0"
+            ):
+                embedding_group_size = None
+            else:
+                embedding_group_size = int(embedding_group_size)
+
+            self.model_ = transform_embedding_for_pre_quantization(
+                self.model_,
+                checkpoint,
+                DType[self.args.dtype_override].to_torch_dtype(),
+                int(embedding_bit_width),
+                embedding_group_size,
+            )
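One side effect of the refactor is that the removed local fp32/fp16/bf16 mapping is replaced by `DType[...].to_torch_dtype()`. A small equivalence check, assuming `DType` exposes members named `fp32`, `fp16`, and `bf16` as the removed dict implies:

import torch

from executorch.extension.llm.export.builder import DType

# The removed local mapping and DType[...].to_torch_dtype() should agree for the
# three dtype_override values the pre-quantization flow accepts.
mapping = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
for name, expected in mapping.items():
    assert DType[name].to_torch_dtype() == expected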
examples/models/llama2/source_transformation/lora.py

Lines changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+# Helper functions for transforming the model to be able to load checkpoints with
+# LoRA adaptors. See https://arxiv.org/abs/2106.09685 for more details about LoRA.
+
+from typing import Any
+
+import torch
+from torch import nn
+from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
+from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
+
+
+class LoRAAdaptorLinear(nn.Module):
+    """
+    LoRA adaptor for linear layers.
+
+    This class implements Low-Rank Adaptation (LoRA) for linear layers.
+    See more details about LoRA here https://arxiv.org/abs/2106.09685.
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        rank: int,
+        scale: float = 2.0,
+        dtype=torch.float32,
+        device=None,
+    ) -> None:
+        super().__init__()
+        self.scale = scale
+        self.A = nn.Linear(in_features, rank, bias=False, dtype=dtype, device=device)
+        self.B = nn.Linear(rank, out_features, bias=False, dtype=dtype, device=device)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.scale * self.B(self.A(x))  # pyre-ignore[7]
+
+
+class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
+    """
+    Int8DynActInt4WeightLinear with LoRA adaptor.
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        lora_rank: int,
+        bias=True,
+        device=None,
+        groupsize: int = 256,
+        precision: torch.dtype = torch.float32,
+        scales_precision: torch.dtype = torch.float32,
+        lora_adaptor_precision: torch.dtype = torch.bfloat16,
+        lora_scale: float = 2.0,
+    ) -> None:
+        super().__init__(
+            in_features,
+            out_features,
+            bias=bias,
+            device=device,
+            groupsize=groupsize,
+            precision=precision,
+            scales_precision=scales_precision,
+        )
+        self.adaptor = LoRAAdaptorLinear(
+            in_features,
+            out_features,
+            lora_rank,
+            scale=lora_scale,
+            dtype=lora_adaptor_precision,
+            device=device,
+        )
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        return super().forward(input) + self.adaptor(input).to(dtype=self.precision)
+
+
+def _replace_linear_8da4w_for_lora(
+    module: torch.nn.Module,
+    checkpoint: Any,
+    lora_rank: int,
+):
+    def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
+        # Only replace linear layers where the checkpoint contains explicit adaptors
+        adaptor_A_key = f"{cur_fqn}.adaptor.A.weight"
+        adaptor_B_key = f"{cur_fqn}.adaptor.B.weight"
+        if (
+            isinstance(child, Int8DynActInt4WeightLinear)
+            and adaptor_A_key in checkpoint
+            and adaptor_B_key in checkpoint
+        ):
+            assert checkpoint[adaptor_A_key].dtype == torch.bfloat16
+            assert checkpoint[adaptor_A_key].shape[0] == lora_rank
+            assert checkpoint[adaptor_A_key].shape[1] == child.in_features
+            assert checkpoint[adaptor_B_key].dtype == torch.bfloat16
+            assert checkpoint[adaptor_B_key].shape[0] == child.out_features
+            assert checkpoint[adaptor_B_key].shape[1] == lora_rank
+            return True
+        return False
+
+    def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
+        new_linear = Int8DynActInt4WeightLinearLoRA(
+            child.in_features,
+            child.out_features,
+            lora_rank=lora_rank,
+            bias=False,
+            device=child.weight.device,
+            groupsize=child.groupsize,
+            precision=child.precision,
+            scales_precision=child.scales.dtype,
+        )
+        return new_linear
+
+    _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)
+
+
+def transform_linear_for_lora_after_quantization(
+    module: torch.nn.Module,
+    checkpoint: Any,
+    lora_rank: int,
+) -> torch.nn.Module:
+    """
+    Transform the model to be able to load checkpoints with LoRA adaptors.
+    The model should already have been transformed to load pre-quantized
+    checkpoints. The checkpoint should have been pre-quantized and augmented
+    with LoRA adaptors.
+    """
+    _replace_linear_8da4w_for_lora(
+        module,
+        checkpoint,
+        lora_rank,
+    )
+    return module
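For reference, a small usage sketch that mirrors the call site added in model.py. The helper name `load_qat_lora_checkpoint`, the checkpoint path, and the load options are placeholders, not part of this commit:

import torch

from executorch.examples.models.llama2.source_transformation.lora import (
    transform_linear_for_lora_after_quantization,
)


def load_qat_lora_checkpoint(
    model: torch.nn.Module, checkpoint_path: str, lora_rank: int
) -> torch.nn.Module:
    """Hypothetical helper; `model` must already be in 8da4w pre-quantized form,
    i.e. its linear layers are Int8DynActInt4WeightLinear instances."""
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    # Swap in Int8DynActInt4WeightLinearLoRA wherever the checkpoint carries
    # adaptor weights, then load the state dict by assignment (as in model.py).
    model = transform_linear_for_lora_after_quantization(model, checkpoint, lora_rank)
    model.load_state_dict(checkpoint, strict=False, assign=True)
    return model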
