Commit f288508

rohansjoshi authored and facebook-github-bot committed
Range setting pt2e
Summary: The code consists of: 1) two observer classes, 2) functions to reverse the module swaps, and 3) a function to compute scales (which can be called from the script) plus a function to set scales manually in the Qualcomm LlamaModel. This code is a draft; I would like early feedback before adding unit tests.

Rollback Plan:

Differential Revision: D77680635
1 parent defa089 commit f288508

1 file changed: 273 additions, 0 deletions
@@ -0,0 +1,273 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


"""
The goal of this module is to allow range setting methods from TorchAO (formerly Quanty)
to be incorporated into the PT2E flow.

We implement the two main range setting methods:
1) MSE weight range setting (via a custom observer)
2) Activation loss weight range setting (via precomputing scales with Quanty and
   loading them into a manual observer)
"""
import sys
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
from executorch.backends.qualcomm.quantizer.annotators import OP_ANNOTATOR
from executorch.backends.qualcomm.quantizer.observers.per_channel_param_observer import (
    PerChannelParamObserver,
)

from executorch.backends.qualcomm.quantizer.qconfig import (
    _derived_bias_quant_spec,
    QuantizationConfig,
)
from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype

from executorch.examples.qualcomm.utils import make_quantizer

from torchao.prototype.quantization.module_swap import (
    QuantizationRecipe,
    quantize_module_swap,
    QuantizedLinear,
)
from torchao.prototype.quantization.module_swap.module_swap import (
    get_layer_parent_by_name,
)
from torchao.prototype.quantization.module_swap.quantized_modules import (
    QuantizedEmbedding,
)
from torchao.prototype.quantization.module_swap.range_setting_methods import (
    set_weight_range_activation_loss,
)

from torchao.quantization.pt2e import (
    HistogramObserver,
    MinMaxObserver,
    ObserverBase,
    PerChannelMinMaxObserver,
)
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.quantization.pt2e.quantizer import QuantizationSpec


class PerChannelMSEObserver(PerChannelParamObserver):

    @torch.jit.export
    def forward(self, x_orig):
        # since params are static, one calibration is enough
        if not self.calibrated:
            x = x_orig.detach().to(self.min_val.dtype)
            self.min_val, self.max_val = self.line_search(x)
            self.calibrated = True

        return x_orig


class PerChannelFixedQParamsObserver(PerChannelMinMaxObserver):
    r"""
    Observer with a fixed, manually set scale. Quantization is symmetric, so the
    zero point is always zero. Used for per-channel quantization.
    If no scale has been set, falls back to the usual min-max qparams.
    """

    def __init__(
        self,
        ch_axis=0,
        dtype=torch.quint8,
        qscheme=torch.per_channel_symmetric,
        quant_min=0,
        quant_max=255,
        is_dynamic=False,
        **kwargs,
    ):
        super().__init__(
            ch_axis=ch_axis, dtype=dtype, qscheme=qscheme, is_dynamic=is_dynamic, **kwargs
        )
        self.quant_min = quant_min
        self.quant_max = quant_max

    def set_scale(self, scale, device):
        self.scale = scale.to(device=device)
        self.zero_point = torch.zeros_like(scale).to(device=device)

    @torch.jit.export
    def calculate_qparams(self):
        if hasattr(self, "scale"):
            return self.scale, self.zero_point
        return self._calculate_qparams(self.min_val, self.max_val)
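

# Illustrative sketch (not part of the commit's tested surface): how the manual
# observer above is expected to be driven. Shapes and values are assumptions made
# only for this example.
def _example_fixed_qparams_observer():
    obs = PerChannelFixedQParamsObserver()  # defaults: quint8, per-channel symmetric
    obs(torch.randn(8, 16))  # fallback path: records per-channel min/max
    obs.set_scale(torch.full((8, 1), 0.02), device="cpu")
    # With a scale set, calculate_qparams returns it (and an all-zero zero point)
    # instead of deriving qparams from the observed min/max.
    return obs.calculate_qparams()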


def reverse_quantize_module_swap(model: nn.Module) -> nn.Module:
    """
    Reverse `quantize_module_swap`
    QuantizedLinear --> Linear
    QuantizedEmbedding --> Embedding
    """
    model = reverse_replace_all_linear_with_quantized(model)
    model = reverse_replace_all_embedding_with_quantized(model)
    return model
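

# Round-trip sketch (hypothetical helper, not part of this commit's API): after a
# Quanty module swap, the reverse helpers above should restore plain nn modules.
# `recipe` can be any QuantizationRecipe, e.g. the one built in compute_scales below.
def _example_reverse_module_swap(model: nn.Module, recipe: QuantizationRecipe) -> nn.Module:
    swapped = quantize_module_swap(model, recipe)  # Linear -> QuantizedLinear, etc.
    restored = reverse_quantize_module_swap(swapped)  # QuantizedLinear -> Linear, etc.
    assert not any(
        isinstance(m, (QuantizedLinear, QuantizedEmbedding)) for m in restored.modules()
    )
    return restored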


def reverse_replace_all_embedding_with_quantized(
    model: nn.Module,
) -> nn.Module:
    """
    Reverse `replace_all_embedding_with_quantized`
    QuantizedEmbedding --> Embedding
    """
    for name, module in model.named_modules():
        if isinstance(module, QuantizedEmbedding):
            embedding = nn.Embedding(
                num_embeddings=module.num_embeddings,
                embedding_dim=module.embedding_dim,
                padding_idx=module.padding_idx,
                max_norm=module.max_norm,
                norm_type=module.norm_type,
                scale_grad_by_freq=module.scale_grad_by_freq,
                sparse=module.sparse,
                _weight=module.weight,
            )
            attribute_name = name.rsplit(".", 1)[-1]
            parent_of_module = get_layer_parent_by_name(model, name)
            setattr(parent_of_module, attribute_name, embedding)

    return model


def reverse_replace_all_linear_with_quantized(
    model: nn.Module,
) -> nn.Module:
    """
    Reverse `replace_all_linear_with_quantized_linear`
    QuantizedLinear --> Linear
    """
    for name, module in model.named_modules():
        if isinstance(module, QuantizedLinear):
            linear = nn.Linear(
                in_features=module.in_features,
                out_features=module.out_features,
                bias=module.bias is not None,
            )
            linear.weight = module.weight
            linear.bias = module.bias

            attribute_name = name.rsplit(".", 1)[-1]
            parent_of_module = get_layer_parent_by_name(model, name)
            setattr(parent_of_module, attribute_name, linear)

    return model


def make_custom_quantizer(quant_dtype, range_setting_weight=None):
    """
    Build a custom quantizer which uses either the MSE observer or the manual
    (fixed-qparams) observer, depending on the weight range setting method provided.
    """
    quantizer = make_quantizer(
        quant_dtype=quant_dtype,
        per_channel_conv=True,
        per_channel_linear=True,
        act_observer=MinMaxObserver,
    )
    if range_setting_weight in ("mse", "activation_loss"):
        if range_setting_weight == "mse":
            observer = PerChannelMSEObserver.with_args(**{"steps": 200, "use_mse": True})
        else:
            observer = PerChannelFixedQParamsObserver.with_args(**{"eps": 2**-12})
        weight_dtype = (
            torch.int4
            if quant_dtype in (QuantDtype.use_16a4w, QuantDtype.use_16a4w_block)
            else torch.int8
        )
        per_channel_q_config = quantizer.default_quant_config.quant_config
        weight_qspec = QuantizationSpec(
            dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
            quant_min=(
                -7
                if weight_dtype == torch.int4
                else torch.iinfo(weight_dtype).min + 1
            ),
            quant_max=(
                7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max
            ),
            qscheme=torch.per_channel_symmetric,
            ch_axis=0,
            observer_or_fake_quant_ctr=observer,
        )
        quantizer.default_quant_config.per_channel_quant_config = QuantizationConfig(
            input_activation=per_channel_q_config.input_activation,
            output_activation=per_channel_q_config.output_activation,
            weight=weight_qspec,
            bias=_derived_bias_quant_spec,
        )

    return quantizer
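

# Usage sketch for the MSE path (hypothetical glue code: `model` and `example_inputs`
# are placeholders, and whether plain torch.export.export is the right capture API for
# this flow is an assumption, not something this commit pins down):
def _example_prepare_with_mse_observer(model: nn.Module, example_inputs: tuple):
    quantizer = make_custom_quantizer(QuantDtype.use_16a4w, range_setting_weight="mse")
    graph_module = torch.export.export(model, example_inputs).module()
    prepared = prepare_pt2e(graph_module, quantizer)
    prepared(*example_inputs)  # one calibration pass; static weights only need one
    return convert_pt2e(prepared)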


def compute_scales(model, data, num_points=100, weight_bits=4, activation_bits=16):
    """
    Compute scales for weight quantization using activation loss range setting.
    Uses functions from Quanty:
    1. Perform module swap
    2. Apply method from Quanty to compute optimal scales
    3. Save scales in dictionary
    4. Undo module swap
    """
    recipe = QuantizationRecipe(
        weight_bits=weight_bits,
        weight_quantization=True,
        dynamic_weights=False,
        weight_group_size="per_channel",
        activation_bits=activation_bits,
        activation_quantization=True,
        activation_group_size="per_tensor",
        input_quantization=True,
        output_quantization=True,
        dynamic_activations=False,
    )

    quantized_model = quantize_module_swap(model, recipe)

    set_weight_range_activation_loss(quantized_model, data, 1, num_points)  # batch_size = 1
    scale_dict = dict()
    for name, module in quantized_model.named_modules():
        if isinstance(module, QuantizedLinear):
            scale_dict[name] = module.weight_scale.clone().detach().to(device=model.device)

    reverse_quantize_module_swap(model)

    return scale_dict
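

# Usage sketch (hypothetical names; `calibration_data` stands for whatever batch
# set_weight_range_activation_loss expects for the model):
#
#     scale_dict = compute_scales(llama_model, calibration_data, num_points=100)
#     # scale_dict maps QuantizedLinear module names to per-channel weight scale
#     # tensors (one scale per output channel); set_scales below reconstructs the
#     # same keys from the prepared graph's parameter names.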


def set_scales(model, scale_dict, num_heads=32, dim=2048):
    """
    Given a prepared model with manual observers inserted after the weights, set the
    scales manually. This is specific to the Llama architecture as prepared in the HTP
    flow (for example, scales must be split because the attention heads are split).
    """
    head_dim = dim // num_heads
    for node in model.graph.nodes:
        if node.op == "get_attr":
            l = node.target.split(".")
            if len(l) > 3 and l[-3] in ("wq_sha", "wk_sha", "wv_sha"):
                shorter_name = l[-3][:2]
                key = ".".join(["model"] + l[:-3] + [shorter_name])
                observer_name = str(list(node.users.keys())[0])
                observer = getattr(model, observer_name)
                i = int(l[-2])
                observer.set_scale(
                    scale_dict[key][head_dim * i : head_dim * (i + 1), :],
                    device=model.device,
                )
            elif len(l) > 1 and l[-2] in ("wo_sha", "w1_conv", "w2_conv", "w3_conv"):
                shorter_name = l[-2][:2]
                key = ".".join(["model"] + l[:-2] + [shorter_name])
                observer_name = str(list(node.users.keys())[0])
                observer = getattr(model, observer_name)
                observer.set_scale(scale_dict[key], model.device)
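

# End-to-end sketch of the activation-loss path described in the module docstring
# (hypothetical glue code, not part of this commit's tested surface). It assumes
# `eager_llama` is the eager Llama model Quanty can swap, `prepared_llama` is a graph
# module already prepared via prepare_pt2e with the quantizer returned by
# make_custom_quantizer(..., "activation_loss") after the HTP-flow source transforms
# (which produce the wq_sha / w1_conv parameter names matched above), and
# `example_inputs` is a tuple of calibration inputs.
def _example_activation_loss_flow(eager_llama, prepared_llama, calibration_data, example_inputs):
    scale_dict = compute_scales(eager_llama, calibration_data)      # 1) Quanty scales
    set_scales(prepared_llama, scale_dict, num_heads=32, dim=2048)  # 2) load into manual observers
    prepared_llama(*example_inputs)                                 # 3) calibrate activation observers
    return convert_pt2e(prepared_llama)                             # 4) convert as usual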
