
Commit 4e19891

rohansjoshi authored and facebook-github-bot committed
Implemented range setting in QNN llama flow (#12377)
Summary: `llama.py` now has a `--range_setting` flag, with the options `mse_weight_only` and `mse_with_act_loss`. There is also a new eval script, `eval_llama_qnn.py`, for computing perplexity. That script additionally has a `--quant_linear_only` flag to quantize only linear/conv nodes, for running faster experiments (for faster eval, try a sequence length of 1024).

Reviewed By: cccclai

Differential Revision: D78127727
1 parent 4df7223 commit 4e19891
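
In outline, the `mse_with_act_loss` path computes weight quantization scales on the float model before export, then re-applies them after PT2E calibration. Below is a minimal sketch of that flow, pieced together from the `eval_llama_qnn.py` diff in this commit; the helper signatures are copied from the call sites there, but the semantics of the positional arguments (e.g. the `1600` passed to `compute_scales`) are assumptions, not documented API.

# Sketch only: assembled from the call sites in eval_llama_qnn.py below;
# argument meanings are assumptions.
import torch

from executorch.examples.qualcomm.oss_scripts.llama.range_setting_pt2e import (
    compute_scales,
    make_custom_quantizer,
    reverse_quantize_module_swap,
    set_scales,
    WrappedLlamaModel,
)
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


def quantize_with_act_loss(model, tokens, atten_mask, quant_dtype, config, args):
    # 1) On the float model, search for weight scales that minimize an
    #    activation-based loss (here 4 weight bits, 16 activation bits).
    wrapped = WrappedLlamaModel(
        model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
    )
    scales_state_dict = compute_scales(wrapped, tokens, 4, 16, 1600)
    reverse_quantize_module_swap(wrapped)  # undo the temporary module swap

    # 2) Export, attach observers via the custom quantizer, and calibrate.
    quantizer = make_custom_quantizer(
        quant_dtype, args.range_setting, (), args.quant_linear_only
    )
    model = torch.export.export(model, (tokens, atten_mask), strict=True).module()
    model = prepare_pt2e(model, quantizer)
    # ... run calibration data through `model` here, as the eval script does ...

    # 3) Overwrite the observed weight ranges with the precomputed scales,
    #    then lower to the final quantized graph.
    set_scales(model, scales_state_dict, config.head_dim)
    return convert_pt2e(model)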

5 files changed: +504 −78 lines

examples/qualcomm/oss_scripts/llama/TARGETS

Lines changed: 12 additions & 0 deletions
@@ -34,6 +34,16 @@ python_library(
     ],
 )

+python_library(
+    name = "range_setting_pt2e",
+    srcs = [
+        "range_setting_pt2e.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+    ],
+)
+
 python_binary(
     name = "llama",
     main_function = "executorch.examples.qualcomm.oss_scripts.llama.llama.main",
@@ -42,6 +52,7 @@ python_binary(
     ],
     deps = [
         ":llama_lib",
+        "//executorch/examples/qualcomm/oss_scripts/llama:range_setting_pt2e",
     ],
 )

@@ -55,6 +66,7 @@ python_binary(
     deps = [
         ":llama_lib",
         "//executorch/examples/models/llama:eval_library",
+        "//executorch/examples/qualcomm/oss_scripts/llama:range_setting_pt2e",
         "fbsource//third-party/pypi/lm-eval:lm-eval",
     ],
 )

examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py

Lines changed: 98 additions & 68 deletions
@@ -5,16 +5,14 @@
 # LICENSE file in the root directory of this source tree.

 import argparse
-import copy
 import json

 import logging
 import sys
-
-from typing import List, Tuple
+import types

 import torch
-import torch.nn as nn
+
 from executorch.backends.qualcomm.quantizer.custom_annotation import (
     annotate_linear_16a8w_in_affine_layer,
     annotate_matmul_16a8w,
@@ -46,14 +44,19 @@
     LlamaModel,
     ModelArgs,
 )
-
-from executorch.examples.qualcomm.utils import make_quantizer
+from executorch.examples.qualcomm.oss_scripts.llama.range_setting_pt2e import (
+    compute_scales,
+    make_custom_quantizer,
+    reverse_quantize_module_swap,
+    set_scales,
+    WrappedLlamaModel,
+)

 from lm_eval.evaluator import simple_evaluate

 from pytorch_tokenizers import get_tokenizer
+from torchao.prototype.spinquant import apply_spinquant

-from torchao.quantization.pt2e import MinMaxObserver
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torchao.quantization.pt2e.quantizer import QuantizationSpec

@@ -64,30 +67,6 @@
 logging.getLogger().setLevel(logging.INFO)


-class WrappedLlamaModel(nn.Module):
-    def __init__(
-        self, model, atten_mask, use_kv_cache=False, max_seq_len=512, device="cuda"
-    ):
-        super(WrappedLlamaModel, self).__init__()
-        self.model = model
-        self.max_seq_len = max_seq_len
-        self.use_kv_cache = use_kv_cache
-        self.device = device
-        self.atten_mask = atten_mask
-
-    def forward(
-        self,
-        tokens: torch.Tensor,
-        *args,
-    ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
-        # Pad input if necessary, since LlamaModel requires static shape
-        if tokens.shape[1] != self.max_seq_len:
-            tokens = torch.nn.functional.pad(
-                tokens, (0, self.max_seq_len - tokens.shape[1])
-            )
-        return self.model.forward(tokens, self.atten_mask)
-
-
 def add_mse_weight_observer(quant_dtype, quantizer):
     weight_dtype = (
         torch.int4
@@ -115,24 +94,16 @@ def add_mse_weight_observer(quant_dtype, quantizer):
     )


-def gen_eval_wrapper(model_name, args):
-    tokenizer = get_tokenizer(args.tokenizer_path)
+def prepare_model(model_name, args):
     with open(args.params) as f:
-        kv_config = ModelArgs(**json.load(f))
+        prefill_config = ModelArgs(**json.load(f))
     # TODO: support batch inputs if necessary
-    kv_config.max_batch_size = 1
-    kv_config.max_seq_len = args.max_seq_length
-    kv_config.use_kv_cache = True
-
-    prefill_config = copy.copy(kv_config)
+    prefill_config.max_batch_size = 1
     prefill_config.max_seq_len = args.max_seq_length
-    prefill_config.use_kv_cache = (
-        False if args.max_seq_length == args.prefill_ar_len else True
-    )
-    config = prefill_config
+    prefill_config.use_kv_cache = False
     use_i64_token = args.embedding_quantize is not None
     model = LlamaModel(
-        config,
+        prefill_config,
         ar_len=args.prefill_ar_len,
         output_new_cache_only=True,
         output_cache=False,
@@ -173,57 +144,83 @@ def permute(w, heads):
     if "model" in state_dict:
         state_dict = state_dict["model"]

+    # TODO: use dtype of model checkpoint
+    model = model.to(device=args.device, dtype=torch.float)
+    inputs = model.get_example_inputs(use_kv_cache=False)
+    tokens, atten_mask = inputs
+
+    scales_state_dict = {}
+    if args.spinquant:
+        config = types.SimpleNamespace(
+            dim=prefill_config.dim,
+            head_dim=prefill_config.dim // prefill_config.n_heads,
+            n_local_heads=prefill_config.n_heads,
+            intermediate_size=4*prefill_config.dim
+        )
+        setattr(model, "config", config)
+        apply_spinquant(model, use_r1=True, use_r2=True, use_r4=False, pretrained_rotation_path=None, qkv_split=True)
+        logging.info("Applied SpinQuant to the model")
+
+    if args.range_setting == "mse_with_act_loss":
+        wrapped_model = WrappedLlamaModel(
+            model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
+        )
+        act_bits, weight_bits = {
+            "8a8w": (8, 8),
+            "16a4w": (16, 4),
+            "16a4w_block": (16, 4),
+        }[args.ptq]
+        scales_state_dict = compute_scales(
+            wrapped_model, tokens, weight_bits, act_bits, 1600
+        )
+        torch.save(scales_state_dict, "scales_state_dict.pth")
+        logging.info("Saved scales to scales_state_dict.pth!")
+        reverse_quantize_module_swap(wrapped_model)
+
     for layer in model.layers:
         if getattr(layer.attention, "prepare_sha", None):
             layer.attention.prepare_sha()
         if getattr(layer.feed_forward, "prepare_feedfoward_conv", None):
             layer.feed_forward.prepare_feedfoward_conv()
-
-    model.to(dtype=torch.float)
-    model.to(device=args.device)
-
-    tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
-    tokens = tokens.to(device=args.device)
-    atten_mask = atten_mask.to(device=args.device)
-    atten_mask = atten_mask.to(dtype=torch.float)
-    inputs = (tokens, atten_mask)
-
     if args.embedding_quantize:
         model = get_quant_embedding_transform(
            embedding_quantize=args.embedding_quantize
        )(model)

     model = convert_linear_to_conv2d(model)
+    return model, prefill_config, inputs, scales_state_dict
+
+
+def gen_eval_wrapper(model_name, args):
+    tokenizer = get_tokenizer(args.tokenizer_path)
+    model, config, inputs, scales_state_dict = prepare_model(model_name, args)
+    tokens, atten_mask = inputs
+    use_i64_token = args.embedding_quantize is not None

-    if args.ptq:
+    if args.ptq is not None:
         quant_dtype = getattr(QuantDtype, f"use_{args.ptq}")

         custom_annotations = (annotate_matmul_16a8w,)
         if args.llama_model == "stories110m":
             custom_annotations = custom_annotations + (
                 annotate_linear_16a8w_in_affine_layer,
             )
-        quantizer = make_quantizer(
-            quant_dtype=quant_dtype,
-            per_channel_conv=True,
-            per_channel_linear=True,
-            act_observer=MinMaxObserver,
-        )
-        quantizer.add_custom_quant_annotations(custom_annotations)

-        if args.range_setting == "mse_weight":
-            add_mse_weight_observer(quant_dtype, quantizer)
+        quantizer = make_custom_quantizer(
+            quant_dtype, args.range_setting, custom_annotations, args.quant_linear_only
+        )

         with torch.no_grad():
+            logging.info("Starting export...")
             model = torch.export.export(model, inputs, strict=True).module()
             if quant_dtype == QuantDtype.use_16a4w_block:
                 conv_nodes = [n for n in model.graph.nodes if "conv" in n.name]
                 block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes}
                 quantizer.set_block_size_map(block_size_map)
-
+            logging.info("Finished export, adding observers (prepare_pt2e)...")
             model = prepare_pt2e(model, quantizer)

-        logging.info("Quantizing the model...")
+        logging.info("Observers added, starting calibration...")

         calibrate(
             inputs,
@@ -236,7 +233,24 @@ def permute(w, heads):
             use_i64_token=use_i64_token,
         )

+        if args.range_setting == "mse_with_act_loss":
+            # scales_state_dict = torch.load("scales_state_dict.pth")
+            set_scales(model, scales_state_dict, config.head_dim)
+
+        logging.info("Quantizing the model...")
         model = convert_pt2e(model)
+        logging.info("Quantization complete! Here is some sample generated text:")
+
+        calibrate(
+            inputs,
+            "Could you tell me about Facebook?",
+            model,
+            tokenizer=tokenizer,
+            ar_len=args.prefill_ar_len,
+            max_seq_len=args.max_seq_len,
+            kv_updater=None,
+            use_i64_token=use_i64_token,
+        )

     model = WrappedLlamaModel(
         model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
@@ -248,7 +262,7 @@ def permute(w, heads):
         max_seq_length=args.calibration_seq_length,
         use_kv_cache=args.use_kv_cache,
         generate_full_logits=args.generate_full_logits,
-        enable_dynamic_shape=args.enable_dynamic_shape,
+        enable_dynamic_shape=False,
     )


@@ -271,6 +285,7 @@ def eval_llama(
         model=eval_wrapper,
         tasks=args.tasks,
         num_fewshot=args.num_fewshot,
+        limit=args.fraction,
     )

     for task, res in eval_results["results"].items():
@@ -290,9 +305,24 @@ def main() -> None:
     )
     parser.add_argument(
         "--range_setting",
-        help="Choose which range setting method (e.g. mse_weight). If not specified, will do minmax for weights and activations",
+        help="Choose which range setting method for weight quantization (e.g. mse_weight_only or mse_with_act_loss). If not specified, defaults to minmax",
         type=str,
     )
+    parser.add_argument(
+        "--spinquant",
+        help="Apply SpinQuant (R1+R2) to the model. Uses random Hadamard matrices for rotations",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--fraction",
+        help="the fraction of examples per task (only use this for testing)",
+        type=float,
+    )
+    parser.add_argument(
+        "--quant_linear_only",
+        help="if you select this option we quantize linear layers only",
+        action="store_true",
+    )

     args = parser.parse_args()
     args.llama_model = "llama3_2"
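
For reference, a hypothetical invocation of the eval script exercising the new flags might look like `python eval_llama_qnn.py --range_setting mse_with_act_loss --quant_linear_only --fraction 0.1 ...` (all other required arguments, such as `--params` and `--tokenizer_path`, are omitted here).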
