File tree (3 files changed: +13, −5 lines)
backends/qualcomm/quantizer
examples/qualcomm/oss_scripts/llama3_2
lines changed Original file line number Diff line number Diff line change 2222from torch .fx import Node
2323
2424
25 - def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None:  # noqa: C901
25 + def annotate_matmul_16a8w(
26 +     gm: torch.fx.GraphModule, traverse_input1=True
27 + ) -> None:  # noqa: C901
2628 """
2729 This function is specific for matmul op 16a8w.
2830 """
@@ -99,7 +101,8 @@ def annotate_matmul_input1(node: Node):
99101 for node in gm .graph .nodes :
100102 if node .op == "call_function" and node .target == torch .ops .aten .matmul .default :
101103 annotate_matmul (node , quantization_config_16a8w )
102 -             annotate_matmul_input1(node.args[1])
104 +             if traverse_input1:
105 +                 annotate_matmul_input1(node.args[1])
103106
104107
105108def custom_annotate_llama_matmul_16a8w (gm : torch .fx .GraphModule ) -> None : # noqa: C901
Original file line number Diff line number Diff line change 88import json
99import logging
1010import os
11    -
12 11   import sys
13 12   import time
   13 + from functools import partial
1414from multiprocessing .connection import Client
1515
1616import torch
@@ -319,8 +319,10 @@ def compile(args):
319319
320320 if args .model_mode == "kv" :
321321 use_kv_cache = output_new_cache_only = True
322 +         matmul_annotate_func = partial(annotate_matmul_16a8w, traverse_input1=True)
322323 elif args .model_mode == "batch_prefill" :
323324 use_kv_cache = output_new_cache_only = False
325 +         matmul_annotate_func = partial(annotate_matmul_16a8w, traverse_input1=False)
324326 elif args .model_mode == "hybrid" :
325327 raise NotImplementedError (
326328 f"model_mode { args .model_mode } is not implemented yet."
@@ -385,7 +387,10 @@ def compile(args):
385387 start_quantize_ts = time .time ()
386388 single_llama .quantize (
387389 quant_dtype ,
388 -             custom_annotations=(annotate_matmul_16a8w,),
390 +             custom_annotations=(
391 +                 custom_annotate_llama_last_conv_16a8w,
392 +                 matmul_annotate_func,
393 +             ),
389394 )
390395 end_quantize_ts = time .time ()
391396 logging .info (f"Time for quantizing: { end_quantize_ts - start_quantize_ts } " )
Original file line number Diff line number Diff line change @@ -137,7 +137,7 @@ def python_is_compatible():
137137 "timm==1.0.7" ,
138138 f"torchaudio==2.5.0.{ NIGHTLY_VERSION } " if USE_PYTORCH_NIGHTLY else "torchaudio" ,
139139 "torchsr==1.0.4" ,
140 -     "transformers==4.42.4",  # TODO update back to 4.46.1 once the error is fixed
140 +     "transformers==4.46.1",
141141]
142142
143143# pip packages needed for development.
You can’t perform that action at this time.
0 commit comments