diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index bd12c374b51..5470670a4cd 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -219,9 +219,7 @@ def pt2e_calibrate( from executorch.examples.models.llama.eval_llama_lib import ( GraphModuleEvalWrapper, ) - from executorch.examples.models.llama.evaluate import ( # pyre-ignore[21] - evaluate_model, - ) + from lm_eval.evaluator import simple_evaluate # pyre-ignore[21] except ImportError: raise ImportError( "Please install the llm eval dependency via examples/models/llama/install_requirements.sh" @@ -266,11 +264,14 @@ def calibrate_template( generate_full_logits=self.generate_full_logits, enable_dynamic_shape=self.enable_dynamic_shape, ) - eval_results = evaluate_model( - eval_wrapper, - calibration_tasks, - calibration_limit, - ) + + # Evaluate the model + with torch.no_grad(): + eval_results = simple_evaluate( + model=eval_wrapper, + tasks=calibration_tasks, + limit=calibration_limit, + ) for task, res in eval_results["results"].items(): print(f"{task}: {res}")