
Commit 39b71bb

Author: Rakshitha Ireddi (committed)
feat: add MedReason environment and environment validation tests
- Add MedReason verifiers environment (closes #31)
  - Supports mixed MCQ and open-ended question evaluation
  - MCQ items graded via multiple_choice_accuracy
  - Open-ended items evaluated via LLM-as-Judge (JudgeRubric)
  - Configurable answer format (XML/boxed), shuffle, judge model
- Add environment package validation test suite (original contribution)
  - Auto-discovers all 35 environment packages
  - Validates pyproject.toml structure, loader discoverability, load_environment presence, verifiers dependency
  - 7 pre-existing issues documented as xfail markers
1 parent d139be9 commit 39b71bb

File tree

3 files changed: +434 lines, -0 lines

medreason.py
Lines changed: 237 additions & 0 deletions
@@ -0,0 +1,237 @@
import json
from typing import Optional

import verifiers as vf
from datasets import load_dataset
from datasets.utils.logging import disable_progress_bar
from medarc_verifiers.parsers.xml_parser import XMLParser
from medarc_verifiers.prompts import THINK_XML_SYSTEM_PROMPT, XML_SYSTEM_PROMPT, AnswerFormat
from medarc_verifiers.rewards.multiple_choice_accuracy import multiple_choice_accuracy
from medarc_verifiers.utils import default_judge_api_key, judge_sampling_args_and_headers
from medarc_verifiers.utils.randomize_multiple_choice import randomize_multiple_choice
from openai import AsyncOpenAI
from verifiers.types import Info, State
from verifiers.utils.data_utils import BOXED_SYSTEM_PROMPT, THINK_BOXED_SYSTEM_PROMPT, extract_boxed_answer

disable_progress_bar()

MCQ_QUESTION_TEMPLATE = """\
Question: {question}
Choices:
{choices}
Answer:"""

OPEN_QUESTION_TEMPLATE = """\
{question}"""

JUDGE_TEMPLATE = """\
You are evaluating an AI assistant's answer to a medical question.

<question>{question}</question>
<reference_answer>{answer}</reference_answer>
<assistant_answer>{response}</assistant_answer>

Is the assistant's answer medically equivalent to the reference answer?
Consider synonyms, paraphrasing, and reasonable generalizations as correct.
Answer [yes/no]."""


def _parse_options(options_str: str | None) -> dict[str, str] | None:
    """Parse the options field from the dataset.

    The options field can be a JSON string representing a dict or list,
    or it can be empty/None for open-ended questions.
    """
    if not options_str or options_str.strip() in ("", "None", "null", "{}"):
        return None
    try:
        parsed = json.loads(options_str)
    except (json.JSONDecodeError, TypeError):
        return None

    if isinstance(parsed, dict):
        if not parsed:
            return None
        return {str(k): str(v) for k, v in parsed.items()}
    if isinstance(parsed, list):
        if not parsed:
            return None
        labels = [chr(ord("A") + i) for i in range(len(parsed))]
        return dict(zip(labels, [str(v) for v in parsed]))
    return None


def _format_mcq_prompt(question: str, options: dict[str, str]) -> str:
    """Format a multiple-choice question prompt."""
    choices = "\n".join(f"{k}. {v}" for k, v in options.items())
    return MCQ_QUESTION_TEMPLATE.format(question=question, choices=choices)


def load_environment(
    use_think: bool = False,
    system_prompt: Optional[str] = None,
    shuffle_answers: bool = False,
    shuffle_seed: int | None = 1618,
    answer_format: AnswerFormat | str = AnswerFormat.XML,
    judge_model: str = "gpt-4o-mini",
    judge_base_url: str | None = None,
    judge_api_key: str | None = None,
) -> vf.Environment:
    """
    MedReason medical reasoning evaluation environment.

    Supports both multiple-choice and open-ended questions from the MedReason
    dataset (UCSC-VLAA/MedReason). MCQ items are graded by accuracy; open-ended
    items use LLM-as-a-Judge evaluation.

    Args:
        use_think: Enable chain-of-thought reasoning with <think> tags.
        system_prompt: Custom system prompt override.
        shuffle_answers: Shuffle MCQ answer options.
        shuffle_seed: Seed for deterministic answer shuffling.
        answer_format: Answer format (xml or boxed).
        judge_model: Model to use for LLM-as-judge evaluation.
        judge_base_url: Base URL for judge API.
        judge_api_key: API key for judge model.
    """
    ds = load_dataset("UCSC-VLAA/MedReason", split="train")

    # Set up judge for open-ended questions
    api_key = default_judge_api_key(judge_base_url) if judge_api_key is None else judge_api_key
    sampling_args, default_headers = judge_sampling_args_and_headers(judge_model, judge_base_url)
    judge_client = AsyncOpenAI(base_url=judge_base_url, api_key=api_key, default_headers=default_headers)
    judge_rubric = vf.JudgeRubric(
        judge_client=judge_client,
        judge_model=judge_model,
        judge_prompt="{question}",
        judge_sampling_args=sampling_args,
    )

    def _map(ex, idx=None):
        question_text = ex["question"]
        answer_text = ex["answer"]
        options = _parse_options(ex.get("options"))

        if options:
            # MCQ: find gold letter by matching answer text to options
            gold_letter = None
            for letter, opt_text in options.items():
                if opt_text.strip().lower() == answer_text.strip().lower():
                    gold_letter = letter
                    break

            if gold_letter is None:
                # Answer is the letter itself
                candidate = answer_text.strip().upper()
                if candidate in options:
                    gold_letter = candidate
                else:
                    gold_letter = "A"

            if shuffle_answers and gold_letter in options:
                options, gold_letter, _ = randomize_multiple_choice(
                    options=options,
                    answer_choice=gold_letter,
                    seed=shuffle_seed,
                    row_id=ex.get("id_in_dataset", idx),
                )

            return {
                "question": _format_mcq_prompt(question_text, options),
                "answer": gold_letter,
                "info": {
                    "is_mcq": True,
                    "answer_text": options.get(gold_letter, answer_text),
                    "dataset_name": ex.get("dataset_name", ""),
                    **({} if not shuffle_answers else {"options": options}),
                },
            }
        else:
            # Open-ended question
            return {
                "question": OPEN_QUESTION_TEMPLATE.format(question=question_text),
                "answer": answer_text,
                "info": {
                    "is_mcq": False,
                    "dataset_name": ex.get("dataset_name", ""),
                    "question_raw": question_text,
                },
            }

    load_from_cache_file = not shuffle_answers
    eval_dataset = ds.map(
        _map,
        with_indices=True,
        remove_columns=ds.column_names,
        load_from_cache_file=load_from_cache_file,
    )

    # Set up parser based on answer format
    answer_format = AnswerFormat(answer_format) if isinstance(answer_format, str) else answer_format
    if answer_format == AnswerFormat.XML:
        final_system_prompt = system_prompt or (THINK_XML_SYSTEM_PROMPT if use_think else XML_SYSTEM_PROMPT)
        parser_fields = ["think", "answer"] if use_think else ["answer"]
        parser = XMLParser(fields=parser_fields, answer_field="answer")
    elif answer_format == AnswerFormat.BOXED:
        parser = vf.ThinkParser(extract_boxed_answer) if use_think else vf.Parser(extract_boxed_answer)
        final_system_prompt = system_prompt or (THINK_BOXED_SYSTEM_PROMPT if use_think else BOXED_SYSTEM_PROMPT)
    else:
        raise ValueError(f"Unsupported answer format: {answer_format=}")

    async def medreason_reward_func(
        completion,
        answer,
        info: Info,
        state: State,
        **kwargs,
    ) -> float:
        """Unified reward: accuracy for MCQ, LLM judge for open-ended."""
        is_mcq = info.get("is_mcq", False)

        if is_mcq:
            parsed_answer = parser.parse_answer(completion) or ""
            answer_text_val = info.get("answer_text", None)
            is_correct = multiple_choice_accuracy(
                llm_answer=parsed_answer,
                answer_letter=answer,
                answer_text=answer_text_val,
            )
            return 1.0 if is_correct else 0.0
        else:
            # Open-ended: use LLM judge
            parsed = parser.parse(completion, last=True)
            model_answer = getattr(parsed, "answer", None)

            if model_answer is not None:
                question_raw = info.get("question_raw", "")
                judge_prompt = JUDGE_TEMPLATE.format(
                    question=question_raw,
                    answer=answer,
                    response=model_answer,
                )
                judge_response = await judge_rubric.judge(judge_prompt, model_answer, answer, state)
                judge_response_clean = judge_response.strip().lower()
            else:
                judge_response_clean = "no"
                judge_response = "no answer"

            info.setdefault("judge_feedback", []).append(
                {
                    "parsed": judge_response_clean,
                    "raw_judge": str(judge_response),
                }
            )

            if "yes" in judge_response_clean and "no" not in judge_response_clean:
                return 1.0
            else:
                return 0.0

    judge_rubric.add_reward_func(medreason_reward_func, weight=1.0)

    return vf.SingleTurnEnv(
        eval_dataset=eval_dataset,
        system_prompt=final_system_prompt,
        parser=parser,
        rubric=judge_rubric,
    )
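
A minimal usage sketch of the loader above. The parameter values are illustrative, the string form of answer_format follows the docstring ("xml" or "boxed"), and it assumes the medreason module is importable and a judge API key is resolvable (via default_judge_api_key or the judge_api_key argument); this snippet is not part of the committed file.

    # Usage sketch: illustrative values, assumes a judge API key is available.
    from medreason import load_environment

    env = load_environment(
        use_think=True,             # request <think> reasoning before the answer
        answer_format="boxed",      # or "xml" (the default)
        shuffle_answers=True,       # deterministic MCQ shuffling via shuffle_seed
        judge_model="gpt-4o-mini",  # used only for open-ended items
    )
    # `env` is a vf.SingleTurnEnv whose rubric scores MCQ rows by accuracy and
    # open-ended rows with the yes/no LLM judge defined above.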
pyproject.toml
Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
[project]
name = "medreason"
version = "0.1.0"
description = "MedReason medical reasoning evaluation with mixed MCQ and open-ended QA"
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
    "datasets>=4.0.0",
    "verifiers>=0.1.2.post0",
    "medarc_verifiers>=0.1.0",
    "openai",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build]
include = ["medreason.py"]

[tool.uv.sources]
medarc_verifiers = { git = "https://github.com/MedARC-AI/med-lm-envs" }

[tool.prime.environment]
loader = "medreason:load_environment"
display_name = "MedReason"
visibility = "PUBLIC"
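
The commit message also describes an environment-package validation suite (the third changed file, not shown in this diff). A rough sketch of the kind of check it describes, assuming pytest, the Python 3.11+ tomllib module, and a flat environments/ directory layout; the directory name and test structure are assumptions, and the real suite reportedly also verifies loader importability and marks 7 known issues as xfail.

    # Hypothetical sketch of a pyproject validation check; the committed test
    # file is not shown here and likely differs in structure and paths.
    import tomllib
    from pathlib import Path

    import pytest

    ENV_ROOT = Path("environments")  # assumed layout
    ENV_DIRS = sorted(p for p in ENV_ROOT.iterdir() if (p / "pyproject.toml").exists())


    @pytest.mark.parametrize("env_dir", ENV_DIRS, ids=lambda p: p.name)
    def test_pyproject_declares_loader_and_verifiers(env_dir: Path):
        data = tomllib.loads((env_dir / "pyproject.toml").read_text())

        # Each environment must depend on verifiers.
        deps = data["project"]["dependencies"]
        assert any(d.startswith("verifiers") for d in deps)

        # The prime loader entry must point at a load_environment function.
        loader = data["tool"]["prime"]["environment"]["loader"]
        module_name, func_name = loader.split(":")
        assert func_name == "load_environment"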
