Skip to content

Commit 7c0f083

Browse files
committed
add option to run mmlu with 5 shots
1 parent 866b40c commit 7c0f083

File tree

2 files changed

+19
-10
lines changed

2 files changed

+19
-10
lines changed

examples/models/llama2/eval_llama_lib.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,14 @@ def build_args_parser() -> argparse.ArgumentParser:
248248
parser.add_argument(
249249
"--limit", type=int, default=5, help="number of samples to evalulate"
250250
)
251-
251+
parser.add_argument(
252+
"-f",
253+
"--num_fewshot",
254+
type=int,
255+
default=None,
256+
metavar="N",
257+
help="Number of examples in few-shot context",
258+
)
252259
# Add additional args specific to eval via an ET Runner
253260
# Note: For initial integration, the tokenizer.model is also required
254261
parser.add_argument(
@@ -282,9 +289,10 @@ def eval_llama(
282289

283290
# Evaluate the model
284291
eval_results = evaluate_model(
285-
eval_wrapper,
286-
args.tasks, # pyre-ignore
287-
args.limit, # pyre-ignore
292+
eval_wrapper=eval_wrapper,
293+
tasks=args.tasks,
294+
num_fewshot=args.num_fewshot,
295+
limit=args.limit,
288296
)
289297

290298
for task, res in eval_results["results"].items():

examples/models/llama2/evaluate/eager_eval.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,8 @@
1515
)
1616

1717
from lm_eval.api.model import LM
18-
from lm_eval.evaluator import evaluate
18+
from lm_eval.evaluator import simple_evaluate
1919
from lm_eval.models.huggingface import HFLM as eval_wrapper
20-
from lm_eval.tasks import get_task_dict
2120

2221
from torch import nn
2322

@@ -85,6 +84,7 @@ def _model_generate(self, context, max_length, eos_token_id):
8584
def evaluate_model(
8685
eval_wrapper: LM,
8786
tasks: Optional[list] = None,
87+
num_fewshot: Optional[int] = None,
8888
limit: Optional[int] = None,
8989
) -> dict:
9090
"""
@@ -93,6 +93,7 @@ def evaluate_model(
9393
Args:
9494
eval_wrapper (LM): A LM wrapper class compatible with lm-evaluation-harness evaluation
9595
tasks: Optional[list]: The names of the evaluation tasks to perform.
96+
num_fewshot: Optional[int]: Number of examples in few-shot context.
9697
limit (Optional[int]): The maximum number of samples to evaluate (None for all available).
9798
9899
Returns:
@@ -107,11 +108,11 @@ def evaluate_model(
107108
tasks += list(
108109
lm_eval.tasks.hendrycks_test.create_all_tasks().keys() # pyre-ignore
109110
)
110-
task_dict = get_task_dict(tasks)
111111

112-
eval_results = evaluate(
113-
eval_wrapper,
114-
task_dict,
112+
eval_results = simple_evaluate(
113+
model=eval_wrapper,
114+
tasks=tasks,
115+
num_fewshot=num_fewshot,
115116
limit=limit,
116117
)
117118
return eval_results

0 commit comments

Comments
 (0)