 )
 
 from lm_eval.api.model import LM
-from lm_eval.evaluator import evaluate
+from lm_eval.evaluator import simple_evaluate
 from lm_eval.models.huggingface import HFLM as eval_wrapper
-from lm_eval.tasks import get_task_dict
 
 from torch import nn
 
@@ -85,6 +84,7 @@ def _model_generate(self, context, max_length, eos_token_id):
 def evaluate_model(
     eval_wrapper: LM,
     tasks: Optional[list] = None,
+    num_fewshot: Optional[int] = None,
     limit: Optional[int] = None,
 ) -> dict:
     """
@@ -93,6 +93,7 @@ def evaluate_model(
     Args:
         eval_wrapper (LM): A LM wrapper class compatible with lm-evaluation-harness evaluation
         tasks: Optional[list]: The names of the evaluation tasks to perform.
+        num_fewshot: Optional[int]: Number of examples in few-shot context.
         limit (Optional[int]): The maximum number of samples to evaluate (None for all available).
 
     Returns:
@@ -102,16 +103,10 @@ def evaluate_model(
     if tasks is None:
         tasks = ["wikitext"]
 
-    if "hendrycks_test" in tasks:
-        tasks.remove("hendrycks_test")
-        tasks += list(
-            lm_eval.tasks.hendrycks_test.create_all_tasks().keys()  # pyre-ignore
-        )
-    task_dict = get_task_dict(tasks)
-
-    eval_results = evaluate(
-        eval_wrapper,
-        task_dict,
+    eval_results = simple_evaluate(
+        model=eval_wrapper,
+        tasks=tasks,
+        num_fewshot=num_fewshot,
         limit=limit,
     )
     return eval_results
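
For reference, a minimal usage sketch of the updated `evaluate_model` entry point. The model name, task list, and argument values below are illustrative assumptions, not part of this change, and the `HFLM` constructor arguments may differ across lm-evaluation-harness versions:

```python
from lm_eval.models.huggingface import HFLM

# Illustrative wrapper; "gpt2" is an assumed example model, not from this PR.
wrapper = HFLM(pretrained="gpt2")

results = evaluate_model(
    eval_wrapper=wrapper,
    tasks=["wikitext"],  # default task used by evaluate_model when tasks is None
    num_fewshot=5,       # new argument added in this change
    limit=10,            # evaluate only a few samples for a quick smoke test
)

# simple_evaluate() returns a dict; per-task metrics live under the "results" key.
print(results["results"])
```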