@@ -164,6 +164,7 @@ def _model_call(self, inps):
 def gen_eval_wrapper(
     model_name: str,
     args: argparse.ArgumentParser,
+    llm_config=None,
 ):
     """
     Generates a wrapper interface around the provided model and tokenizer for
@@ -172,7 +173,15 @@ def gen_eval_wrapper(
     Returns:
         eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library.
     """
-    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore
+    # If llm_config is not provided, convert args to llm_config
+    if llm_config is None:
+        from executorch.examples.models.llama.config.llm_config_utils import (
+            convert_args_to_llm_config,
+        )
+
+        llm_config = convert_args_to_llm_config(args)
+
+    tokenizer = get_tokenizer(llm_config.base.tokenizer_path)
 
     # ExecuTorch Binary Evaluation
     if (model := args.pte) is not None:  # pyre-ignore
@@ -182,7 +191,7 @@ def gen_eval_wrapper(
                 model=model,
                 tokenizer=tokenizer,
                 tokenizer_bin=tokenizer_bin,
-                max_seq_length=args.max_seq_length,  # pyre-ignore
+                max_seq_length=llm_config.export.max_seq_length,
             )
 
         # ETPybindEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated with pybindings
@@ -191,12 +200,14 @@ def gen_eval_wrapper(
             tokenizer=tokenizer,
             # Exported model takes at most (max_seq_length - 1) tokens.
             # Note that the eager model takes at most max_seq_length tokens.
-            max_seq_length=args.max_seq_length - 1,
+            max_seq_length=llm_config.export.max_seq_length - 1,
         )
 
-    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
+    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(
+        llm_config
+    )
     # GPTFastEvalWrapper: Create a wrapper around a pre-exported model
-    manager: LLMEdgeManager = _prepare_for_llama_export(args)
+    manager: LLMEdgeManager = _prepare_for_llama_export(llm_config)
 
     if len(quantizers) != 0:
         manager = manager.export().pt2e_quantize(quantizers)
@@ -208,9 +219,9 @@ def gen_eval_wrapper(
         return GraphModuleEvalWrapper(
             model=model,
             tokenizer=tokenizer,
-            max_seq_length=args.max_seq_length,
-            use_kv_cache=args.use_kv_cache,  # pyre-ignore
-            enable_dynamic_shape=args.enable_dynamic_shape,  # pyre-ignore
+            max_seq_length=llm_config.export.max_seq_length,
+            use_kv_cache=llm_config.model.use_kv_cache,
+            enable_dynamic_shape=llm_config.model.enable_dynamic_shape,
         )
     else:
         # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch
@@ -234,8 +245,8 @@ def gen_eval_wrapper(
         return EagerEvalWrapper(
             model=model,
             tokenizer=tokenizer,
-            max_seq_length=args.max_seq_length,
-            use_kv_cache=args.use_kv_cache,
+            max_seq_length=llm_config.export.max_seq_length,
+            use_kv_cache=llm_config.model.use_kv_cache,
         )
 
 
@@ -296,12 +307,18 @@ def eval_llama(
     model_name: str,
     args: argparse.ArgumentParser,
 ) -> None:
+    # Convert args to LlmConfig
+    from executorch.examples.models.llama.config.llm_config_utils import (
+        convert_args_to_llm_config,
+    )
+
+    llm_config = convert_args_to_llm_config(args)
+
     # Generate the eval wrapper
-    eval_wrapper = gen_eval_wrapper(model_name, args)
+    eval_wrapper = gen_eval_wrapper(model_name, args, llm_config)
 
     # Needed for loading mmlu dataset.
     # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
-    # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `tasks`
     if args.tasks and "mmlu" in args.tasks:
         import datasets
 
@@ -312,8 +329,8 @@ def eval_llama(
         eval_results = simple_evaluate(
             model=eval_wrapper,
             tasks=args.tasks,
-            num_fewshot=args.num_fewshot,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `num_fewshot`
-            limit=args.limit,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `limit`
+            num_fewshot=args.num_fewshot,
+            limit=args.limit,
         )
 
     for task, res in eval_results["results"].items():
@@ -326,19 +343,26 @@ def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParse
 
     This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py
     """
-    assert args.use_attention_sink is not None  # pyre-ignore [16]
-    assert args.attention_sink_eval_tokens > 0  # pyre-ignore [16]
-    attention_sink_params = args.use_attention_sink.split(",")
+    # Convert args to LlmConfig
+    from executorch.examples.models.llama.config.llm_config_utils import (
+        convert_args_to_llm_config,
+    )
+
+    llm_config = convert_args_to_llm_config(args)
+
+    assert llm_config.model.use_attention_sink is not None
+    assert args.attention_sink_eval_tokens > 0
+    attention_sink_params = llm_config.model.use_attention_sink.split(",")
     assert len(attention_sink_params) == 3
     sink_size = int(attention_sink_params[0])
     window_size = int(attention_sink_params[1])
 
-    assert args.max_seq_length == sink_size + window_size  # pyre-ignore [16]
+    assert llm_config.export.max_seq_length == sink_size + window_size
 
     device = "cuda" if torch.cuda.is_available() else "cpu"
-    manager: LLMEdgeManager = _prepare_for_llama_export(args)
+    manager: LLMEdgeManager = _prepare_for_llama_export(llm_config)
     model = manager.model.eval().to(device=device)
-    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore [16]
+    tokenizer = get_tokenizer(llm_config.base.tokenizer_path)
 
     eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
 
@@ -347,7 +371,7 @@ def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParse
     progress_bar = tqdm(total=args.attention_sink_eval_tokens)
     input_pos = 0
     while input_pos < args.attention_sink_eval_tokens:
-        for text in eval_data["text"]:  # pyre-ignore [16]
+        for text in eval_data["text"]:
             tokens = tokenizer.encode(text, bos=False, eos=False)
             if len(tokens) <= 0:
                 continue
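
For reference, a minimal sketch of the args-to-LlmConfig field mapping this change relies on, assuming only the fields read in this file. This is an illustrative stand-in, not the actual convert_args_to_llm_config or LlmConfig implementation; the dataclass names prefixed with an underscore are hypothetical.

# Illustrative sketch only: the real convert_args_to_llm_config lives in
# executorch.examples.models.llama.config.llm_config_utils, and the real
# LlmConfig is defined elsewhere. These stand-in dataclasses cover just the
# fields accessed in eval_llama_lib.py after this change.
from dataclasses import dataclass
from typing import Optional


@dataclass
class _Base:
    tokenizer_path: str = ""


@dataclass
class _Model:
    use_kv_cache: bool = False
    enable_dynamic_shape: bool = False
    use_attention_sink: Optional[str] = None


@dataclass
class _Export:
    max_seq_length: int = 128


@dataclass
class _LlmConfig:
    base: _Base
    model: _Model
    export: _Export


def _convert_args_to_llm_config_sketch(args) -> _LlmConfig:
    # Map the argparse namespace onto the nested config structure used above:
    # base.tokenizer_path, model.use_kv_cache / enable_dynamic_shape /
    # use_attention_sink, and export.max_seq_length.
    return _LlmConfig(
        base=_Base(tokenizer_path=args.tokenizer_path),
        model=_Model(
            use_kv_cache=args.use_kv_cache,
            enable_dynamic_shape=args.enable_dynamic_shape,
            use_attention_sink=args.use_attention_sink,
        ),
        export=_Export(max_seq_length=args.max_seq_length),
    )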