Commit ad12276

feat: introduce results attribute on MMLU evaluator
In order to test the validity of our MMLU results or get information on prior runs, we need to be able to access the full set of results from the lm_eval.evaluator.simple_evaluate API. This commit provides that ability by adding a results attribute on the MMLUEvaluator class and storing the results there.

Signed-off-by: Oleg S <[email protected]>
1 parent fd78adf commit ad12276
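
The snippet below is a minimal sketch of the new results attribute described in the commit message. The model path and task list are placeholders rather than values from this repository; the constructor keywords and the run() return shape follow the diffs further down.

# Minimal sketch (placeholder model path and tasks) of the results attribute
# added in this commit: it is None before any run and holds the full
# lm_eval.evaluator.simple_evaluate output afterwards.
from instructlab.eval.mmlu import MMLUEvaluator

mmlu = MMLUEvaluator(
    model_path="path/to/model",   # placeholder
    tasks=["mmlu_anatomy"],       # placeholder task selection
)

assert mmlu.results is None       # no evaluation has run yet

overall_score, individual_scores = mmlu.run()

full_output = mmlu.results        # full simple_evaluate output, e.g. full_output["results"]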

File tree

3 files changed (+95, -19 lines):
CHANGELOG.md
scripts/test_mmlu.py
src/instructlab/eval/mmlu.py

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 ## 0.4.2
 
 * Adds the ability to provide a custom system prompt to the MMLU-based evaluators. When a system prompt is provided, LM-eval applies the chat template under the hood, else it will pass the model a barebones prompt.
+* Adds an `extra_args` parameter to the `.run` method of all MMLU-based evaluators. This way, consumers are able to directly pass any additional arguments they want through to the `lm_eval.evaluators.simple_evaluate` function.
 
 ## 0.4
 
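
To tie the two changelog entries above together, here is a hedged sketch of how a consumer might combine a custom system prompt with extra_args. The model path is a placeholder, and the extra_args keys simply mirror the ones used in scripts/test_mmlu.py below.

# Illustrative only: placeholder model path; extra_args is forwarded verbatim
# to lm_eval's simple_evaluate, and a system_prompt triggers chat-template use.
from instructlab.eval.mmlu import MMLUEvaluator

SYSTEM_PROMPT = "You are a helpful assistant."  # placeholder system prompt

mmlu = MMLUEvaluator(
    model_path="path/to/model",     # placeholder
    system_prompt=SYSTEM_PROMPT,    # lm_eval applies the chat template when this is set
)
overall_score, individual_scores = mmlu.run(
    extra_args={"log_samples": True, "write_out": True}
)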

scripts/test_mmlu.py

Lines changed: 52 additions & 1 deletion
@@ -1,9 +1,41 @@
+# Standard
+from typing import Dict, List, Tuple, TypedDict
+
 # First Party
 from instructlab.eval.mmlu import MMLUEvaluator
 
 SYSTEM_PROMPT = """I am, Red Hat® Instruct Model based on Granite 7B, an AI language model developed by Red Hat and IBM Research, based on the Granite-7b-base language model. My primary function is to be a chat assistant."""
 
 
+class MMLUSample(TypedDict):
+    """
+    Example of a single sample returned from lm_eval when running MMLU.
+    This is not a comprehensive type, just the subset of fields we care about for this test.
+    """
+
+    # Arguments is the list of (prompt, answer) pairs passed to MMLU as few-shot samples.
+    # They will not be present with few_shot=0
+    arguments: List[Tuple[str, str]]
+
+
+def all_samples_contain_system_prompt(
+    samples: Dict[str, List[MMLUSample]], prompt: str
+) -> bool:
+    """
+    Given a mapping of evaluation --> list of results, validates that all few-shot examples
+    included the system prompt
+    """
+    for topic, samples_set in samples.items():
+        for sample in samples_set:
+            for mmlu_prompt, _ in sample["arguments"]:
+                if prompt not in mmlu_prompt:
+                    # we are looking for the exact system prompt, so no need to convert to normalize to lowercase
+                    print(f"found a sample in the '{topic}' MMLU topic set")
+                    return False
+
+    return True
+
+
 def test_minimal_mmlu():
     print("===> Executing 'test_minimal_mmlu'...")
     try:
@@ -14,9 +46,28 @@ def test_minimal_mmlu():
             tasks=tasks,
             system_prompt=SYSTEM_PROMPT,
         )
-        overall_score, individual_scores = mmlu.run()
+        overall_score, individual_scores = mmlu.run(
+            extra_args={"log_samples": True, "write_out": True}
+        )
+        samples = mmlu.results["samples"]
+
         print(overall_score)
         print(individual_scores)
+
+        # we need n-shots > 1 to be able to validate the inclusion of the system prompt
+        eligible_samples = {
+            topic: samples[topic]
+            for topic, shot in mmlu.results["n-shot"].items()
+            if shot > 1
+        }
+        if eligible_samples:
+            if not all_samples_contain_system_prompt(eligible_samples, SYSTEM_PROMPT):
+                return False
+        else:
+            print(
+                "MMLU was run in zero-shot mode, cannot confirm that system prompt was included, skipping check..."
+            )
+
     except Exception as exc:
         print(f"'test_minimal_mmlu' failed: {exc}")
         return False
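
For orientation, the keys the test pulls from mmlu.results look roughly like the sketch below. Only the fields the test touches are shown; the task name and concrete values are illustrative placeholders, not documented output.

# Rough, illustrative shape of mmlu.results as consumed by the test above.
results = {
    "results": {},                    # per-task metrics; _run_mmlu returns this sub-dict
    "n-shot": {"mmlu_anatomy": 5},    # few-shot count per task
    "samples": {                      # present when log_samples=True is passed via extra_args
        "mmlu_anatomy": [
            {"arguments": [("<few-shot prompt containing the system prompt>", "<target>")]},
        ],
    },
}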

src/instructlab/eval/mmlu.py

Lines changed: 42 additions & 18 deletions
@@ -7,12 +7,12 @@
 """
 
 # Standard
-from typing import Optional, Union
+from typing import Any, Dict, Optional, Union
 import os
 
 # Third Party
-from lm_eval.evaluator import simple_evaluate  # type: ignore
-from lm_eval.tasks import TaskManager  # type: ignore
+from lm_eval.evaluator import simple_evaluate
+from lm_eval.tasks import TaskManager
 import torch
 
 # First Party
@@ -103,6 +103,7 @@ class AbstractMMLUEvaluator(Evaluator):
         batch_size      batch size for evaluation. Valid values are a positive integer or 'auto' to select the largest batch size that will fit in memory, or 'auto:N' to reselect the largest batch size N times'.
         device          PyTorch device (e.g. "cpu" or "cuda:0") for running models
         system_prompt   system prompt to be used when applying the chat template
+        results         full output from the `lm_eval.evaluator.simple_evaluate` function after MMLU has run.
     """
 
     def __init__(
@@ -124,18 +125,33 @@ def __init__(
         self.few_shots = few_shots
         self.batch_size = batch_size
         self.device = device
+        self._results = None
 
-    def run(self, server_url: str | None = None) -> tuple:
+    @property
+    def results(self) -> Dict[str, Any] | None:
+        """
+        Returns the results of the last MMLU evaluation, if one has taken place.
+
+        Returns:
+            Dict[str, Any] | None: The output from `lm_eval.evaluator.simple_evaluate`
+        """
+        return self._results
+
+    def run(
+        self, server_url: str | None = None, extra_args: Dict[str, Any] | None = None
+    ) -> tuple:
         """
         Runs evaluation
 
         Attributes
             server_url          Model server endpoint (Ex: http://localhost:8000/v1) for the model being evaluated
+            extra_args          Dictionary containing any extra arguments to be passed into the lm_eval `lm_eval.evaluator.simple_evaluate` function.
 
         Returns:
             overall_score       Average score for the task group
             individual_scores   Individual scores for each task in the task group
         """
+        extra_args = {} if not extra_args else extra_args
         logger.debug(locals())
 
         # TODO: make this a parameter for class?
@@ -156,7 +172,10 @@ def run(self, server_url: str | None = None) -> tuple:
 
         return overall_score, individual_scores
 
-    def _run_mmlu(self, server_url: str | None = None) -> dict:
+    def _run_mmlu(
+        self, server_url: str | None = None, extra_args: Dict[str, Any] | None = None
+    ) -> dict:
+        extra_args = {} if not extra_args else extra_args
         if server_url is not None:
             # Requires lm_eval >= 0.4.4
             model_args = f"base_url={server_url}/completions,model={self.model_path},tokenizer_backend=huggingface"
@@ -172,19 +191,24 @@ def _run_mmlu(self, server_url: str | None = None) -> dict:
             raise InvalidTasksDirError(self.tasks_dir)
         tm = TaskManager(verbosity="DEBUG", include_path=self.tasks_dir)
         should_apply_chat_template = self.system_prompt is not None
-        mmlu_output = self._simple_evaluate_with_error_handling(
-            model=model,
-            model_args=model_args,
-            tasks=self.tasks,
-            num_fewshot=self.few_shots,
-            batch_size=self.batch_size,
-            device=self.device,
-            task_manager=tm,
-            system_instruction=self.system_prompt,
-            apply_chat_template=should_apply_chat_template,
-        )
-        results = mmlu_output["results"]
-        return results
+
+        # configure the args here so users can override them as necessary
+        simple_evaluate_kwargs = {
+            "model": model,
+            "model_args": model_args,
+            "tasks": self.tasks,
+            "num_fewshot": self.few_shots,
+            "batch_size": self.batch_size,
+            "device": self.device,
+            "task_manager": tm,
+            "system_instruction": self.system_prompt,
+            "apply_chat_template": should_apply_chat_template,
+        }
+        simple_evaluate_kwargs.update(extra_args)
+
+        results = self._simple_evaluate_with_error_handling(**simple_evaluate_kwargs)
+        self._results = results
+        return results["results"]
 
     # This method converts general errors from simple_evaluate
     # into a more user-understandable error
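
Because _run_mmlu builds its defaults into a dict and then calls .update(extra_args), user-supplied keys take precedence over the evaluator's own settings. A standalone sketch of that precedence, with illustrative values:

# Standalone illustration of the merge order used above: keys in extra_args
# overwrite the defaults computed from the evaluator's attributes.
defaults = {"num_fewshot": 5, "batch_size": "auto"}
extra_args = {"num_fewshot": 0}            # e.g. force zero-shot for a quick smoke test

simple_evaluate_kwargs = dict(defaults)
simple_evaluate_kwargs.update(extra_args)  # same pattern as in _run_mmlu

print(simple_evaluate_kwargs)              # {'num_fewshot': 0, 'batch_size': 'auto'}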
