Commit 501be36

add option to run mmlu with 5 shots
1 parent df5b2ab

2 files changed: +20 -17

examples/models/llama2/eval_llama_lib.py

Lines changed: 13 additions & 5 deletions
@@ -246,9 +246,16 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="list of lm-eluther tasks to evaluate usage: --tasks task1 task2",
     )
     parser.add_argument(
-        "--limit", type=int, default=5, help="number of samples to evalulate"
+        "--limit", type=int, default=None, help="number of samples to evalulate"
+    )
+    parser.add_argument(
+        "-f",
+        "--num_fewshot",
+        type=int,
+        default=None,
+        metavar="N",
+        help="Number of examples in few-shot context",
     )
-
     # Add additional args specific to eval via an ET Runner
     # Note: For initial integration, the tokenizer.model is also required
     parser.add_argument(
@@ -282,9 +289,10 @@ def eval_llama(

     # Evaluate the model
     eval_results = evaluate_model(
-        eval_wrapper,
-        args.tasks,  # pyre-ignore
-        args.limit,  # pyre-ignore
+        eval_wrapper=eval_wrapper,
+        tasks=args.tasks,
+        num_fewshot=args.num_fewshot,
+        limit=args.limit,
     )

     for task, res in eval_results["results"].items():
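As a quick illustration of the new flag, below is a minimal standalone argparse sketch mirroring only the two options this hunk touches (the rest of the full parser is assumed, not shown in the diff):

import argparse

# Minimal parser mirroring just the two arguments touched by this diff.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--limit", type=int, default=None, help="number of samples to evaluate"
)
parser.add_argument(
    "-f",
    "--num_fewshot",
    type=int,
    default=None,
    metavar="N",
    help="Number of examples in few-shot context",
)

# The 5-shot setup from the commit message: pass -f 5 (or --num_fewshot 5).
args = parser.parse_args(["--num_fewshot", "5"])
assert args.num_fewshot == 5 and args.limit is None  # limit now defaults to None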

examples/models/llama2/evaluate/eager_eval.py

Lines changed: 7 additions & 12 deletions
@@ -15,9 +15,8 @@
 )

 from lm_eval.api.model import LM
-from lm_eval.evaluator import evaluate
+from lm_eval.evaluator import simple_evaluate
 from lm_eval.models.huggingface import HFLM as eval_wrapper
-from lm_eval.tasks import get_task_dict

 from torch import nn


@@ -85,6 +84,7 @@ def _model_generate(self, context, max_length, eos_token_id):
 def evaluate_model(
     eval_wrapper: LM,
     tasks: Optional[list] = None,
+    num_fewshot: Optional[int] = None,
     limit: Optional[int] = None,
 ) -> dict:
     """
@@ -93,6 +93,7 @@ def evaluate_model(
     Args:
         eval_wrapper (LM): A LM wrapper class compatible with lm-evaluation-harness evaluation
         tasks: Optional[list]: The names of the evaluation tasks to perform.
+        num_fewshot: Optional[int]: Number of examples in few-shot context.
         limit (Optional[int]): The maximum number of samples to evaluate (None for all available).

     Returns:
@@ -102,16 +103,10 @@
     if tasks is None:
         tasks = ["wikitext"]

-    if "hendrycks_test" in tasks:
-        tasks.remove("hendrycks_test")
-        tasks += list(
-            lm_eval.tasks.hendrycks_test.create_all_tasks().keys()  # pyre-ignore
-        )
-    task_dict = get_task_dict(tasks)
-
-    eval_results = evaluate(
-        eval_wrapper,
-        task_dict,
+    eval_results = simple_evaluate(
+        model=eval_wrapper,
+        tasks=tasks,
+        num_fewshot=num_fewshot,
         limit=limit,
     )
     return eval_results
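For context, lm-eval's simple_evaluate resolves task names internally, which is why the manual hendrycks_test expansion and get_task_dict call can be dropped. A minimal sketch of the resulting call pattern, assuming an lm-eval-compatible model wrapper (the HFLM construction below is illustrative only, not part of this diff):

from lm_eval.evaluator import simple_evaluate
from lm_eval.models.huggingface import HFLM

# Illustrative stand-in for the ExecuTorch eval wrapper used by eval_llama.
lm = HFLM(pretrained="gpt2")

# simple_evaluate looks up task configs by name itself, so "mmlu" needs no
# manual expansion into per-subject tasks the way "hendrycks_test" did.
results = simple_evaluate(
    model=lm,
    tasks=["mmlu"],
    num_fewshot=5,  # the 5-shot setting from the commit message
    limit=None,     # None evaluates all available samples
)

for task, res in results["results"].items():
    print(f"{task}: {res}")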
