logging.getLogger().setLevel(logging.INFO)


+def smart_mask_updator(atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches):
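+    # Smart mask: write the step's new K/V entries in place at index `pos`
+    # and unmask that position; the cache tensors keep a fixed memory layout.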
+    for i, k_cache in enumerate(k_caches):
+        k_cache[:, :, pos] = new_k_caches[i][:, :, 0]
+
+    for i, v_cache in enumerate(v_caches):
+        v_cache[:, pos, :] = new_v_caches[i]
+
+    atten_mask[0][pos] = 0
+    pos += 1
+    return (atten_mask, pos, k_caches, v_caches)
+
+
+def shift_pointer_updator(
+    atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
+):
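+    # Shift pointer: drop the oldest entry and append the new K/V slice,
+    # sliding the cache window; unmask one more slot from the right end
+    # of the attention mask.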
+    k_caches = [
+        torch.cat([k_cache[:, :, 1:], new_k_caches[i]], dim=-1)
+        for i, k_cache in enumerate(k_caches)
+    ]
+    v_caches = [
+        torch.cat([v_cache[:, 1:, :], new_v_caches[i]], dim=1)
+        for i, v_cache in enumerate(v_caches)
+    ]
+
+    pos += 1
+    atten_mask[0][-pos - 1] = 0
+    return (atten_mask, pos, k_caches, v_caches)
+
+
def _kv_calibrate(
    example_inputs,
    user_prompts,
    module: torch.fx.GraphModule,
    tokenizer,
    max_seq_len=512,
+    updator=smart_mask_updator,
):
    _, atten_mask, _, k_caches, v_caches = example_inputs

@@ -105,17 +135,9 @@ def _kv_calibrate(
                *k_caches,
                *v_caches,
            )
-            k_caches = [
-                torch.cat([k_cache[:, :, 1:], new_k_caches[i]], dim=-1)
-                for i, k_cache in enumerate(k_caches)
-            ]
-            v_caches = [
-                torch.cat([v_cache[:, 1:, :], new_v_caches[i]], dim=1)
-                for i, v_cache in enumerate(v_caches)
-            ]
-
-            pos += 1
-            atten_mask[0][-pos - 1] = 0
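+            # Hand the per-token cache and mask update off to the selected strategy.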
+            atten_mask, pos, k_caches, v_caches = updator(
+                atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
+            )
            if pos >= len(token_list):
                token_list.append(torch.argmax(logits[:, -1], dim=-1).item())

@@ -174,6 +196,7 @@ def calibrate(
    module: torch.fx.GraphModule,
    tokenizer,
    max_seq_len=512,
+    kv_updator=smart_mask_updator,
):
    if len(example_inputs) == 2:
        _prefill_calibrate(
@@ -190,6 +213,7 @@ def calibrate(
            module,
            tokenizer,
            max_seq_len,
+            updator=kv_updator,
        )
    else:
        raise RuntimeError("Get wrong inputs")
@@ -319,13 +343,15 @@ def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()):
                self.llama_model, self.inputs, strict=True
            ).module()
            fx_graph_module = prepare_pt2e(fx_graph_module, quantizer)
+
        logging.info("Quantizing the model...")
        calibrate(
            self.get_example_inputs(self.llama_meta["get_use_kv_cache"]),
            args.prompt,
            fx_graph_module,
            tokenizer=tokenizer,
            max_seq_len=self.llama_meta["get_max_seq_len"],
+            kv_updator=args.kv_updator,
        )

        self.llama_model = convert_pt2e(fx_graph_module)
@@ -337,6 +363,7 @@ def lowering_modules(
        use_fp16=False,
        soc_model=QcomChipset.SM8650,
        num_sharding=0,
+        shared_buffer=False,
    ):
        executorch_config = ExecutorchBackendConfig(
            # For shared buffer, user must pass the memory address
@@ -357,7 +384,7 @@ def lowering_modules(
            compiler_specs = generate_qnn_executorch_compiler_spec(
                soc_model=soc_model,
                backend_options=backend_options,
-                shared_buffer=False,
+                shared_buffer=shared_buffer,
            )
            skip_node_op_set = {"llama.fallback.default"}
            partitioner = QnnPartitioner(
@@ -530,6 +557,7 @@ def compile(args, pte_filename, tokenizer):
                use_fp16=use_fp16,
                soc_model=get_soc_to_chipset_map()[args.model],
                num_sharding=args.num_sharding,
+                shared_buffer=args.shared_buffer,
            )
        quant_attrs = llama_instance_list[0].get_quant_attrs()
    else:
@@ -564,7 +592,7 @@ def compile(args, pte_filename, tokenizer):
            generate_qnn_executorch_compiler_spec(
                soc_model=get_soc_to_chipset_map()[args.model],
                backend_options=backend_options,
-                shared_buffer=True,
+                shared_buffer=args.shared_buffer,
                multiple_graphs=True,
                graph_name=graph_name,
            )
@@ -736,6 +764,7 @@ def inference(args, quant_attrs, pte_filename, runtime_tokenizer_path, pre_gen_p
736764 f"--system_prompt '{ args .system_prompt } '" ,
737765 f"--logits_scale { quant_attrs ['scale' ]} " ,
738766 f"--logits_offset { quant_attrs ['zero_point' ]} " ,
767+ f"--kv_updator { 'SmartMask' if args .kv_updator == smart_mask_updator else 'ShiftPointer' } " ,
739768 ]
740769 )
741770 runner_cmd = " " .join (
@@ -907,6 +936,14 @@ def main():
        type=int,
    )

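+    # Select the KV-cache update strategy; "smart_mask" additionally forces
+    # shared_buffer on (see the kv_updator handling further below).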
+    parser.add_argument(
+        "--kv_updator",
+        help="Choose how to update kv cache during runtime",
+        choices=["smart_mask", "shift_pointer"],
+        default="smart_mask",
+        type=str,
+    )
+
    args = parser.parse_args()
    if args.compile_only and args.pre_gen_pte:
        exit("Cannot set both compile_only and pre_gen_pte as true")
@@ -941,6 +978,14 @@ def main():
    else:
        raise RuntimeError(f"Unknown llama_model: {args.llama_model}.")

+    if args.kv_updator == "smart_mask":
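+        # Smart mask mutates the cache tensors in place, so the runtime needs
+        # direct access to their memory; shared buffers are therefore required.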
+        args.shared_buffer = True
+        args.kv_updator = smart_mask_updator
+    elif args.kv_updator == "shift_pointer":
+        args.kv_updator = shift_pointer_updator
+    else:
+        exit(f"Using an unknown kv updator {args.kv_updator}")
+
    if args.pre_gen_pte:
        quant_attrs = json.load(
            open(f"{args.pre_gen_pte}/{pte_filename}_quant_attrs.txt")