Skip to content

Commit a0e3bd5

Browse files
committed
Added example to illustrate triage agent with delegation
1 parent 92289ce commit a0e3bd5

File tree

2 files changed

+340
-0
lines changed

2 files changed

+340
-0
lines changed
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
Medical triage and delegation system built with **Pydantic AI**, demonstrating how an orchestrator agent (`triage_agent`) coordinates multiple specialized agents (e.g. cardiology, neurology, and senior clinician).
2+
3+
Demonstrates:
4+
- [Agent delegation and coordination](../multi-agent-applications.md#agent-delegation)
5+
- [structured `output_type`](../output.md#structured-output)
6+
- [tools](../tools.md)
7+
8+
---
9+
10+
## Overview
11+
12+
This example shows how to use **multiple Pydantic AI agents** to simulate a medical triage workflow.
13+
14+
The system includes:
15+
- **General Practitioner, Cardiology, and Neurology agents** — for Level 1 consultation.
16+
- **Senior Doctor agent** — for escalations and treatment planning.
17+
- **Triage Agent (Coordinator)** — which decides which tool to invoke and when to escalate.
18+
19+
The `triage_agent` uses two tools:
20+
1. `consult_specialist` — routes the complaint to a domain specialist.
21+
2. `consult_senior_doctor` — escalates the case for critical or ambiguous scenarios.
22+
23+
Each specialist produces a structured `MedicalReport`, and the senior doctor produces a structured `TreatmentPlan`.
24+
The orchestrator then compiles both into a final `TriageFinalOutput`.
25+
26+
---
27+
28+
## Running the Example
29+
30+
With [dependencies installed and environment variables set](./setup.md#usage), run:
31+
32+
```bash
33+
python -m pydantic_ai_examples.medical_agent_delegation
34+
35+
Make sure to set a valid **Cohere API key** or replace the model reference:
36+
37+
```bash
38+
export CO_API_KEY="your-cohere-api-key"
39+
```
40+
41+
You may also switch to an OpenAI or Anthropic model if preferred:
42+
43+
```python
44+
MODEL = "openai:gpt-4o"
45+
```
Lines changed: 295 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,295 @@
1+
"""Medical Triage System with Agent Delegation.
2+
3+
The `triage_agent` acts as the central decision-maker and orchestrator.
4+
The instructions helps it to "call tools to consult specialists or a senior doctor."
5+
It delegates the actual medical work (diagnosis or treatment planning) to other agents.
6+
7+
The two core functions act as the delegation mechanism:
8+
9+
- consult_specialist: This tool routes the complaint to a specific Specialist Agent
10+
(cardiology_agent, neurology_agent, etc.). This is Level 1 Delegation: Routing to expertise.
11+
12+
- consult_senior_doctor: This tool routes the complaint to a Senior Agent (senior_doctor_agent).
13+
This is Level 2 Delegation: Escalation for critical decision-making.
14+
15+
Demonstrates:
16+
- Master agent coordinating specialized sub-agents
17+
- Dynamic routing and delegation based on symptom analysis
18+
- Structured output
19+
20+
Run with:
21+
22+
uv run -m pydantic_ai_examples.medical_agent_delegation
23+
"""
24+
25+
import asyncio
26+
from dataclasses import dataclass
27+
from datetime import UTC, datetime
28+
from enum import Enum
29+
from typing import Any
30+
31+
from pydantic import BaseModel, Field
32+
from typing_extensions import TypedDict
33+
34+
from pydantic_ai import Agent, ModelHTTPError, RunContext
35+
36+
# Make sure to set CO_API_KEY token. Or Change the model and token based on relevant provider
37+
MODEL = 'cohere:command-r7b-12-2024'
38+
39+
40+
# Structured Outputs
41+
class Specialty(str, Enum):
42+
general = 'general'
43+
cardiology = 'cardiology'
44+
neurology = 'neurology'
45+
46+
47+
class MedicalReport(BaseModel):
48+
diagnosis: list[str]
49+
differential: list[str]
50+
recommended_tests: list[str]
51+
immediate_actions: list[str]
52+
estimated_time_minutes: int
53+
54+
55+
class TreatmentPlan(BaseModel):
56+
plan_summary: str = Field(
57+
description='The structured treatment plan from the senior doctor'
58+
)
59+
refer_to_specialist: Specialty | None = Field(
60+
description='Specialty to route the patient to for further treatment, if necessary'
61+
)
62+
follow_up_days: int
63+
64+
65+
class TriageFinalOutput(BaseModel):
66+
"""The final structured output containing the result of the entire flow."""
67+
68+
specialty: Specialty | None = None
69+
final_report: MedicalReport | None = None
70+
treatment_plan: TreatmentPlan | None = None
71+
final_status: str = Field(
72+
..., description="Status: 'resolved_by_specialist' or 'escalated'"
73+
)
74+
75+
76+
# Shared Dependency
77+
@dataclass
78+
class PatientInfo:
79+
patient_id: str
80+
age: int
81+
known_conditions: list[str]
82+
83+
84+
class TestPatient(TypedDict):
85+
complaint: str
86+
patient: PatientInfo
87+
88+
89+
# Specialist and Senior Agents
90+
gp_agent = Agent(
91+
MODEL,
92+
output_type=MedicalReport,
93+
deps_type=PatientInfo,
94+
system_prompt="""
95+
You are a general practitioner.
96+
""",
97+
)
98+
99+
cardiology_agent = Agent(
100+
MODEL,
101+
output_type=MedicalReport,
102+
deps_type=PatientInfo,
103+
system_prompt="""
104+
You are a cardiology specialist.
105+
""",
106+
)
107+
108+
neurology_agent = Agent(
109+
MODEL,
110+
output_type=MedicalReport,
111+
deps_type=PatientInfo,
112+
system_prompt="""
113+
You are a neurology specialist.
114+
""",
115+
)
116+
117+
senior_doctor_agent = Agent(
118+
MODEL,
119+
output_type=TreatmentPlan,
120+
deps_type=PatientInfo,
121+
system_prompt="""
122+
You are a senior clinician overseeing complex or ambiguous cases.
123+
Integrate all prior findings to produce a clear treatment plan.
124+
""",
125+
)
126+
127+
SPECIALIST_MAP = {
128+
'general': gp_agent,
129+
'cardiology': cardiology_agent,
130+
'neurology': neurology_agent,
131+
}
132+
133+
# Agent-as-Orchestrator: triage_agent with Delegation Tools
134+
triage_agent = Agent(
135+
MODEL,
136+
output_type=TriageFinalOutput,
137+
deps_type=PatientInfo,
138+
system_prompt="""
139+
You are a triage clinician coordinating medical workflow.
140+
You can call tools to consult specialists or a senior doctor.
141+
142+
AVAILABLE SPECIALTIES:
143+
- "general": General practitioner for common issues
144+
- "cardiology": For heart, chest pain, cardiac symptoms
145+
- "neurology": For brain, nerve, stroke, headache symptoms
146+
147+
Always produce a structured TriageFinalOutput.
148+
""",
149+
)
150+
151+
152+
@triage_agent.tool
153+
async def consult_specialist(
154+
ctx: RunContext[PatientInfo],
155+
specialty: Specialty,
156+
question: str,
157+
) -> TriageFinalOutput | str:
158+
"""Consult the appropriate specialist for expert consultation."""
159+
specialist_agent = SPECIALIST_MAP.get(specialty)
160+
print(f'Proceed with specialist - {specialty}')
161+
if not specialist_agent:
162+
print('Selected specialist does not exists!')
163+
return f'No specialist found for {specialty.name}.'
164+
165+
result = await specialist_agent.run(f'Consultation: {question}', deps=ctx.deps)
166+
report: MedicalReport = result.output
167+
168+
return TriageFinalOutput(
169+
final_status='resolved_by_specialist',
170+
specialty=specialty,
171+
final_report=report,
172+
)
173+
174+
175+
@triage_agent.tool
176+
async def consult_senior_doctor(
177+
ctx: RunContext[PatientInfo], reason_for_escalation: str, initial_complaint: str
178+
) -> TriageFinalOutput:
179+
"""Consult senior doctor in case of escalation and emergency cases.
180+
181+
Immediately escalates the case to the senior clinician for severe cases and for a final TreatmentPlan.
182+
Use this for high severity, critical, or ambiguous cases.
183+
184+
Args:
185+
ctx: Pydantic AI agent RunContext
186+
reason_for_escalation: Summary of why the case must be escalated (e.g., "Severe pain, possible cardiac event").
187+
initial_complaint: The patient's original complaint.
188+
"""
189+
patient = ctx.deps
190+
senior_note = f'Reason: {reason_for_escalation}\nComplaint and context:\n{initial_complaint}\nPatient: {patient.patient_id}, age {patient.age}\n'
191+
192+
print('Direct escalation triggered by Triage LLM.')
193+
treatment_plan = None
194+
try:
195+
result = await senior_doctor_agent.run(
196+
f'Consultation for: {senior_note}', deps=ctx.deps
197+
)
198+
treatment_plan = result.output
199+
except ModelHTTPError as e:
200+
# Handle case where LLM fails to provide TreatmentPlan structure
201+
treatment_plan = TreatmentPlan(
202+
plan_summary=f'Consultation failed due to API error: {e.status_code}. Requires manual review.',
203+
refer_to_specialist=None,
204+
follow_up_days=1,
205+
)
206+
207+
return TriageFinalOutput(
208+
final_status='escalated',
209+
treatment_plan=treatment_plan,
210+
)
211+
212+
213+
# Coordinator System
214+
class MedicalTriageSystem:
215+
"""Coordinator that invokes triage_agent as the orchestrator."""
216+
217+
def __init__(self):
218+
self.triage = triage_agent
219+
self.medical_history: list[dict[str, Any]] = []
220+
221+
async def handle_patient(
222+
self, complaint: str, patient: PatientInfo
223+
) -> dict[str, str]:
224+
timestamp = datetime.now(UTC).isoformat()
225+
print(f'\n[{timestamp}] Processing complaint: {complaint}')
226+
227+
triage_prompt = (
228+
f'Patient {patient.patient_id}, age {patient.age}\n'
229+
f'Complaint: {complaint}\n'
230+
f'Known conditions: {patient.known_conditions}\n'
231+
f'If necessary, use your tools to consult specialists or senior doctor.'
232+
)
233+
234+
triage_result = await self.triage.run(triage_prompt, deps=patient)
235+
final_output: TriageFinalOutput = triage_result.output
236+
237+
record = {
238+
'timestamp': timestamp,
239+
'patient_id': patient.patient_id,
240+
'path': final_output.final_status,
241+
'specialty': final_output.specialty,
242+
'report_summary': final_output.final_report.diagnosis
243+
if final_output.final_report
244+
else 'N/A',
245+
'treatment_summary': final_output.treatment_plan.plan_summary
246+
if final_output.treatment_plan
247+
else 'N/A',
248+
}
249+
self.medical_history.append(record)
250+
251+
return final_output.model_dump()
252+
253+
254+
async def demo_medical_triage():
255+
system = MedicalTriageSystem()
256+
257+
test_patients: list[TestPatient] = [
258+
{
259+
'complaint': 'Sudden severe chest pain radiating to left arm and shortness of breath.',
260+
'patient': PatientInfo(
261+
patient_id='P001', age=64, known_conditions=['hypertension']
262+
),
263+
},
264+
{
265+
'complaint': 'Intermittent headaches for 2 weeks, mild nausea, no weakness.',
266+
'patient': PatientInfo(patient_id='P002', age=34, known_conditions=[]),
267+
},
268+
{
269+
'complaint': "Sudden onset of the worst headache of my life, followed by blurry vision and now I can't feel my left leg. I took aspirin an hour ago.",
270+
'patient': PatientInfo(
271+
patient_id='P003',
272+
age=71,
273+
known_conditions=['Type 2 Diabetes', 'Chronic Migraines'],
274+
),
275+
},
276+
{
277+
'complaint': 'Hard to breath and faint every few minutes.',
278+
'patient': PatientInfo(patient_id='P003', age=71, known_conditions=[]),
279+
},
280+
]
281+
282+
for entry in test_patients:
283+
print(f'Processing patient {entry["patient"].patient_id}')
284+
result = await system.handle_patient(entry['complaint'], entry['patient'])
285+
print('Result:', result)
286+
287+
print('\nMEDICAL HISTORY SUMMARY:')
288+
for history in system.medical_history:
289+
print(
290+
f'- {history["timestamp"]} | Patient {history["patient_id"]} | Path: {history["path"]}'
291+
)
292+
293+
294+
if __name__ == '__main__':
295+
asyncio.run(demo_medical_triage())

0 commit comments

Comments
 (0)