Skip to content

Commit bccc277

Browse files
committed
Add MetaMedQA dataset
1 parent 3160b02 commit bccc277

File tree

2 files changed

+210
-0
lines changed

2 files changed

+210
-0
lines changed
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
"""Load and process the MetaMedQA dataset.
2+
3+
Dataset: HuggingFace `maximegmd/MetaMedQA` dataset.
4+
Each example is normalized to the fields expected by `vf.Verifiers`:
5+
{
6+
"question": "<formatted question + options>", # string used as the user prompt
7+
"answer": "<A|B|C|D|E>", # top-level gold letter
8+
"info": { ...original example fields... } # full source row for debugging
9+
}
10+
"""
11+
12+
from typing import Any
13+
14+
from datasets import load_dataset
15+
16+
17+
class MetaMedQADataset:
    """Load, format, and shuffle the MetaMedQA dataset for verifiers.

    Each row of the resulting dataset carries the fields verifiers expects:
    ``question`` (prompt string with the options appended), ``answer``
    (gold option letter), and ``info`` (the full original row).
    """

    def __init__(
        self,
        split: str = "test",
        num_examples: int = -1,
    ):
        """Initialize the MetaMedQA dataset processor.

        Args:
            split: Dataset split to use (train, validation, test)
            num_examples: Number of examples to use (-1 for all)
        """
        self.split = split
        self.num_examples = num_examples
        # Fixed seed so the post-formatting shuffle is reproducible.
        self.rng_seed = 12345

        # Load and process immediately so `self.dataset` is ready to use.
        self.dataset = self._load_and_process_dataset()

    def _load_and_process_dataset(self) -> Any:
        """Load the raw split, optionally truncate, format, and shuffle."""
        raw = load_dataset("maximegmd/MetaMedQA", split=self.split)

        # Truncate to the requested number of examples (-1 keeps everything).
        if self.num_examples != -1:
            keep = min(self.num_examples, len(raw))
            raw = raw.select(range(keep))

        formatted = self._format_for_verifiers(raw)
        return formatted.shuffle(seed=self.rng_seed)

    def _build_prompt(self, question: str, options: dict) -> str:
        """Render the question plus its lettered options as one prompt string."""
        option_lines = "\n".join(f"{letter}. {text}" for letter, text in options.items())
        letter_list = ", ".join(sorted(options.keys()))
        segments = [
            "You are a clinician. Choose exactly ONE option letter.",
            f"Question:\n{question}",
            f"Options:\n{option_lines}",
            f"Answer with ONLY the letter ({letter_list}).",
        ]
        return "\n\n".join(segments)

    def _format_for_verifiers(self, dataset: Any) -> Any:
        """Map each raw row to the question/answer/info fields verifiers needs."""
        valid = {"A", "B", "C", "D", "E"}

        def format_row(row: dict) -> dict:
            row = dict(row)

            q: str = row["question"]
            options: dict = row["options"]
            gold_text: str = row["answer"]

            # Match the gold answer text against the option texts
            # (case- and whitespace-insensitive) to recover its letter.
            wanted = (gold_text or "").strip().lower()
            gold_letter = next(
                (
                    letter
                    for letter, text in options.items()
                    if (text or "").strip().lower() == wanted
                ),
                None,
            )

            # No match (or a letter outside A-E): fall back to the first
            # option so every row still carries a well-formed gold letter.
            # NOTE(review): this silently mislabels unmatched rows rather
            # than filtering them out — confirm this is acceptable.
            if gold_letter not in valid:
                gold_letter = next(iter(options))

            return {
                "question": self._build_prompt(q, options),
                "answer": gold_letter,
                # Full source row retained for debugging.
                "info": dict(row),
            }

        return dataset.map(format_row)
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
"""MetaMedQA Evaluation.
2+
3+
Dataset: HuggingFace `maximegmd/MetaMedQA` dataset.
4+
5+
- Parser: Extracts first letter A-Z from completions
6+
- Reward Functions:
7+
- Correct answer reward
8+
- Format reward
9+
"""
10+
11+
import os
12+
from typing import Any
13+
14+
import verifiers as vf
15+
from dotenv import load_dotenv
16+
from openai import OpenAI
17+
18+
from med_reason_evals.data.metamedqa import MetaMedQADataset
19+
from med_reason_evals.verifiers.answer_correctness_reward import (
20+
correct_answer_reward_func,
21+
)
22+
23+
24+
class LetterParser:
    """Parser that extracts the answer letter (A-Z) from completions.

    Reasoning wrapped in ``<think>...</think>`` tags is removed before the
    letter is extracted. Without this, the system prompt's instruction to
    think inside tags would cause the first letter of the *reasoning* text
    (e.g. "T" from "The patient...") to be returned instead of the final
    answer letter.
    """

    def __init__(self) -> None:
        """Initialize the LetterParser."""
        pass

    def parse_answer(self, completion: Any) -> str:
        """Extract the first letter A-Z from the completion's answer text.

        Args:
            completion: Either a raw string or a chat-style list of message
                dicts (the last message's ``content`` is used).

        Returns:
            The first A-Z character found after think-tag stripping,
            uppercased, or "" if none is present.
        """
        text = self._get_text_from_completion(completion)
        text = self._strip_think(text)
        return self._first_letter(text) or ""

    def get_format_reward_func(self) -> Any:
        """Return a format reward function (simple placeholder)."""

        def format_reward(
            parser: Any, completion: str, answer: str, **kwargs: Any
        ) -> float:
            # Basic format reward - just check if we were able to extract a letter
            parsed = self.parse_answer(completion)
            return 1.0 if parsed != "" else 0.0

        return format_reward

    def _get_text_from_completion(self, completion: Any) -> str:
        """Normalize a string or chat-message-list completion to plain text."""
        if isinstance(completion, str):
            return completion
        if isinstance(completion, list) and completion:
            last = completion[-1]
            if isinstance(last, dict):
                return str(last.get("content", ""))
            return str(last)
        return str(completion)

    def _strip_think(self, text: str) -> str:
        """Remove ``<think>...</think>`` spans (case-insensitive) from text.

        An unclosed ``<think>`` tag discards everything after it, since the
        trailing text is reasoning that never reached a final answer.
        """
        lowered = (text or "").lower()
        pieces = []
        pos = 0
        while True:
            start = lowered.find("<think>", pos)
            if start == -1:
                pieces.append(text[pos:])
                break
            pieces.append(text[pos:start])
            end = lowered.find("</think>", start)
            if end == -1:
                # Unclosed tag: drop the rest of the text.
                break
            pos = end + len("</think>")
        return "".join(pieces)

    def _first_letter(self, text: str) -> str:
        """Return the first A-Z character of the uppercased text, or ""."""
        t = (text or "").upper()
        for ch in t:
            if "A" <= ch <= "Z":
                return ch
        return ""
64+
65+
66+
def main() -> None:
    """Run the evaluation on the MetaMedQA dataset.

    Loads the full test split, builds a single-turn verifiers environment
    with a letter parser and a correctness + format rubric, and evaluates
    a Groq-hosted model on a small smoke-test slice.
    """
    # Load environment variables (GROQ_API_KEY).
    load_dotenv()

    # Create an instance of the processor (-1 = full split).
    dataset = MetaMedQADataset(split="test", num_examples=-1)

    # Construct prompts. The think-tag instruction must use a well-formed
    # tag so the parser can strip the reasoning before extracting the letter.
    system_prompt = (
        "Think step-by-step inside <think> tags, then give only the letter "
        "of the correct answer. Do not include option text; only the letter."
    )

    parser = LetterParser()

    # Correctness drives the score; the format reward is tracked but
    # weighted 0 so it only serves as a diagnostic.
    rubric = vf.Rubric(
        funcs=[correct_answer_reward_func, parser.get_format_reward_func()],
        weights=[1.0, 0.0],
        parser=parser,
    )

    env = vf.SingleTurnEnv(
        dataset=dataset.dataset,
        eval_dataset=dataset.dataset,  # Using same dataset for both train and eval as in original
        system_prompt=system_prompt,
        parser=parser,
        rubric=rubric,
    )

    # Run the evaluation. Groq's OpenAI-compatible endpoint lives under
    # /openai/v1 (the bare /v1 path returns 404).
    client = OpenAI(
        api_key=os.getenv("GROQ_API_KEY"),
        base_url="https://api.groq.com/openai/v1",
    )
    results = env.evaluate(
        client=client,
        model="llama-3.3-70b-versatile",
        num_examples=2,  # smoke-test slice; raise for a full run
        rollouts_per_example=5,
    )
    print(results)


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)