Commit 3dc62c6

Lunwen He authored and facebook-github-bot committed
update spinquant quantization options to be general purposed pre-quantization (#5797)
Summary:
Pull Request resolved: #5797

We decided to use the same quantization scheme and checkpoint format for QAT + LoRA. This PR updates the related quantization CLI options to be general-purpose options for pre-quantized checkpoints.

Differential Revision: D63708762
1 parent 6923ae5 commit 3dc62c6

File tree

9 files changed, +240 -222 lines changed


examples/models/llama2/README.md

Lines changed: 3 additions & 3 deletions
@@ -162,13 +162,13 @@ python -m examples.models.llama2.export_llama \
     --params "${LLAMA_PARAMS:?}" \
     --use_sdpa_with_kv_cache \
     -X \
-    --spin_qmode 8da4w_output_8da8w \
-    --spin_group_size 32 \
+    --preq_mode 8da4w_output_8da8w \
+    --preq_group_size 32 \
     --max_seq_length 2048 \
     --output_name "llama3_2.pte" \
     -kv \
     -d fp32 \
-    --spin_embedding_quantize 8,0 \
+    --preq_embedding_quantize 8,0 \
     --use_spin_quant native \
     --metadata '{"append_eos_to_prompt": 0, "get_bos_id":128000, "get_eos_ids":[128009, 128001], "get_n_bos": 0, "get_n_eos": 0}'
 ```

examples/models/llama2/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -80,6 +80,7 @@ runtime.python_library(
         "export_llama_lib.py",
         "model.py",
         "source_transformation/apply_spin_quant_r1_r2.py",
+        "source_transformation/pre_quantization.py",
         "source_transformation/prune_output.py",
         "source_transformation/quantize.py",
         "source_transformation/rms_norm.py",

examples/models/llama2/export_llama_lib.py

Lines changed: 6 additions & 6 deletions
@@ -381,25 +381,25 @@ def build_args_parser() -> argparse.ArgumentParser:
     )

     parser.add_argument(
-        "--spin_qmode",
+        "--preq_mode",
         type=str,
         default=None,
         choices=["8da4w", "8da4w_output_8da8w"],
-        help="Quantization mode for SpinQuant. Only support 8da4w and 8da4w_output_8da8w right now.",
+        help="Quantization mode used for pre-quantized checkpoint. Only support 8da4w and 8da4w_output_8da8w right now.",
     )

     parser.add_argument(
-        "--spin_group_size",
+        "--preq_group_size",
         type=int,
         default=32,
-        help="group_size for SpinQuant weight quantization",
+        help="group_size for pre-quantized checkpoint weight quantization",
     )

     parser.add_argument(
-        "--spin_embedding_quantize",
+        "--preq_embedding_quantize",
         default="8,0",
         type=str,
-        help="type of embedding quantization for SpinQuant, '<bitwidth>,<groupsize>', e.g., '8,1024'.",
+        help="type of embedding quantization for pre-quantized checkpoint, '<bitwidth>,<groupsize>', e.g., '8,1024'.",
     )

     parser.add_argument(
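
For illustration only (not part of this commit), the following minimal sketch shows how the renamed options behave once parsed. The argument definitions mirror the diff above, and the '<bitwidth>,<groupsize>' split for --preq_embedding_quantize follows the handling shown in model.py below.

import argparse

# Minimal reproduction of the renamed pre-quantization options from build_args_parser().
parser = argparse.ArgumentParser()
parser.add_argument(
    "--preq_mode",
    type=str,
    default=None,
    choices=["8da4w", "8da4w_output_8da8w"],
)
parser.add_argument("--preq_group_size", type=int, default=32)
parser.add_argument("--preq_embedding_quantize", type=str, default="8,0")

# Example invocation mirroring the README command above.
args = parser.parse_args(
    ["--preq_mode", "8da4w_output_8da8w", "--preq_group_size", "32", "--preq_embedding_quantize", "8,0"]
)

# model.py splits '<bitwidth>,<groupsize>' the same way.
bit_width, group_size = args.preq_embedding_quantize.split(",")
print(args.preq_mode, args.preq_group_size, int(bit_width), int(group_size))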

examples/models/llama2/model.py

Lines changed: 20 additions & 20 deletions
@@ -191,20 +191,20 @@ def __init__(self, **kwargs):
             )
         elif hasattr(self.args, "use_spin_quant") and self.args.use_spin_quant:
             print("Using SPIN quantization.")
-            assert hasattr(self.args, "spin_qmode"), "spin_qmode must be specified"
-            assert self.args.spin_qmode in [
+            assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"
+            assert self.args.preq_mode in [
                 "8da4w",
                 "8da4w_output_8da8w",
-            ], f"Quantization mode {self.args.spin_qmode} is not compatible with SpinQuant."
+            ], f"Quantization mode {self.args.preq_mode} is not compatible with SpinQuant."
             assert hasattr(
-                self.args, "spin_group_size"
-            ), "spin_group_size must be specified"
+                self.args, "preq_group_size"
+            ), "preq_group_size must be specified"
             assert hasattr(
                 self.args, "dtype_override"
             ), "dtype_override must be specified"
-            from .source_transformation.spin_quant import (
-                sanitize_checkpoint_from_spinquant,
-                transform_linear_for_spinquant,
+            from .source_transformation.pre_quantization import (
+                sanitize_checkpoint_from_pre_quantization,
+                transform_linear_for_pre_quantization,
             )

             mapping = {
@@ -214,31 +214,31 @@ def __init__(self, **kwargs):
             }

             # Transform the output layer first if needed.
-            if self.args.spin_qmode == "8da4w_output_8da8w":
-                from .source_transformation.spin_quant import (
-                    transform_output_linear_for_spinquant,
+            if self.args.preq_mode == "8da4w_output_8da8w":
+                from .source_transformation.pre_quantization import (
+                    transform_output_linear_for_pre_quantization,
                 )

-                self.model_ = transform_output_linear_for_spinquant(
+                self.model_ = transform_output_linear_for_pre_quantization(
                     module=self.model_,
                     checkpoint=checkpoint,
                     dtype=mapping[self.args.dtype_override],
                 )

-            self.model_ = transform_linear_for_spinquant(
+            self.model_ = transform_linear_for_pre_quantization(
                 self.model_,
                 checkpoint,
-                self.args.spin_group_size,
+                self.args.preq_group_size,
                 mapping[self.args.dtype_override],
             )

             embedding_bit_width, embedding_group_size = None, None
-            if hasattr(self.args, "spin_embedding_quantize"):
+            if hasattr(self.args, "preq_embedding_quantize"):
                 embedding_bit_width, embedding_group_size = (
-                    self.args.spin_embedding_quantize.split(",")
+                    self.args.preq_embedding_quantize.split(",")
                 )
-                from .source_transformation.spin_quant import (
-                    transform_embedding_for_spinquant,
+                from .source_transformation.pre_quantization import (
+                    transform_embedding_for_pre_quantization,
                 )

                 if (
@@ -250,15 +250,15 @@ def __init__(self, **kwargs):
                 else:
                     embedding_group_size = int(embedding_group_size)

-                self.model_ = transform_embedding_for_spinquant(
+                self.model_ = transform_embedding_for_pre_quantization(
                     self.model_,
                     checkpoint,
                     mapping[self.args.dtype_override],
                     int(embedding_bit_width),
                     embedding_group_size,
                 )

-            sanitize_checkpoint_from_spinquant(checkpoint)
+            sanitize_checkpoint_from_pre_quantization(checkpoint)

         # assign=True: load params/buffers by assignment instead of performing an in-place copy.
         # Because we are using device="meta", tensors do not have memory associated with them
examples/models/llama2/source_transformation/pre_quantization.py (new file)

Lines changed: 191 additions & 0 deletions
@@ -0,0 +1,191 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

# Helper functions for tranforming the model to be able to load pre-quantized checkpoints.

from typing import Any, Optional

import torch
from torch import nn

from torchao.quantization.GPTQ import _check_linear_int4_k, Int8DynActInt4WeightLinear
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter

from .quantize import Int8DynActInt8WeightLinear, QuantizedGroupEmbedding


def _replace_linear_with_linear_8da4w_for_pre_quantization(
    module: torch.nn.Module,
    checkpoint: Any,
    group_size: int,
    precision: torch.dtype,
    scales_precision: torch.dtype,
):
    def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
        # Only replace linear layers where the checkpoint contains explicit scales
        scales_key = f"{cur_fqn}.scales"
        if isinstance(child, nn.Linear) and scales_key in checkpoint:
            assert _check_linear_int4_k(child.in_features, group_size)
            assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8
            assert checkpoint[scales_key].dtype == scales_precision
            return True
        return False

    def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
        new_linear = Int8DynActInt4WeightLinear(
            child.in_features,
            child.out_features,
            bias=False,
            device=child.weight.device,
            groupsize=group_size,
            precision=precision,
            scales_precision=scales_precision,
        )
        return new_linear

    _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)


def transform_linear_for_pre_quantization(
    module: torch.nn.Module,
    checkpoint: Any,
    group_size: int,
    dtype: torch.dtype,
) -> torch.nn.Module:
    """
    Transform the model to be able to load pre-quantized checkpoints that
    are quantized with the given group size and quantization mode for
    linear layers.
    """

    if group_size not in [32, 64, 128, 256]:
        raise ValueError(
            f"Group size {group_size} is not supported for pre-quantized checkpoint."
        )
    _replace_linear_with_linear_8da4w_for_pre_quantization(
        module,
        checkpoint,
        group_size,
        dtype,
        dtype,
    )
    return module


def _replace_output_linear_with_linear_int8_for_pre_quantization(
    module: torch.nn.Module,
    checkpoint: Any,
    dtype: torch.dtype,
):
    def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
        scales_key = f"{cur_fqn}.scales"
        if (
            isinstance(child, nn.Linear)
            and scales_key in checkpoint
            and "output" in cur_fqn
        ):
            assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8
            assert checkpoint[scales_key].dtype == dtype
            return True
        return False

    def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
        new_linear = Int8DynActInt8WeightLinear(
            device=child.weight.device,
            in_features=child.in_features,
            out_features=child.out_features,
            precision=dtype,
            bias=False,
        )
        return new_linear

    _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)


def transform_output_linear_for_pre_quantization(
    module: torch.nn.Module,
    checkpoint: Any,
    dtype: torch.dtype,
) -> torch.nn.Module:
    """
    Transform the model to be able to load pre-quantized checkpoints that
    has the output layer quantized per-channel.
    """
    _replace_output_linear_with_linear_int8_for_pre_quantization(
        module,
        checkpoint,
        dtype,
    )
    return module


def _replace_embedding_with_quantized_group_embedding_for_pre_quantization(
    module: torch.nn.Module,
    checkpoint: Any,
    dtype: torch.dtype,
    bit_width: int,
    group_size: Optional[int] = None,
):
    def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
        # Only replace embedding layers where the checkpoint contains explicit scales
        scales_key = f"{cur_fqn}.scales"
        if isinstance(child, nn.Embedding) and scales_key in checkpoint:
            assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8
            assert checkpoint[scales_key].dtype == torch.float32
            return True
        return False

    def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
        new_embedding = QuantizedGroupEmbedding(
            device=child.weight.device,
            vocab_size=child.weight.shape[0],
            embedding_dim=child.weight.shape[1],
            group_size=group_size,
            dtype=dtype,
            packed=False,  # TODO(lunwenh): support packed embedding for pre-quantized
        )
        return new_embedding

    _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)


def transform_embedding_for_pre_quantization(
    module: torch.nn.Module,
    checkpoint: Any,
    dtype: torch.dtype,
    bit_width: int,
    group_size: Optional[int] = None,
) -> torch.nn.Module:
    """
    Transform the model to be able to load pre-quantized checkpoints that
    are quantized with the given bit_width and group size for embedding.
    """
    if group_size is not None and group_size not in [0, 32, 64, 128, 256]:
        raise ValueError(
            f"Group size {group_size} is not supported for pre-quantized checkpoint."
        )
    _replace_embedding_with_quantized_group_embedding_for_pre_quantization(
        module,
        checkpoint,
        dtype,
        bit_width,
        group_size,
    )
    return module


def sanitize_checkpoint_from_pre_quantization(
    checkpoint: Any,
):
    """
    Sanitize the pre-quantized checkpoint.
    - Converts all tensors to contiguous format
    - Squeeze all tensors
    """
    for k, v in checkpoint.items():
        checkpoint[k] = torch.squeeze(v.contiguous())
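
To make the intended flow concrete, here is a hypothetical usage sketch (not part of this commit) that mirrors how model.py applies these helpers. It assumes this module is importable as pre_quantization and that torchao is installed; the Tiny class, tensor shapes, and values are fabricated for illustration only.

import torch
from torch import nn

# Assumes this file is importable as `pre_quantization` (e.g., run from its directory).
from pre_quantization import (
    sanitize_checkpoint_from_pre_quantization,
    transform_linear_for_pre_quantization,
)


class Tiny(nn.Module):
    """Toy stand-in for the real Transformer module (hypothetical)."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(64, 64, bias=False)


model = Tiny()
group_size = 32
dtype = torch.float32

# Fabricated "pre-quantized" state dict: int8 weight plus per-group scales.
# The `<fqn>.scales` key is what filter_fn above uses to decide which layers to swap.
checkpoint = {
    "linear.weight": torch.zeros(64, 64, dtype=torch.int8),
    "linear.scales": torch.ones(64, 64 // group_size, dtype=dtype),
}

# Swap nn.Linear for Int8DynActInt4WeightLinear, then sanitize the checkpoint
# (contiguous + squeeze). model.py then loads it with load_state_dict(..., assign=True).
model = transform_linear_for_pre_quantization(model, checkpoint, group_size, dtype)
sanitize_checkpoint_from_pre_quantization(checkpoint)
print(type(model.linear).__name__)  # Int8DynActInt4WeightLinear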
