# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

-
import argparse
-
from typing import Optional, Union

import torch
-
- from datasets import load_dataset
from executorch.examples.models.llama.export_llama_lib import (
+     _convert_args_to_config,
+     _prepare_for_llama_export,
+     build_args_parser as _build_args_parser,
    get_quantizer_and_quant_params,
)

from executorch.extension.llm.export.builder import LLMEdgeManager
from lm_eval.evaluator import simple_evaluate
+ from omegaconf import DictConfig, OmegaConf
from pytorch_tokenizers import get_tokenizer
from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer
from pytorch_tokenizers.tiktoken import TiktokenTokenizer as Tiktoken
- from torch.nn import CrossEntropyLoss
- from tqdm import tqdm

from .evaluate.eager_eval import EagerEvalWrapper

- from .export_llama_lib import (
-     _prepare_for_llama_export,
-     build_args_parser as _build_args_parser,
- )
-

class GraphModuleEvalWrapper(EagerEvalWrapper):
    """
@@ -163,7 +156,7 @@ def _model_call(self, inps):

def gen_eval_wrapper(
    model_name: str,
-     args: argparse.ArgumentParser,
+     config: DictConfig,
):
    """
    Generates a wrapper interface around the provided model and tokenizer for
@@ -172,17 +165,17 @@ def gen_eval_wrapper(
    Returns:
        eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library.
    """
-     tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore
+     tokenizer = get_tokenizer(config.model.tokenizer_path)

    # ExecuTorch Binary Evaluation
-     if (model := args.pte) is not None:  # pyre-ignore
-         if (tokenizer_bin := args.tokenizer_bin) is not None:  # pyre-ignore
+     if (model := config.eval.pte) is not None:
+         if (tokenizer_bin := config.eval.tokenizer_bin) is not None:
            # ETRunnerEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated at runtime
            return ETRunnerEvalWrapper(
                model=model,
                tokenizer=tokenizer,
                tokenizer_bin=tokenizer_bin,
-                 max_seq_length=args.max_seq_length,  # pyre-ignore
+                 max_seq_length=config.sequence.max_seq_length,
            )

        # ETPybindEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated with pybindings
@@ -191,12 +184,12 @@ def gen_eval_wrapper(
            tokenizer=tokenizer,
            # Exported model takes at most (max_seq_length - 1) tokens.
            # Note that the eager model takes at most max_seq_length tokens.
-             max_seq_length=args.max_seq_length - 1,
+             max_seq_length=config.sequence.max_seq_length - 1,
        )

-     pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
+     pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(config)
    # GPTFastEvalWrapper: Create a wrapper around a pre-exported model
-     manager: LLMEdgeManager = _prepare_for_llama_export(args)
+     manager: LLMEdgeManager = _prepare_for_llama_export(config)

    if len(quantizers) != 0:
        manager = manager.export().pt2e_quantize(quantizers)
@@ -208,9 +201,9 @@ def gen_eval_wrapper(
        return GraphModuleEvalWrapper(
            model=model,
            tokenizer=tokenizer,
-             max_seq_length=args.max_seq_length,
-             use_kv_cache=args.use_kv_cache,  # pyre-ignore
-             enable_dynamic_shape=args.enable_dynamic_shape,  # pyre-ignore
+             max_seq_length=config.sequence.max_seq_length,
+             use_kv_cache=config.kv_cache.use_kv_cache,
+             enable_dynamic_shape=config.misc.enable_dynamic_shape,
        )
    else:
        # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch
@@ -228,18 +221,94 @@ def gen_eval_wrapper(
        # that is not available in this eval_llama. We save the checkpoint
        # here for consistency with eval_llama. The accuracy results we
        # get from eval_llama can be used as a reference to other evaluations.
-         if args.output_eager_checkpoint_file is not None:  # pyre-ignore
-             torch.save(model, args.output_eager_checkpoint_file)
+         if config.eval.output_eager_checkpoint_file is not None:
+             torch.save(model, config.eval.output_eager_checkpoint_file)

        return EagerEvalWrapper(
            model=model,
            tokenizer=tokenizer,
-             max_seq_length=args.max_seq_length,
-             use_kv_cache=args.use_kv_cache,
+             max_seq_length=config.sequence.max_seq_length,
+             use_kv_cache=config.kv_cache.use_kv_cache,
+         )
+
+
+ def eval_llama(
+     model_name: str,
+     config: DictConfig,
+ ) -> None:
+     # Generate the eval wrapper
+     eval_wrapper = gen_eval_wrapper(model_name, config)
+
+     # Needed for loading mmlu dataset.
+     # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
+     if config.eval.tasks and "mmlu" in config.eval.tasks:
+         import datasets
+
+         datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
+
+     # Evaluate the model
+     tasks = (
+         None if config.eval.tasks is None else OmegaConf.to_container(config.eval.tasks)
+     )
+     with torch.no_grad():
+         eval_results = simple_evaluate(
+             model=eval_wrapper,
+             tasks=tasks,
+             num_fewshot=config.eval.num_fewshot,
+             limit=config.eval.limit,
        )

+     for task, res in eval_results["results"].items():
+         print(f"{task}: {res}")
+
+
+ def eval_llama_with_attention_sink(
+     model_name: str,
+     config: DictConfig,
+ ) -> None:
+     # Generate the eval wrapper
+     eval_wrapper = gen_eval_wrapper(model_name, config)
+
+     # Needed for loading mmlu dataset.
+     # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
+     if config.eval.tasks and "mmlu" in config.eval.tasks:
+         import datasets
+
+         datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
+
+     # Evaluate the model
+     with torch.no_grad():
+         eval_results = simple_evaluate(
+             model=eval_wrapper,
+             tasks=OmegaConf.to_container(config.eval.tasks),
+             num_fewshot=config.eval.num_fewshot,
+             limit=config.eval.limit,
+         )
+
+     for task, res in eval_results["results"].items():
+         print(f"{task}: {res}")
+
+
+ def _convert_cli_to_config_format(args) -> DictConfig:
+     """Convert CLI arguments to config format."""
+     # First convert common args using the shared function
+     config = _convert_args_to_config(args)
+
+     # Add evaluation-specific settings
+     config.eval = OmegaConf.create()
+     config.eval.tasks = args.tasks
+     config.eval.limit = args.limit
+     config.eval.num_fewshot = args.num_fewshot
+     config.eval.pte = args.pte
+     config.eval.tokenizer_bin = args.tokenizer_bin
+     config.eval.output_eager_checkpoint_file = args.output_eager_checkpoint_file
+     config.eval.attention_sink_eval_tokens = args.attention_sink_eval_tokens
+
+     return config
+

def build_args_parser() -> argparse.ArgumentParser:
+     """Build argument parser for evaluation, extending the export parser with eval-specific args."""
    # Start with arg parser from export_llama_lib
    parser = _build_args_parser()

@@ -286,92 +355,7 @@ def build_args_parser() -> argparse.ArgumentParser:
        help="Save the checkpoint after source transformations, for other evaluation platform to run the same checkpoint.",
    )

-     # Set of parameters secpific to AttentionSink.
+     # Set of parameters specific to AttentionSink.
    parser.add_argument("--attention_sink_eval_tokens", type=int, default=0)

    return parser
-
-
- def eval_llama(
-     model_name: str,
-     args: argparse.ArgumentParser,
- ) -> None:
-     # Generate the eval wrapper
-     eval_wrapper = gen_eval_wrapper(model_name, args)
-
-     # Needed for loading mmlu dataset.
-     # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
-     # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `tasks`
-     if args.tasks and "mmlu" in args.tasks:
-         import datasets
-
-         datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
-
-     # Evaluate the model
-     with torch.no_grad():
-         eval_results = simple_evaluate(
-             model=eval_wrapper,
-             tasks=args.tasks,
-             num_fewshot=args.num_fewshot,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `num_fewshot`
-             limit=args.limit,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `limit`
-         )
-
-     for task, res in eval_results["results"].items():
-         print(f"{task}: {res}")
-
-
- def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParser):
-     """
-     Evaluate the model's perplexity when AttentionSink is enabled.
-
-     This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py
-     """
-     assert args.use_attention_sink is not None  # pyre-ignore [16]
-     assert args.attention_sink_eval_tokens > 0  # pyre-ignore [16]
-     attention_sink_params = args.use_attention_sink.split(",")
-     assert len(attention_sink_params) == 3
-     sink_size = int(attention_sink_params[0])
-     window_size = int(attention_sink_params[1])
-
-     assert args.max_seq_length == sink_size + window_size  # pyre-ignore [16]
-
-     device = "cuda" if torch.cuda.is_available() else "cpu"
-     manager: LLMEdgeManager = _prepare_for_llama_export(args)
-     model = manager.model.eval().to(device=device)
-     tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore [16]
-
-     eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
-
-     nlls = []
-     loss_fn = CrossEntropyLoss(reduction="none")
-     progress_bar = tqdm(total=args.attention_sink_eval_tokens)
-     input_pos = 0
-     while input_pos < args.attention_sink_eval_tokens:
-         for text in eval_data["text"]:  # pyre-ignore [16]
-             tokens = tokenizer.encode(text, bos=False, eos=False)
-             if len(tokens) <= 0:
-                 continue
-             with torch.no_grad():
-                 num_tokens = min(
-                     len(tokens) - 1, args.attention_sink_eval_tokens - input_pos
-                 )
-                 logits = model(
-                     torch.tensor(
-                         [tokens[:num_tokens]], dtype=torch.int64, device=device
-                     ),
-                     torch.tensor([input_pos], dtype=torch.int64, device=device),
-                 ).squeeze(dim=0)
-                 neg_log_likelihood = loss_fn(
-                     logits,
-                     torch.tensor(
-                         [tokens[1 : num_tokens + 1]], dtype=torch.int64, device=device
-                     ).view(-1),
-                 )
-                 nlls.append(neg_log_likelihood)
-                 input_pos += num_tokens
-                 progress_bar.update(num_tokens)
-                 if input_pos >= args.attention_sink_eval_tokens:
-                     break
-     ppl = torch.exp(torch.cat(nlls).mean())
-     print(f"Perplexity: {ppl.item()}")
-     return ppl.item()
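
Note on the config layout: the attribute paths read above (config.model.tokenizer_path, config.sequence.max_seq_length, config.kv_cache.use_kv_cache, config.misc.enable_dynamic_shape, and the config.eval.* fields populated by _convert_cli_to_config_format) imply a nested DictConfig roughly like the sketch below. This is an illustrative assumption pieced together from this diff; the authoritative schema is whatever _convert_args_to_config in export_llama_lib produces, and all example values are placeholders.

from omegaconf import OmegaConf

# Hypothetical example only; field names mirror the accesses in the diff, values are placeholders.
example_config = OmegaConf.create(
    {
        "model": {"tokenizer_path": "tokenizer.model"},
        "sequence": {"max_seq_length": 128},
        "kv_cache": {"use_kv_cache": True},
        "misc": {"enable_dynamic_shape": False},
        "eval": {
            "tasks": ["wikitext"],
            "limit": None,
            "num_fewshot": None,
            "pte": None,
            "tokenizer_bin": None,
            "output_eager_checkpoint_file": None,
            "attention_sink_eval_tokens": 0,
        },
    }
)
print(OmegaConf.to_yaml(example_config))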
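How the pieces are meant to fit together, as a minimal sketch: a driver script outside this diff would build the extended parser, convert the parsed namespace to a DictConfig, and hand it to one of the eval entry points. The import path and the model name below are assumptions for illustration only.

# Hypothetical driver; assumes this module is importable as
# executorch.examples.models.llama.eval_llama_lib.
from executorch.examples.models.llama.eval_llama_lib import (
    _convert_cli_to_config_format,
    build_args_parser,
    eval_llama,
)


def main() -> None:
    parser = build_args_parser()  # export parser plus the eval-specific flags
    args = parser.parse_args()
    config = _convert_cli_to_config_format(args)  # argparse Namespace -> nested DictConfig
    eval_llama("llama2", config)  # "llama2" is a placeholder model name


if __name__ == "__main__":
    main()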