|
| 1 | +"""Medical Triage System with Agent Delegation. |
| 2 | +
|
| 3 | +The `triage_agent` acts as the central decision-maker and orchestrator. |
| 4 | +The instructions helps it to "call tools to consult specialists or a senior doctor." |
| 5 | +It delegates the actual medical work (diagnosis or treatment planning) to other agents. |
| 6 | +
|
| 7 | +The two core functions act as the delegation mechanism: |
| 8 | +
|
| 9 | +- consult_specialist: This tool routes the complaint to a specific Specialist Agent |
| 10 | +(cardiology_agent, neurology_agent, etc.). This is Level 1 Delegation: Routing to expertise. |
| 11 | +
|
| 12 | +- consult_senior_doctor: This tool routes the complaint to a Senior Agent (senior_doctor_agent). |
| 13 | +This is Level 2 Delegation: Escalation for critical decision-making. |
| 14 | +
|
| 15 | +Demonstrates: |
| 16 | +- Master agent coordinating specialized sub-agents |
| 17 | +- Dynamic routing and delegation based on symptom analysis |
| 18 | +- Structured output |
| 19 | +
|
| 20 | +Run with: |
| 21 | +
|
| 22 | + uv run -m pydantic_ai_examples.medical_agent_delegation |
| 23 | +""" |
| 24 | + |
| 25 | +import asyncio |
| 26 | +from dataclasses import dataclass |
| 27 | +from datetime import UTC, datetime |
| 28 | +from enum import Enum |
| 29 | +from typing import Any |
| 30 | + |
| 31 | +from pydantic import BaseModel, Field |
| 32 | +from typing_extensions import TypedDict |
| 33 | + |
| 34 | +from pydantic_ai import Agent, ModelHTTPError, RunContext |
| 35 | + |
| 36 | +# Make sure to set CO_API_KEY token. Or Change the model and token based on relevant provider |
| 37 | +MODEL = 'cohere:command-r7b-12-2024' |
| 38 | + |
| 39 | + |
| 40 | +# Structured Outputs |
| 41 | +class Specialty(str, Enum): |
| 42 | + general = 'general' |
| 43 | + cardiology = 'cardiology' |
| 44 | + neurology = 'neurology' |
| 45 | + |
| 46 | + |
| 47 | +class MedicalReport(BaseModel): |
| 48 | + diagnosis: list[str] |
| 49 | + differential: list[str] |
| 50 | + recommended_tests: list[str] |
| 51 | + immediate_actions: list[str] |
| 52 | + estimated_time_minutes: int |
| 53 | + |
| 54 | + |
| 55 | +class TreatmentPlan(BaseModel): |
| 56 | + plan_summary: str = Field( |
| 57 | + description='The structured treatment plan from the senior doctor' |
| 58 | + ) |
| 59 | + refer_to_specialist: Specialty | None = Field( |
| 60 | + description='Specialty to route the patient to for further treatment, if necessary' |
| 61 | + ) |
| 62 | + follow_up_days: int |
| 63 | + |
| 64 | + |
| 65 | +class TriageFinalOutput(BaseModel): |
| 66 | + """The final structured output containing the result of the entire flow.""" |
| 67 | + |
| 68 | + specialty: Specialty | None = None |
| 69 | + final_report: MedicalReport | None = None |
| 70 | + treatment_plan: TreatmentPlan | None = None |
| 71 | + final_status: str = Field( |
| 72 | + ..., description="Status: 'resolved_by_specialist' or 'escalated'" |
| 73 | + ) |
| 74 | + |
| 75 | + |
| 76 | +# Shared Dependency |
| 77 | +@dataclass |
| 78 | +class PatientInfo: |
| 79 | + patient_id: str |
| 80 | + age: int |
| 81 | + known_conditions: list[str] |
| 82 | + |
| 83 | + |
| 84 | +class TestPatient(TypedDict): |
| 85 | + complaint: str |
| 86 | + patient: PatientInfo |
| 87 | + |
| 88 | + |
| 89 | +# Specialist and Senior Agents |
| 90 | +gp_agent = Agent( |
| 91 | + MODEL, |
| 92 | + output_type=MedicalReport, |
| 93 | + deps_type=PatientInfo, |
| 94 | + system_prompt=""" |
| 95 | + You are a general practitioner. |
| 96 | + """, |
| 97 | +) |
| 98 | + |
| 99 | +cardiology_agent = Agent( |
| 100 | + MODEL, |
| 101 | + output_type=MedicalReport, |
| 102 | + deps_type=PatientInfo, |
| 103 | + system_prompt=""" |
| 104 | + You are a cardiology specialist. |
| 105 | + """, |
| 106 | +) |
| 107 | + |
| 108 | +neurology_agent = Agent( |
| 109 | + MODEL, |
| 110 | + output_type=MedicalReport, |
| 111 | + deps_type=PatientInfo, |
| 112 | + system_prompt=""" |
| 113 | + You are a neurology specialist. |
| 114 | + """, |
| 115 | +) |
| 116 | + |
| 117 | +senior_doctor_agent = Agent( |
| 118 | + MODEL, |
| 119 | + output_type=TreatmentPlan, |
| 120 | + deps_type=PatientInfo, |
| 121 | + system_prompt=""" |
| 122 | + You are a senior clinician overseeing complex or ambiguous cases. |
| 123 | + Integrate all prior findings to produce a clear treatment plan. |
| 124 | + """, |
| 125 | +) |
| 126 | + |
| 127 | +SPECIALIST_MAP = { |
| 128 | + 'general': gp_agent, |
| 129 | + 'cardiology': cardiology_agent, |
| 130 | + 'neurology': neurology_agent, |
| 131 | +} |
| 132 | + |
| 133 | +# Agent-as-Orchestrator: triage_agent with Delegation Tools |
| 134 | +triage_agent = Agent( |
| 135 | + MODEL, |
| 136 | + output_type=TriageFinalOutput, |
| 137 | + deps_type=PatientInfo, |
| 138 | + system_prompt=""" |
| 139 | + You are a triage clinician coordinating medical workflow. |
| 140 | + You can call tools to consult specialists or a senior doctor. |
| 141 | +
|
| 142 | + AVAILABLE SPECIALTIES: |
| 143 | + - "general": General practitioner for common issues |
| 144 | + - "cardiology": For heart, chest pain, cardiac symptoms |
| 145 | + - "neurology": For brain, nerve, stroke, headache symptoms |
| 146 | +
|
| 147 | + Always produce a structured TriageFinalOutput. |
| 148 | + """, |
| 149 | +) |
| 150 | + |
| 151 | + |
| 152 | +@triage_agent.tool |
| 153 | +async def consult_specialist( |
| 154 | + ctx: RunContext[PatientInfo], |
| 155 | + specialty: Specialty, |
| 156 | + question: str, |
| 157 | +) -> TriageFinalOutput | str: |
| 158 | + """Consult the appropriate specialist for expert consultation.""" |
| 159 | + specialist_agent = SPECIALIST_MAP.get(specialty) |
| 160 | + print(f'Proceed with specialist - {specialty}') |
| 161 | + if not specialist_agent: |
| 162 | + print('Selected specialist does not exists!') |
| 163 | + return f'No specialist found for {specialty.name}.' |
| 164 | + |
| 165 | + result = await specialist_agent.run(f'Consultation: {question}', deps=ctx.deps) |
| 166 | + report: MedicalReport = result.output |
| 167 | + |
| 168 | + return TriageFinalOutput( |
| 169 | + final_status='resolved_by_specialist', |
| 170 | + specialty=specialty, |
| 171 | + final_report=report, |
| 172 | + ) |
| 173 | + |
| 174 | + |
| 175 | +@triage_agent.tool |
| 176 | +async def consult_senior_doctor( |
| 177 | + ctx: RunContext[PatientInfo], reason_for_escalation: str, initial_complaint: str |
| 178 | +) -> TriageFinalOutput: |
| 179 | + """Consult senior doctor in case of escalation and emergency cases. |
| 180 | +
|
| 181 | + Immediately escalates the case to the senior clinician for severe cases and for a final TreatmentPlan. |
| 182 | + Use this for high severity, critical, or ambiguous cases. |
| 183 | +
|
| 184 | + Args: |
| 185 | + ctx: Pydantic AI agent RunContext |
| 186 | + reason_for_escalation: Summary of why the case must be escalated (e.g., "Severe pain, possible cardiac event"). |
| 187 | + initial_complaint: The patient's original complaint. |
| 188 | + """ |
| 189 | + patient = ctx.deps |
| 190 | + senior_note = f'Reason: {reason_for_escalation}\nComplaint and context:\n{initial_complaint}\nPatient: {patient.patient_id}, age {patient.age}\n' |
| 191 | + |
| 192 | + print('Direct escalation triggered by Triage LLM.') |
| 193 | + treatment_plan = None |
| 194 | + try: |
| 195 | + result = await senior_doctor_agent.run( |
| 196 | + f'Consultation for: {senior_note}', deps=ctx.deps |
| 197 | + ) |
| 198 | + treatment_plan = result.output |
| 199 | + except ModelHTTPError as e: |
| 200 | + # Handle case where LLM fails to provide TreatmentPlan structure |
| 201 | + treatment_plan = TreatmentPlan( |
| 202 | + plan_summary=f'Consultation failed due to API error: {e.status_code}. Requires manual review.', |
| 203 | + refer_to_specialist=None, |
| 204 | + follow_up_days=1, |
| 205 | + ) |
| 206 | + |
| 207 | + return TriageFinalOutput( |
| 208 | + final_status='escalated', |
| 209 | + treatment_plan=treatment_plan, |
| 210 | + ) |
| 211 | + |
| 212 | + |
| 213 | +# Coordinator System |
| 214 | +class MedicalTriageSystem: |
| 215 | + """Coordinator that invokes triage_agent as the orchestrator.""" |
| 216 | + |
| 217 | + def __init__(self): |
| 218 | + self.triage = triage_agent |
| 219 | + self.medical_history: list[dict[str, Any]] = [] |
| 220 | + |
| 221 | + async def handle_patient( |
| 222 | + self, complaint: str, patient: PatientInfo |
| 223 | + ) -> dict[str, str]: |
| 224 | + timestamp = datetime.now(UTC).isoformat() |
| 225 | + print(f'\n[{timestamp}] Processing complaint: {complaint}') |
| 226 | + |
| 227 | + triage_prompt = ( |
| 228 | + f'Patient {patient.patient_id}, age {patient.age}\n' |
| 229 | + f'Complaint: {complaint}\n' |
| 230 | + f'Known conditions: {patient.known_conditions}\n' |
| 231 | + f'If necessary, use your tools to consult specialists or senior doctor.' |
| 232 | + ) |
| 233 | + |
| 234 | + triage_result = await self.triage.run(triage_prompt, deps=patient) |
| 235 | + final_output: TriageFinalOutput = triage_result.output |
| 236 | + |
| 237 | + record = { |
| 238 | + 'timestamp': timestamp, |
| 239 | + 'patient_id': patient.patient_id, |
| 240 | + 'path': final_output.final_status, |
| 241 | + 'specialty': final_output.specialty, |
| 242 | + 'report_summary': final_output.final_report.diagnosis |
| 243 | + if final_output.final_report |
| 244 | + else 'N/A', |
| 245 | + 'treatment_summary': final_output.treatment_plan.plan_summary |
| 246 | + if final_output.treatment_plan |
| 247 | + else 'N/A', |
| 248 | + } |
| 249 | + self.medical_history.append(record) |
| 250 | + |
| 251 | + return final_output.model_dump() |
| 252 | + |
| 253 | + |
| 254 | +async def demo_medical_triage(): |
| 255 | + system = MedicalTriageSystem() |
| 256 | + |
| 257 | + test_patients: list[TestPatient] = [ |
| 258 | + { |
| 259 | + 'complaint': 'Sudden severe chest pain radiating to left arm and shortness of breath.', |
| 260 | + 'patient': PatientInfo( |
| 261 | + patient_id='P001', age=64, known_conditions=['hypertension'] |
| 262 | + ), |
| 263 | + }, |
| 264 | + { |
| 265 | + 'complaint': 'Intermittent headaches for 2 weeks, mild nausea, no weakness.', |
| 266 | + 'patient': PatientInfo(patient_id='P002', age=34, known_conditions=[]), |
| 267 | + }, |
| 268 | + { |
| 269 | + 'complaint': "Sudden onset of the worst headache of my life, followed by blurry vision and now I can't feel my left leg. I took aspirin an hour ago.", |
| 270 | + 'patient': PatientInfo( |
| 271 | + patient_id='P003', |
| 272 | + age=71, |
| 273 | + known_conditions=['Type 2 Diabetes', 'Chronic Migraines'], |
| 274 | + ), |
| 275 | + }, |
| 276 | + { |
| 277 | + 'complaint': 'Hard to breath and faint every few minutes.', |
| 278 | + 'patient': PatientInfo(patient_id='P003', age=71, known_conditions=[]), |
| 279 | + }, |
| 280 | + ] |
| 281 | + |
| 282 | + for entry in test_patients: |
| 283 | + print(f'Processing patient {entry["patient"].patient_id}') |
| 284 | + result = await system.handle_patient(entry['complaint'], entry['patient']) |
| 285 | + print('Result:', result) |
| 286 | + |
| 287 | + print('\nMEDICAL HISTORY SUMMARY:') |
| 288 | + for history in system.medical_history: |
| 289 | + print( |
| 290 | + f'- {history["timestamp"]} | Patient {history["patient_id"]} | Path: {history["path"]}' |
| 291 | + ) |
| 292 | + |
| 293 | + |
| 294 | +if __name__ == '__main__': |
| 295 | + asyncio.run(demo_medical_triage()) |
0 commit comments