-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
244 lines (195 loc) · 8.59 KB
/
main.py
File metadata and controls
244 lines (195 loc) · 8.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
import os
from dotenv import load_dotenv
load_dotenv()
from IPython.display import Image, display
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, MessagesState, START, END
from langgraph.store.base import BaseStore
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.runnables.config import RunnableConfig
from langgraph.store.memory import InMemoryStore
from langchain_openai import ChatOpenAI
# Create the chat model shared by every LLM-backed node below.
# temperature=0 keeps scheduling/summary output deterministic.
model = ChatOpenAI(temperature=0)

# Clinic name interpolated into the assistant's system prompt.
CLINIC_NAME = "Good Health Clinic"

# This message will provide context to the LLM about its role and the "patient" data it should use.
# Placeholders: {clinic_name} <- CLINIC_NAME, {history} <- stored profile text (see call_model).
MODEL_SYSTEM_MESSAGE = """You are a helpful medical assistant for {clinic_name}.
Use the patient's history to provide relevant, personalized appointment scheduling or advice.
Patient profile: {history}"""

# Instruction for how we update the patient profile (storing appointment data, medical notes, etc.).
# Placeholder: {history} <- the profile text currently in the store, or a default (see write_memory).
UPDATE_PATIENT_PROFILE_INSTRUCTION = """Update the patient's medical/appointment profile with new information.
CURRENT PROFILE:
{history}
ANALYZE FOR:
1. Appointment history (dates, times, no-shows)
2. Medical preferences or concerns
3. Previous diagnoses or treatments
4. Medication usage or allergies
5. Follow-up needs
Focus on verified appointment and medical details only. Summarize key points clearly.
Update the profile based on this conversation:
"""
def check_condition(state: MessagesState, config: RunnableConfig, store: BaseStore):
    """Route the conversation based on the most recent message.

    Returns {'decision': 'emergency_route'} when the latest message contains
    the word 'emergency' (case-insensitive); otherwise returns
    {'decision': 'regular_route'}. The conditional edges after this node
    dispatch on that 'decision' value.
    """
    latest_text = state["messages"][-1].content.lower()
    is_emergency = "emergency" in latest_text
    return {"decision": "emergency_route" if is_emergency else "regular_route"}
def handle_emergency(state: MessagesState, config: RunnableConfig, store: BaseStore):
    """Handle the emergency branch with a fixed escalation notice.

    No model call is made here: the node appends one static SystemMessage
    pointing the patient at emergency services and the urgent phone line.
    """
    urgent_notice = SystemMessage(
        content="We’ve detected an emergency. Please contact emergency services immediately or call our 24/7 urgent line: +43 00 00 00."
    )
    return {"messages": [urgent_notice]}
def call_model(state: MessagesState, config: RunnableConfig, store: BaseStore):
    """Generate the assistant's reply, personalized with the stored profile.

    Args:
        state (MessagesState): Current conversation messages
        config (RunnableConfig): Runtime configuration with patient_id
        store (BaseStore): Persistent storage for patient data

    Returns:
        dict: Generated response messages
    """
    # Locate this patient's saved profile, if one exists.
    patient_id = config["configurable"]["patient_id"]
    namespace = ("patient_interactions", patient_id)
    key = "patient_data_memory"
    record = store.get(namespace, key)

    # Use the stored profile text, or a default when nothing is saved yet.
    if record:
        history = record.value.get("patient_data_memory")
    else:
        history = "No existing patient profile found."

    # Inject profile and clinic name into the system prompt, then invoke.
    prompt = MODEL_SYSTEM_MESSAGE.format(history=history, clinic_name=CLINIC_NAME)
    reply = model.invoke([SystemMessage(content=prompt)] + state["messages"])
    return {"messages": reply}
def write_memory(state: MessagesState, config: RunnableConfig, store: BaseStore):
    """Persist an updated patient profile derived from the conversation.

    Args:
        state (MessagesState): Current conversation messages
        config (RunnableConfig): Runtime config containing patient_id
        store (BaseStore): Persistent storage for patient data
    """
    # Locate this patient's record in the store.
    patient_id = config["configurable"]["patient_id"]
    namespace = ("patient_interactions", patient_id)
    key = "patient_data_memory"
    existing = store.get(namespace=namespace, key=key)

    # Fall back to a default when no profile has been saved yet.
    history = existing.value.get(key) if existing else "No existing history."

    # Ask the model to fold the new conversation into the profile text.
    instruction = UPDATE_PATIENT_PROFILE_INSTRUCTION.format(history=history)
    updated = model.invoke([SystemMessage(content=instruction)] + state["messages"])

    # Overwrite the stored profile text under 'patient_data_memory'.
    store.put(namespace, key, {"patient_data_memory": updated.content})
# Build the graph over the prebuilt MessagesState (a single `messages` channel).
builder = StateGraph(MessagesState)
# Register the four nodes defined above.
builder.add_node("check_condition", check_condition)
builder.add_node("call_model", call_model)
builder.add_node("handle_emergency", handle_emergency)
builder.add_node("write_memory", write_memory)
# Entry edge: every run starts at the routing node.
builder.add_edge(START, "check_condition")
# builder.set_entry_point("check_condition")
# Branch on the 'decision' value produced by check_condition:
#   'emergency_route' -> handle_emergency
#   'regular_route'   -> call_model
# NOTE(review): check_condition only ever returns 'emergency_route' or
# 'regular_route', so the 'end' mapping below looks unreachable — confirm.
# NOTE(review): MessagesState does not declare a 'decision' key; this relies on
# the routing callable seeing the node's returned value — verify against the
# installed langgraph version.
builder.add_conditional_edges(
    "check_condition",
    lambda state: state["decision"],
    {
        "emergency_route": "handle_emergency",
        "regular_route": "call_model",
        "end": END
    }
)
# Both branches converge on write_memory so the profile is always refreshed.
builder.add_edge("handle_emergency", "write_memory")
builder.add_edge("call_model", "write_memory")
# write_memory is the terminal node.
builder.add_edge("write_memory", END)
# Two memory layers: a cross-thread store and a per-thread checkpointer.
across_thread_memory = InMemoryStore() # Long-term storage for patient interactions
within_thread_memory = MemorySaver() # Keeps current conversation state
# Compile the graph with memory configuration.
graph = builder.compile(
    checkpointer=within_thread_memory, # Tracks conversation state in memory
    store=across_thread_memory # Persists patient data
)
# Optionally, visualize the graph
# display(Image(graph.get_graph(xray=1).draw_png()))
# Render the compiled graph to a PNG on disk (needs mermaid rendering support).
png_graph = graph.get_graph().draw_mermaid_png()
with open("my_graph.png", "wb") as f:
    f.write(png_graph)
# Runtime configuration identifying the conversation thread (checkpointer key)
# and the patient (store namespace key) for all three turns below.
config = {
    "configurable": {
        "thread_id": "1", # Current conversation ID
        "patient_id": "1" # Identify the patient in the store
    }
}
# Turn 1: the patient introduces themselves and asks for an appointment.
input_msg = [
    HumanMessage(content="Hi, I'm Taher. I'd like to schedule an appointment for a routine check-up.")
]
# Stream state snapshots; print the newest message after each node runs.
for chunk in graph.stream(
    {"messages": input_msg}, # Current user message
    config, # Our runtime configuration
    stream_mode="values" # We just want the messages content
):
    chunk["messages"][-1].pretty_print()
# print("\n---- EMERGENCY CASE ----")
# emergency_msg = [
#     HumanMessage(content="This is an emergency! I'm experiencing severe chest pain.")
# ]
# for chunk in graph.stream({"messages": emergency_msg}, config, stream_mode="values"):
#     chunk["messages"][-1].pretty_print()
# Turn 2: same thread_id/patient_id, so the checkpointer resumes the
# conversation and write_memory has already stored a profile to draw on.
patient_followup_msg = [
    HumanMessage(
        content="Yes, I'd like to schedule it for next Tuesday around 10 AM, if possible."
    )
]
for chunk in graph.stream(
    {"messages": patient_followup_msg}, # Next user message
    config, # Same config with patient_id="1"
    stream_mode="values"
):
    chunk["messages"][-1].pretty_print()
# Turn 3: the patient closes the conversation and asks for a summary.
final_user_msg = [
    HumanMessage(
        content="No, that's all for now. Could you give me a brief summary of my appointment details?"
    )
]
# Send this final message through the same graph:
for chunk in graph.stream(
    {"messages": final_user_msg}, # Next user message
    config, # Same runtime config { "configurable": { "patient_id":"1", ...}}
    stream_mode="values"
):
    # The AI's final response (including summary) is in chunk["messages"][-1]
    chunk["messages"][-1].pretty_print()
# OPTIONAL: If you want to retrieve the entire "patient profile" from memory and display it:
# namespace = ("patient_interactions", "1") # patient_id = "1"
# key = "patient_data_memory"
# memory_data = across_thread_memory.get(namespace, key)
# if memory_data:
#     patient_profile = memory_data.value.get("patient_data_memory")
#     print("\n--- STORED PATIENT PROFILE ---")
#     print(patient_profile)