12 changes: 12 additions & 0 deletions examples/qualcomm/oss_scripts/llama/TARGETS
@@ -34,6 +34,16 @@ python_library(
    ],
)

python_library(
    name = "range_setting_pt2e",
    srcs = [
        "range_setting_pt2e.py",
    ],
    deps = [
        "//caffe2:torch",
    ],
)

python_binary(
    name = "llama",
    main_function = "executorch.examples.qualcomm.oss_scripts.llama.llama.main",
@@ -42,6 +52,7 @@ python_binary(
    ],
    deps = [
        ":llama_lib",
        "//executorch/examples/qualcomm/oss_scripts/llama:range_setting_pt2e",
    ],
)

@@ -55,6 +66,7 @@
    deps = [
        ":llama_lib",
        "//executorch/examples/models/llama:eval_library",
        "//executorch/examples/qualcomm/oss_scripts/llama:range_setting_pt2e",
        "fbsource//third-party/pypi/lm-eval:lm-eval",
    ],
)
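
As a pointer for readers: the new `range_setting_pt2e` library exists so that the eval binary can import the range-setting helpers. A minimal sketch of the consuming import is below; the exported names are taken from the import block added to `eval_llama_qnn.py` later in this PR, and the per-name comments are inferred from their call sites rather than from documented API.

```python
# Helpers provided by the new range_setting_pt2e target (names as imported in this PR).
from executorch.examples.qualcomm.oss_scripts.llama.range_setting_pt2e import (
    WrappedLlamaModel,             # static-shape wrapper around LlamaModel for calibration
    compute_scales,                # compute scales via MSE with activation loss
    make_custom_quantizer,         # build a quantizer for the chosen range-setting mode
    reverse_quantize_module_swap,  # undo the module swap used while computing scales
    set_scales,                    # apply precomputed scales to the prepared model
)
```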
173 changes: 105 additions & 68 deletions examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py
@@ -5,16 +5,14 @@
# LICENSE file in the root directory of this source tree.

import argparse
import copy
import json

import logging
import sys

from typing import List, Tuple
import types

import torch
import torch.nn as nn

from executorch.backends.qualcomm.quantizer.custom_annotation import (
    annotate_linear_16a8w_in_affine_layer,
    annotate_matmul_16a8w,
@@ -46,14 +44,19 @@
    LlamaModel,
    ModelArgs,
)

from executorch.examples.qualcomm.utils import make_quantizer
from executorch.examples.qualcomm.oss_scripts.llama.range_setting_pt2e import (
    compute_scales,
    make_custom_quantizer,
    reverse_quantize_module_swap,
    set_scales,
    WrappedLlamaModel,
)

from lm_eval.evaluator import simple_evaluate

from pytorch_tokenizers import get_tokenizer
from torchao.prototype.spinquant import apply_spinquant

from torchao.quantization.pt2e import MinMaxObserver
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.quantization.pt2e.quantizer import QuantizationSpec

@@ -64,30 +67,6 @@
logging.getLogger().setLevel(logging.INFO)


class WrappedLlamaModel(nn.Module):
    def __init__(
        self, model, atten_mask, use_kv_cache=False, max_seq_len=512, device="cuda"
    ):
        super(WrappedLlamaModel, self).__init__()
        self.model = model
        self.max_seq_len = max_seq_len
        self.use_kv_cache = use_kv_cache
        self.device = device
        self.atten_mask = atten_mask

    def forward(
        self,
        tokens: torch.Tensor,
        *args,
    ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
        # Pad input if necessary, since LlamaModel requires static shape
        if tokens.shape[1] != self.max_seq_len:
            tokens = torch.nn.functional.pad(
                tokens, (0, self.max_seq_len - tokens.shape[1])
            )
        return self.model.forward(tokens, self.atten_mask)


def add_mse_weight_observer(quant_dtype, quantizer):
    weight_dtype = (
        torch.int4
@@ -115,24 +94,16 @@ def add_mse_weight_observer(quant_dtype, quantizer):
    )


def gen_eval_wrapper(model_name, args):
    tokenizer = get_tokenizer(args.tokenizer_path)
def prepare_model(model_name, args):
    with open(args.params) as f:
        kv_config = ModelArgs(**json.load(f))
        prefill_config = ModelArgs(**json.load(f))
    # TODO: support batch inputs if necessary
    kv_config.max_batch_size = 1
    kv_config.max_seq_len = args.max_seq_length
    kv_config.use_kv_cache = True

    prefill_config = copy.copy(kv_config)
    prefill_config.max_batch_size = 1
    prefill_config.max_seq_len = args.max_seq_length
    prefill_config.use_kv_cache = (
        False if args.max_seq_length == args.prefill_ar_len else True
    )
    config = prefill_config
    prefill_config.use_kv_cache = False
    use_i64_token = args.embedding_quantize is not None
    model = LlamaModel(
        config,
        prefill_config,
        ar_len=args.prefill_ar_len,
        output_new_cache_only=True,
        output_cache=False,
@@ -173,57 +144,90 @@ def permute(w, heads):
    if "model" in state_dict:
        state_dict = state_dict["model"]

    # TODO: use dtype of model checkpoint
    model = model.to(device=args.device, dtype=torch.float)
    inputs = model.get_example_inputs(use_kv_cache=False)
    tokens, atten_mask = inputs

    scales_state_dict = {}
    if args.spinquant:
        config = types.SimpleNamespace(
            dim=prefill_config.dim,
            head_dim=prefill_config.dim // prefill_config.n_heads,
            n_local_heads=prefill_config.n_heads,
            intermediate_size=4 * prefill_config.dim,
        )
        model.config = config
        apply_spinquant(
            model,
            use_r1=True,
            use_r2=True,
            use_r4=False,
            pretrained_rotation_path=None,
            qkv_split=True,
        )
        logging.info("Applied SpinQuant to the model")

    if args.range_setting == "mse_with_act_loss":
        wrapped_model = WrappedLlamaModel(
            model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
        )
        act_bits, weight_bits = {
            "8a8w": (8, 8),
            "16a4w": (16, 4),
            "16a4w_block": (16, 4),
        }[args.ptq]
        scales_state_dict = compute_scales(
            wrapped_model, tokens, weight_bits, act_bits, 1600
        )
        torch.save(scales_state_dict, "scales_state_dict.pth")
        logging.info("Saved scales to scales_state_dict.pth!")
        reverse_quantize_module_swap(wrapped_model)

    for layer in model.layers:
        if getattr(layer.attention, "prepare_sha", None):
            layer.attention.prepare_sha()
        if getattr(layer.feed_forward, "prepare_feedfoward_conv", None):
            layer.feed_forward.prepare_feedfoward_conv()

    model.to(dtype=torch.float)
    model.to(device=args.device)

    tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
    tokens = tokens.to(device=args.device)
    atten_mask = atten_mask.to(device=args.device)
    atten_mask = atten_mask.to(dtype=torch.float)
    inputs = (tokens, atten_mask)

    if args.embedding_quantize:
        model = get_quant_embedding_transform(
            embedding_quantize=args.embedding_quantize
        )(model)

    model = convert_linear_to_conv2d(model)
    return model, prefill_config, inputs, scales_state_dict


def gen_eval_wrapper(model_name, args):
    tokenizer = get_tokenizer(args.tokenizer_path)
    model, config, inputs, scales_state_dict = prepare_model(model_name, args)
    tokens, atten_mask = inputs
    use_i64_token = args.embedding_quantize is not None

    if args.ptq:
    if args.ptq is not None:
        quant_dtype = getattr(QuantDtype, f"use_{args.ptq}")

        custom_annotations = (annotate_matmul_16a8w,)
        if args.llama_model == "stories110m":
            custom_annotations = custom_annotations + (
                annotate_linear_16a8w_in_affine_layer,
            )
        quantizer = make_quantizer(
            quant_dtype=quant_dtype,
            per_channel_conv=True,
            per_channel_linear=True,
            act_observer=MinMaxObserver,
        )
        quantizer.add_custom_quant_annotations(custom_annotations)

        if args.range_setting == "mse_weight":
            add_mse_weight_observer(quant_dtype, quantizer)
        quantizer = make_custom_quantizer(
            quant_dtype, args.range_setting, custom_annotations, args.quant_linear_only
        )

        with torch.no_grad():
            logging.info("Starting export...")
            model = torch.export.export(model, inputs, strict=True).module()
            if quant_dtype == QuantDtype.use_16a4w_block:
                conv_nodes = [n for n in model.graph.nodes if "conv" in n.name]
                block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes}
                quantizer.set_block_size_map(block_size_map)

            logging.info("Finished export, adding observers (prepare_pt2e)...")
            model = prepare_pt2e(model, quantizer)

            logging.info("Quantizing the model...")
            logging.info("Observers added, starting calibration...")

            calibrate(
                inputs,
@@ -236,7 +240,24 @@ def permute(w, heads):
                use_i64_token=use_i64_token,
            )

            if args.range_setting == "mse_with_act_loss":
                # scales_state_dict = torch.load("scales_state_dict.pth")
                set_scales(model, scales_state_dict, config.head_dim)

            logging.info("Quantizing the model...")
            model = convert_pt2e(model)
            logging.info("Quantization complete! Here is some sample generated text:")

            calibrate(
                inputs,
                "Could you tell me about Facebook?",
                model,
                tokenizer=tokenizer,
                ar_len=args.prefill_ar_len,
                max_seq_len=args.max_seq_len,
                kv_updater=None,
                use_i64_token=use_i64_token,
            )

    model = WrappedLlamaModel(
        model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
@@ -248,7 +269,7 @@ def permute(w, heads):
        max_seq_length=args.calibration_seq_length,
        use_kv_cache=args.use_kv_cache,
        generate_full_logits=args.generate_full_logits,
        enable_dynamic_shape=args.enable_dynamic_shape,
        enable_dynamic_shape=False,
    )


@@ -271,6 +292,7 @@ def eval_llama(
        model=eval_wrapper,
        tasks=args.tasks,
        num_fewshot=args.num_fewshot,
        limit=args.fraction,
    )

    for task, res in eval_results["results"].items():
@@ -290,9 +312,24 @@ def main() -> None:
    )
    parser.add_argument(
        "--range_setting",
        help="Choose which range setting method (e.g. mse_weight). If not specified, will do minmax for weights and activations",
        help="Choose which range setting method for weight quantization (e.g. mse_weight_only or mse_with_act_loss). If not specified, defaults to minmax",
        type=str,
    )
    parser.add_argument(
        "--spinquant",
        help="Apply SpinQuant (R1+R2) to the model. Uses random Hadamard matrices for rotations",
        action="store_true",
    )
    parser.add_argument(
        "--fraction",
        help="the fraction of examples per task (only use this for testing)",
        type=float,
    )
    parser.add_argument(
        "--quant_linear_only",
        help="if you select this option we quantize linear layers only",
        action="store_true",
    )

    args = parser.parse_args()
    args.llama_model = "llama3_2"
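
To make the new flow easier to follow, the sketch below condenses what `eval_llama_qnn.py` does after this PR: optionally learn scales with the MSE-with-activation-loss range setter on the float model, export, insert observers, calibrate, override the observed ranges with the precomputed scales, and convert. It is illustrative only; the helper calls mirror their call sites in this diff, the real script calibrates with its `calibrate(...)` helper rather than a bare forward pass, and any keyword or variable not shown in the diff is an assumption.

```python
# Illustrative condensation of gen_eval_wrapper()'s quantization path (not the exact code).
import torch
from executorch.examples.qualcomm.oss_scripts.llama.range_setting_pt2e import (
    WrappedLlamaModel,
    compute_scales,
    make_custom_quantizer,
    reverse_quantize_module_swap,
    set_scales,
)
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


def quantize_like_eval_script(model, args, quant_dtype, custom_annotations, config):
    tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
    inputs = (tokens, atten_mask)

    # Optional range setting: learn scales on the float model, then undo the
    # module swap before PT2E export (mirrors prepare_model in this PR).
    scales_state_dict = {}
    if args.range_setting == "mse_with_act_loss":
        wrapped = WrappedLlamaModel(
            model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
        )
        # Positional args follow the call in this diff: weight_bits=4, act_bits=16 for 16a4w.
        scales_state_dict = compute_scales(wrapped, tokens, 4, 16, 1600)
        reverse_quantize_module_swap(wrapped)

    quantizer = make_custom_quantizer(
        quant_dtype, args.range_setting, custom_annotations, args.quant_linear_only
    )

    with torch.no_grad():
        module = torch.export.export(model, inputs, strict=True).module()
        module = prepare_pt2e(module, quantizer)   # insert observers
        module(*inputs)                            # calibration forward pass
        if args.range_setting == "mse_with_act_loss":
            # Override observed ranges with the precomputed scales.
            set_scales(module, scales_state_dict, config.head_dim)
        return convert_pt2e(module)                # produce the quantized model
```

The returned module is what the script then wraps in `WrappedLlamaModel` and hands to the lm-eval harness, with `--fraction` limiting the number of examples per task.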