
Commit bf82b26

rohansjoshi authored and facebook-github-bot committed
AWQ for Qualcomm LlamaModels
Differential Revision: D81645608
1 parent dbac09c commit bf82b26

File tree

examples/qualcomm/oss_scripts/llama/TARGETS
examples/qualcomm/oss_scripts/llama/awq.py
examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py

3 files changed: +172 -0 lines changed

examples/qualcomm/oss_scripts/llama/TARGETS

Lines changed: 12 additions & 0 deletions
@@ -66,6 +66,17 @@ python_library(
     ],
 )
 
+python_library(
+    name = "awq",
+    srcs = [
+        "awq.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        ":llama_lib",
+    ],
+)
+
 python_binary(
     name = "llama",
     main_function = "executorch.examples.qualcomm.oss_scripts.llama.llama.main",
@@ -89,6 +100,7 @@ python_binary(
         "//executorch/examples/models/llama:eval_library",
         "//executorch/examples/qualcomm/oss_scripts/llama:range_setting_pt2e",
         "fbsource//third-party/pypi/lm-eval:lm-eval",
+        ":awq"
     ],
 )
 
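For reference, the new :awq target only packages awq.py so that dependents can import it: the :llama binary above picks it up through the ":awq" dep, and eval_llama_qnn.py later in this commit is the actual call site. A minimal sketch of the import this target enables, using the module path implied by the package location (both symbols are defined in the new file below):

# Import path implied by srcs = ["awq.py"] in this package.
from executorch.examples.qualcomm.oss_scripts.llama.awq import (
    apply_awq,
    fake_quantize_tensor,
)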

examples/qualcomm/oss_scripts/llama/awq.py

Lines changed: 151 additions & 0 deletions
@@ -0,0 +1,151 @@
+import functools
+from collections import defaultdict
+
+import torch
+import torch.nn as nn
+from executorch.examples.qualcomm.oss_scripts.llama.model.static_llama import LlamaModel
+from tqdm import tqdm
+
+
+def fake_quantize_tensor(w, n_bit=4):
+    # symmetric quantization
+    max_val = w.abs().amax(dim=1, keepdim=True)
+    max_val = max_val.clamp(min=1e-5)
+    max_int = 2 ** (n_bit - 1) - 1
+    min_int = -(2 ** (n_bit - 1) - 1)
+    scales = max_val / max_int
+
+    w = (torch.clamp(torch.round(w / scales), min_int, max_int)) * scales
+    return w
+
+
+def compute_and_apply_scale(prev_op, layers, inp, block, **kwargs):
+    x = inp
+    # w: co, ci
+    # x: n, ci
+    with torch.no_grad():
+        org_out = block(x, **kwargs)
+        if isinstance(org_out, tuple):
+            org_out = org_out[0]
+
+        x_max = x.abs().view(-1, x.shape[-1]).mean(0)
+
+        best_error = float("inf")
+        best_scales = None
+
+        n_grid = 20
+
+        org_sd = {k: v.cpu() for k, v in block.state_dict().items()}
+        for ratio in range(n_grid):
+            ratio = ratio * 1 / n_grid
+            scales = x_max.pow(ratio).clamp(min=1e-4).view(-1)
+            scales = scales / (scales.max() * scales.min()).sqrt()
+            for layer in layers:
+                layer.weight.mul_(scales.view(1, -1))
+                layer.weight.data = fake_quantize_tensor(layer.weight.data).detach() / (
+                    scales.view(1, -1)
+                )
+
+            out = block(x, **kwargs)
+            if isinstance(out, tuple):
+                out = out[0]
+
+            loss = (org_out - out).float().pow(2).mean().item()  # float prevents overflow
+            is_best = loss < best_error
+            if is_best:
+                best_error = loss
+                best_scales = scales
+            block.load_state_dict(org_sd)
+
+        best_scales = best_scales.view(-1)
+
+        # Apply scales to previous
+        scales = best_scales.detach()
+        if isinstance(prev_op, nn.Linear):
+            prev_op.weight[-scales.size(0) :].div_(scales.view(-1, 1))
+            if prev_op.bias is not None:
+                prev_op.bias.div_(scales.view(-1))
+        elif isinstance(prev_op, nn.RMSNorm):
+            prev_op.weight.div_(scales)
+            if hasattr(prev_op, "bias") and prev_op.bias is not None:
+                prev_op.bias.div_(scales)
+        # Apply scales to layers
+        for layer in layers:
+            layer.weight.mul_(scales.view(1, -1))
+
+
+def apply_awq(model: LlamaModel, w_bit=4):
+    """Implement AWQ from scratch... but in a way that is applicable to the Qualcomm LlamaModel definition."""
+    with torch.no_grad():
+        tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
+        # tokens = get_dataset() will add this later...
+        hidden_states = model.tok_embeddings(tokens)
+
+        for i in tqdm(range(len(model.layers)), desc="Running AWQ..."):
+            """Solving AWQ layer by layer. In each decoder layer, we apply scales at four points:
+            attention_norm -- [q, k, v]
+            v -- o *not implemented yet
+            ffn_norm -- [gate, up]
+            up -- down
+            """
+            decoder_layer = model.layers[i]
+            named_linears = {
+                name: m
+                for name, m in decoder_layer.named_modules()
+                if isinstance(m, nn.Linear)
+            }
+
+            def cache_input_hook(module, x, y, name, feat_dict):
+                x = x[0]
+                x = x.detach()
+                feat_dict[name].append(x)
+
+            input_feat = defaultdict(list)
+            handles = []
+            for name in named_linears:
+                handles.append(
+                    named_linears[name].register_forward_hook(
+                        functools.partial(
+                            cache_input_hook, name=name, feat_dict=input_feat
+                        )
+                    )
+                )
+            # get output as next layer's input
+            kwargs = {
+                "freqs_cos": model.freqs_cos,
+                "freqs_sin": model.freqs_sin,
+                "atten_mask": atten_mask,
+                "k_caches": None,
+                "v_caches": None,
+            }
+            hidden_states, _, _ = decoder_layer(hidden_states, **kwargs)
+            for h in handles:
+                h.remove()
+            input_feat = {
+                k: torch.cat(v, dim=0) for k, v in input_feat.items()
+            }  # multi-input?
+
+            compute_and_apply_scale(
+                prev_op=decoder_layer.attention_norm,
+                layers=[
+                    decoder_layer.attention.wq,
+                    decoder_layer.attention.wk,
+                    decoder_layer.attention.wv,
+                ],
+                inp=input_feat["attention.wq"],
+                block=decoder_layer.attention,
+                **kwargs
+            )
+            # Need to add v--o. Technical difficulty due to GQA. But apparently this has the least impact anyway
+            compute_and_apply_scale(
+                prev_op=decoder_layer.ffn_norm,
+                layers=[decoder_layer.feed_forward.w1, decoder_layer.feed_forward.w3],
+                inp=input_feat["feed_forward.w1"],
+                block=decoder_layer.feed_forward,
+            )
+            compute_and_apply_scale(
+                prev_op=decoder_layer.feed_forward.w3,
+                layers=[decoder_layer.feed_forward.w2],
+                inp=input_feat["feed_forward.w2"],
+                block=decoder_layer.feed_forward.w2,
+            )
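As an aside, a small self-contained sketch of what fake_quantize_tensor does to a weight matrix: each row is quantized symmetrically against its own max(|w|) using 4-bit levels (integers in [-7, 7]) and immediately dequantized, so the tensor keeps its shape and dtype but takes at most 15 distinct values per row. The example values below are made up for illustration:

import torch

from executorch.examples.qualcomm.oss_scripts.llama.awq import fake_quantize_tensor

# Two rows with different dynamic ranges; quantization is per-row (dim=1).
w = torch.tensor([[0.10, -0.40, 0.80, -0.20],
                  [1.00, 0.25, -0.50, 0.05]])
w_q = fake_quantize_tensor(w, n_bit=4)
# Row 0 uses scale max|w| / 7 = 0.8 / 7, so 0.80 maps back to itself (up to float
# rounding) while 0.10 snaps to the nearest multiple of that scale (~0.114).
print(w_q)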

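compute_and_apply_scale relies on the usual AWQ scale-folding step: multiplying a linear layer's weight columns (input channels) by per-channel scales while dividing the producing op's output by the same scales leaves the float computation unchanged, which is why the grid search can shift quantization difficulty into the weights without changing model semantics. A quick standalone check of that identity (plain torch, nothing from this commit):

import torch

torch.manual_seed(0)
x = torch.randn(8, 16)       # activations: (n, ci)
w = torch.randn(32, 16)      # linear weight: (co, ci)
s = torch.rand(16) + 0.5     # per-input-channel scales, strictly positive

y_ref = x @ w.t()
y_folded = (x / s) @ (w * s).t()   # scale inputs down, weight columns up
print(torch.allclose(y_ref, y_folded, atol=1e-5))  # True up to float rounding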
examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py

Lines changed: 9 additions & 0 deletions
@@ -39,6 +39,7 @@
 from executorch.examples.models.llama.source_transformation.quantize import (
     get_quant_embedding_transform,
 )
+from executorch.examples.qualcomm.oss_scripts.llama.awq import apply_awq
 
 from executorch.examples.qualcomm.oss_scripts.llama.decoder_utils import calibrate
 
@@ -170,6 +171,9 @@ def permute(w, heads):
         )
         logging.info("Applied SpinQuant to the model")
 
+    if args.awq:
+        apply_awq(model)
+
     if args.range_setting == "mse_with_act_loss":
         wrapped_model = WrappedLlamaModel(
             model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
@@ -331,6 +335,11 @@ def main() -> None:
         help="Apply SpinQuant (R1+R2) to the model. Uses random Hadamard matrices for rotations",
         action="store_true",
     )
+    parser.add_argument(
+        "--awq",
+        help="Apply AWQ to the model",
+        action="store_true",
+    )
     parser.add_argument(
         "--fraction",
         help="the fraction of examples per task (only use this for testing)",
