File tree Expand file tree Collapse file tree 3 files changed +8
-5
lines changed
backends/qualcomm/quantizer
examples/qualcomm/oss_scripts/llama3_2 Expand file tree Collapse file tree 3 files changed +8
-5
lines changed Original file line number Diff line number Diff line change 2222from torch .fx import Node
2323
2424
25- def annotate_matmul_16a8w (gm : torch .fx .GraphModule ) -> None :
25+ def annotate_matmul_16a8w (gm : torch .fx .GraphModule , traverse_input1 = True ) -> None :
2626 """
2727 This function is specific for matmul op 16a8w.
2828 """
@@ -99,7 +99,8 @@ def annotate_matmul_input1(node: Node):
9999 for node in gm .graph .nodes :
100100 if node .op == "call_function" and node .target == torch .ops .aten .matmul .default :
101101 annotate_matmul (node , quantization_config_16a8w )
102- annotate_matmul_input1 (node .args [1 ])
102+ if traverse_input1 :
103+ annotate_matmul_input1 (node .args [1 ])
103104
104105
105106def custom_annotate_llama_matmul_16a8w (gm : torch .fx .GraphModule ) -> None : # noqa: C901
Original file line number Diff line number Diff line change 88import json
99import logging
1010import os
11-
1211import sys
1312import time
13+ from functools import partial
1414from multiprocessing .connection import Client
1515
1616import torch
@@ -319,8 +319,10 @@ def compile(args):
319319
320320 if args .model_mode == "kv" :
321321 use_kv_cache = output_new_cache_only = True
322+ matmul_annotate_func = partial (annotate_matmul_16a8w , traverse_input1 = True )
322323 elif args .model_mode == "batch_prefill" :
323324 use_kv_cache = output_new_cache_only = False
325+ matmul_annotate_func = partial (annotate_matmul_16a8w , traverse_input1 = False )
324326 elif args .model_mode == "hybrid" :
325327 raise NotImplementedError (
326328 f"model_mode { args .model_mode } is not implemented yet."
@@ -387,7 +389,7 @@ def compile(args):
387389 quant_dtype ,
388390 custom_annotations = (
389391 custom_annotate_llama_last_conv_16a8w ,
390- annotate_matmul_16a8w ,
392+ matmul_annotate_func ,
391393 ),
392394 )
393395 end_quantize_ts = time .time ()
Original file line number Diff line number Diff line change @@ -137,7 +137,7 @@ def python_is_compatible():
137137 "timm==1.0.7" ,
138138 f"torchaudio==2.5.0.{ NIGHTLY_VERSION } " if USE_PYTORCH_NIGHTLY else "torchaudio" ,
139139 "torchsr==1.0.4" ,
140- "transformers==4.42.4" , # TODO update back to 4. 46.1 once the error is fixed
140+ "transformers==4.46.1" ,
141141]
142142
143143# pip packages needed for development.
You can’t perform that action at this time.
0 commit comments