 from typing import Optional, Union
 
 import torch
-from executorch.examples.models.llama2.evaluate import EagerEvalWrapper, evaluate_model
 from executorch.examples.models.llama2.export_llama_lib import (
     get_quantizer_and_quant_params,
 )
 from executorch.examples.models.llama2.tokenizer.tiktoken import Tokenizer as Tiktoken
 
-from executorch.extension.llm.export import LLMEdgeManager
+from executorch.extension.llm.export.builder import LLMEdgeManager
 from executorch.extension.llm.tokenizer.tokenizer import (
     Tokenizer as SentencePieceTokenizer,
 )
 from executorch.extension.llm.tokenizer.utils import get_tokenizer
 from lm_eval.api.model import LM
 
+from .evaluate.eager_eval import EagerEvalWrapper, evaluate_model
+
 from .export_llama_lib import (
     _prepare_for_llama_export,
     build_args_parser as _build_args_parser,
@@ -91,7 +92,7 @@ def __init__(
         tokenizer: Union[SentencePieceTokenizer, Tiktoken],
         max_seq_length: Optional[int] = None,
     ):
-        super().__init__(None, tokenizer, max_seq_length)
+        super().__init__(None, tokenizer, max_seq_length)  # pyre-ignore
         self._model = model  # Expects model to be path to a .pte file
 
         from executorch.extension.pybindings.portable_lib import _load_for_executorch
@@ -106,7 +107,7 @@ def __init__(
         from executorch.kernels import quantized  # noqa
 
         self._et_model = _load_for_executorch(self._model)
-        self._use_kv_cache = self._et_model.run_method("use_kv_cache")[0]
+        self._use_kv_cache = self._et_model.run_method("use_kv_cache")[0]  # pyre-ignore
 
     def _model_call(self, inps):
         # Given inps (tokens), return the logits from a single forward call
@@ -140,7 +141,7 @@ def __init__(
         tokenizer_bin: str,
         max_seq_length: Optional[int] = None,
     ):
-        super().__init__(None, tokenizer, max_seq_length)
+        super().__init__(None, tokenizer, max_seq_length)  # pyre-ignore
         self._model = model
         self._tokenizer_bin = tokenizer_bin
 
@@ -165,17 +166,17 @@ def gen_eval_wrapper(
     Returns:
         eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library.
     """
-    tokenizer = get_tokenizer(args.tokenizer_path)
+    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore
 
     # ExecuTorch Binary Evaluation
-    if (model := args.pte) is not None:
-        if (tokenizer_bin := args.tokenizer_bin) is not None:
+    if (model := args.pte) is not None:  # pyre-ignore
+        if (tokenizer_bin := args.tokenizer_bin) is not None:  # pyre-ignore
             # ETRunnerEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated at runtime
             return ETRunnerEvalWrapper(
                 model=model,
                 tokenizer=tokenizer,
                 tokenizer_bin=tokenizer_bin,
-                max_seq_length=args.max_seq_length,
+                max_seq_length=args.max_seq_length,  # pyre-ignore
             )
 
         # ETPybindEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated with pybindings
@@ -194,16 +195,16 @@ def gen_eval_wrapper(
     if len(quantizers) != 0:
         manager = manager.capture_pre_autograd_graph().pt2e_quantize(quantizers)
         model = (
-            manager.pre_autograd_graph_module.to(device="cuda")
+            manager.pre_autograd_graph_module.to(device="cuda")  # pyre-ignore
             if torch.cuda.is_available()
            else manager.pre_autograd_graph_module.to(device="cpu")
         )
         return GraphModuleEvalWrapper(
             model=model,
             tokenizer=tokenizer,
             max_seq_length=args.max_seq_length,
-            use_kv_cache=args.use_kv_cache,
-            enable_dynamic_shape=args.enable_dynamic_shape,
+            use_kv_cache=args.use_kv_cache,  # pyre-ignore
+            enable_dynamic_shape=args.enable_dynamic_shape,  # pyre-ignore
         )
     else:
         # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch
@@ -221,7 +222,7 @@ def gen_eval_wrapper(
         # that is not available in this eval_llama. We save the checkpoint
         # here for consistency with eval_llama. The accuracy results we
         # get from eval_llama can be used as a reference to other evaluations.
-        if args.output_eager_checkpoint_file is not None:
+        if args.output_eager_checkpoint_file is not None:  # pyre-ignore
             torch.save(model, args.output_eager_checkpoint_file)
 
         return EagerEvalWrapper(
@@ -282,8 +283,8 @@ def eval_llama(
     # Evaluate the model
     eval_results = evaluate_model(
         eval_wrapper,
-        args.tasks,
-        args.limit,
+        args.tasks,  # pyre-ignore
+        args.limit,  # pyre-ignore
     )
 
     for task, res in eval_results["results"].items():
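
For reference, a minimal usage sketch of the pieces this diff touches: the ETPybindEvalWrapper constructor, get_tokenizer, and the evaluate_model call shape come straight from the hunks above, while the importing module paths, file paths, task list, and limit below are placeholder assumptions rather than part of the commit.

    # Hedged sketch only -- module paths, file paths, and tasks are assumptions.
    from executorch.examples.models.llama2.eval_llama_lib import (  # assumed location of the file in this diff
        ETPybindEvalWrapper,
    )
    from executorch.examples.models.llama2.evaluate.eager_eval import (  # assumed absolute form of the new relative import
        evaluate_model,
    )
    from executorch.extension.llm.tokenizer.utils import get_tokenizer

    # Wrap an exported .pte program and its tokenizer for lm-evaluation-harness.
    tokenizer = get_tokenizer("/path/to/tokenizer.model")  # SentencePiece or Tiktoken file
    eval_wrapper = ETPybindEvalWrapper(
        model="/path/to/llama2.pte",  # ExecuTorch program, loaded via pybindings
        tokenizer=tokenizer,
        max_seq_length=2048,
    )

    # Same call shape as the eval_llama hunk: evaluate_model(wrapper, tasks, limit).
    eval_results = evaluate_model(eval_wrapper, ["wikitext"], None)
    for task, res in eval_results["results"].items():
        print(f"{task}: {res}")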