 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
 
 from executorch.extension.export_util.utils import export_to_edge, save_pte_program
-from executorch.extension.llm.tokenizer.utils import get_tokenizer
 from torch._export import capture_pre_autograd_graph
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.ao.quantization.quantizer import Quantizer
@@ -67,11 +66,6 @@ def __init__(
         use_kv_cache,
         example_inputs,
         enable_dynamic_shape: bool = False,
-        calibration_tasks: Optional[List[str]] = None,
-        calibration_limit: Optional[int] = None,
-        calibration_seq_length: Optional[int] = None,
-        calibration_data: Optional[str] = None,
-        tokenizer_path: Optional[str] = None,
         verbose: bool = False,
         metadata: Optional[dict] = None,
         dynamic_shapes: Optional[Any] = None,
@@ -93,11 +87,6 @@ def __init__(
         self.output_dir = "."
         self.dynamic_shapes = dynamic_shapes
         self._saved_pte_filename = None
-        self.calibration_tasks = calibration_tasks
-        self.calibration_limit = calibration_limit
-        self.calibration_seq_length = calibration_seq_length
-        self.calibration_data = calibration_data
-        self.tokenizer_path = tokenizer_path
 
     def set_output_dir(self, output_dir: str) -> "LLMEdgeManager":
         """
@@ -178,69 +167,6 @@ def capture_pre_autograd_graph(self) -> "LLMEdgeManager":
             )
         return self
 
-    def pt2e_calibrate(
-        self,
-        prepared_module,
-        calibration_tasks,
-        calibration_limit,
-        calibration_seq_length,
-        calibration_data,
-        tokenizer_path,
-    ):
-        logging.info("Run calibration...")
-        try:
-            from executorch.examples.models.llama2.eval_llama_lib import (
-                GraphModuleEvalWrapper,
-            )
-            from executorch.examples.models.llama2.evaluate import evaluate_model
-        except ImportError:
-            raise ImportError(
-                "Please install the llm eval dependency via examples/models/llama2/install_requirements.sh"
-            )
-
-        tokenizer = get_tokenizer(tokenizer_path)
-
-        def calibrate_template(
-            module: torch.fx.GraphModule, tokenizer, prompts: str, max_len: int
-        ):
-            # TODO: change criteria & support batch inputs if necessary
-            pos = torch.tensor(0, dtype=torch.int64)
-            token_list = tokenizer.encode(prompts, bos=True, eos=False)
-
-            with torch.no_grad():
-                while token_list[-1] != tokenizer.eos_id and pos < max_len:
-                    logits = module(
-                        torch.full((1, 1), token_list[pos]),
-                        torch.tensor((pos,)),
-                    )
-                    pos += 1
-                    if pos >= len(token_list):
-                        token_list.append(torch.argmax(logits[:], dim=-1).item())
-
-        calibrate_template(
-            module=prepared_module,
-            tokenizer=tokenizer,
-            prompts=calibration_data,
-            max_len=calibration_seq_length,
-        )
-
-        eval_wrapper = GraphModuleEvalWrapper(
-            model=prepared_module,
-            tokenizer=tokenizer,
-            max_seq_length=calibration_seq_length,
-            use_kv_cache=self.use_kv_cache,
-            enable_dynamic_shape=self.enable_dynamic_shape,
-        )
-        eval_results = evaluate_model(
-            eval_wrapper,
-            calibration_tasks,
-            calibration_limit,
-        )
-
-        for task, res in eval_results["results"].items():
-            print(f"{task}: {res}")
-        logging.info("Calibration finish...")
-
     def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
         """
         Quantize the model via pt2e flow and retrieve LLMEdgeManager including the quantized model.
@@ -263,33 +189,8 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
                     self.pre_autograd_graph_module is not None
                 ), "Please run capture_pre_autograd_graph first"
                 m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
-                logging.info(
-                    f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
-                )
                 # Calibrate
-                if (
-                    self.calibration_tasks is not None
-                    and self.calibration_limit is not None
-                    and self.calibration_seq_length is not None
-                    and self.calibration_data is not None
-                    and self.tokenizer_path is not None
-                ):
-                    logging.info(
-                        f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
-                    )
-                    self.pt2e_calibrate(
-                        prepared_module=m,
-                        calibration_tasks=self.calibration_tasks,
-                        calibration_limit=self.calibration_limit,
-                        calibration_seq_length=self.calibration_seq_length,
-                        calibration_data=self.calibration_data,
-                        tokenizer_path=self.tokenizer_path,
-                    )
-                else:
-                    logging.info(
-                        "No calibration provided, using dummy input to calibrate..."
-                    )
-                    m(*self.example_inputs)
+                m(*self.example_inputs)
                 m = convert_pt2e(m)
                 DuplicateDynamicQuantChainPass()(m)
                 self.pre_autograd_graph_module = m
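
Not part of the patch above: a minimal standalone sketch of the pt2e flow that pt2e_quantize() follows after this change (prepare the captured graph, calibrate by running the example inputs through it, then convert). The toy model and the XNNPACKQuantizer setup are illustrative assumptions, not code from this PR.

import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

# Toy stand-in for the LLM; the real flow uses self.model and self.example_inputs.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(1, 8),)

# Export to a pre-autograd graph module, as LLMEdgeManager.capture_pre_autograd_graph() does.
graph_module = capture_pre_autograd_graph(model, example_inputs)

# Prepare with a quantizer (illustrative choice), then calibrate with the example
# inputs only, which is all pt2e_quantize() does after this change.
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
prepared = prepare_pt2e(graph_module, quantizer)
prepared(*example_inputs)

# Convert to the quantized graph.
quantized = convert_pt2e(prepared)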