Skip to content

Commit 37b1ff4

Browse files
committed
add agentclinic env
1 parent 432a5c6 commit 37b1ff4

File tree

1 file changed: +3 −35 lines changed

environments/agentclinic/agentclinic.py

Lines changed: 3 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
11
"""
22
AgentClinic Environment for Prime Intellect Verifiers
33
Reimplemented to match the original paper exactly.
4-
Supports multiple LLM backends via OpenAI-compatible APIs
54
"""
65
from __future__ import annotations
76
from typing import Any, Dict, List, Optional
@@ -91,7 +90,6 @@ async def inference_patient(self, question: str) -> str:
9190
if not answer:
9291
answer = "I'm not sure about that."
9392

94-
# Update history like the paper
9593
self.agent_hist += question + "\n\n" + answer + "\n\n"
9694
return answer
9795

@@ -140,7 +138,6 @@ async def inference_measurement(self, question: str) -> str:
140138
if not answer:
141139
answer = "RESULTS: NORMAL READINGS"
142140

143-
# Update history
144141
self.agent_hist += question + "\n\n" + answer + "\n\n"
145142
return answer
146143

@@ -183,7 +180,6 @@ def __init__(self, scenario_dict: Dict[str, Any]):
183180
self.question = scenario_dict.get("question", "")
184181
self.image_url = scenario_dict.get("image_url", "")
185182

186-
# Extract correct answer from answers array
187183
answers = scenario_dict.get("answers", [])
188184
self.diagnosis = next((a["text"] for a in answers if a.get("correct")), "")
189185

@@ -237,7 +233,7 @@ def _compose_doctor_system(use_think: bool, max_infs: int, current_infs: int) ->
237233

238234
class AgentClinicEnv(vf.MultiTurnEnv):
239235
"""
240-
AgentClinic environment matching the paper's main() loop.
236+
AgentClinic environment
241237
Doctor is the evaluated model, Patient and Measurement are helper agents.
242238
"""
243239

@@ -247,15 +243,12 @@ def __init__(
247243
max_turns: int = 20,
248244
use_think: bool = False,
249245
name: str = "AgentClinic",
250-
# Patient agent config
251246
patient_model: str = "gpt-4o-mini",
252247
patient_base_url: Optional[str] = None,
253248
patient_api_key: Optional[str] = None,
254-
# Measurement agent config
255249
measurement_model: str = "gpt-4o-mini",
256250
measurement_base_url: Optional[str] = None,
257251
measurement_api_key: Optional[str] = None,
258-
# Moderator/judge config
259252
moderator_model: str = "gpt-4o-mini",
260253
moderator_base_url: Optional[str] = None,
261254
moderator_api_key: Optional[str] = None,
@@ -283,22 +276,18 @@ def __init__(
283276
self._max_turns = max_turns
284277
self._use_think = use_think
285278

286-
# Store patient agent LLM configuration
287279
self._patient_model = patient_model
288280
self._patient_base_url = patient_base_url
289281
self._patient_api_key = patient_api_key or os.environ.get("OPENAI_API_KEY")
290282

291-
# Store measurement agent LLM configuration
292283
self._measurement_model = measurement_model
293284
self._measurement_base_url = measurement_base_url
294285
self._measurement_api_key = measurement_api_key or os.environ.get("OPENAI_API_KEY")
295286

296-
# Store moderator LLM configuration
297287
self._moderator_model = moderator_model
298288
self._moderator_base_url = moderator_base_url
299289
self._moderator_api_key = moderator_api_key or os.environ.get("OPENAI_API_KEY")
300290

301-
# Create moderator client for scoring
302291
self._moderator_client = AsyncOpenAI(
303292
base_url=moderator_base_url,
304293
api_key=self._moderator_api_key
@@ -334,13 +323,11 @@ async def setup_state(self, state: vf.State, **kwargs) -> vf.State:
334323
Override MultiTurnEnv.setup_state to initialize agents for each case.
335324
This is called by the rollout() method with the initial state.
336325
"""
337-
# Get the case index from info (passed through from dataset)
338326
info = state.get("info", {})
339327
case_index = info.get("case_id", 0)
340328

341329
scenario = self._scenarios[case_index]
342330

343-
# Create separate AsyncOpenAI clients for each agent
344331
patient_client = AsyncOpenAI(
345332
base_url=self._patient_base_url,
346333
api_key=self._patient_api_key
@@ -351,7 +338,7 @@ async def setup_state(self, state: vf.State, **kwargs) -> vf.State:
351338
api_key=self._measurement_api_key
352339
)
353340

354-
# Create fresh agents for this case (like paper's initialization)
341+
# Create new agents for each case (like paper's initialization)
355342
patient_agent = PatientAgent(
356343
client=patient_client,
357344
model=self._patient_model,
@@ -381,7 +368,6 @@ async def is_completed(self, messages: vf.Messages, state: vf.State, info: Dict[
381368
"""Check if conversation is complete."""
382369
turns = state.get("turn", 0)
383370

384-
# Check last assistant message for DIAGNOSIS READY (like paper)
385371
last_assistant = None
386372
for m in reversed(messages):
387373
if isinstance(m, dict) and m.get("role") == "assistant":
@@ -405,32 +391,26 @@ async def env_response(self, messages: vf.Messages, state: vf.State, info: Dict[
405391
new_state = dict(state)
406392
new_state["turn"] = state.get("turn", 0) + 1
407393

408-
# Get agents from state
409394
patient_agent = new_state["_patient_agent"]
410395
measurement_agent = new_state["_measurement_agent"]
411396

412-
# Get last doctor message
413397
doctor_dialogue = ""
414398
for m in reversed(messages):
415399
if isinstance(m, dict) and m.get("role") == "assistant":
416400
doctor_dialogue = m.get("content", "")
417401
break
418402

419-
# Final turn nudge
420403
if new_state["turn"] >= self._max_turns:
421404
return (
422405
[{"role": "user", "content": "This is the final question. Please provide a diagnosis.\nDIAGNOSIS READY: "}],
423406
new_state
424407
)
425408

426-
# Check if doctor requested test (like paper's main loop)
427409
if "REQUEST TEST" in doctor_dialogue:
428-
# Measurement agent responds
429410
result = await measurement_agent.inference_measurement(doctor_dialogue)
430411
patient_agent.agent_hist += result + "\n\n" # Add to patient history too
431412
return ([{"role": "user", "content": result}], new_state)
432413

433-
# Otherwise, patient responds
434414
pi_dialogue = await patient_agent.inference_patient(doctor_dialogue)
435415
measurement_agent.agent_hist += pi_dialogue + "\n\n" # Add to measurement history too
436416

@@ -448,7 +428,7 @@ async def compare_results_llm(
448428
moderator_model: str = "gpt-4o-mini"
449429
) -> float:
450430
"""
451-
LLM judge matching paper's compare_results().
431+
LLM judge
452432
Uses configurable moderator LLM.
453433
"""
454434
# Fallback to default OpenAI if no client provided
@@ -458,11 +438,9 @@ async def compare_results_llm(
458438
return 0.0
459439
moderator_client = AsyncOpenAI(api_key=api_key)
460440

461-
# Extract diagnosis from "DIAGNOSIS READY: [diagnosis]" format
462441
if "DIAGNOSIS READY:" in prediction:
463442
prediction = prediction.split("DIAGNOSIS READY:")[-1].strip()
464443

465-
# Also handle \boxed{} format for verifiers compatibility
466444
try:
467445
boxed = extract_boxed_answer(prediction)
468446
if boxed:
@@ -495,7 +473,6 @@ async def compare_results_llm(
495473
class AccuracyReward:
496474
"""Reward function class that holds reference to moderator client."""
497475

498-
# Add __name__ for verifiers framework compatibility
499476
__name__ = "accuracy_reward"
500477

501478
def __init__(self, moderator_client: AsyncOpenAI, moderator_model: str):
@@ -506,7 +483,6 @@ async def __call__(self, prompt: str, completion: str, answer: str, state: Dict[
506483
"""Reward function for verifiers."""
507484
gold = (state.get("info") or {}).get("gold", "") or answer
508485

509-
# Extract final diagnosis from completion
510486
if isinstance(completion, list):
511487
completion_text = ""
512488
for msg in reversed(completion):
@@ -542,15 +518,12 @@ def load_environment(
542518
dataset_type: Optional[str] = None,
543519
use_think: bool = False,
544520
max_turns: int = 20,
545-
# Patient agent config
546521
patient_model: str = "gpt-4o-mini",
547522
patient_base_url: Optional[str] = None,
548523
patient_api_key: Optional[str] = None,
549-
# Measurement agent config
550524
measurement_model: str = "gpt-4o-mini",
551525
measurement_base_url: Optional[str] = None,
552526
measurement_api_key: Optional[str] = None,
553-
# Moderator config
554527
moderator_model: str = "gpt-4o-mini",
555528
moderator_base_url: Optional[str] = None,
556529
moderator_api_key: Optional[str] = None,
@@ -578,7 +551,6 @@ def load_environment(
578551
Returns:
579552
AgentClinic environment instance
580553
"""
581-
# Find dataset file
582554
if dataset_path:
583555
# User specified a path via --env-args
584556
# Check if it's an absolute path
@@ -615,7 +587,6 @@ def load_environment(
615587
if not cases:
616588
raise ValueError(f"No cases loaded from: {found}")
617589

618-
# Auto-detect dataset type if not specified
619590
if dataset_type is None:
620591
dataset_type = _detect_dataset_type(cases)
621592

@@ -630,7 +601,6 @@ def load_environment(
630601
else:
631602
raise ValueError(f"Unknown dataset type: {dataset_type}. Use 'medqa' or 'nejm'")
632603

633-
# Create environment
634604
env = AgentClinicEnv(
635605
cases=cases,
636606
max_turns=max_turns,
@@ -648,10 +618,8 @@ def load_environment(
648618
**kwargs,
649619
)
650620

651-
# Override scenarios with typed versions
652621
env._scenarios = scenarios
653622

654-
# Set rubric with moderator client from environment
655623
accuracy_reward_func = AccuracyReward(env._moderator_client, env._moderator_model)
656624
env.rubric = vf.Rubric(
657625
funcs=[accuracy_reward_func],

0 commit comments

Comments
 (0)