@@ -164,6 +164,7 @@ def _model_call(self, inps):
 def gen_eval_wrapper(
     model_name: str,
     args: argparse.ArgumentParser,
+    llm_config=None,
 ):
     """
     Generates a wrapper interface around the provided model and tokenizer for
@@ -172,7 +173,13 @@ def gen_eval_wrapper(
     Returns:
         eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library.
     """
-    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore
+    # If llm_config is not provided, convert args to llm_config
+    if llm_config is None:
+        from executorch.examples.models.llama.config.llm_config import LlmConfig
+
+        llm_config = LlmConfig.from_args(args)
+
+    tokenizer = get_tokenizer(llm_config.base.tokenizer_path)
 
     # ExecuTorch Binary Evaluation
     if (model := args.pte) is not None:  # pyre-ignore
@@ -182,7 +189,7 @@ def gen_eval_wrapper(
                 model=model,
                 tokenizer=tokenizer,
                 tokenizer_bin=tokenizer_bin,
-                max_seq_length=args.max_seq_length,  # pyre-ignore
+                max_seq_length=llm_config.export.max_seq_length,
             )
 
         # ETPybindEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated with pybindings
@@ -191,12 +198,14 @@ def gen_eval_wrapper(
             tokenizer=tokenizer,
             # Exported model takes at most (max_seq_length - 1) tokens.
             # Note that the eager model takes at most max_seq_length tokens.
-            max_seq_length=args.max_seq_length - 1,
+            max_seq_length=llm_config.export.max_seq_length - 1,
         )
 
-    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
+    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(
+        llm_config
+    )
     # GPTFastEvalWrapper: Create a wrapper around a pre-exported model
-    manager: LLMEdgeManager = _prepare_for_llama_export(args)
+    manager: LLMEdgeManager = _prepare_for_llama_export(llm_config)
 
     if len(quantizers) != 0:
         manager = manager.export().pt2e_quantize(quantizers)
@@ -208,9 +217,9 @@ def gen_eval_wrapper(
         return GraphModuleEvalWrapper(
             model=model,
             tokenizer=tokenizer,
-            max_seq_length=args.max_seq_length,
-            use_kv_cache=args.use_kv_cache,  # pyre-ignore
-            enable_dynamic_shape=args.enable_dynamic_shape,  # pyre-ignore
+            max_seq_length=llm_config.export.max_seq_length,
+            use_kv_cache=llm_config.model.use_kv_cache,
+            enable_dynamic_shape=llm_config.model.enable_dynamic_shape,
         )
     else:
         # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch
@@ -234,8 +243,8 @@ def gen_eval_wrapper(
         return EagerEvalWrapper(
             model=model,
             tokenizer=tokenizer,
-            max_seq_length=args.max_seq_length,
-            use_kv_cache=args.use_kv_cache,
+            max_seq_length=llm_config.export.max_seq_length,
+            use_kv_cache=llm_config.model.use_kv_cache,
         )
 
 
@@ -296,12 +305,16 @@ def eval_llama(
     model_name: str,
     args: argparse.ArgumentParser,
 ) -> None:
+    # Convert args to LlmConfig
+    from executorch.examples.models.llama.config.llm_config import LlmConfig
+
+    llm_config = LlmConfig.from_args(args)
+
     # Generate the eval wrapper
-    eval_wrapper = gen_eval_wrapper(model_name, args)
+    eval_wrapper = gen_eval_wrapper(model_name, args, llm_config)
 
     # Needed for loading mmlu dataset.
     # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
-    # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `tasks`
     if args.tasks and "mmlu" in args.tasks:
         import datasets
 
@@ -312,8 +325,8 @@ def eval_llama(
     eval_results = simple_evaluate(
         model=eval_wrapper,
         tasks=args.tasks,
-        num_fewshot=args.num_fewshot,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `num_fewshot`
-        limit=args.limit,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `limit`
+        num_fewshot=args.num_fewshot,
+        limit=args.limit,
     )
 
     for task, res in eval_results["results"].items():
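
Note on the pattern in the hunks above: each one replaces a flat `args.<attr>` read (often needing a `# pyre-ignore`) with a field on a structured `LlmConfig` that groups settings into `base`, `export`, and `model` sub-configs, and `gen_eval_wrapper` falls back to `LlmConfig.from_args(args)` when no config is passed so existing call sites keep working. Below is a minimal, self-contained sketch of that adapter pattern; the simplified dataclasses, fields, and defaults are illustrative stand-ins, not the real `LlmConfig` from `executorch.examples.models.llama.config.llm_config`.

```python
import argparse
from dataclasses import dataclass, field


@dataclass
class BaseConfig:
    tokenizer_path: str = ""


@dataclass
class ExportConfig:
    max_seq_length: int = 128


@dataclass
class ModelConfig:
    use_kv_cache: bool = False
    enable_dynamic_shape: bool = False


@dataclass
class LlmConfig:
    """Simplified stand-in for the real LlmConfig; illustration only."""

    base: BaseConfig = field(default_factory=BaseConfig)
    export: ExportConfig = field(default_factory=ExportConfig)
    model: ModelConfig = field(default_factory=ModelConfig)

    @classmethod
    def from_args(cls, args: argparse.Namespace) -> "LlmConfig":
        # Map flat argparse attributes onto typed, grouped fields so that
        # downstream code reads cfg.export.max_seq_length instead of
        # args.max_seq_length (and no longer needs pyre-ignore comments).
        cfg = cls()
        if getattr(args, "tokenizer_path", None) is not None:
            cfg.base.tokenizer_path = args.tokenizer_path
        if getattr(args, "max_seq_length", None) is not None:
            cfg.export.max_seq_length = args.max_seq_length
        if getattr(args, "use_kv_cache", None) is not None:
            cfg.model.use_kv_cache = args.use_kv_cache
        if getattr(args, "enable_dynamic_shape", None) is not None:
            cfg.model.enable_dynamic_shape = args.enable_dynamic_shape
        return cfg


def gen_eval_wrapper_sketch(args: argparse.Namespace, llm_config=None):
    # Same fallback as the diff: build the config from args when absent,
    # so legacy call sites that pass only args keep working.
    if llm_config is None:
        llm_config = LlmConfig.from_args(args)
    return llm_config
```

Keeping `llm_config=None` as a default makes the refactor incremental: call sites can be migrated to the typed config one at a time.
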
@@ -326,19 +339,24 @@ def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParse
 
     This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py
     """
-    assert args.use_attention_sink is not None  # pyre-ignore [16]
-    assert args.attention_sink_eval_tokens > 0  # pyre-ignore [16]
-    attention_sink_params = args.use_attention_sink.split(",")
+    # Convert args to LlmConfig
+    from executorch.examples.models.llama.config.llm_config import LlmConfig
+
+    llm_config = LlmConfig.from_args(args)
+
+    assert llm_config.model.use_attention_sink is not None
+    assert args.attention_sink_eval_tokens > 0
+    attention_sink_params = llm_config.model.use_attention_sink.split(",")
     assert len(attention_sink_params) == 3
     sink_size = int(attention_sink_params[0])
     window_size = int(attention_sink_params[1])
 
-    assert args.max_seq_length == sink_size + window_size  # pyre-ignore [16]
+    assert llm_config.export.max_seq_length == sink_size + window_size
 
     device = "cuda" if torch.cuda.is_available() else "cpu"
-    manager: LLMEdgeManager = _prepare_for_llama_export(args)
+    manager: LLMEdgeManager = _prepare_for_llama_export(llm_config)
     model = manager.model.eval().to(device=device)
-    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore [16]
+    tokenizer = get_tokenizer(llm_config.base.tokenizer_path)
 
     eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
 
@@ -347,7 +365,7 @@ def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParse
     progress_bar = tqdm(total=args.attention_sink_eval_tokens)
     input_pos = 0
     while input_pos < args.attention_sink_eval_tokens:
-        for text in eval_data["text"]:  # pyre-ignore [16]
+        for text in eval_data["text"]:
             tokens = tokenizer.encode(text, bos=False, eos=False)
             if len(tokens) <= 0:
                 continue
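
For context on the attention-sink hunks: `use_attention_sink` is a single comma-separated string carrying three integers, of which this code reads the first two, and the export-time `max_seq_length` must equal their sum. A small worked example follows; the concrete values are invented, and reading the third field as an eviction batch size is an assumption this diff does not confirm.

```python
# Worked example of the "sink_size,window_size,<third param>" convention the
# asserts above enforce. Values are made up for illustration; the third field
# (possibly an eviction batch size) is unused in the code shown in this diff.
use_attention_sink = "4,2044,1024"  # hypothetical CLI value

params = use_attention_sink.split(",")
assert len(params) == 3
sink_size = int(params[0])    # 4 tokens pinned at the start of the KV cache
window_size = int(params[1])  # 2044 tokens of sliding window

# The export-time sequence length must cover exactly sink + window:
max_seq_length = 2048
assert max_seq_length == sink_size + window_size  # 4 + 2044 == 2048
```
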