6666logging .getLogger ().setLevel (logging .INFO )
6767
6868
def smart_mask_updator(atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches):
    """Write the newly produced KV entries into slot ``pos`` of the caches.

    The cache tensors and the attention mask are mutated in place: each
    layer's new key/value is stored at index ``pos``, the mask column for
    ``pos`` is unmasked (set to 0), and the advanced position is returned.

    Returns a ``(atten_mask, pos, k_caches, v_caches)`` tuple.
    """
    for k_cache, new_k in zip(k_caches, new_k_caches):
        # Only the single new time step (index 0) is written into the slot.
        k_cache[:, :, pos] = new_k[:, :, 0]

    for v_cache, new_v in zip(v_caches, new_v_caches):
        v_cache[:, pos, :] = new_v

    # 0 in the mask means "attend to this position".
    atten_mask[0][pos] = 0
    return (atten_mask, pos + 1, k_caches, v_caches)
79+
80+
def shift_pointer_updator(
    atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
):
    """Advance the KV caches by one step: drop the oldest entry, append the new.

    Fresh cache tensors are built via concatenation (the input cache lists
    are not mutated); the attention mask is unmasked from the tail backwards
    as ``pos`` grows.

    Returns a ``(atten_mask, pos, k_caches, v_caches)`` tuple.
    """
    shifted_k = []
    for k_cache, new_k in zip(k_caches, new_k_caches):
        # Keys are laid out (batch, head_dim, seq): shift along the last dim.
        shifted_k.append(torch.cat((k_cache[:, :, 1:], new_k), dim=-1))

    shifted_v = []
    for v_cache, new_v in zip(v_caches, new_v_caches):
        # Values are laid out (batch, seq, head_dim): shift along dim 1.
        shifted_v.append(torch.cat((v_cache[:, 1:, :], new_v), dim=1))

    pos += 1
    # 0 in the mask means "attend"; unmask positions from the end backwards.
    atten_mask[0][-pos - 1] = 0
    return (atten_mask, pos, shifted_k, shifted_v)
96+
97+
6998def _kv_calibrate (
7099 example_inputs ,
71100 user_prompts ,
72101 module : torch .fx .GraphModule ,
73102 tokenizer_model_path = "tokenizer.model" ,
74103 max_seq_len = 512 ,
104+ updator = smart_mask_updator ,
75105):
76106 sp_model = get_tokenizer (tokenizer_model_path )
77107 _ , atten_mask , _ , k_caches , v_caches = example_inputs
@@ -92,17 +122,9 @@ def _kv_calibrate(
92122 * k_caches ,
93123 * v_caches ,
94124 )
95- k_caches = [
96- torch .cat ([k_cache [:, :, 1 :], new_k_caches [i ]], dim = - 1 )
97- for i , k_cache in enumerate (k_caches )
98- ]
99- v_caches = [
100- torch .cat ([v_cache [:, 1 :, :], new_v_caches [i ]], dim = 1 )
101- for i , v_cache in enumerate (v_caches )
102- ]
103-
104- pos += 1
105- atten_mask [0 ][- pos - 1 ] = 0
125+ atten_mask , pos , k_caches , v_caches = updator (
126+ atten_mask , pos , k_caches , v_caches , new_k_caches , new_v_caches
127+ )
106128 if pos >= len (token_list ):
107129 token_list .append (torch .argmax (logits [:, - 1 ], dim = - 1 ).item ())
108130
@@ -153,6 +175,7 @@ def calibrate(
153175 module : torch .fx .GraphModule ,
154176 tokenizer_model_path = "tokenizer.model" ,
155177 max_seq_len = 512 ,
178+ kv_updator = smart_mask_updator ,
156179):
157180 if len (example_inputs ) == 2 :
158181 _prefill_calibrate (
@@ -169,6 +192,7 @@ def calibrate(
169192 module ,
170193 tokenizer_model_path ,
171194 max_seq_len ,
195+ updator = kv_updator ,
172196 )
173197 else :
174198 raise RuntimeError ("Get wrong inputs" )
@@ -298,13 +322,15 @@ def quantize(self, quant_dtype, args, custom_annotations=()):
298322 self .llama_model , self .inputs , strict = True
299323 ).module ()
300324 fx_graph_module = prepare_pt2e (fx_graph_module , quantizer )
325+
301326 logging .info ("Quantizing the model..." )
302327 calibrate (
303328 self .get_example_inputs (self .llama_meta ["get_use_kv_cache" ]),
304329 args .prompt ,
305330 fx_graph_module ,
306331 tokenizer_model_path = args .tokenizer_model ,
307332 max_seq_len = self .llama_meta ["get_max_seq_len" ],
333+ kv_updator = args .kv_updator ,
308334 )
309335
310336 self .llama_model = convert_pt2e (fx_graph_module )
@@ -316,6 +342,7 @@ def lowering_modules(
316342 use_fp16 = False ,
317343 soc_model = QcomChipset .SM8650 ,
318344 num_sharding = 0 ,
345+ shared_buffer = False ,
319346 ):
320347 executorch_config = ExecutorchBackendConfig (
321348 # For shared buffer, user must pass the memory address
@@ -336,7 +363,7 @@ def lowering_modules(
336363 compiler_specs = generate_qnn_executorch_compiler_spec (
337364 soc_model = soc_model ,
338365 backend_options = backend_options ,
339- shared_buffer = False ,
366+ shared_buffer = shared_buffer ,
340367 )
341368 skip_node_op_set = {"llama.fallback.default" }
342369 partitioner = QnnPartitioner (
@@ -366,7 +393,7 @@ def lowering_modules(
366393 if num_sharding > 0 :
367394 update_spill_fill_size (edge_prog_mgr .exported_program ())
368395 exec_prog_mgr = edge_prog_mgr .to_executorch (config = executorch_config )
369- with open (f"{ work_space } /{ pte_filename } .pte" , "wb" ) as file :
396+ with open (f"{ work_space } /{ self . pte_filename } .pte" , "wb" ) as file :
370397 exec_prog_mgr .write_to_file (file )
371398
372399 def get_example_inputs (self , use_kv_cache = True ):
@@ -491,6 +518,7 @@ def compile(args, pte_filename):
491518 use_fp16 = use_fp16 ,
492519 soc_model = get_soc_to_chipset_map ()[args .model ],
493520 num_sharding = args .num_sharding ,
521+ shared_buffer = args .shared_buffer ,
494522 )
495523 quant_attrs = llama_instance_list [0 ].get_quant_attrs ()
496524 else :
@@ -525,7 +553,7 @@ def compile(args, pte_filename):
525553 generate_qnn_executorch_compiler_spec (
526554 soc_model = get_soc_to_chipset_map ()[args .model ],
527555 backend_options = backend_options ,
528- shared_buffer = True ,
556+ shared_buffer = args . shared_buffer ,
529557 multiple_graphs = True ,
530558 graph_name = graph_name ,
531559 )
@@ -697,6 +725,7 @@ def inference(args, quant_attrs, pte_filename, pre_gen_pte=""):
697725 f"--system_prompt '{ args .system_prompt } '" ,
698726 f"--logits_scale { quant_attrs ['scale' ]} " ,
699727 f"--logits_offset { quant_attrs ['zero_point' ]} " ,
728+ f"--kv_updator { 'SmartMask' if args .kv_updator == smart_mask_updator else 'ShiftPointer' } " ,
700729 ]
701730 )
702731 runner_cmd = " " .join (
@@ -862,6 +891,14 @@ def main():
862891 type = int ,
863892 )
864893
894+ parser .add_argument (
895+ "--kv_updator" ,
896+ help = "Choose how to update kv cache during runtime" ,
897+ choices = ["smart_mask" , "shift_pointer" ],
898+ default = "smart_mask" ,
899+ type = str ,
900+ )
901+
865902 args = parser .parse_args ()
866903 if args .compile_only and args .pre_gen_pte :
867904 exit ("Cannot set both compile_only and pre_gen_pte as true" )
@@ -878,6 +915,14 @@ def main():
878915 else :
879916 raise RuntimeError (f"No such model_mode { args .model_mode } ." )
880917
918+ if args .kv_updator == "smart_mask" :
919+ args .shared_buffer = True
920+ args .kv_updator = smart_mask_updator
921+ elif args .kv_updator == "shift_pointer" :
922+ args .kv_updator = shift_pointer_updator
923+ else :
924+ exit (f"Using an unkown kv update { args .kv_updator } " )
925+
881926 if args .pre_gen_pte :
882927 quant_attrs = json .load (
883928 open (f"{ args .pre_gen_pte } /{ pte_filename } _quant_attrs.txt" )
0 commit comments