@@ -164,6 +164,7 @@ def _model_call(self, inps):
 def gen_eval_wrapper(
     model_name: str,
     args: argparse.ArgumentParser,
+    llm_config=None,
 ):
     """
     Generates a wrapper interface around the provided model and tokenizer for
@@ -172,7 +173,15 @@ def gen_eval_wrapper(
     Returns:
         eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library.
     """
-    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore
+    # If llm_config is not provided, convert args to llm_config
+    if llm_config is None:
+        from executorch.examples.models.llama.config.llm_config_utils import (
+            convert_args_to_llm_config,
+        )
+
+        llm_config = convert_args_to_llm_config(args)
+
+    tokenizer = get_tokenizer(llm_config.base.tokenizer_path)

     # ExecuTorch Binary Evaluation
     if (model := args.pte) is not None:  # pyre-ignore
@@ -182,7 +191,7 @@ def gen_eval_wrapper(
                 model=model,
                 tokenizer=tokenizer,
                 tokenizer_bin=tokenizer_bin,
-                max_seq_length=args.max_seq_length,  # pyre-ignore
+                max_seq_length=llm_config.export.max_seq_length,
             )

         # ETPybindEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated with pybindings
@@ -191,12 +200,14 @@ def gen_eval_wrapper(
             tokenizer=tokenizer,
             # Exported model takes at most (max_seq_length - 1) tokens.
             # Note that the eager model takes at most max_seq_length tokens.
-            max_seq_length=args.max_seq_length - 1,
+            max_seq_length=llm_config.export.max_seq_length - 1,
         )

-    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
+    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(
+        llm_config
+    )
     # GPTFastEvalWrapper: Create a wrapper around a pre-exported model
-    manager: LLMEdgeManager = _prepare_for_llama_export(args)
+    manager: LLMEdgeManager = _prepare_for_llama_export(llm_config, args)

     if len(quantizers) != 0:
         manager = manager.export().pt2e_quantize(quantizers)
@@ -208,9 +219,9 @@ def gen_eval_wrapper(
         return GraphModuleEvalWrapper(
             model=model,
             tokenizer=tokenizer,
-            max_seq_length=args.max_seq_length,
-            use_kv_cache=args.use_kv_cache,  # pyre-ignore
-            enable_dynamic_shape=args.enable_dynamic_shape,  # pyre-ignore
+            max_seq_length=llm_config.export.max_seq_length,
+            use_kv_cache=llm_config.model.use_kv_cache,
+            enable_dynamic_shape=llm_config.model.enable_dynamic_shape,
         )
     else:
         # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch
@@ -234,8 +245,8 @@ def gen_eval_wrapper(
         return EagerEvalWrapper(
             model=model,
             tokenizer=tokenizer,
-            max_seq_length=args.max_seq_length,
-            use_kv_cache=args.use_kv_cache,
+            max_seq_length=llm_config.export.max_seq_length,
+            use_kv_cache=llm_config.model.use_kv_cache,
         )


@@ -296,12 +307,18 @@ def eval_llama(
     model_name: str,
     args: argparse.ArgumentParser,
 ) -> None:
+    # Convert args to LlmConfig
+    from executorch.examples.models.llama.config.llm_config_utils import (
+        convert_args_to_llm_config,
+    )
+
+    llm_config = convert_args_to_llm_config(args)
+
     # Generate the eval wrapper
-    eval_wrapper = gen_eval_wrapper(model_name, args)
+    eval_wrapper = gen_eval_wrapper(model_name, args, llm_config)

     # Needed for loading mmlu dataset.
     # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
-    # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `tasks`
     if args.tasks and "mmlu" in args.tasks:
         import datasets

@@ -312,8 +329,8 @@ def eval_llama(
     eval_results = simple_evaluate(
         model=eval_wrapper,
         tasks=args.tasks,
-        num_fewshot=args.num_fewshot,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `num_fewshot`
-        limit=args.limit,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `limit`
+        num_fewshot=args.num_fewshot,
+        limit=args.limit,
     )

     for task, res in eval_results["results"].items():
@@ -326,19 +343,26 @@ def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParse

     This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py
     """
-    assert args.use_attention_sink is not None  # pyre-ignore [16]
-    assert args.attention_sink_eval_tokens > 0  # pyre-ignore [16]
-    attention_sink_params = args.use_attention_sink.split(",")
+    # Convert args to LlmConfig
+    from executorch.examples.models.llama.config.llm_config_utils import (
+        convert_args_to_llm_config,
+    )
+
+    llm_config = convert_args_to_llm_config(args)
+
+    assert llm_config.model.use_attention_sink is not None
+    assert args.attention_sink_eval_tokens > 0
+    attention_sink_params = llm_config.model.use_attention_sink.split(",")
     assert len(attention_sink_params) == 3
     sink_size = int(attention_sink_params[0])
     window_size = int(attention_sink_params[1])

-    assert args.max_seq_length == sink_size + window_size  # pyre-ignore [16]
+    assert llm_config.export.max_seq_length == sink_size + window_size

     device = "cuda" if torch.cuda.is_available() else "cpu"
-    manager: LLMEdgeManager = _prepare_for_llama_export(args)
+    manager: LLMEdgeManager = _prepare_for_llama_export(llm_config, args)
     model = manager.model.eval().to(device=device)
-    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore [16]
+    tokenizer = get_tokenizer(llm_config.base.tokenizer_path)

     eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

@@ -347,7 +371,7 @@ def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParse
     progress_bar = tqdm(total=args.attention_sink_eval_tokens)
     input_pos = 0
     while input_pos < args.attention_sink_eval_tokens:
-        for text in eval_data["text"]:  # pyre-ignore [16]
+        for text in eval_data["text"]:
             tokens = tokenizer.encode(text, bos=False, eos=False)
             if len(tokens) <= 0:
                 continue