Skip to content

Commit 3160b02

Browse files
committed
Add MedBullets dataset
1 parent 68b4825 commit 3160b02

File tree

3 files changed

+228
-0
lines changed

3 files changed

+228
-0
lines changed
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
"""Load and process the Medbullets dataset.
2+
3+
Dataset: HuggingFace `mkieffer/Medbullets` dataset.
4+
Each example is normalized to the fields expected by `vf.Verifiers`:
5+
{
6+
"question": "<stem + formatted options>", # string used as the user prompt
7+
"answer": "<A|B|C|D|E>", # top-level gold letter
8+
"info": { ...original example fields... } # full source row for debugging
9+
}
10+
11+
- num_options=4 : loads splits `op4_train` / `op4_eval` and drops option "E"
12+
- num_options=5 : loads splits `op5_train` / `op5_eval`
13+
"""
14+
15+
from typing import Any
16+
17+
from datasets import load_dataset
18+
19+
20+
class MedBulletsDataset:
    """Load and normalize the HuggingFace `mkieffer/Medbullets` dataset.

    After construction, `train_ds` and `eval_ds` hold shuffled datasets
    whose rows carry the fields verifiers expects:

        question -- stem plus the formatted answer options
        answer   -- gold letter (A-E), or "" when it cannot be determined
        info     -- the full original row, kept for debugging
    """

    def __init__(
        self,
        num_train_examples: int = -1,
        num_eval_examples: int = -1,
        num_options: int = 4,
    ):
        """Set up the processor and immediately load both splits.

        Args:
            num_train_examples: Number of training examples to use (-1 for all)
            num_eval_examples: Number of evaluation examples to use (-1 for all)
            num_options: Number of options per question (4 or 5)

        Raises:
            ValueError: If `num_options` is neither 4 nor 5.
        """
        if num_options not in (4, 5):
            raise ValueError("'num_options' must be 4 or 5")

        self.num_train_examples = num_train_examples
        self.num_eval_examples = num_eval_examples
        self.num_options = num_options
        # Fixed seed so repeated runs produce the same shuffle order.
        self.rng_seed = 12345

        # Eagerly download and process both splits at construction time.
        self.train_ds, self.eval_ds = self._load_and_process_datasets()

    def _load_and_process_datasets(self) -> tuple:
        """Download, trim, format and shuffle the train/eval splits."""
        # Split names encode the option count: op4_* or op5_*.
        prefix = f"op{self.num_options}"
        train_raw, eval_raw = load_dataset(
            "mkieffer/Medbullets", split=[f"{prefix}_train", f"{prefix}_eval"]
        )
        if self.num_options == 4:
            # The 4-option splits still carry an "E" slot; drop it.
            train_raw = self._remove_option_e(train_raw)
            eval_raw = self._remove_option_e(eval_raw)

        # Optionally cap the number of examples (-1 keeps everything).
        if self.num_train_examples != -1:
            cap = min(self.num_train_examples, len(train_raw))
            train_raw = train_raw.select(range(cap))
        if self.num_eval_examples != -1:
            cap = min(self.num_eval_examples, len(eval_raw))
            eval_raw = eval_raw.select(range(cap))

        train_ds = self._format_for_verifiers(train_raw, "train")
        eval_ds = self._format_for_verifiers(eval_raw, "eval")

        return (
            train_ds.shuffle(seed=self.rng_seed),
            eval_ds.shuffle(seed=self.rng_seed),
        )

    def _remove_option_e(self, dataset: Any) -> Any:
        """Return a copy of `dataset` with the "E" answer option removed."""

        def drop_e(example: dict) -> dict:
            example = dict(example)
            trimmed = dict(example["options"])
            trimmed.pop("E", None)
            example["options"] = trimmed
            return example

        return dataset.map(drop_e)

    def _format_for_verifiers(self, dataset: Any, split: str) -> Any:
        """Map every row onto the question/answer/info layout.

        Args:
            dataset: Source split to reformat.
            split: Split label ("train"/"eval"); kept for symmetry with
                callers, not otherwise used.
        """
        letters = {"A", "B", "C", "D", "E"}

        def reshape(example: dict) -> dict:
            example = dict(example)

            stem = example.get("question", "") or ""
            choices = example.get("options", {}) or {}

            # Stem first, then one line per non-empty option. None/empty
            # texts are skipped (the combined dataset stores a null "E"
            # for 4-option rows).
            prompt = f"Question: {stem}\n"
            for letter, text in choices.items():
                if text is not None and text != "":
                    prompt += f"\n{letter}: {text}"

            # Lift the gold answer to top level as a single upper-case letter.
            gold = (example.get("answer") or "").strip().upper()
            if gold not in letters:
                # Some rows may store the letter under 'answer_letter' instead.
                if gold == "" and "answer_letter" in example:
                    gold = str(example["answer_letter"]).strip().upper()
                if gold not in letters:
                    # Final guard: unexpected value maps to an empty answer.
                    gold = ""

            return {
                "question": prompt,
                "answer": gold,
                # Full original example preserved for debugging.
                "info": dict(example),
            }

        return dataset.map(reshape)
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
"""Reward Functions for the Medbullets dataset.
2+
3+
- Parser extracts \\boxed{A|B|C|D|E} from completions
4+
- Reward looks for exact match between parsed letter and answer letter
5+
"""
6+
7+
from typing import Any
8+
9+
10+
def correct_answer_reward_func(
    parser: Any, completion: str, answer: str, **kwargs
) -> float:
    """Return 1.0 when the parsed completion matches the gold answer.

    Args:
        parser: Object exposing ``parse_answer`` (e.g. a verifiers parser)
        completion: Model completion to grade
        answer: Gold answer letter
        **kwargs: Ignored; accepted for rubric-call compatibility
    Returns:
        float: 1.0 on an exact string match, otherwise 0.0
    """
    # A failed parse yields None; treat it as an empty answer string.
    parsed = parser.parse_answer(completion) or ""

    return float(parsed == answer)
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
"""Medbullets Evaluation.
2+
3+
Dataset: HuggingFace `mkieffer/Medbullets` dataset.
4+
5+
- Parser: Extracts \\boxed{A|B|C|D|E} from completions
6+
- Reward Functions:
7+
- Correct answer reward
8+
- Format reward
9+
"""
10+
11+
import os
12+
13+
import verifiers as vf
14+
from dotenv import load_dotenv
15+
from openai import OpenAI
16+
from verifiers.utils.data_utils import extract_boxed_answer
17+
18+
from med_reason_evals.data.medbullets import MedBulletsDataset
19+
from med_reason_evals.verifiers.answer_correctness_reward import (
20+
correct_answer_reward_func,
21+
)
22+
23+
24+
def main() -> None:
    """Run the evaluation on the Medbullets dataset."""
    # Load environment variables (expects GROQ_API_KEY).
    load_dotenv()

    # Create an instance of the dataset processor.
    dataset = MedBulletsDataset(
        num_train_examples=-1, num_eval_examples=-1, num_options=4
    )

    # The option list shown to the model depends on the dataset variant.
    options = "(A, B, C, or D)" if dataset.num_options == 4 else "(A, B, C, D, or E)"

    # FIX: the tag was previously rendered as "think>"; spell out the full
    # "<think>" tag so the instruction matches what ThinkParser expects.
    system_prompt = (
        f"Think step-by-step inside <think> tags, then give only the letter "
        f"of the correct answer inside \\boxed{{...}} {options}. Do not include option "
        f"text in the box; only the letter."
    )

    parser = vf.ThinkParser(extract_fn=extract_boxed_answer)

    # Correctness is the only weighted signal; format reward is tracked
    # at weight 0.0 for diagnostics only.
    rubric = vf.Rubric(
        funcs=[correct_answer_reward_func, parser.get_format_reward_func()],
        weights=[1.0, 0.0],
        parser=parser,
    )

    env = vf.SingleTurnEnv(
        dataset=dataset.train_ds,
        eval_dataset=dataset.eval_ds,
        system_prompt=system_prompt,
        parser=parser,
        rubric=rubric,
    )

    # FIX: Groq's OpenAI-compatible endpoint lives under /openai/v1,
    # not /v1 — the old base_url would 404 on every request.
    client = OpenAI(
        api_key=os.getenv("GROQ_API_KEY"),
        base_url="https://api.groq.com/openai/v1",
    )
    results = env.evaluate(
        client=client,
        model="llama-3.3-70b-versatile",
        num_examples=2,
        rollouts_per_example=5,
    )
    print(results)


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)