 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-
 import argparse
-
 from typing import Optional, Union
 
 import torch
-
-from datasets import load_dataset
 from executorch.examples.models.llama.export_llama_lib import (
+    _convert_args_to_config,
+    _prepare_for_llama_export,
+    build_args_parser as _build_args_parser,
     get_quantizer_and_quant_params,
 )
 from executorch.examples.models.llama.tokenizer.tiktoken import Tokenizer as Tiktoken
-
 from executorch.extension.llm.export.builder import LLMEdgeManager
 from executorch.extension.llm.tokenizer.tokenizer import (
     Tokenizer as SentencePieceTokenizer,
 )
 from executorch.extension.llm.tokenizer.utils import get_tokenizer
 from lm_eval.evaluator import simple_evaluate
-from torch.nn import CrossEntropyLoss
-from tqdm import tqdm
+from omegaconf import DictConfig, OmegaConf
 
 from .evaluate.eager_eval import EagerEvalWrapper
 
-from .export_llama_lib import (
-    _prepare_for_llama_export,
-    build_args_parser as _build_args_parser,
-)
-
 
 class GraphModuleEvalWrapper(EagerEvalWrapper):
     """
@@ -165,7 +157,7 @@ def _model_call(self, inps):
 
 def gen_eval_wrapper(
     model_name: str,
-    args: argparse.ArgumentParser,
+    config: DictConfig,
 ):
     """
     Generates a wrapper interface around the provided model and tokenizer for
@@ -174,17 +166,17 @@ def gen_eval_wrapper(
     Returns:
         eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library.
     """
-    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore
+    tokenizer = get_tokenizer(config.model.tokenizer_path)
 
     # ExecuTorch Binary Evaluation
-    if (model := args.pte) is not None:  # pyre-ignore
-        if (tokenizer_bin := args.tokenizer_bin) is not None:  # pyre-ignore
+    if (model := config.eval.pte) is not None:
+        if (tokenizer_bin := config.eval.tokenizer_bin) is not None:
             # ETRunnerEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated at runtime
             return ETRunnerEvalWrapper(
                 model=model,
                 tokenizer=tokenizer,
                 tokenizer_bin=tokenizer_bin,
-                max_seq_length=args.max_seq_length,  # pyre-ignore
+                max_seq_length=config.sequence.max_seq_length,
             )
 
         # ETPybindEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated with pybindings
@@ -193,12 +185,12 @@ def gen_eval_wrapper(
             tokenizer=tokenizer,
             # Exported model takes at most (max_seq_length - 1) tokens.
             # Note that the eager model takes at most max_seq_length tokens.
-            max_seq_length=args.max_seq_length - 1,
+            max_seq_length=config.sequence.max_seq_length - 1,
         )
 
-    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
+    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(config)
     # GPTFastEvalWrapper: Create a wrapper around a pre-exported model
-    manager: LLMEdgeManager = _prepare_for_llama_export(args)
+    manager: LLMEdgeManager = _prepare_for_llama_export(config)
 
     if len(quantizers) != 0:
         manager = manager.export().pt2e_quantize(quantizers)
@@ -210,9 +202,9 @@ def gen_eval_wrapper(
         return GraphModuleEvalWrapper(
             model=model,
             tokenizer=tokenizer,
-            max_seq_length=args.max_seq_length,
-            use_kv_cache=args.use_kv_cache,  # pyre-ignore
-            enable_dynamic_shape=args.enable_dynamic_shape,  # pyre-ignore
+            max_seq_length=config.sequence.max_seq_length,
+            use_kv_cache=config.kv_cache.use_kv_cache,  # pyre-ignore
+            enable_dynamic_shape=config.misc.enable_dynamic_shape,  # pyre-ignore
         )
     else:
         # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch
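For orientation, the accesses above imply a nested config layout with model, sequence, kv_cache, misc, and eval groups. The snippet below is only an illustrative sketch with placeholder values; the real config is produced by _convert_args_to_config and _convert_cli_to_config_format later in this diff, not by a hand-written dictionary.

# Illustrative sketch only: the nested layout implied by the config accesses
# above. Values and paths are placeholders, not defaults from this commit.
from omegaconf import OmegaConf

example_config = OmegaConf.create(
    {
        "model": {"tokenizer_path": "/tmp/tokenizer.model"},  # hypothetical path
        "sequence": {"max_seq_length": 128},
        "kv_cache": {"use_kv_cache": True},
        "misc": {"enable_dynamic_shape": False},
        "eval": {"pte": None, "tokenizer_bin": None},
    }
)
assert example_config.sequence.max_seq_length == 128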
@@ -230,18 +222,94 @@ def gen_eval_wrapper(
         # that is not available in this eval_llama. We save the checkpoint
         # here for consistency with eval_llama. The accuracy results we
         # get from eval_llama can be used as a reference to other evaluations.
-        if args.output_eager_checkpoint_file is not None:  # pyre-ignore
-            torch.save(model, args.output_eager_checkpoint_file)
+        if config.eval.output_eager_checkpoint_file is not None:  # pyre-ignore
+            torch.save(model, config.eval.output_eager_checkpoint_file)
 
         return EagerEvalWrapper(
             model=model,
             tokenizer=tokenizer,
-            max_seq_length=args.max_seq_length,
-            use_kv_cache=args.use_kv_cache,
+            max_seq_length=config.sequence.max_seq_length,
+            use_kv_cache=config.kv_cache.use_kv_cache,
+        )
+
+
+def eval_llama(
+    model_name: str,
+    config: DictConfig,
+) -> None:
+    # Generate the eval wrapper
+    eval_wrapper = gen_eval_wrapper(model_name, config)
+
+    # Needed for loading mmlu dataset.
+    # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
+    if config.eval.tasks and "mmlu" in config.eval.tasks:
+        import datasets
+
+        datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
+
+    # Evaluate the model
+    tasks = (
+        None if config.eval.tasks is None else OmegaConf.to_container(config.eval.tasks)
+    )
+    with torch.no_grad():
+        eval_results = simple_evaluate(
+            model=eval_wrapper,
+            tasks=tasks,
+            num_fewshot=config.eval.num_fewshot,
+            limit=config.eval.limit,
+        )
+
+    for task, res in eval_results["results"].items():
+        print(f"{task}: {res}")
+
+
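A note on the tasks handling in eval_llama above: config.eval.tasks comes back from OmegaConf as a ListConfig, while lm-eval's simple_evaluate expects a plain Python list, hence the OmegaConf.to_container conversion. A minimal sketch of that conversion follows; the task names are placeholders, not defaults from this file.

# Sketch of the ListConfig -> list conversion used above.
from omegaconf import OmegaConf

cfg = OmegaConf.create({"eval": {"tasks": ["wikitext", "mmlu"]}})
plain_tasks = OmegaConf.to_container(cfg.eval.tasks)
assert plain_tasks == ["wikitext", "mmlu"] and isinstance(plain_tasks, list)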
+def eval_llama_with_attention_sink(
+    model_name: str,
+    config: DictConfig,
+) -> None:
+    # Generate the eval wrapper
+    eval_wrapper = gen_eval_wrapper(model_name, config)
+
+    # Needed for loading mmlu dataset.
+    # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
+    if config.eval.tasks and "mmlu" in config.eval.tasks:
+        import datasets
+
+        datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
+
+    # Evaluate the model
+    with torch.no_grad():
+        eval_results = simple_evaluate(
+            model=eval_wrapper,
+            tasks=OmegaConf.to_container(config.eval.tasks),
+            num_fewshot=config.eval.num_fewshot,
+            limit=config.eval.limit,
         )
 
+    for task, res in eval_results["results"].items():
+        print(f"{task}: {res}")
+
+
+def _convert_cli_to_config_format(args) -> DictConfig:
+    """Convert CLI arguments to config format."""
+    # First convert common args using the shared function
+    config = _convert_args_to_config(args)
+
+    # Add evaluation-specific settings
+    config.eval = OmegaConf.create()
+    config.eval.tasks = args.tasks
+    config.eval.limit = args.limit
+    config.eval.num_fewshot = args.num_fewshot
+    config.eval.pte = args.pte
+    config.eval.tokenizer_bin = args.tokenizer_bin
+    config.eval.output_eager_checkpoint_file = args.output_eager_checkpoint_file
+    config.eval.attention_sink_eval_tokens = args.attention_sink_eval_tokens
+
+    return config
+
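To see how the new pieces fit together, here is a hedged sketch of a caller such as an eval_llama entry point: build the parser defined below, convert the parsed args with _convert_cli_to_config_format, then dispatch to one of the eval functions. The entry-point name, model name, and dispatch condition are illustrative assumptions, not code from this commit.

# Hypothetical wiring sketch, not part of this commit.
def _example_entry_point() -> None:
    parser = build_args_parser()
    args = parser.parse_args()
    config = _convert_cli_to_config_format(args)
    if config.eval.attention_sink_eval_tokens > 0:  # assumed dispatch condition
        eval_llama_with_attention_sink("llama", config)  # "llama" is a placeholder model name
    else:
        eval_llama("llama", config)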
 
 def build_args_parser() -> argparse.ArgumentParser:
+    """Build argument parser for evaluation, extending the export parser with eval-specific args."""
     # Start with arg parser from export_llama_lib
     parser = _build_args_parser()
 
@@ -288,92 +356,7 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="Save the checkpoint after source transformations, for other evaluation platform to run the same checkpoint.",
     )
 
-    # Set of parameters secpific to AttentionSink.
+    # Set of parameters specific to AttentionSink.
     parser.add_argument("--attention_sink_eval_tokens", type=int, default=0)
 
     return parser
-
-
-def eval_llama(
-    model_name: str,
-    args: argparse.ArgumentParser,
-) -> None:
-    # Generate the eval wrapper
-    eval_wrapper = gen_eval_wrapper(model_name, args)
-
-    # Needed for loading mmlu dataset.
-    # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
-    # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `tasks`
-    if args.tasks and "mmlu" in args.tasks:
-        import datasets
-
-        datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
-
-    # Evaluate the model
-    with torch.no_grad():
-        eval_results = simple_evaluate(
-            model=eval_wrapper,
-            tasks=args.tasks,
-            num_fewshot=args.num_fewshot,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `num_fewshot`
-            limit=args.limit,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `limit`
-        )
-
-    for task, res in eval_results["results"].items():
-        print(f"{task}: {res}")
-
-
-def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParser):
-    """
-    Evaluate the model's perplexity when AttentionSink is enabled.
-
-    This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py
-    """
-    assert args.use_attention_sink is not None  # pyre-ignore [16]
-    assert args.attention_sink_eval_tokens > 0  # pyre-ignore [16]
-    attention_sink_params = args.use_attention_sink.split(",")
-    assert len(attention_sink_params) == 3
-    sink_size = int(attention_sink_params[0])
-    window_size = int(attention_sink_params[1])
-
-    assert args.max_seq_length == sink_size + window_size  # pyre-ignore [16]
-
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    manager: LLMEdgeManager = _prepare_for_llama_export(args)
-    model = manager.model.eval().to(device=device)
-    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore [16]
-
-    eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
-
-    nlls = []
-    loss_fn = CrossEntropyLoss(reduction="none")
-    progress_bar = tqdm(total=args.attention_sink_eval_tokens)
-    input_pos = 0
-    while input_pos < args.attention_sink_eval_tokens:
-        for text in eval_data["text"]:  # pyre-ignore [16]
-            tokens = tokenizer.encode(text, bos=False, eos=False)
-            if len(tokens) <= 0:
-                continue
-            with torch.no_grad():
-                num_tokens = min(
-                    len(tokens) - 1, args.attention_sink_eval_tokens - input_pos
-                )
-                logits = model(
-                    torch.tensor(
-                        [tokens[:num_tokens]], dtype=torch.int64, device=device
-                    ),
-                    torch.tensor([input_pos], dtype=torch.int64, device=device),
-                ).squeeze(dim=0)
-                neg_log_likelihood = loss_fn(
-                    logits,
-                    torch.tensor(
-                        [tokens[1 : num_tokens + 1]], dtype=torch.int64, device=device
-                    ).view(-1),
-                )
-                nlls.append(neg_log_likelihood)
-                input_pos += num_tokens
-                progress_bar.update(num_tokens)
-            if input_pos >= args.attention_sink_eval_tokens:
-                break
-    ppl = torch.exp(torch.cat(nlls).mean())
-    print(f"Perplexity: {ppl.item()}")
-    return ppl.item()