examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py (67 changes: 58 additions & 9 deletions)
@@ -20,6 +20,14 @@
     annotate_matmul_16a8w,
 )
 
+from executorch.backends.qualcomm.quantizer.observers.per_channel_param_observer import (
+    PerChannelParamObserver,
+)
+from executorch.backends.qualcomm.quantizer.qconfig import (
+    _derived_bias_quant_spec,
+    QuantizationConfig,
+)
+
 from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
 from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d
 
@@ -47,6 +55,8 @@
 
 from torchao.quantization.pt2e import MinMaxObserver
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torchao.quantization.pt2e.quantizer import QuantizationSpec
+
 
 sys.setrecursionlimit(4096)
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -78,6 +88,33 @@ def forward(
         return self.model.forward(tokens, self.atten_mask)
 
 
+def add_mse_weight_observer(quant_dtype, quantizer):
+    weight_dtype = (
+        torch.int4
+        if quant_dtype in (QuantDtype.use_16a4w, QuantDtype.use_16a4w_block)
+        else torch.int8
+    )
+    per_channel_q_config = quantizer.default_quant_config.quant_config
+    weight_qspec = QuantizationSpec(
+        dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
+        quant_min=(
+            -7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1
+        ),
+        quant_max=(7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max),
+        qscheme=torch.per_channel_symmetric,
+        ch_axis=0,
+        observer_or_fake_quant_ctr=PerChannelParamObserver.with_args(
+            **{"steps": 200, "use_mse": True}
+        ),
+    )
+    quantizer.default_quant_config.per_channel_quant_config = QuantizationConfig(
+        input_activation=per_channel_q_config.input_activation,
+        output_activation=per_channel_q_config.output_activation,
+        weight=weight_qspec,
+        bias=_derived_bias_quant_spec,
+    )
+
+
 def gen_eval_wrapper(model_name, args):
     tokenizer = get_tokenizer(args.tokenizer_path)
     with open(args.params) as f:
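For context, `PerChannelParamObserver.with_args(steps=200, use_mse=True)` above asks the observer to search for per-channel weight ranges that minimize quantization error instead of simply taking min/max. The following is a minimal sketch of that idea, not the actual observer implementation; the function name, the shrinking-candidate loop, and `qmax=7` (mirroring the `-7`/`7` int4 bounds above) are illustrative assumptions:

```python
import torch

# Illustrative sketch only: search a symmetric per-channel scale that
# minimizes MSE between the weight and its quantize/dequantize round trip.
def mse_range_search(w: torch.Tensor, steps: int = 200, qmax: int = 7) -> torch.Tensor:
    flat = w.flatten(1)                              # [C, N], matching ch_axis=0
    absmax = flat.abs().amax(dim=1).clamp(min=1e-8)  # per-channel |max|
    best_scale = absmax / qmax                       # min-max baseline
    best_err = torch.full_like(absmax, float("inf"))
    for i in range(1, steps + 1):
        scale = (absmax * i / steps) / qmax          # shrunken candidate range
        q = torch.clamp(torch.round(flat / scale[:, None]), -qmax, qmax)
        err = ((q * scale[:, None] - flat) ** 2).mean(dim=1)
        keep = err < best_err
        best_scale = torch.where(keep, scale, best_scale)
        best_err = torch.minimum(err, best_err)
    return best_scale                                # one scale per output channel
```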
Expand Down Expand Up @@ -142,13 +179,13 @@ def permute(w, heads):
if getattr(layer.feed_forward, "prepare_feedfoward_conv", None):
layer.feed_forward.prepare_feedfoward_conv()

model.to(dtype=torch.bfloat16)
model.to(dtype=torch.float)
model.to(device=args.device)

tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
tokens = tokens.to(device=args.device)
atten_mask = atten_mask.to(device=args.device)
atten_mask = atten_mask.to(dtype=torch.bfloat16)
atten_mask = atten_mask.to(dtype=torch.float)
inputs = (tokens, atten_mask)

if args.embedding_quantize:
@@ -174,7 +211,8 @@ def permute(w, heads):
     )
     quantizer.add_custom_quant_annotations(custom_annotations)
 
-    model.has_quant_io = True
+    if args.range_setting == "mse_weight":
+        add_mse_weight_observer(quant_dtype, quantizer)
 
     with torch.no_grad():
        model = torch.export.export(model, inputs, strict=True).module()
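The collapsed region presumably continues with the usual pt2e prepare/calibrate/convert steps (`prepare_pt2e` and `convert_pt2e` are already imported at the top of the file). A hedged sketch of how the new `range_setting` branch participates in that flow; the helper name `calibrate_and_convert` and the single-batch calibration are illustrative assumptions, not part of this diff:

```python
# Sketch, assuming `quantizer` is the QnnQuantizer configured above and
# `inputs` is the (tokens, atten_mask) example pair.
def calibrate_and_convert(model, inputs, quantizer, args, quant_dtype):
    if args.range_setting == "mse_weight":
        # Swap the default min-max weight observers for MSE-searched ones.
        add_mse_weight_observer(quant_dtype, quantizer)
    exported = torch.export.export(model, inputs, strict=True).module()
    prepared = prepare_pt2e(exported, quantizer)
    prepared(*inputs)  # calibration forward pass drives the observers
    return convert_pt2e(prepared)
```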
@@ -245,6 +283,23 @@ def main() -> None:
     torch.manual_seed(seed)
     modelname = "llama2"
     parser = build_args_parser()
+    parser.add_argument(
+        "-P",
+        "--ptq",
+        help="If specified, run PTQ quantization. The default is 16-bit activations and 4-bit weights. Supported modes: 8a8w, 16a4w, and 16a4w_block.",
+        type=str,
+    )
+    parser.add_argument(
+        "--range_setting",
+        help="Range setting method to use (e.g. mse_weight). If not specified, min-max is used for weights and activations.",
+        type=str,
+    )
+    parser.add_argument(
+        "--limit",
+        help="The number of examples per task (only use this for testing). If <1, limit is a percentage of the total number of examples.",
+        type=str,
+    )
+
     args = parser.parse_args()
     args.llama_model = "llama3_2"
     # Overrides this arg, because evaluation requires full logits.
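Since `--limit` arrives as a string, it has to be coerced before being handed to the evaluator; lm-eval treats values below 1 as a proportion of the dataset. A small illustrative helper (the name `parse_limit` is an assumption, not part of this diff):

```python
# Hypothetical helper: map the string --limit flag onto lm-eval's
# convention (None = all examples, <1 = proportion, otherwise a count).
def parse_limit(limit_str):
    if limit_str is None:
        return None
    value = float(limit_str)
    return value if value < 1 else int(value)
```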
@@ -257,15 +312,9 @@ def main() -> None:
     args.use_kv_cache = False
     args.prefill_ar_len = args.max_seq_length
 
-    # To do fewer samples for faster evaluation
-    args.limit = 0.1
-    # args.samples = {'wikitext': list(range(1))}
-
     args.device = "cuda" if torch.cuda.is_available() else "cpu"
     torch.set_default_device(args.device)
 
-    args.ptq = "8a8w"
-
     eval_llama(modelname, args)
 
 