@@ -73,37 +73,49 @@ def _kv_calibrate(
     max_seq_len=512,
 ):
     sp_model = get_tokenizer(tokenizer_model_path)
-    _, atten_mask, _, k_caches, v_caches = example_inputs
 
     # TODO: change criteria & support batch inputs if necessary
-    pos = torch.tensor(0, dtype=torch.int32)
     max_cache_len = max_seq_len - 1
-    token_list = sp_model.encode(user_prompts, bos=True, eos=False)
 
-    with torch.no_grad():
-        while token_list[-1] != sp_model.eos_id and pos < max_cache_len:
-            logits, new_k_caches, new_v_caches = module(
-                torch.full((1, 1), token_list[pos], dtype=torch.int32),
-                atten_mask,
-                torch.full((1, 1), pos),
-                *k_caches,
-                *v_caches,
-            )
-            k_caches = [
-                torch.cat([k_cache[:, :, 1:], new_k_caches[i]], dim=-1)
-                for i, k_cache in enumerate(k_caches)
-            ]
-            v_caches = [
-                torch.cat([v_cache[:, 1:, :], new_v_caches[i]], dim=1)
-                for i, v_cache in enumerate(v_caches)
-            ]
-
-            pos += 1
-            atten_mask[0][-pos - 1] = 0
-            if pos >= len(token_list):
-                token_list.append(torch.argmax(logits[:, -1], dim=-1).item())
 
-    print(f"calibration data:\n{sp_model.decode(token_list)}")
+    # token_list = sp_model.encode(user_prompts, bos=True, eos=False)
+
+    user_token_list = [
+        # what is the capital of the united states
+        [128000, 128006, 882, 128007, 271, 12840, 374, 279, 6864, 315, 279, 29292, 5415, 128009, 128006, 78191, 128007, 271],
+        # what is 1 + 1
+        [128000, 128006, 882, 128007, 271, 12840, 374, 220, 16, 489, 220, 16, 128009, 128006, 78191, 128007, 271],
+        # what is the meaning of life
+        [128000, 128006, 882, 128007, 271, 12840, 374, 279, 7438, 315, 2324, 128009, 128006, 78191, 128007, 271],
+    ]
+
+    for token_list in user_token_list:
+        _, atten_mask, _, k_caches, v_caches = copy.deepcopy(example_inputs)
+        pos = torch.tensor(0, dtype=torch.int32)
+        with torch.no_grad():
+            while token_list[-1] != sp_model.eos_id and pos < max_cache_len:
+                logits, new_k_caches, new_v_caches = module(
+                    torch.full((1, 1), token_list[pos], dtype=torch.int32),
+                    atten_mask,
+                    torch.full((1, 1), pos),
+                    *k_caches,
+                    *v_caches,
+                )
+                k_caches = [
+                    torch.cat([k_cache[:, :, 1:], new_k_caches[i]], dim=-1)
+                    for i, k_cache in enumerate(k_caches)
+                ]
+                v_caches = [
+                    torch.cat([v_cache[:, 1:, :], new_v_caches[i]], dim=1)
+                    for i, v_cache in enumerate(v_caches)
+                ]
+
+                pos += 1
+                atten_mask[0][-pos - 1] = 0
+                if pos >= len(token_list):
+                    token_list.append(torch.argmax(logits[:, -1], dim=-1).item())
+
+        logging.info(f"calibration data:\n{sp_model.decode(token_list)}")
 
 
 def _prefill_calibrate(
@@ -114,32 +126,44 @@ def _prefill_calibrate(
     max_seq_len=512,
 ):
     sp_model = get_tokenizer(tokenizer_model_path)
-    _, atten_mask = example_inputs
     max_cache_len = max_seq_len - 1
 
     # TODO: change criteria & support batch inputs if necessary
-    token_list = sp_model.encode(user_prompts, bos=True, eos=False)
-    token_list = torch.tensor(token_list)[:max_cache_len].reshape(1, -1)
-    last_prompt_pos = token_list.numel()
-    if last_prompt_pos < max_cache_len:
-        token_list = torch.cat(
-            [
-                token_list,
-                torch.zeros((1, max_cache_len - last_prompt_pos), dtype=torch.int32),
-            ],
-            dim=1,
-        )
-    else:
-        token_list = token_list[:, :max_cache_len]
-
-    with torch.no_grad():
-        logits, new_k_caches, new_v_caches = module(
-            token_list,
-            atten_mask,
-        )
-    predict = [torch.argmax(logits[:, last_prompt_pos - 1], dim=-1).item()]
+
+    # token_list = sp_model.encode(user_prompts, bos=True, eos=False)
+
+    user_token_list = [
+        # what is the capital of the united states
+        [128000, 128006, 882, 128007, 271, 12840, 374, 279, 6864, 315, 279, 29292, 5415, 128009, 128006, 78191, 128007, 271],
+        # what is 1 + 1
+        [128000, 128006, 882, 128007, 271, 12840, 374, 220, 16, 489, 220, 16, 128009, 128006, 78191, 128007, 271],
+        # what is the meaning of life
+        [128000, 128006, 882, 128007, 271, 12840, 374, 279, 7438, 315, 2324, 128009, 128006, 78191, 128007, 271],
+    ]
+
+    for token_list in user_token_list:
+        _, atten_mask = copy.deepcopy(example_inputs)
+        token_list = torch.tensor(token_list)[:max_cache_len].reshape(1, -1)
+        last_prompt_pos = token_list.numel()
+        if last_prompt_pos < max_cache_len:
+            token_list = torch.cat(
+                [
+                    token_list,
+                    torch.zeros((1, max_cache_len - last_prompt_pos), dtype=torch.int32),
+                ],
+                dim=1,
+            )
+        else:
+            token_list = token_list[:, :max_cache_len]
 
-    print(f"calibration data:\n{sp_model.decode(predict)}")
+        with torch.no_grad():
+            logits, new_k_caches, new_v_caches = module(
+                token_list,
+                atten_mask,
+            )
+        predict = [torch.argmax(logits[:, last_prompt_pos - 1], dim=-1).item()]
+
+        logging.info(f"calibration data:\n{sp_model.decode(predict)}")
 
 
 def calibrate(
@@ -249,7 +273,17 @@ def quantize(self, quant_dtype, args, custom_annotations=()):
             max_seq_len=self.llama_meta["get_max_seq_len"],
         )
 
-        self.llama_model = convert_pt2e(fx_graph_module)
+        fx_graph_module = convert_pt2e(fx_graph_module)
+
+        logging.info("Evaluating the converted model...")
+        calibrate(
+            self.get_example_inputs(self.llama_meta["get_use_kv_cache"]),
+            args.prompt,
+            fx_graph_module,
+            tokenizer_model_path=args.tokenizer_model,
+            max_seq_len=self.llama_meta["get_max_seq_len"],
+        )
+        self.llama_model = fx_graph_module
 
     def lowering_modules(
         self,
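
Side note on the calibration data introduced above: the `user_token_list` entries in both calibration functions are pre-tokenized Llama 3 chat-template prompts that replace the commented-out `sp_model.encode(user_prompts, bos=True, eos=False)` call. The sketch below is a hypothetical helper, not part of this commit, showing how similar lists could be produced from plain-text prompts; it assumes only the `get_tokenizer` function and the `encode(text, bos=..., eos=...)` signature visible in the commented-out line, and note that a plain `encode` call will not insert the chat-template special tokens (e.g. 128006/128007) that the hard-coded lists carry.

# Hypothetical helper (not part of this change): rebuild calibration token
# lists from plain-text prompts via the tokenizer used elsewhere in this file.
# Assumes get_tokenizer() returns an object exposing encode(text, bos=..., eos=...),
# as in the commented-out sp_model.encode(...) line; chat-template special tokens
# would still need to be added separately to match the hard-coded lists exactly.
def build_calibration_token_lists(tokenizer_model_path, prompts):
    sp_model = get_tokenizer(tokenizer_model_path)
    return [sp_model.encode(p, bos=True, eos=False) for p in prompts]

# Example usage with the same prompts this commit hard-codes:
# token_lists = build_calibration_token_lists(
#     args.tokenizer_model,
#     [
#         "what is the capital of the united states",
#         "what is 1 + 1",
#         "what is the meaning of life",
#     ],
# )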