
Commit fef34de

Aymane authored and committed
refactor with judgerubric
1 parent d54b6e2 commit fef34de

2 files changed: +94 -88 lines changed

environments/agentclinic/README.md

Lines changed: 7 additions & 6 deletions
@@ -109,7 +109,7 @@ uv run --active -m verifiers.scripts.eval \
 
 ### Agent Parameters
 
-Each agent (patient, measurement, moderator) can be configured via `--env-args`:
+Each agent (patient, measurement) can be configured via `--env-args`:
 
 ### Datasets
 
@@ -121,10 +121,11 @@ Each agent (patient, measurement, moderator) can be configured via `--env-args`:
 - `max_turns`: Maximum conversation turns (default: 20)
 - `use_think`: Enable chain-of-thought prompting (default: false)
 
-## Agent Roles
 
-- **Doctor** (evaluated): Asks questions, orders tests, makes diagnosis
-- **Patient** (helper): Simulates patient responses
-- **Measurement** (helper): Returns test results
-- **Moderator** (helper): Judges diagnosis accuracy
+### Agent Roles
+
+- **Doctor** (evaluated model): Asks questions, requests tests (e.g., "REQUEST TEST: MRI_Brain_Spine"), makes diagnosis
+- **Patient** (auxiliary LLM): Simulates realistic patient responses based on case symptoms
+- **Measurement** (auxiliary LLM): Returns test results from scenario data when requested
+- **Moderator** (auxiliary LLM): Evaluates diagnosis accuracy using JudgeRubric
 

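To make the README's `--env-args` parameters concrete, here is a minimal usage sketch in Python. The keyword names mirror the parameters documented above and the attributes visible in the diff below; the import path and the exact `load_environment` signature are assumptions, not something this commit shows.

```python
# Illustrative sketch only: argument names follow the README's --env-args parameters,
# but the import path and exact load_environment signature are assumed.
from agentclinic import load_environment

env = load_environment(
    patient_model="gpt-4o-mini",      # helper LLM that simulates the patient
    measurement_model="gpt-4o-mini",  # helper LLM that returns test results
    moderator_model="gpt-4o-mini",    # judge model used by the JudgeRubric
    max_turns=20,                     # maximum conversation turns (README default)
    use_think=False,                  # chain-of-thought prompting (README default)
)
```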
environments/agentclinic/agentclinic.py

Lines changed: 87 additions & 82 deletions
@@ -1,6 +1,5 @@
 """
-AgentClinic Environment for Prime Intellect Verifiers
-Reimplemented to match the original paper exactly.
+AgentClinic Environment
 """
 from __future__ import annotations
 from typing import Any, Dict, List, Optional
@@ -284,15 +283,6 @@ def __init__(
         self._measurement_base_url = measurement_base_url
         self._measurement_api_key = measurement_api_key or os.environ.get("OPENAI_API_KEY")
 
-        self._moderator_model = moderator_model
-        self._moderator_base_url = moderator_base_url
-        self._moderator_api_key = moderator_api_key or os.environ.get("OPENAI_API_KEY")
-
-        self._moderator_client = AsyncOpenAI(
-            base_url=moderator_base_url,
-            api_key=self._moderator_api_key
-        )
-
         # Build dataset for verifiers
         prompts = []
         infos = []
@@ -338,7 +328,6 @@ async def setup_state(self, state: vf.State, **kwargs) -> vf.State:
             api_key=self._measurement_api_key
         )
 
-        # Create new agents for each case (like paper's initialization)
        patient_agent = PatientAgent(
             client=patient_client,
             model=self._patient_model,
@@ -377,8 +366,14 @@ async def is_completed(self, messages: vf.Messages, state: vf.State, info: Dict[
         if last_assistant and "DIAGNOSIS READY" in last_assistant:
             return True
 
-        # Max turns reached
-        if turns >= self._max_turns:
+        # Check if we just sent the final diagnosis prompt
+        # If so allow one more turn for the model to respond
+        if messages and isinstance(messages[-1], dict) and messages[-1].get("role") == "user":
+            last_user_content = messages[-1].get("content", "")
+            if "You have reached the maximum number of questions" in last_user_content:
+                return False  # Give model one more turn to provide diagnosis
+
+        if turns > self._max_turns:
             return True
 
         return False
@@ -400,9 +395,18 @@ async def env_response(self, messages: vf.Messages, state: vf.State, info: Dict[
                 doctor_dialogue = m.get("content", "")
                 break
 
+        # Warning when approaching max turns
+        if new_state["turn"] == self._max_turns - 2:
+            warning = "You have 2 questions remaining. Please start formulating your diagnosis."
+            pi_dialogue = await patient_agent.inference_patient(doctor_dialogue)
+            measurement_agent.agent_hist += pi_dialogue + "\n\n"
+            combined_response = f"{pi_dialogue}\n\n[System: {warning}]"
+            return ([{"role": "user", "content": combined_response}], new_state)
+
+        # Final turn force diagnosis
         if new_state["turn"] >= self._max_turns:
             return (
-                [{"role": "user", "content": "This is the final question. Please provide a diagnosis.\nDIAGNOSIS READY: "}],
+                [{"role": "user", "content": "You have reached the maximum number of questions. Based on the information gathered, please provide your diagnosis now using the format: DIAGNOSIS READY: [your diagnosis]"}],
                 new_state
             )
 
@@ -421,78 +425,46 @@ async def env_response(self, messages: vf.Messages, state: vf.State, info: Dict[
 # Scoring (LLM Judge)
 # ============================================================
 
-async def compare_results_llm(
-    prediction: str,
-    gold: str,
-    moderator_client: Optional[AsyncOpenAI] = None,
-    moderator_model: str = "gpt-4o-mini"
-) -> float:
-    """
-    LLM judge
-    Uses configurable moderator LLM.
-    """
-    # Fallback to default OpenAI if no client provided
-    if moderator_client is None:
-        api_key = os.environ.get("OPENAI_API_KEY")
-        if not api_key:
-            return 0.0
-        moderator_client = AsyncOpenAI(api_key=api_key)
+# Judge prompt template for medical diagnosis evaluation
+JUDGE_TEMPLATE = """\
+Your job is to evaluate whether a doctor's diagnosis is medically equivalent to the ground truth diagnosis.
 
-    if "DIAGNOSIS READY:" in prediction:
-        prediction = prediction.split("DIAGNOSIS READY:")[-1].strip()
+You will be given:
+1. The ground truth diagnosis (answer)
+2. The doctor's predicted diagnosis (response)
 
-    try:
-        boxed = extract_boxed_answer(prediction)
-        if boxed:
-            prediction = boxed
-    except:
-        pass
+Your task is to determine if the doctor's diagnosis is medically equivalent to the ground truth, even if worded differently.
 
-    system = "You are responsible for determining if the correct diagnosis and the doctor diagnosis are the same disease. Please respond only with Yes or No. Nothing else."
-    user = f"\nHere is the correct diagnosis: {gold}\nHere was the doctor dialogue: {prediction}\nAre these the same?"
+Consider these guidelines:
+- Medical terms that refer to the same condition should be considered equivalent
+- Different levels of specificity may be acceptable
+- Spelling variations of medical terms should be considered equivalent
+- The core medical meaning should match, even if additional details vary
 
-    messages = [
-        {"role": "system", "content": system},
-        {"role": "user", "content": user}
-    ]
+Ground truth diagnosis: {answer}
 
-    try:
-        response = await moderator_client.chat.completions.create(
-            model=moderator_model,
-            messages=messages,
-            temperature=0.0,
-            max_tokens=10
-        )
-        answer = (response.choices[0].message.content or "").lower()
-        return 1.0 if "yes" in answer else 0.0
-    except Exception as e:
-        print(f"[compare_results_llm] Error: {e}")
-        return 0.0
+Doctor's diagnosis: {response}
 
+Is the doctor's diagnosis medically equivalent to the ground truth diagnosis?
+Respond with either "CORRECT" or "INCORRECT".
+""".strip()
 
-class AccuracyReward:
-    """Reward function class that holds reference to moderator client."""
 
-    __name__ = "accuracy_reward"
-
-    def __init__(self, moderator_client: AsyncOpenAI, moderator_model: str):
-        self.moderator_client = moderator_client
-        self.moderator_model = moderator_model
-
-    async def __call__(self, prompt: str, completion: str, answer: str, state: Dict[str, Any]) -> float:
-        """Reward function for verifiers."""
-        gold = (state.get("info") or {}).get("gold", "") or answer
+def extract_diagnosis(completion_text: str) -> str:
+    """
+    Extract diagnosis from completion text, handling DIAGNOSIS READY and boxed formats.
+    """
+    if "DIAGNOSIS READY:" in completion_text:
+        completion_text = completion_text.split("DIAGNOSIS READY:")[-1].strip()
 
-        if isinstance(completion, list):
-            completion_text = ""
-            for msg in reversed(completion):
-                if isinstance(msg, dict) and msg.get("role") == "assistant":
-                    completion_text = msg.get("content", "")
-                    break
-        else:
-            completion_text = str(completion)
+    try:
+        boxed = extract_boxed_answer(completion_text)
+        if boxed:
+            completion_text = boxed
+    except:
+        pass
 
-        return await compare_results_llm(completion_text, gold, self.moderator_client, self.moderator_model)
+    return completion_text.strip()
 
 
 # ============================================================
@@ -546,7 +518,7 @@ def load_environment(
         moderator_model: Model name for Moderator/Judge
         moderator_base_url: API base URL for moderator
        moderator_api_key: API key for moderator
-        **kwargs: Additional arguments passed to AgentClinicEnv
+        **kwargs: Additional arguments
 
     Returns:
         AgentClinic environment instance
@@ -620,12 +592,45 @@ def load_environment(
 
     env._scenarios = scenarios
 
-    accuracy_reward_func = AccuracyReward(env._moderator_client, env._moderator_model)
-    env.rubric = vf.Rubric(
-        funcs=[accuracy_reward_func],
-        names=["accuracy_reward"]
+    moderator_api_key = moderator_api_key or os.environ.get("OPENAI_API_KEY")
+    judge_client = AsyncOpenAI(base_url=moderator_base_url, api_key=moderator_api_key) if moderator_api_key else None
+
+    rubric = vf.JudgeRubric(
+        judge_client=judge_client,
+        judge_model=moderator_model,
+        judge_prompt=JUDGE_TEMPLATE,
     )
 
+    async def diagnosis_reward_func(judge, prompt, completion, answer, state, **kwargs) -> float:
+        """
+        Reward function that uses LLM judge to evaluate diagnosis
+        """
+        if isinstance(completion, list):
+            completion_text = ""
+            for msg in reversed(completion):
+                if isinstance(msg, dict) and msg.get("role") == "assistant":
+                    completion_text = msg.get("content", "")
+                    break
+        else:
+            completion_text = str(completion)
+
+        diagnosis_text = extract_diagnosis(completion_text)
+
+        gold = (state.get("info") or {}).get("gold", "") or answer
+
+        judge_response = await judge(prompt, diagnosis_text, gold, state, **kwargs)
+
+        judge_response_clean = judge_response.strip().upper()
+
+        if "CORRECT" in judge_response_clean and "INCORRECT" not in judge_response_clean:
+            return 1.0
+        else:
+            return 0.0
+
+    rubric.add_reward_func(diagnosis_reward_func, weight=1.0)
+
+    env.rubric = rubric
+
     return env
 
 
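To summarize the scoring refactor above, the sketch below shows how the new `JUDGE_TEMPLATE` placeholders are filled and how the verdict is parsed, mirroring `diagnosis_reward_func`. The diagnosis strings and judge reply are hypothetical, and the mapping of `{answer}` to the gold diagnosis and `{response}` to the doctor's extracted diagnosis is inferred from the `judge(prompt, diagnosis_text, gold, state)` call in the diff.

```python
# Worked example of the judge step, mirroring diagnosis_reward_func in the diff above.
# Assumes JUDGE_TEMPLATE from the diff is in scope; all concrete values are hypothetical.
filled_prompt = JUDGE_TEMPLATE.format(
    answer="Multiple sclerosis",        # gold diagnosis from the scenario
    response="Relapsing-remitting MS",  # diagnosis extracted from the doctor's reply
)

judge_reply = "CORRECT"  # hypothetical judge model output
verdict = judge_reply.strip().upper()
score = 1.0 if "CORRECT" in verdict and "INCORRECT" not in verdict else 0.0
```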