
Commit 341545c

Lunwen He authored and facebook-github-bot committed
add option to quantize output layer per-channel for SpinQuant (#5614)
Summary:
Pull Request resolved: pytorch/executorch#5614

Add an option to quantize the output layer int8 per-channel for SpinQuant.

Reviewed By: mergennachin, iseeyuan

Differential Revision: D62787491

fbshipit-source-id: cc86a9105966dddbdfb26f77c62a6e0f9c01d24c
1 parent 6f9cd8c commit 341545c

File tree

4 files changed: +126 -14 lines changed

examples/models/llama2/export_llama_lib.py

Lines changed: 2 additions & 2 deletions
@@ -374,8 +374,8 @@ def build_args_parser() -> argparse.ArgumentParser:
         "--spin_qmode",
         type=str,
         default=None,
-        choices=["8da4w"],
-        help="Quantization mode for SpinQuant. Only support 8da4w right now.",
+        choices=["8da4w", "8da4w_output_8da8w"],
+        help="Quantization mode for SpinQuant. Only support 8da4w and 8da4w_output_8da8w right now.",
     )

     parser.add_argument(
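With this change, a SpinQuant export can request per-channel int8 quantization of the output layer by selecting the new mode. A rough example invocation is sketched below; the entry-point module, the checkpoint/params paths, and the --use_spin_quant value are assumptions that may differ across ExecuTorch revisions:

python -m examples.models.llama2.export_llama \
    --checkpoint /path/to/spinquant_checkpoint.pth \
    --params /path/to/params.json \
    --use_spin_quant native \
    --spin_qmode 8da4w_output_8da8w \
    --spin_group_size 256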

examples/models/llama2/model.py

Lines changed: 16 additions & 1 deletion
@@ -192,6 +192,10 @@ def __init__(self, **kwargs):
         elif hasattr(self.args, "use_spin_quant") and self.args.use_spin_quant:
             print("Using SPIN quantization.")
             assert hasattr(self.args, "spin_qmode"), "spin_qmode must be specified"
+            assert self.args.spin_qmode in [
+                "8da4w",
+                "8da4w_output_8da8w",
+            ], f"Quantization mode {self.args.spin_qmode} is not compatible with SpinQuant."
             assert hasattr(
                 self.args, "spin_group_size"
             ), "spin_group_size must be specified"
@@ -209,11 +213,22 @@ def __init__(self, **kwargs):
                 "bf16": torch.bfloat16,
             }

+            # Transform the output layer first if needed.
+            if self.args.spin_qmode == "8da4w_output_8da8w":
+                from .source_transformation.spin_quant import (
+                    transform_output_linear_for_spinquant,
+                )
+
+                self.model_ = transform_output_linear_for_spinquant(
+                    module=self.model_,
+                    checkpoint=checkpoint,
+                    dtype=mapping[self.args.dtype_override],
+                )
+
             self.model_ = transform_linear_for_spinquant(
                 self.model_,
                 checkpoint,
                 self.args.spin_group_size,
-                self.args.spin_qmode,
                 mapping[self.args.dtype_override],
             )
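For context, "8da4w" uses 8-bit dynamically quantized activations with 4-bit grouped weights, while the new "8da4w_output_8da8w" mode additionally stores the output projection with int8 per-channel weights. The minimal sketch below illustrates that per-channel scheme; the helper names and shapes are invented for illustration and are not the Int8DynActInt8WeightLinear implementation used by the diff:

import torch

def quantize_weight_per_channel_int8(w: torch.Tensor):
    # One symmetric scale per output row (per channel), int8 weights.
    scales = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    w_int8 = torch.clamp(torch.round(w / scales), -128, 127).to(torch.int8)
    return w_int8, scales.squeeze(1)

def int8_dyn_act_int8_weight_linear(x, w_int8, scales):
    # Dynamically quantize activations per token, dequantize, then matmul
    # against the dequantized per-channel weights (fake-quant reference).
    x_scales = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_q = torch.clamp(torch.round(x / x_scales), -128, 127)
    w_deq = w_int8.to(x.dtype) * scales.to(x.dtype).unsqueeze(1)
    return (x_q * x_scales) @ w_deq.T

w = torch.randn(128, 64)   # stand-in for a small output projection weight
x = torch.randn(2, 64)     # activations for two tokens
w_q, s = quantize_weight_per_channel_int8(w)
y = int8_dyn_act_int8_weight_linear(x, w_q, s)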

examples/models/llama2/source_transformation/spin_quant.py

Lines changed: 56 additions & 9 deletions
@@ -20,7 +20,7 @@
 from torchao.quantization.GPTQ import _check_linear_int4_k, Int8DynActInt4WeightLinear
 from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter

-from .quantize import QuantizedGroupEmbedding
+from .quantize import Int8DynActInt8WeightLinear, QuantizedGroupEmbedding


 def _inject_fast_hadamard_transform_cuda_for_spin_quant(module: torch.nn.Module):
@@ -129,20 +129,16 @@ def transform_linear_for_spinquant(
     module: torch.nn.Module,
     checkpoint: Any,
     group_size: int,
-    quantization_mode: str,
     dtype: torch.dtype,
 ) -> torch.nn.Module:
     """
     Transform the model to be able to load SpinQuant checkpoints that
-    are quantized with the given group size and quantization mode.
+    are quantized with the given group size and quantization mode for
+    linear layers.
     """

     if group_size not in [32, 64, 128, 256]:
         raise ValueError(f"Group size {group_size} is not supported for SpinQuant.")
-    if quantization_mode not in ["8da4w"]:
-        raise ValueError(
-            f"Quantization mode {quantization_mode} is not compatible with SpinQuant."
-        )
     _replace_linear_with_linear_8da4w_for_spin_quant(
         module,
         checkpoint,
@@ -153,6 +149,53 @@ def transform_linear_for_spinquant(
     return module


+def _replace_output_linear_with_linear_int8_for_spinquant(
+    module: torch.nn.Module,
+    checkpoint: Any,
+    dtype: torch.dtype,
+):
+    def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
+        scales_key = f"{cur_fqn}.scale"
+        if (
+            isinstance(child, nn.Linear)
+            and scales_key in checkpoint
+            and "output" in cur_fqn
+        ):
+            assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8
+            assert checkpoint[scales_key].dtype == dtype
+            return True
+        return False
+
+    def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
+        new_linear = Int8DynActInt8WeightLinear(
+            device=child.weight.device,
+            in_features=child.in_features,
+            out_features=child.out_features,
+            precision=dtype,
+            bias=False,
+        )
+        return new_linear
+
+    _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)
+
+
+def transform_output_linear_for_spinquant(
+    module: torch.nn.Module,
+    checkpoint: Any,
+    dtype: torch.dtype,
+) -> torch.nn.Module:
+    """
+    Transform the model to be able to load SpinQuant checkpoints that
+    has the output layer quantized per-channel.
+    """
+    _replace_output_linear_with_linear_int8_for_spinquant(
+        module,
+        checkpoint,
+        dtype,
+    )
+    return module
+
+
 def _replace_embedding_with_quantized_group_embedding_for_spinquant(
     module: torch.nn.Module,
     checkpoint: Any,
@@ -233,8 +276,10 @@ def sanitize_checkpoint_from_spinquant(
         module_name = new_key[0 : new_key.rfind(".")]
         sub_module = module.get_submodule(module_name)
         assert sub_module is not None
-        assert isinstance(sub_module, Int8DynActInt4WeightLinear) or isinstance(
-            sub_module, QuantizedGroupEmbedding
+        assert (
+            isinstance(sub_module, Int8DynActInt4WeightLinear)
+            or isinstance(sub_module, QuantizedGroupEmbedding)
+            or isinstance(sub_module, Int8DynActInt8WeightLinear)
         )
         # Checkpoints with SpinQuant could come with two formats for scales:
         # 1. scales is grouped by group size
@@ -245,6 +290,8 @@ def sanitize_checkpoint_from_spinquant(
             checkpoint[new_key] = (
                 old_val if linear_group_size == -1 else old_val[:, ::linear_group_size]
             )
+        elif isinstance(sub_module, Int8DynActInt8WeightLinear):
+            checkpoint[new_key] = old_val[:, 0]
         elif isinstance(sub_module, QuantizedGroupEmbedding):
             if (
                 embedding_group_size is None or embedding_group_size == 0
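The new branch in sanitize_checkpoint_from_spinquant collapses the output layer's per-channel scales to a 1-D tensor, while grouped scales are strided down to one value per group. A small sketch of both cases (tensor shapes chosen arbitrarily for the example, not taken from an actual checkpoint):

import torch

out_features, in_features, group_size = 8, 64, 16

# Grouped (8da4w) checkpoints may store each group's scale repeated across the
# group, so slicing with a stride of group_size keeps one scale per group:
# (8, 64) -> (8, 4).
grouped = torch.rand(out_features, in_features)
grouped_scales = grouped[:, ::group_size]

# Per-channel (8da8w output layer) checkpoints carry one scale per output row
# with a (possibly repeated) trailing dimension; taking column 0 keeps one
# scale per row: (8, 1) -> (8,).
per_channel = torch.rand(out_features, 1)
per_channel_scales = per_channel[:, 0]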

examples/models/llama2/tests/test_spinquant_transforms.py

Lines changed: 52 additions & 2 deletions
@@ -8,10 +8,14 @@

 import torch
 from executorch.examples.models.llama2.llama_transformer import ModelArgs, Transformer
+from executorch.examples.models.llama2.source_transformation.quantize import (
+    dynamically_quantize_per_channel,
+)
 from executorch.examples.models.llama2.source_transformation.spin_quant import (
     sanitize_checkpoint_from_spinquant,
     transform_embedding_for_spinquant,
     transform_linear_for_spinquant,
+    transform_output_linear_for_spinquant,
 )
 from torchao.quantization.utils import group_quantize_tensor_symmetric

@@ -51,8 +55,7 @@ def test_transform_linear_for_spinquant(self):
         n_bit = 4
         scales_precision = torch.float32
         for fqn, mod in model.named_modules():
-            # Quantize everything except the last layer
-            if isinstance(mod, torch.nn.Linear) and ("output" not in fqn):
+            if isinstance(mod, torch.nn.Linear):
                 weight = mod.weight.data
                 (
                     weight_int8,
@@ -92,6 +95,53 @@ def test_transform_linear_for_spinquant(self):
             # have to iterate over the keys.
             self.assertTrue(torch.allclose(new_checkpoint[k], v))

+    def test_transform_output_linear_for_spinquant(self):
+        # Step 1: Create llama class with dummy weights
+        model = self._prepare_dummy_model()
+        checkpoint = model.state_dict()
+
+        # Step 2:
+        # Do per-channel quantization and amend the checkpoints with
+        # int8 weight and fp32 scales
+        for fqn, mod in model.named_modules():
+            if isinstance(mod, torch.nn.Linear) and fqn == "output":
+                weight = mod.weight.data
+                weight_int8, scales, _ = dynamically_quantize_per_channel(
+                    weight,
+                    quant_min=-128,
+                    quant_max=127,
+                    target_dtype=torch.int8,
+                    scales_dtype=torch.float32,
+                )
+                checkpoint[f"{fqn}.weight"] = weight_int8.to("cpu")
+                checkpoint[f"{fqn}.scale"] = scales.to("cpu")
+
+        # Step 3:
+        # Transform the model so that it is compatible with the new checkpoint
+        transform_output_linear_for_spinquant(
+            model,
+            checkpoint,
+            torch.float32,
+        )
+        sanitize_checkpoint_from_spinquant(
+            model,
+            checkpoint,
+            -1,
+        )
+
+        model.load_state_dict(
+            checkpoint,
+            strict=False,
+            assign=True,
+        )
+
+        new_checkpoint = model.state_dict()
+
+        for k, v in checkpoint.items():
+            # The new_checkpoint contains zeros so
+            # have to iterate over the keys.
+            self.assertTrue(torch.allclose(new_checkpoint[k], v))
+
     def test_transform_embedding_for_spinquant(self):

         # Step 1: Create llama class with dummy weights
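To exercise the new test locally, an invocation along these lines should work from the repository root (the file path is taken from the file tree above; pytest as the runner is an assumption, adjust to your checkout's tooling):

python -m pytest examples/models/llama2/tests/test_spinquant_transforms.py -k test_transform_output_linear_for_spinquant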
