Skip to content

Commit 0f86944

Browse files
rohansjoshifacebook-github-bot
authored andcommitted
Implemented range setting in QNN llama flow
Summary: `llama.py` now has the `--range_setting` flag, for which there are the options `mse_weight_only` and `mse_with_act_loss`. There is also an eval script for computing perplexity called `eval_llama_qnn.py`. This script also has a flag --quant_linear_only to only quantize linear/conv nodes, to run faster experiments. (for faster eval, try seq length 1024) Reviewed By: cccclai Differential Revision: D78127727
1 parent 4d7f9ca commit 0f86944

File tree

5 files changed

+454
-78
lines changed

5 files changed

+454
-78
lines changed

examples/qualcomm/oss_scripts/llama/TARGETS

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,16 @@ python_library(
3434
],
3535
)
3636

37+
python_library(
38+
name = "range_setting_pt2e",
39+
srcs = [
40+
"range_setting_pt2e.py",
41+
],
42+
deps = [
43+
"//caffe2:torch",
44+
],
45+
)
46+
3747
python_binary(
3848
name = "llama",
3949
main_function = "executorch.examples.qualcomm.oss_scripts.llama.llama.main",
@@ -42,6 +52,7 @@ python_binary(
4252
],
4353
deps = [
4454
":llama_lib",
55+
"//executorch/examples/qualcomm/oss_scripts/llama:range_setting_pt2e",
4556
],
4657
)
4758

@@ -55,6 +66,7 @@ python_binary(
5566
deps = [
5667
":llama_lib",
5768
"//executorch/examples/models/llama:eval_library",
69+
"//executorch/examples/qualcomm/oss_scripts/llama:range_setting_pt2e",
5870
"fbsource//third-party/pypi/lm-eval:lm-eval",
5971
],
6072
)

examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py

Lines changed: 80 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,13 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import argparse
8-
import copy
98
import json
109

1110
import logging
1211
import sys
1312

14-
from typing import List, Tuple
15-
1613
import torch
17-
import torch.nn as nn
14+
1815
from executorch.backends.qualcomm.quantizer.custom_annotation import (
1916
annotate_linear_16a8w_in_affine_layer,
2017
annotate_matmul_16a8w,
@@ -46,14 +43,18 @@
4643
LlamaModel,
4744
ModelArgs,
4845
)
49-
50-
from executorch.examples.qualcomm.utils import make_quantizer
46+
from executorch.examples.qualcomm.oss_scripts.llama.range_setting_pt2e import (
47+
compute_scales,
48+
make_custom_quantizer,
49+
reverse_quantize_module_swap,
50+
set_scales,
51+
WrappedLlamaModel,
52+
)
5153

5254
from lm_eval.evaluator import simple_evaluate
5355

5456
from pytorch_tokenizers import get_tokenizer
5557

56-
from torchao.quantization.pt2e import MinMaxObserver
5758
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
5859
from torchao.quantization.pt2e.quantizer import QuantizationSpec
5960

@@ -64,30 +65,6 @@
6465
logging.getLogger().setLevel(logging.INFO)
6566

6667

67-
class WrappedLlamaModel(nn.Module):
68-
def __init__(
69-
self, model, atten_mask, use_kv_cache=False, max_seq_len=512, device="cuda"
70-
):
71-
super(WrappedLlamaModel, self).__init__()
72-
self.model = model
73-
self.max_seq_len = max_seq_len
74-
self.use_kv_cache = use_kv_cache
75-
self.device = device
76-
self.atten_mask = atten_mask
77-
78-
def forward(
79-
self,
80-
tokens: torch.Tensor,
81-
*args,
82-
) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
83-
# Pad input if necessary, since LlamaModel requires static shape
84-
if tokens.shape[1] != self.max_seq_len:
85-
tokens = torch.nn.functional.pad(
86-
tokens, (0, self.max_seq_len - tokens.shape[1])
87-
)
88-
return self.model.forward(tokens, self.atten_mask)
89-
90-
9168
def add_mse_weight_observer(quant_dtype, quantizer):
9269
weight_dtype = (
9370
torch.int4
@@ -115,24 +92,16 @@ def add_mse_weight_observer(quant_dtype, quantizer):
11592
)
11693

11794

118-
def gen_eval_wrapper(model_name, args):
119-
tokenizer = get_tokenizer(args.tokenizer_path)
95+
def prepare_model(model_name, args):
12096
with open(args.params) as f:
121-
kv_config = ModelArgs(**json.load(f))
97+
prefill_config = ModelArgs(**json.load(f))
12298
# TODO: support batch inputs if necessary
123-
kv_config.max_batch_size = 1
124-
kv_config.max_seq_len = args.max_seq_length
125-
kv_config.use_kv_cache = True
126-
127-
prefill_config = copy.copy(kv_config)
99+
prefill_config.max_batch_size = 1
128100
prefill_config.max_seq_len = args.max_seq_length
129-
prefill_config.use_kv_cache = (
130-
False if args.max_seq_length == args.prefill_ar_len else True
131-
)
132-
config = prefill_config
101+
prefill_config.use_kv_cache = False
133102
use_i64_token = args.embedding_quantize is not None
134103
model = LlamaModel(
135-
config,
104+
prefill_config,
136105
ar_len=args.prefill_ar_len,
137106
output_new_cache_only=True,
138107
output_cache=False,
@@ -173,57 +142,72 @@ def permute(w, heads):
173142
if "model" in state_dict:
174143
state_dict = state_dict["model"]
175144

145+
# TODO: use dtype of model checkpoint
146+
model = model.to(device=args.device, dtype=torch.float)
147+
inputs = model.get_example_inputs(use_kv_cache=False)
148+
tokens, atten_mask = inputs
149+
150+
scales_state_dict = {}
151+
if args.range_setting == "mse_with_act_loss":
152+
wrapped_model = WrappedLlamaModel(
153+
model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
154+
)
155+
act_bits, weight_bits = {
156+
"8a8w": (8, 8),
157+
"16a4w": (16, 4),
158+
"16a4w_block": (16, 4),
159+
}[args.ptq]
160+
scales_state_dict = compute_scales(
161+
wrapped_model, tokens, weight_bits, act_bits, 1600
162+
)
163+
torch.save(scales_state_dict, "scales_state_dict.pth")
164+
logging.info("Saved scales to scales_state_dict.pth!")
165+
reverse_quantize_module_swap(wrapped_model)
166+
176167
for layer in model.layers:
177168
if getattr(layer.attention, "prepare_sha", None):
178169
layer.attention.prepare_sha()
179170
if getattr(layer.feed_forward, "prepare_feedfoward_conv", None):
180171
layer.feed_forward.prepare_feedfoward_conv()
181-
182-
model.to(dtype=torch.float)
183-
model.to(device=args.device)
184-
185-
tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
186-
tokens = tokens.to(device=args.device)
187-
atten_mask = atten_mask.to(device=args.device)
188-
atten_mask = atten_mask.to(dtype=torch.float)
189-
inputs = (tokens, atten_mask)
190-
191172
if args.embedding_quantize:
192173
model = get_quant_embedding_transform(
193174
embedding_quantize=args.embedding_quantize
194175
)(model)
195176

196177
model = convert_linear_to_conv2d(model)
178+
return model, prefill_config, inputs, scales_state_dict
179+
197180

198-
if args.ptq:
181+
def gen_eval_wrapper(model_name, args):
182+
tokenizer = get_tokenizer(args.tokenizer_path)
183+
model, config, inputs, scales_state_dict = prepare_model(model_name, args)
184+
tokens, atten_mask = inputs
185+
use_i64_token = args.embedding_quantize is not None
186+
187+
if args.ptq is not None:
199188
quant_dtype = getattr(QuantDtype, f"use_{args.ptq}")
200189

201190
custom_annotations = (annotate_matmul_16a8w,)
202191
if args.llama_model == "stories110m":
203192
custom_annotations = custom_annotations + (
204193
annotate_linear_16a8w_in_affine_layer,
205194
)
206-
quantizer = make_quantizer(
207-
quant_dtype=quant_dtype,
208-
per_channel_conv=True,
209-
per_channel_linear=True,
210-
act_observer=MinMaxObserver,
211-
)
212-
quantizer.add_custom_quant_annotations(custom_annotations)
213195

214-
if args.range_setting == "mse_weight":
215-
add_mse_weight_observer(quant_dtype, quantizer)
196+
quantizer = make_custom_quantizer(
197+
quant_dtype, args.range_setting, custom_annotations, args.quant_linear_only
198+
)
216199

217200
with torch.no_grad():
201+
logging.info("Starting export...")
218202
model = torch.export.export(model, inputs, strict=True).module()
219203
if quant_dtype == QuantDtype.use_16a4w_block:
220204
conv_nodes = [n for n in model.graph.nodes if "conv" in n.name]
221205
block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes}
222206
quantizer.set_block_size_map(block_size_map)
223-
207+
logging.info("Finished export, adding observers (prepare_pt2e)...")
224208
model = prepare_pt2e(model, quantizer)
225209

226-
logging.info("Quantizing the model...")
210+
logging.info("Observers added, starting calibration...")
227211

228212
calibrate(
229213
inputs,
@@ -236,7 +220,24 @@ def permute(w, heads):
236220
use_i64_token=use_i64_token,
237221
)
238222

223+
if args.range_setting == "mse_with_act_loss":
224+
# scales_state_dict = torch.load("scales_state_dict.pth")
225+
set_scales(model, scales_state_dict, config.head_dim)
226+
227+
logging.info("Quantizing the model...")
239228
model = convert_pt2e(model)
229+
logging.info("Quantization complete! Here is some sample generated text:")
230+
231+
calibrate(
232+
inputs,
233+
"Could you tell me about Facebook?",
234+
model,
235+
tokenizer=tokenizer,
236+
ar_len=args.prefill_ar_len,
237+
max_seq_len=args.max_seq_len,
238+
kv_updater=None,
239+
use_i64_token=use_i64_token,
240+
)
240241

241242
model = WrappedLlamaModel(
242243
model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
@@ -248,7 +249,7 @@ def permute(w, heads):
248249
max_seq_length=args.calibration_seq_length,
249250
use_kv_cache=args.use_kv_cache,
250251
generate_full_logits=args.generate_full_logits,
251-
enable_dynamic_shape=args.enable_dynamic_shape,
252+
enable_dynamic_shape=False,
252253
)
253254

254255

@@ -271,6 +272,7 @@ def eval_llama(
271272
model=eval_wrapper,
272273
tasks=args.tasks,
273274
num_fewshot=args.num_fewshot,
275+
limit=args.fraction,
274276
)
275277

276278
for task, res in eval_results["results"].items():
@@ -290,9 +292,19 @@ def main() -> None:
290292
)
291293
parser.add_argument(
292294
"--range_setting",
293-
help="Choose which range setting method (e.g. mse_weight). If not specified, will do minmax for weights and activations",
295+
help="Choose which range setting method for weight quantization (e.g. mse_weight_only or mse_with_act_loss). If not specified, defaults to minmax",
294296
type=str,
295297
)
298+
parser.add_argument(
299+
"--fraction",
300+
help="the fraction of examples per task (only use this for testing)",
301+
type=float,
302+
)
303+
parser.add_argument(
304+
"--quant_linear_only",
305+
help="if you select this option we quantize linear layers only",
306+
action="store_true",
307+
)
296308

297309
args = parser.parse_args()
298310
args.llama_model = "llama3_2"

0 commit comments

Comments
 (0)