11"""
2- AgentClinic Environment for Prime Intellect Verifiers
3- Reimplemented to match the original paper exactly.
2+ AgentClinic Environment
43"""
54from __future__ import annotations
65from typing import Any , Dict , List , Optional
@@ -284,15 +283,6 @@ def __init__(
         self._measurement_base_url = measurement_base_url
         self._measurement_api_key = measurement_api_key or os.environ.get("OPENAI_API_KEY")

-        self._moderator_model = moderator_model
-        self._moderator_base_url = moderator_base_url
-        self._moderator_api_key = moderator_api_key or os.environ.get("OPENAI_API_KEY")
-
-        self._moderator_client = AsyncOpenAI(
-            base_url=moderator_base_url,
-            api_key=self._moderator_api_key
-        )
-
         # Build dataset for verifiers
         prompts = []
         infos = []
@@ -338,7 +328,6 @@ async def setup_state(self, state: vf.State, **kwargs) -> vf.State:
             api_key=self._measurement_api_key
         )

-        # Create new agents for each case (like paper's initialization)
         patient_agent = PatientAgent(
             client=patient_client,
             model=self._patient_model,
@@ -377,8 +366,14 @@ async def is_completed(self, messages: vf.Messages, state: vf.State, info: Dict[
         if last_assistant and "DIAGNOSIS READY" in last_assistant:
             return True

-        # Max turns reached
-        if turns >= self._max_turns:
+        # Check whether we just sent the final diagnosis prompt.
+        # If so, allow one more turn for the model to respond.
+        if messages and isinstance(messages[-1], dict) and messages[-1].get("role") == "user":
+            last_user_content = messages[-1].get("content", "")
+            if "You have reached the maximum number of questions" in last_user_content:
+                return False  # give the model one more turn to provide the diagnosis
+
+        if turns > self._max_turns:
             return True

         return False
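Note: the off-by-one here is easiest to see with a concrete trace. A minimal sketch of the intended accounting (hypothetical helper mirroring the `is_completed` logic above, not part of this commit; assumes `turn` counts doctor turns):

def should_stop(turn: int, max_turns: int, last_user_msg: str) -> bool:
    # Hold the episode open right after the forced-diagnosis prompt,
    # mirroring the early `return False` above.
    if "You have reached the maximum number of questions" in last_user_msg:
        return False
    return turn > max_turns

# With max_turns = 3: the forced prompt goes out at turn 3, and the episode
# only ends at turn 4, after the model has had a chance to answer.
assert not should_stop(3, 3, "Patient: the pain started yesterday.")
assert should_stop(4, 3, "Patient: the pain started yesterday.")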
@@ -400,9 +395,18 @@ async def env_response(self, messages: vf.Messages, state: vf.State, info: Dict[
                 doctor_dialogue = m.get("content", "")
                 break

+        # Warn the doctor when approaching the turn limit
+        if new_state["turn"] == self._max_turns - 2:
+            warning = "You have 2 questions remaining. Please start formulating your diagnosis."
+            pi_dialogue = await patient_agent.inference_patient(doctor_dialogue)
+            measurement_agent.agent_hist += pi_dialogue + "\n\n"
+            combined_response = f"{pi_dialogue}\n\n[System: {warning}]"
+            return ([{"role": "user", "content": combined_response}], new_state)
+
+        # Final turn: force a diagnosis
         if new_state["turn"] >= self._max_turns:
             return (
-                [{"role": "user", "content": "This is the final question. Please provide a diagnosis.\nDIAGNOSIS READY: "}],
+                [{"role": "user", "content": "You have reached the maximum number of questions. Based on the information gathered, please provide your diagnosis now using the format: DIAGNOSIS READY: [your diagnosis]"}],
                 new_state
             )
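For concreteness, assuming one `turn` increment per doctor message and max_turns = 10, the tail of an episode under this change plays out roughly as follows (messages paraphrased):

# turn 8  (== max_turns - 2): patient reply + "[System: You have 2 questions remaining. ...]"
# turn 9:  normal patient reply
# turn 10 (>= max_turns): forced prompt "You have reached the maximum number of questions. ..."
# turn 11: is_completed sees "DIAGNOSIS READY" in the reply (or turn > max_turns) and ends the episode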
@@ -421,78 +425,46 @@ async def env_response(self, messages: vf.Messages, state: vf.State, info: Dict[
 # Scoring (LLM Judge)
 # ============================================================

-async def compare_results_llm(
-    prediction: str,
-    gold: str,
-    moderator_client: Optional[AsyncOpenAI] = None,
-    moderator_model: str = "gpt-4o-mini"
-) -> float:
-    """
-    LLM judge
-    Uses configurable moderator LLM.
-    """
-    # Fallback to default OpenAI if no client provided
-    if moderator_client is None:
-        api_key = os.environ.get("OPENAI_API_KEY")
-        if not api_key:
-            return 0.0
-        moderator_client = AsyncOpenAI(api_key=api_key)
+# Judge prompt template for medical diagnosis evaluation
+JUDGE_TEMPLATE = """\
+Your job is to evaluate whether a doctor's diagnosis is medically equivalent to the ground truth diagnosis.

-    if "DIAGNOSIS READY:" in prediction:
-        prediction = prediction.split("DIAGNOSIS READY:")[-1].strip()
+You will be given:
+1. The ground truth diagnosis (answer)
+2. The doctor's predicted diagnosis (response)

-    try:
-        boxed = extract_boxed_answer(prediction)
-        if boxed:
-            prediction = boxed
-    except:
-        pass
+Your task is to determine if the doctor's diagnosis is medically equivalent to the ground truth, even if worded differently.

-    system = "You are responsible for determining if the correct diagnosis and the doctor diagnosis are the same disease. Please respond only with Yes or No. Nothing else."
-    user = f"\nHere is the correct diagnosis: {gold}\nHere was the doctor dialogue: {prediction}\nAre these the same?"
+Consider these guidelines:
+- Medical terms that refer to the same condition should be considered equivalent
+- Different levels of specificity may be acceptable
+- Spelling variations of medical terms should be considered equivalent
+- The core medical meaning should match, even if additional details vary

-    messages = [
-        {"role": "system", "content": system},
-        {"role": "user", "content": user}
-    ]
+Ground truth diagnosis: {answer}

-    try:
-        response = await moderator_client.chat.completions.create(
-            model=moderator_model,
-            messages=messages,
-            temperature=0.0,
-            max_tokens=10
-        )
-        answer = (response.choices[0].message.content or "").lower()
-        return 1.0 if "yes" in answer else 0.0
-    except Exception as e:
-        print(f"[compare_results_llm] Error: {e}")
-        return 0.0
+Doctor's diagnosis: {response}

+Is the doctor's diagnosis medically equivalent to the ground truth diagnosis?
+Respond with either "CORRECT" or "INCORRECT".
+""".strip()

-class AccuracyReward:
-    """Reward function class that holds reference to moderator client."""

-    __name__ = "accuracy_reward"
-
-    def __init__(self, moderator_client: AsyncOpenAI, moderator_model: str):
-        self.moderator_client = moderator_client
-        self.moderator_model = moderator_model
-
-    async def __call__(self, prompt: str, completion: str, answer: str, state: Dict[str, Any]) -> float:
-        """Reward function for verifiers."""
-        gold = (state.get("info") or {}).get("gold", "") or answer
+def extract_diagnosis(completion_text: str) -> str:
+    """
+    Extract the diagnosis from completion text, handling DIAGNOSIS READY and boxed formats.
+    """
+    if "DIAGNOSIS READY:" in completion_text:
+        completion_text = completion_text.split("DIAGNOSIS READY:")[-1].strip()

-        if isinstance(completion, list):
-            completion_text = ""
-            for msg in reversed(completion):
-                if isinstance(msg, dict) and msg.get("role") == "assistant":
-                    completion_text = msg.get("content", "")
-                    break
-        else:
-            completion_text = str(completion)
+    try:
+        boxed = extract_boxed_answer(completion_text)
+        if boxed:
+            completion_text = boxed
+    except Exception:
+        pass

-        return await compare_results_llm(completion_text, gold, self.moderator_client, self.moderator_model)
+    return completion_text.strip()


 # ============================================================
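As a quick sanity check on the new extraction path (illustrative values; the second example assumes `extract_boxed_answer` unwraps a LaTeX `\boxed{...}` expression):

extract_diagnosis("I have enough information. DIAGNOSIS READY: Acute appendicitis")
# -> "Acute appendicitis"

extract_diagnosis(r"DIAGNOSIS READY: \boxed{Kawasaki disease}")
# -> "Kawasaki disease"

# The judge then sees the template with its placeholders filled, e.g.
# (assuming the rubric fills them via str.format):
prompt = JUDGE_TEMPLATE.format(answer="Myocardial infarction", response="Heart attack")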
@@ -546,7 +518,7 @@ def load_environment(
         moderator_model: Model name for Moderator/Judge
         moderator_base_url: API base URL for moderator
         moderator_api_key: API key for moderator
-        **kwargs: Additional arguments passed to AgentClinicEnv
+        **kwargs: Additional arguments

     Returns:
         AgentClinic environment instance
@@ -620,12 +592,45 @@ def load_environment(

     env._scenarios = scenarios

-    accuracy_reward_func = AccuracyReward(env._moderator_client, env._moderator_model)
-    env.rubric = vf.Rubric(
-        funcs=[accuracy_reward_func],
-        names=["accuracy_reward"]
+    moderator_api_key = moderator_api_key or os.environ.get("OPENAI_API_KEY")
+    judge_client = AsyncOpenAI(base_url=moderator_base_url, api_key=moderator_api_key) if moderator_api_key else None
+
+    rubric = vf.JudgeRubric(
+        judge_client=judge_client,
+        judge_model=moderator_model,
+        judge_prompt=JUDGE_TEMPLATE,
     )

+    async def diagnosis_reward_func(judge, prompt, completion, answer, state, **kwargs) -> float:
+        """
+        Reward function that uses the LLM judge to evaluate the diagnosis.
+        """
+        if isinstance(completion, list):
+            completion_text = ""
+            for msg in reversed(completion):
+                if isinstance(msg, dict) and msg.get("role") == "assistant":
+                    completion_text = msg.get("content", "")
+                    break
+        else:
+            completion_text = str(completion)
+
+        diagnosis_text = extract_diagnosis(completion_text)
+
+        gold = (state.get("info") or {}).get("gold", "") or answer
+
+        judge_response = await judge(prompt, diagnosis_text, gold, state, **kwargs)
+
+        judge_response_clean = judge_response.strip().upper()
+
+        # Note: "INCORRECT" contains the substring "CORRECT", so exclude it explicitly
+        if "CORRECT" in judge_response_clean and "INCORRECT" not in judge_response_clean:
+            return 1.0
+        else:
+            return 0.0
+
+    rubric.add_reward_func(diagnosis_reward_func, weight=1.0)
+
+    env.rubric = rubric
+
     return env

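A minimal usage sketch of the resulting wiring (the keyword values below are assumptions for illustration, not defaults taken from this diff):

env = load_environment(
    moderator_model="gpt-4o-mini",   # judge model
    moderator_base_url=None,         # default OpenAI endpoint
    moderator_api_key=None,          # falls back to OPENAI_API_KEY
)
# env.rubric is now a vf.JudgeRubric whose single weighted reward function
# returns 1.0 when the judge replies "CORRECT" and 0.0 otherwise.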