Skip to content

Commit 4ef2cd4

Browse files
committed
Add MedQA dataset
1 parent ab8bfd9 commit 4ef2cd4

File tree

2 files changed

+194
-0
lines changed

2 files changed

+194
-0
lines changed

src/med_reason_evals/data/medqa.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
"""Load and process the MedQA dataset.
2+
3+
Dataset: HuggingFace `GBaker/MedQA-USMLE-4-options` dataset.
4+
Each example is normalized to the following fields:
5+
{
6+
"question": "<question + formatted options>", # string used as the user prompt
7+
"answer": "<A|B|C|D>", # top-level gold letter
8+
"info": { ...original example fields... } # full source row for debugging
9+
}
10+
"""
11+
12+
from typing import Any
13+
14+
from datasets import load_dataset
15+
16+
17+
class MedQADataset:
    """Load and normalize the MedQA-USMLE-4-options dataset.

    Wraps the HuggingFace ``GBaker/MedQA-USMLE-4-options`` dataset and
    reshapes each row into the ``{"question", "answer", "info"}`` schema
    used by the verifiers environments:

    - ``question``: question stem plus formatted answer options
    - ``answer``: gold letter, one of ``A``/``B``/``C``/``D`` (or ``""``)
    - ``info``: the full original source row, kept for debugging
    """

    # Accepted gold labels; anything else is normalized to "".
    _VALID_ANSWERS = frozenset({"A", "B", "C", "D"})

    def __init__(
        self,
        num_train_examples: int = -1,
        num_test_examples: int = -1,
    ):
        """Initialize the MedQA dataset processor.

        Args:
            num_train_examples: Number of training examples to use (-1 for all)
            num_test_examples: Number of test examples to use (-1 for all)
        """
        self.num_train_examples = num_train_examples
        self.num_test_examples = num_test_examples
        # Fixed seed so the train/test shuffles are reproducible across runs.
        self.rng_seed = 12345

        # Load and process datasets eagerly on initialization (network/disk I/O).
        self.train_ds, self.test_ds = self._load_and_process_datasets()

    def _load_and_process_datasets(self) -> tuple[Any, Any]:
        """Load, truncate, format, and shuffle the train/test splits.

        Returns:
            Tuple of (train_dataset, test_dataset), each normalized via
            :meth:`_format_for_verifiers` and shuffled with ``self.rng_seed``.
        """
        ds = load_dataset("GBaker/MedQA-USMLE-4-options")
        train_raw = ds["train"]
        test_raw = ds["test"]

        # Optionally truncate each split; min() clamps to the split size so
        # requesting more examples than exist does not raise.
        if self.num_train_examples != -1:
            train_raw = train_raw.select(
                range(min(self.num_train_examples, len(train_raw)))
            )
        if self.num_test_examples != -1:
            test_raw = test_raw.select(
                range(min(self.num_test_examples, len(test_raw)))
            )

        # Normalize rows to the question/answer/info schema.
        train_formatted = self._format_for_verifiers(train_raw, "train")
        test_formatted = self._format_for_verifiers(test_raw, "test")

        # Deterministic shuffle of both splits.
        train_formatted = train_formatted.shuffle(seed=self.rng_seed)
        test_formatted = test_formatted.shuffle(seed=self.rng_seed)

        return train_formatted, test_formatted

    def _format_for_verifiers(self, dataset: Any, split: str) -> Any:
        """Map every row of *dataset* to ``{"question", "answer", "info"}``.

        Args:
            dataset: Source dataset; must expose a ``.map(fn)`` method.
            split: Split name ("train"/"test"). Currently unused; kept for
                interface stability and possible per-split formatting later.

        Returns:
            The mapped dataset.
        """
        valid = self._VALID_ANSWERS

        def format_row(row: dict) -> dict:
            row = dict(row)

            # Build the user-visible question string (stem + options).
            q = row.get("question", "") or ""
            opts = row.get("options", {}) or {}

            # Render only non-empty options, each on its own line.
            option_lines = [
                f"\n{letter}. {text}"
                for letter, text in opts.items()
                if text is not None and text != ""
            ]
            question_str = f"Question: {q}\n" + "".join(option_lines)

            # Lift the gold label top-level, normalized to a single letter.
            # str() guards against non-string labels (e.g. ints), which
            # would otherwise crash on .strip().
            ans = str(row.get("answer_idx") or "").strip().upper()
            if ans not in valid:
                # Final guard: unexpected labels become "" rather than
                # silently scoring as a plausible-looking letter.
                ans = ""

            # Keep the full original example under 'info' for debugging.
            info = dict(row)

            return {
                "question": question_str,
                "answer": ans,
                "info": info,
            }

        return dataset.map(format_row)
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
"""MedQA Evaluation.
2+
3+
Dataset: HuggingFace `GBaker/MedQA-USMLE-4-options` dataset.
4+
5+
- Parser: Extracts \\boxed{A|B|C|D} from completions
6+
- Reward Functions:
7+
- Correct answer reward
8+
- Format reward
9+
"""
10+
11+
import os
12+
import verifiers as vf
13+
from dotenv import load_dotenv
14+
from openai import OpenAI
15+
from verifiers.utils.data_utils import extract_boxed_answer
16+
17+
from med_reason_evals.data.medqa import MedQADataset
18+
from med_reason_evals.verifiers.exact_match_reward import (
19+
exact_match_reward_func,
20+
)
21+
22+
23+
def load_environment(
    use_think: bool = True,
    num_train_examples: int = -1,
    num_test_examples: int = -1,
) -> vf.SingleTurnEnv:
    """MedQA-USMLE-4-options multiple-choice evaluation.

    Args:
        use_think: Whether to require step-by-step reasoning inside
            <think> tags (default: True). When False, the model is asked
            for the boxed letter only and a plain boxed-answer parser is
            used.
        num_train_examples: Number of training examples to use (-1 for all)
        num_test_examples: Number of test examples to use (-1 for all)

    Returns:
        vf.SingleTurnEnv configured with MedQA dataset
    """
    dataset = MedQADataset(
        num_train_examples=num_train_examples,
        num_test_examples=num_test_examples,
    )

    options = "(A, B, C, or D)"  # MedQA has exactly 4 options

    # BUG FIX: `use_think` was previously accepted but ignored — the env
    # always demanded <think> reasoning. It now selects prompt + parser;
    # the default (True) preserves the original behavior exactly.
    if use_think:
        system_prompt = (
            f"Think step-by-step inside <think> tags, then give only the letter "
            f"of the correct answer inside \\boxed{{...}} {options}. Do not include option "
            f"text in the box; only the letter."
        )
        parser = vf.ThinkParser(extract_fn=extract_boxed_answer)
    else:
        system_prompt = (
            f"Give only the letter of the correct answer inside "
            f"\\boxed{{...}} {options}. Do not include option text in the box; "
            f"only the letter."
        )
        parser = vf.Parser(extract_fn=extract_boxed_answer)

    # Correct-answer reward drives the score; the format reward is logged
    # at weight 0.0 (tracked but not optimized).
    rubric = vf.Rubric(
        funcs=[exact_match_reward_func, parser.get_format_reward_func()],
        weights=[1.0, 0.0],
        parser=parser,
    )

    return vf.SingleTurnEnv(
        dataset=dataset.train_ds,
        eval_dataset=dataset.test_ds,
        system_prompt=system_prompt,
        parser=parser,
        rubric=rubric,
    )
66+
67+
68+
def main() -> None:
    """Run a small smoke-test evaluation on the MedQA dataset."""
    # Load environment variables (e.g. GROQ_API_KEY) from a local .env file.
    load_dotenv()

    # Fail fast with a clear message instead of an opaque HTTP 401 later.
    api_key = os.getenv("GROQ_API_KEY")
    if not api_key:
        raise RuntimeError(
            "GROQ_API_KEY is not set; add it to your environment or .env file."
        )

    # Build the MedQA environment (full train/test splits).
    env = load_environment(
        use_think=True,
        num_train_examples=-1,
        num_test_examples=-1,
    )

    # Groq exposes an OpenAI-compatible API, so the OpenAI client works
    # against it with a custom base URL.
    client = OpenAI(
        api_key=api_key,
        base_url="https://api.groq.com/openai/v1",
    )

    # Small run: 2 questions x 5 rollouts each.
    results = env.evaluate(
        client=client,
        model="llama-3.3-70b-versatile",
        num_examples=2,
        rollouts_per_example=5,
    )
    print(results)


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)