
Commit e5e5dab

Implemented range setting in QNN llama flow
Differential Revision: D78127727
Pull Request resolved: #12377
1 parent 706490b commit e5e5dab

5 files changed: +518, -78 lines


examples/qualcomm/oss_scripts/llama/TARGETS

Lines changed: 12 additions & 0 deletions
@@ -34,6 +34,16 @@ python_library(
     ],
 )
 
+python_library(
+    name = "range_setting_pt2e",
+    srcs = [
+        "range_setting_pt2e.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+    ],
+)
+
 python_binary(
     name = "llama",
     main_function = "executorch.examples.qualcomm.oss_scripts.llama.llama.main",
@@ -42,6 +52,7 @@ python_binary(
     ],
     deps = [
         ":llama_lib",
+        "//executorch/examples/qualcomm/oss_scripts/llama:range_setting_pt2e",
     ],
 )

@@ -55,6 +66,7 @@ python_binary(
     deps = [
         ":llama_lib",
         "//executorch/examples/models/llama:eval_library",
+        "//executorch/examples/qualcomm/oss_scripts/llama:range_setting_pt2e",
         "fbsource//third-party/pypi/lm-eval:lm-eval",
     ],
 )

examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py

Lines changed: 105 additions & 68 deletions
@@ -5,16 +5,14 @@
 # LICENSE file in the root directory of this source tree.
 
 import argparse
-import copy
 import json
 
 import logging
 import sys
-
-from typing import List, Tuple
+import types
 
 import torch
-import torch.nn as nn
+
 from executorch.backends.qualcomm.quantizer.custom_annotation import (
     annotate_linear_16a8w_in_affine_layer,
     annotate_matmul_16a8w,
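The dropped typing and torch.nn imports were only needed by the WrappedLlamaModel class that a later hunk removes from this file, while the new types import is used further down to build a small config object for SpinQuant via types.SimpleNamespace. A quick, generic illustration of that stdlib helper (the field names mirror the ones used later in this diff; the values are arbitrary and this is not code from the commit):

    import types

    # SimpleNamespace gives attribute-style access without declaring a class,
    # which is convenient for a small ad-hoc config object.
    config = types.SimpleNamespace(dim=2048, head_dim=64, n_local_heads=32)
    print(config.dim, config.head_dim, config.n_local_heads)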
@@ -46,14 +44,19 @@
     LlamaModel,
     ModelArgs,
 )
-
-from executorch.examples.qualcomm.utils import make_quantizer
+from executorch.examples.qualcomm.oss_scripts.llama.range_setting_pt2e import (
+    compute_scales,
+    make_custom_quantizer,
+    reverse_quantize_module_swap,
+    set_scales,
+    WrappedLlamaModel,
+)
 
 from lm_eval.evaluator import simple_evaluate
 
 from pytorch_tokenizers import get_tokenizer
+from torchao.prototype.spinquant import apply_spinquant
 
-from torchao.quantization.pt2e import MinMaxObserver
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torchao.quantization.pt2e.quantizer import QuantizationSpec

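range_setting_pt2e.py itself is not shown on this page, so the signatures of the helpers imported above are not visible. The stubs below are only a sketch of the expected interface, inferred from the call sites later in eval_llama_qnn.py (WrappedLlamaModel is the wrapper class that the next hunk deletes from this file, presumably relocated into the new module); they are not the actual contents of range_setting_pt2e.py:

    # Hypothetical interface sketch reconstructed from the call sites in this diff.
    from typing import Dict

    import torch


    def make_custom_quantizer(quant_dtype, range_setting, custom_annotations, linear_only):
        """Build a QNN quantizer whose observers match the chosen range-setting method."""
        ...


    def compute_scales(wrapped_model, tokens, weight_bits, act_bits, num_iters) -> Dict[str, torch.Tensor]:
        """Search per-tensor quantization scales that minimize an activation-loss objective."""
        ...


    def reverse_quantize_module_swap(wrapped_model) -> None:
        """Undo the temporary module swaps applied by compute_scales."""
        ...


    def set_scales(prepared_model, scales_state_dict, head_dim) -> None:
        """Write precomputed scales into the observers of a prepare_pt2e'd graph module."""
        ...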
@@ -64,30 +67,6 @@
 logging.getLogger().setLevel(logging.INFO)
 
 
-class WrappedLlamaModel(nn.Module):
-    def __init__(
-        self, model, atten_mask, use_kv_cache=False, max_seq_len=512, device="cuda"
-    ):
-        super(WrappedLlamaModel, self).__init__()
-        self.model = model
-        self.max_seq_len = max_seq_len
-        self.use_kv_cache = use_kv_cache
-        self.device = device
-        self.atten_mask = atten_mask
-
-    def forward(
-        self,
-        tokens: torch.Tensor,
-        *args,
-    ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
-        # Pad input if necessary, since LlamaModel requires static shape
-        if tokens.shape[1] != self.max_seq_len:
-            tokens = torch.nn.functional.pad(
-                tokens, (0, self.max_seq_len - tokens.shape[1])
-            )
-        return self.model.forward(tokens, self.atten_mask)
-
-
 def add_mse_weight_observer(quant_dtype, quantizer):
     weight_dtype = (
         torch.int4
@@ -115,24 +94,16 @@ def add_mse_weight_observer(quant_dtype, quantizer):
     )
 
 
-def gen_eval_wrapper(model_name, args):
-    tokenizer = get_tokenizer(args.tokenizer_path)
+def prepare_model(model_name, args):
     with open(args.params) as f:
-        kv_config = ModelArgs(**json.load(f))
+        prefill_config = ModelArgs(**json.load(f))
     # TODO: support batch inputs if necessary
-    kv_config.max_batch_size = 1
-    kv_config.max_seq_len = args.max_seq_length
-    kv_config.use_kv_cache = True
-
-    prefill_config = copy.copy(kv_config)
+    prefill_config.max_batch_size = 1
     prefill_config.max_seq_len = args.max_seq_length
-    prefill_config.use_kv_cache = (
-        False if args.max_seq_length == args.prefill_ar_len else True
-    )
-    config = prefill_config
+    prefill_config.use_kv_cache = False
     use_i64_token = args.embedding_quantize is not None
     model = LlamaModel(
-        config,
+        prefill_config,
         ar_len=args.prefill_ar_len,
         output_new_cache_only=True,
         output_cache=False,
@@ -173,57 +144,90 @@ def permute(w, heads):
     if "model" in state_dict:
         state_dict = state_dict["model"]
 
+    # TODO: use dtype of model checkpoint
+    model = model.to(device=args.device, dtype=torch.float)
+    inputs = model.get_example_inputs(use_kv_cache=False)
+    tokens, atten_mask = inputs
+
+    scales_state_dict = {}
+    if args.spinquant:
+        config = types.SimpleNamespace(
+            dim=prefill_config.dim,
+            head_dim=prefill_config.dim // prefill_config.n_heads,
+            n_local_heads=prefill_config.n_heads,
+            intermediate_size=4 * prefill_config.dim,
+        )
+        model.config = config
+        apply_spinquant(
+            model,
+            use_r1=True,
+            use_r2=True,
+            use_r4=False,
+            pretrained_rotation_path=None,
+            qkv_split=True,
+        )
+        logging.info("Applied SpinQuant to the model")
+
+    if args.range_setting == "mse_with_act_loss":
+        wrapped_model = WrappedLlamaModel(
+            model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
+        )
+        act_bits, weight_bits = {
+            "8a8w": (8, 8),
+            "16a4w": (16, 4),
+            "16a4w_block": (16, 4),
+        }[args.ptq]
+        scales_state_dict = compute_scales(
+            wrapped_model, tokens, weight_bits, act_bits, 1600
+        )
+        torch.save(scales_state_dict, "scales_state_dict.pth")
+        logging.info("Saved scales to scales_state_dict.pth!")
+        reverse_quantize_module_swap(wrapped_model)
+
     for layer in model.layers:
         if getattr(layer.attention, "prepare_sha", None):
             layer.attention.prepare_sha()
         if getattr(layer.feed_forward, "prepare_feedfoward_conv", None):
             layer.feed_forward.prepare_feedfoward_conv()
-
-    model.to(dtype=torch.float)
-    model.to(device=args.device)
-
-    tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
-    tokens = tokens.to(device=args.device)
-    atten_mask = atten_mask.to(device=args.device)
-    atten_mask = atten_mask.to(dtype=torch.float)
-    inputs = (tokens, atten_mask)
-
     if args.embedding_quantize:
         model = get_quant_embedding_transform(
             embedding_quantize=args.embedding_quantize
         )(model)
 
     model = convert_linear_to_conv2d(model)
+    return model, prefill_config, inputs, scales_state_dict
+
+
+def gen_eval_wrapper(model_name, args):
+    tokenizer = get_tokenizer(args.tokenizer_path)
+    model, config, inputs, scales_state_dict = prepare_model(model_name, args)
+    tokens, atten_mask = inputs
+    use_i64_token = args.embedding_quantize is not None
 
-    if args.ptq:
+    if args.ptq is not None:
         quant_dtype = getattr(QuantDtype, f"use_{args.ptq}")
 
         custom_annotations = (annotate_matmul_16a8w,)
         if args.llama_model == "stories110m":
             custom_annotations = custom_annotations + (
                 annotate_linear_16a8w_in_affine_layer,
             )
-        quantizer = make_quantizer(
-            quant_dtype=quant_dtype,
-            per_channel_conv=True,
-            per_channel_linear=True,
-            act_observer=MinMaxObserver,
-        )
-        quantizer.add_custom_quant_annotations(custom_annotations)
 
-        if args.range_setting == "mse_weight":
-            add_mse_weight_observer(quant_dtype, quantizer)
+        quantizer = make_custom_quantizer(
+            quant_dtype, args.range_setting, custom_annotations, args.quant_linear_only
+        )
 
     with torch.no_grad():
+        logging.info("Starting export...")
         model = torch.export.export(model, inputs, strict=True).module()
         if quant_dtype == QuantDtype.use_16a4w_block:
             conv_nodes = [n for n in model.graph.nodes if "conv" in n.name]
             block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes}
             quantizer.set_block_size_map(block_size_map)
-
+        logging.info("Finished export, adding observers (prepare_pt2e)...")
         model = prepare_pt2e(model, quantizer)
 
-        logging.info("Quantizing the model...")
+        logging.info("Observers added, starting calibration...")
 
         calibrate(
             inputs,
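For context on what an MSE-based range-setting method does: instead of taking a tensor's raw min/max as the quantization range, the clipping threshold is chosen so that the quantize-dequantize error is minimized, and compute_scales above presumably optimizes an activation-aware variant of that idea. A generic, self-contained illustration of the weight-only version (not the implementation used in this commit):

    import torch


    def mse_range_search(w: torch.Tensor, n_bits: int = 4, n_grid: int = 100) -> float:
        """Pick a symmetric clipping threshold for w that minimizes quantize-dequantize MSE."""
        qmax = 2 ** (n_bits - 1) - 1
        max_abs = w.abs().max()
        best_clip, best_err = max_abs, float("inf")
        for i in range(1, n_grid + 1):
            clip = max_abs * i / n_grid  # candidate clipping threshold
            scale = clip / qmax
            q = torch.clamp(torch.round(w / scale), -qmax - 1, qmax)
            err = ((q * scale - w) ** 2).mean().item()
            if err < best_err:
                best_clip, best_err = clip, err
        return float(best_clip)


    # A single outlier inflates the plain min/max range; the MSE search picks a tighter clip.
    w = torch.randn(4096) * 0.05
    w[0] = 3.0
    print("min/max clip:", float(w.abs().max()), "mse clip:", mse_range_search(w, n_bits=4))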
@@ -236,7 +240,24 @@ def permute(w, heads):
             use_i64_token=use_i64_token,
         )
 
+        if args.range_setting == "mse_with_act_loss":
+            # scales_state_dict = torch.load("scales_state_dict.pth")
+            set_scales(model, scales_state_dict, config.head_dim)
+
+        logging.info("Quantizing the model...")
         model = convert_pt2e(model)
+        logging.info("Quantization complete! Here is some sample generated text:")
+
+        calibrate(
+            inputs,
+            "Could you tell me about Facebook?",
+            model,
+            tokenizer=tokenizer,
+            ar_len=args.prefill_ar_len,
+            max_seq_len=args.max_seq_len,
+            kv_updater=None,
+            use_i64_token=use_i64_token,
+        )
 
         model = WrappedLlamaModel(
             model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
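Taken together with the previous hunk, the new quantization path in gen_eval_wrapper reads roughly as follows; this is only a condensed restatement of the code above for orientation (block quantization, custom annotations, and the calibration arguments are elided):

    # Condensed restatement of the flow implemented above, not additional code.
    model, config, inputs, scales_state_dict = prepare_model(model_name, args)
    quantizer = make_custom_quantizer(
        quant_dtype, args.range_setting, custom_annotations, args.quant_linear_only
    )
    with torch.no_grad():
        model = torch.export.export(model, inputs, strict=True).module()
        model = prepare_pt2e(model, quantizer)  # insert observers
        # ...calibration pass over `inputs` via calibrate(...) as shown above...
        if args.range_setting == "mse_with_act_loss":
            set_scales(model, scales_state_dict, config.head_dim)  # override observer-derived ranges
        model = convert_pt2e(model)  # materialize the quantized graph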
@@ -248,7 +269,7 @@ def permute(w, heads):
             max_seq_length=args.calibration_seq_length,
             use_kv_cache=args.use_kv_cache,
             generate_full_logits=args.generate_full_logits,
-            enable_dynamic_shape=args.enable_dynamic_shape,
+            enable_dynamic_shape=False,
         )
 

@@ -271,6 +292,7 @@ def eval_llama(
         model=eval_wrapper,
         tasks=args.tasks,
         num_fewshot=args.num_fewshot,
+        limit=args.fraction,
     )
 
     for task, res in eval_results["results"].items():
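lm-eval's simple_evaluate caps the number of examples per task via limit; a float below 1 is treated as a fraction of each task, which is what the new --fraction flag feeds in. An illustrative call (the wrapper object and the task/limit values are placeholders, not from this commit):

    from lm_eval.evaluator import simple_evaluate

    # eval_wrapper: the wrapper returned by gen_eval_wrapper above; values below are examples only.
    eval_results = simple_evaluate(
        model=eval_wrapper,
        tasks=["wikitext"],
        num_fewshot=0,
        limit=0.1,  # evaluate on 10% of each task's examples; an int would mean an absolute count
    )
    print(eval_results["results"])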
@@ -290,9 +312,24 @@ def main() -> None:
     )
     parser.add_argument(
         "--range_setting",
-        help="Choose which range setting method (e.g. mse_weight). If not specified, will do minmax for weights and activations",
+        help="Choose which range setting method for weight quantization (e.g. mse_weight_only or mse_with_act_loss). If not specified, defaults to minmax",
         type=str,
     )
+    parser.add_argument(
+        "--spinquant",
+        help="Apply SpinQuant (R1+R2) to the model. Uses random Hadamard matrices for rotations",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--fraction",
+        help="the fraction of examples per task (only use this for testing)",
+        type=float,
+    )
+    parser.add_argument(
+        "--quant_linear_only",
+        help="if you select this option we quantize linear layers only",
+        action="store_true",
+    )
 
     args = parser.parse_args()
     args.llama_model = "llama3_2"
