11"""
22AgentClinic Environment for Prime Intellect Verifiers
33Reimplemented to match the original paper exactly.
4- Supports multiple LLM backends via OpenAI-compatible APIs
54"""
65from __future__ import annotations
76from typing import Any , Dict , List , Optional
@@ -91,7 +90,6 @@ async def inference_patient(self, question: str) -> str:
9190 if not answer :
9291 answer = "I'm not sure about that."
9392
94- # Update history like the paper
9593 self .agent_hist += question + "\n \n " + answer + "\n \n "
9694 return answer
9795
@@ -140,7 +138,6 @@ async def inference_measurement(self, question: str) -> str:
140138 if not answer :
141139 answer = "RESULTS: NORMAL READINGS"
142140
143- # Update history
144141 self .agent_hist += question + "\n \n " + answer + "\n \n "
145142 return answer
146143
@@ -183,7 +180,6 @@ def __init__(self, scenario_dict: Dict[str, Any]):
183180 self .question = scenario_dict .get ("question" , "" )
184181 self .image_url = scenario_dict .get ("image_url" , "" )
185182
186- # Extract correct answer from answers array
187183 answers = scenario_dict .get ("answers" , [])
188184 self .diagnosis = next ((a ["text" ] for a in answers if a .get ("correct" )), "" )
189185
@@ -237,7 +233,7 @@ def _compose_doctor_system(use_think: bool, max_infs: int, current_infs: int) ->
237233
238234class AgentClinicEnv (vf .MultiTurnEnv ):
239235 """
240- AgentClinic environment matching the paper's main() loop.
236+ AgentClinic environment
241237 Doctor is the evaluated model, Patient and Measurement are helper agents.
242238 """
243239
@@ -247,15 +243,12 @@ def __init__(
247243 max_turns : int = 20 ,
248244 use_think : bool = False ,
249245 name : str = "AgentClinic" ,
250- # Patient agent config
251246 patient_model : str = "gpt-4o-mini" ,
252247 patient_base_url : Optional [str ] = None ,
253248 patient_api_key : Optional [str ] = None ,
254- # Measurement agent config
255249 measurement_model : str = "gpt-4o-mini" ,
256250 measurement_base_url : Optional [str ] = None ,
257251 measurement_api_key : Optional [str ] = None ,
258- # Moderator/judge config
259252 moderator_model : str = "gpt-4o-mini" ,
260253 moderator_base_url : Optional [str ] = None ,
261254 moderator_api_key : Optional [str ] = None ,
@@ -283,22 +276,18 @@ def __init__(
283276 self ._max_turns = max_turns
284277 self ._use_think = use_think
285278
286- # Store patient agent LLM configuration
287279 self ._patient_model = patient_model
288280 self ._patient_base_url = patient_base_url
289281 self ._patient_api_key = patient_api_key or os .environ .get ("OPENAI_API_KEY" )
290282
291- # Store measurement agent LLM configuration
292283 self ._measurement_model = measurement_model
293284 self ._measurement_base_url = measurement_base_url
294285 self ._measurement_api_key = measurement_api_key or os .environ .get ("OPENAI_API_KEY" )
295286
296- # Store moderator LLM configuration
297287 self ._moderator_model = moderator_model
298288 self ._moderator_base_url = moderator_base_url
299289 self ._moderator_api_key = moderator_api_key or os .environ .get ("OPENAI_API_KEY" )
300290
301- # Create moderator client for scoring
302291 self ._moderator_client = AsyncOpenAI (
303292 base_url = moderator_base_url ,
304293 api_key = self ._moderator_api_key
@@ -334,13 +323,11 @@ async def setup_state(self, state: vf.State, **kwargs) -> vf.State:
334323 Override MultiTurnEnv.setup_state to initialize agents for each case.
335324 This is called by the rollout() method with the initial state.
336325 """
337- # Get the case index from info (passed through from dataset)
338326 info = state .get ("info" , {})
339327 case_index = info .get ("case_id" , 0 )
340328
341329 scenario = self ._scenarios [case_index ]
342330
343- # Create separate AsyncOpenAI clients for each agent
344331 patient_client = AsyncOpenAI (
345332 base_url = self ._patient_base_url ,
346333 api_key = self ._patient_api_key
@@ -351,7 +338,7 @@ async def setup_state(self, state: vf.State, **kwargs) -> vf.State:
351338 api_key = self ._measurement_api_key
352339 )
353340
354- # Create fresh agents for this case (like paper's initialization)
341+ # Create new agents for each case (like paper's initialization)
355342 patient_agent = PatientAgent (
356343 client = patient_client ,
357344 model = self ._patient_model ,
@@ -381,7 +368,6 @@ async def is_completed(self, messages: vf.Messages, state: vf.State, info: Dict[
381368 """Check if conversation is complete."""
382369 turns = state .get ("turn" , 0 )
383370
384- # Check last assistant message for DIAGNOSIS READY (like paper)
385371 last_assistant = None
386372 for m in reversed (messages ):
387373 if isinstance (m , dict ) and m .get ("role" ) == "assistant" :
@@ -405,32 +391,26 @@ async def env_response(self, messages: vf.Messages, state: vf.State, info: Dict[
405391 new_state = dict (state )
406392 new_state ["turn" ] = state .get ("turn" , 0 ) + 1
407393
408- # Get agents from state
409394 patient_agent = new_state ["_patient_agent" ]
410395 measurement_agent = new_state ["_measurement_agent" ]
411396
412- # Get last doctor message
413397 doctor_dialogue = ""
414398 for m in reversed (messages ):
415399 if isinstance (m , dict ) and m .get ("role" ) == "assistant" :
416400 doctor_dialogue = m .get ("content" , "" )
417401 break
418402
419- # Final turn nudge
420403 if new_state ["turn" ] >= self ._max_turns :
421404 return (
422405 [{"role" : "user" , "content" : "This is the final question. Please provide a diagnosis.\n DIAGNOSIS READY: " }],
423406 new_state
424407 )
425408
426- # Check if doctor requested test (like paper's main loop)
427409 if "REQUEST TEST" in doctor_dialogue :
428- # Measurement agent responds
429410 result = await measurement_agent .inference_measurement (doctor_dialogue )
430411 patient_agent .agent_hist += result + "\n \n " # Add to patient history too
431412 return ([{"role" : "user" , "content" : result }], new_state )
432413
433- # Otherwise, patient responds
434414 pi_dialogue = await patient_agent .inference_patient (doctor_dialogue )
435415 measurement_agent .agent_hist += pi_dialogue + "\n \n " # Add to measurement history too
436416
@@ -448,7 +428,7 @@ async def compare_results_llm(
448428 moderator_model : str = "gpt-4o-mini"
449429) -> float :
450430 """
451- LLM judge matching paper's compare_results().
431+ LLM judge
452432 Uses configurable moderator LLM.
453433 """
454434 # Fallback to default OpenAI if no client provided
@@ -458,11 +438,9 @@ async def compare_results_llm(
458438 return 0.0
459439 moderator_client = AsyncOpenAI (api_key = api_key )
460440
461- # Extract diagnosis from "DIAGNOSIS READY: [diagnosis]" format
462441 if "DIAGNOSIS READY:" in prediction :
463442 prediction = prediction .split ("DIAGNOSIS READY:" )[- 1 ].strip ()
464443
465- # Also handle \boxed{} format for verifiers compatibility
466444 try :
467445 boxed = extract_boxed_answer (prediction )
468446 if boxed :
@@ -495,7 +473,6 @@ async def compare_results_llm(
495473class AccuracyReward :
496474 """Reward function class that holds reference to moderator client."""
497475
498- # Add __name__ for verifiers framework compatibility
499476 __name__ = "accuracy_reward"
500477
501478 def __init__ (self , moderator_client : AsyncOpenAI , moderator_model : str ):
@@ -506,7 +483,6 @@ async def __call__(self, prompt: str, completion: str, answer: str, state: Dict[
506483 """Reward function for verifiers."""
507484 gold = (state .get ("info" ) or {}).get ("gold" , "" ) or answer
508485
509- # Extract final diagnosis from completion
510486 if isinstance (completion , list ):
511487 completion_text = ""
512488 for msg in reversed (completion ):
@@ -542,15 +518,12 @@ def load_environment(
542518 dataset_type : Optional [str ] = None ,
543519 use_think : bool = False ,
544520 max_turns : int = 20 ,
545- # Patient agent config
546521 patient_model : str = "gpt-4o-mini" ,
547522 patient_base_url : Optional [str ] = None ,
548523 patient_api_key : Optional [str ] = None ,
549- # Measurement agent config
550524 measurement_model : str = "gpt-4o-mini" ,
551525 measurement_base_url : Optional [str ] = None ,
552526 measurement_api_key : Optional [str ] = None ,
553- # Moderator config
554527 moderator_model : str = "gpt-4o-mini" ,
555528 moderator_base_url : Optional [str ] = None ,
556529 moderator_api_key : Optional [str ] = None ,
@@ -578,7 +551,6 @@ def load_environment(
578551 Returns:
579552 AgentClinic environment instance
580553 """
581- # Find dataset file
582554 if dataset_path :
583555 # User specified a path via --env-args
584556 # Check if it's an absolute path
@@ -615,7 +587,6 @@ def load_environment(
615587 if not cases :
616588 raise ValueError (f"No cases loaded from: { found } " )
617589
618- # Auto-detect dataset type if not specified
619590 if dataset_type is None :
620591 dataset_type = _detect_dataset_type (cases )
621592
@@ -630,7 +601,6 @@ def load_environment(
630601 else :
631602 raise ValueError (f"Unknown dataset type: { dataset_type } . Use 'medqa' or 'nejm'" )
632603
633- # Create environment
634604 env = AgentClinicEnv (
635605 cases = cases ,
636606 max_turns = max_turns ,
@@ -648,10 +618,8 @@ def load_environment(
648618 ** kwargs ,
649619 )
650620
651- # Override scenarios with typed versions
652621 env ._scenarios = scenarios
653622
654- # Set rubric with moderator client from environment
655623 accuracy_reward_func = AccuracyReward (env ._moderator_client , env ._moderator_model )
656624 env .rubric = vf .Rubric (
657625 funcs = [accuracy_reward_func ],
0 commit comments