-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathstreamlit_app.py
More file actions
518 lines (434 loc) · 21.2 KB
/
streamlit_app.py
File metadata and controls
518 lines (434 loc) · 21.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
import asyncio
import os
import urllib.parse
import uuid
from collections.abc import AsyncGenerator
import streamlit as st
from dotenv import load_dotenv
from pydantic import ValidationError
from client import AgentClient, AgentClientError
from schema import ChatHistory, ChatMessage
from schema.task_data import TaskData, TaskDataStatus
# A Streamlit app for interacting with the langgraph agent via a simple chat interface.
# The app has three main functions which are all run async:
# - main() - sets up the streamlit app and high level structure
# - draw_messages() - draws a set of chat messages - either replaying existing messages
# or streaming new ones.
# - handle_feedback() - Draws a feedback widget and records feedback from the user.
# The app heavily uses AgentClient to interact with the agent's FastAPI endpoints.
APP_TITLE = "Agent Service Toolkit"
APP_ICON = "🧰"
USER_ID_COOKIE = "user_id"
def get_or_create_user_id() -> str:
    """Resolve a stable user ID for this browser session.

    Resolution order: Streamlit session state first, then the URL query
    parameters, and finally a freshly generated UUID. A newly minted ID
    is written to both session state and the URL so the chat can be
    bookmarked or shared while keeping the same identity.
    """
    # Session state wins: the ID was already resolved during this session.
    if USER_ID_COOKIE in st.session_state:
        return st.session_state[USER_ID_COOKIE]

    # Next, honour an ID carried in the URL (a shared/bookmarked link).
    if USER_ID_COOKIE in st.query_params:
        uid = st.query_params[USER_ID_COOKIE]
    else:
        # Nothing found anywhere - mint a brand-new identifier and expose
        # it in the URL so this chat can be revisited later.
        uid = str(uuid.uuid4())
        st.query_params[USER_ID_COOKIE] = uid

    # Cache for the remainder of this session either way.
    st.session_state[USER_ID_COOKIE] = uid
    return uid
async def main() -> None:
    """Top-level Streamlit entry point.

    Sets up the page, connects to the agent service, builds the sidebar
    (settings, architecture/share dialogs), replays existing chat history,
    and streams or invokes a new agent response when the user submits input.
    Runs top-to-bottom on every Streamlit rerun.
    """
    st.set_page_config(
        page_title=APP_TITLE,
        page_icon=APP_ICON,
        menu_items={},
    )

    # Hide the streamlit upper-right chrome (status widget).
    st.html(
        """
        <style>
        [data-testid="stStatusWidget"] {
            visibility: hidden;
            height: 0%;
            position: fixed;
        }
        </style>
        """,
    )
    # Force minimal toolbar mode; the short sleep lets the option apply
    # before triggering a rerun so the change takes visual effect.
    if st.get_option("client.toolbarMode") != "minimal":
        st.set_option("client.toolbarMode", "minimal")
        await asyncio.sleep(0.1)
        st.rerun()

    # Get or create user ID (session state -> URL params -> new UUID).
    user_id = get_or_create_user_id()

    # Lazily create the AgentClient once per session.
    if "agent_client" not in st.session_state:
        load_dotenv()
        agent_url = os.getenv("AGENT_URL")
        if not agent_url:
            # Fall back to HOST/PORT; PORT default is an int but is only
            # ever interpolated into the URL string, so that is harmless.
            host = os.getenv("HOST", "0.0.0.0")
            port = os.getenv("PORT", 8080)
            agent_url = f"http://{host}:{port}"
        try:
            with st.spinner("Connecting to agent service..."):
                st.session_state.agent_client = AgentClient(base_url=agent_url)
        except AgentClientError as e:
            st.error(f"Error connecting to agent service at {agent_url}: {e}")
            st.markdown("The service might be booting up. Try again in a few seconds.")
            st.stop()
    agent_client: AgentClient = st.session_state.agent_client

    # Resolve the conversation thread: resume from the URL if possible,
    # otherwise start a fresh thread with an empty history.
    if "thread_id" not in st.session_state:
        thread_id = st.query_params.get("thread_id")
        if not thread_id:
            thread_id = str(uuid.uuid4())
            messages = []
        else:
            try:
                messages: list[ChatMessage] = agent_client.get_history(thread_id=thread_id).messages
            except AgentClientError:
                st.error("No message history found for this Thread ID.")
                messages = []
        st.session_state.messages = messages
        st.session_state.thread_id = thread_id

    # Config options
    with st.sidebar:
        st.header(f"{APP_ICON} {APP_TITLE}")
        # Bare strings below are rendered via Streamlit "magic" (st.write).
        ""
        "Full toolkit for running an AI agent service built with LangGraph, FastAPI and Streamlit"
        ""
        if st.button(":material/chat: New Chat", use_container_width=True):
            st.session_state.messages = []
            st.session_state.thread_id = str(uuid.uuid4())
            st.rerun()

        with st.popover(":material/settings: Settings", use_container_width=True):
            model_idx = agent_client.info.models.index(agent_client.info.default_model)
            model = st.selectbox("LLM to use", options=agent_client.info.models, index=model_idx)
            agent_list = [a.key for a in agent_client.info.agents]
            agent_idx = agent_list.index(agent_client.info.default_agent)
            agent_client.agent = st.selectbox(
                "Agent to use",
                options=agent_list,
                index=agent_idx,
            )
            use_streaming = st.toggle("Stream results", value=True)

        # Display user ID (for debugging or user information)
        st.text_input("User ID (read-only)", value=user_id, disabled=True)

        @st.dialog("Architecture")
        def architecture_dialog() -> None:
            st.image(
                "https://github.com/JoshuaC215/agent-service-toolkit/blob/main/media/agent_architecture.png?raw=true"
            )
            "[View full size on Github](https://github.com/JoshuaC215/agent-service-toolkit/blob/main/media/agent_architecture.png)"
            st.caption(
                "App hosted on [Streamlit Cloud](https://share.streamlit.io/) with FastAPI service running in [Azure](https://learn.microsoft.com/en-us/azure/app-service/)"
            )

        if st.button(":material/schema: Architecture", use_container_width=True):
            architecture_dialog()

        with st.popover(":material/policy: Privacy", use_container_width=True):
            st.write(
                "Prompts, responses and feedback in this app are anonymously recorded and saved to LangSmith for product evaluation and improvement purposes only."
            )

        @st.dialog("Share/resume chat")
        def share_chat_dialog() -> None:
            # NOTE(review): reaches into Streamlit internals (_session_mgr)
            # to recover the base URL - may break across Streamlit versions.
            session = st.runtime.get_instance()._session_mgr.list_active_sessions()[0]
            st_base_url = urllib.parse.urlunparse(
                [session.client.request.protocol, session.client.request.host, "", "", "", ""]
            )
            # if it's not localhost, switch to https by default
            if not st_base_url.startswith("https") and "localhost" not in st_base_url:
                st_base_url = st_base_url.replace("http", "https")
            # Include both thread_id and user_id in the URL for sharing to maintain user identity
            chat_url = (
                f"{st_base_url}?thread_id={st.session_state.thread_id}&{USER_ID_COOKIE}={user_id}"
            )
            st.markdown(f"**Chat URL:**\n```text\n{chat_url}\n```")
            st.info("Copy the above URL to share or revisit this chat")

        if st.button(":material/upload: Share/resume chat", use_container_width=True):
            share_chat_dialog()

        "[View the source code](https://github.com/JoshuaC215/agent-service-toolkit)"
        st.caption(
            "Made with :material/favorite: by [Joshua](https://www.linkedin.com/in/joshua-k-carroll/) in Oakland"
        )

    # Draw existing messages
    messages: list[ChatMessage] = st.session_state.messages

    # On an empty thread, greet with an agent-specific welcome message.
    if len(messages) == 0:
        match agent_client.agent:
            case "chatbot":
                WELCOME = "Hello! I'm a simple chatbot. Ask me anything!"
            case "interrupt-agent":
                WELCOME = "Hello! I'm an interrupt agent. Tell me your birthday and I will predict your personality!"
            case "research-assistant":
                WELCOME = "Hello! I'm an AI-powered research assistant with web search and a calculator. Ask me anything!"
            case "rag-assistant":
                WELCOME = """Hello! I'm an AI-powered Company Policy & HR assistant with access to AcmeTech's Employee Handbook.
I can help you find information about benefits, remote work, time-off policies, company values, and more. Ask me anything!"""
            case _:
                WELCOME = "Hello! I'm an AI agent. Ask me anything!"
        with st.chat_message("ai"):
            st.write(WELCOME)

    # draw_messages() expects an async iterator over messages, so wrap the
    # already-materialized history list in a trivial async generator.
    async def amessage_iter() -> AsyncGenerator[ChatMessage, None]:
        for m in messages:
            yield m

    await draw_messages(amessage_iter())

    # Generate new message if the user provided new input
    if user_input := st.chat_input():
        messages.append(ChatMessage(type="human", content=user_input))
        st.chat_message("human").write(user_input)
        try:
            if use_streaming:
                stream = agent_client.astream(
                    message=user_input,
                    model=model,
                    thread_id=st.session_state.thread_id,
                    user_id=user_id,
                )
                await draw_messages(stream, is_new=True)
            else:
                response = await agent_client.ainvoke(
                    message=user_input,
                    model=model,
                    thread_id=st.session_state.thread_id,
                    user_id=user_id,
                )
                messages.append(response)
                st.chat_message("ai").write(response.content)
            st.rerun()  # Clear stale containers
        except AgentClientError as e:
            st.error(f"Error generating response: {e}")
            st.stop()

    # If messages have been generated, show feedback widget attached to the
    # last AI message container recorded by draw_messages().
    if len(messages) > 0 and st.session_state.last_message:
        with st.session_state.last_message:
            await handle_feedback()
async def draw_messages(
    messages_agen: AsyncGenerator[ChatMessage | str, None],
    is_new: bool = False,
) -> None:
    """
    Draws a set of chat messages - either replaying existing messages
    or streaming new ones.

    This function has additional logic to handle streaming tokens and tool calls.
    - Use a placeholder container to render streaming tokens as they arrive.
    - Use a status container to render tool calls. Track the tool inputs and outputs
      and update the status container accordingly.

    The function also needs to track the last message container in session state
    since later messages can draw to the same container. This is also used for
    drawing the feedback widget in the latest chat message.

    Args:
        messages_agen: An async iterator over messages to draw. Plain `str`
            items are intermediate streaming tokens; `ChatMessage` items are
            complete messages.
        is_new: Whether the messages are new (append them to session state)
            or replayed from history.
    """
    # Keep track of the last message container
    last_message_type = None
    st.session_state.last_message = None

    # Placeholder for intermediate streaming tokens
    streaming_content = ""
    streaming_placeholder = None

    # Iterate over the messages and draw them. The walrus loop ends when
    # anext() returns the None default (generator exhausted).
    while msg := await anext(messages_agen, None):
        # str message represents an intermediate token being streamed
        if isinstance(msg, str):
            # If placeholder is empty, this is the first token of a new message
            # being streamed. We need to do setup.
            if not streaming_placeholder:
                if last_message_type != "ai":
                    last_message_type = "ai"
                    st.session_state.last_message = st.chat_message("ai")
                with st.session_state.last_message:
                    streaming_placeholder = st.empty()

            # Accumulate tokens and re-render the partial message in place.
            streaming_content += msg
            streaming_placeholder.write(streaming_content)
            continue

        # Anything that is neither str nor ChatMessage is a protocol error.
        if not isinstance(msg, ChatMessage):
            st.error(f"Unexpected message type: {type(msg)}")
            st.write(msg)
            st.stop()

        match msg.type:
            # A message from the user, the easiest case
            case "human":
                last_message_type = "human"
                st.chat_message("human").write(msg.content)

            # A message from the agent is the most complex case, since we need to
            # handle streaming tokens and tool calls.
            case "ai":
                # If we're rendering new messages, store the message in session state
                if is_new:
                    st.session_state.messages.append(msg)

                # If the last message type was not AI, create a new chat message
                if last_message_type != "ai":
                    last_message_type = "ai"
                    st.session_state.last_message = st.chat_message("ai")

                with st.session_state.last_message:
                    # If the message has content, write it out.
                    # Reset the streaming variables to prepare for the next message.
                    if msg.content:
                        if streaming_placeholder:
                            # Final content replaces the streamed partial text.
                            streaming_placeholder.write(msg.content)
                            streaming_content = ""
                            streaming_placeholder = None
                        else:
                            st.write(msg.content)

                    if msg.tool_calls:
                        # Create a status container for each tool call and store the
                        # status container by ID to ensure results are mapped to the
                        # correct status container.
                        call_results = {}
                        for tool_call in msg.tool_calls:
                            # Use different labels for transfer vs regular tool calls
                            if "transfer_to" in tool_call["name"]:
                                label = f"""💼 Sub Agent: {tool_call["name"]}"""
                            else:
                                label = f"""🛠️ Tool Call: {tool_call["name"]}"""
                            status = st.status(
                                label,
                                state="running" if is_new else "complete",
                            )
                            call_results[tool_call["id"]] = status

                        # Expect one ToolMessage for each tool call.
                        for tool_call in msg.tool_calls:
                            if "transfer_to" in tool_call["name"]:
                                # Hand the generator to the sub-agent renderer;
                                # it consumes messages until control returns.
                                status = call_results[tool_call["id"]]
                                status.update(expanded=True)
                                await handle_sub_agent_msgs(messages_agen, status, is_new)
                                # NOTE(review): this break skips any remaining
                                # tool calls in this message after the first
                                # transfer - confirm a transfer is always the
                                # only/last call in a message.
                                break
                            # Only non-transfer tool calls reach this point
                            status = call_results[tool_call["id"]]
                            status.write("Input:")
                            status.write(tool_call["args"])
                            tool_result: ChatMessage = await anext(messages_agen)

                            if tool_result.type != "tool":
                                st.error(f"Unexpected ChatMessage type: {tool_result.type}")
                                st.write(tool_result)
                                st.stop()

                            # Record the message if it's new, and update the correct
                            # status container with the result
                            if is_new:
                                st.session_state.messages.append(tool_result)
                            if tool_result.tool_call_id:
                                status = call_results[tool_result.tool_call_id]
                            status.write("Output:")
                            status.write(tool_result.content)
                            status.update(state="complete")

            case "custom":
                # CustomData example used by the bg-task-agent
                # See:
                # - src/agents/utils.py CustomData
                # - src/agents/bg_task_agent/task.py
                try:
                    task_data: TaskData = TaskData.model_validate(msg.custom_data)
                except ValidationError:
                    st.error("Unexpected CustomData message received from agent")
                    st.write(msg.custom_data)
                    st.stop()

                if is_new:
                    st.session_state.messages.append(msg)

                # Reuse one "task" chat message (and its TaskDataStatus) for a
                # run of consecutive custom messages.
                if last_message_type != "task":
                    last_message_type = "task"
                    st.session_state.last_message = st.chat_message(
                        name="task", avatar=":material/manufacturing:"
                    )
                    with st.session_state.last_message:
                        status = TaskDataStatus()

                status.add_and_draw_task_data(task_data)

            # In case of an unexpected message type, log an error and stop
            case _:
                st.error(f"Unexpected ChatMessage type: {msg.type}")
                st.write(msg)
                st.stop()
async def handle_feedback() -> None:
    """Draw a star-rating widget and submit new ratings to the agent service.

    The last (run_id, value) pair that was sent is remembered in session
    state so the same rating is not recorded twice across Streamlit reruns.
    """
    # Initialise the de-duplication record on first use.
    if "last_feedback" not in st.session_state:
        st.session_state.last_feedback = (None, None)

    latest_run_id = st.session_state.messages[-1].run_id
    feedback = st.feedback("stars", key=latest_run_id)

    # Bail out unless the user rated AND this (run, value) pair is new.
    if feedback is None or (latest_run_id, feedback) == st.session_state.last_feedback:
        return

    # st.feedback returns a 0-based star index; map 0..4 onto 0.2..1.0.
    normalized_score = (feedback + 1) / 5.0

    agent_client: AgentClient = st.session_state.agent_client
    try:
        await agent_client.acreate_feedback(
            run_id=latest_run_id,
            key="human-feedback-stars",
            score=normalized_score,
            kwargs={"comment": "In-line human feedback"},
        )
    except AgentClientError as e:
        st.error(f"Error recording feedback: {e}")
        st.stop()

    # Remember what was sent and confirm to the user.
    st.session_state.last_feedback = (latest_run_id, feedback)
    st.toast("Feedback recorded", icon=":material/reviews:")
async def handle_sub_agent_msgs(messages_agen, status, is_new):
    """
    This function segregates agent output into a status container.
    It handles all messages after the initial tool call message
    until it reaches the final AI message.

    Enhanced to support nested multi-agent hierarchies with handoff back messages.

    Args:
        messages_agen: Async generator of messages (shared with the caller;
            this function consumes items directly from it)
        status: the status container for the current agent
        is_new: Whether messages are new or replayed
    """
    # Maps tool_call_id -> popover, so each tool's output lands in the
    # popover created for its call.
    nested_popovers = {}
    # looking for the transfer Success tool call message
    first_msg = await anext(messages_agen)
    if is_new:
        st.session_state.messages.append(first_msg)

    # Continue reading until we get an explicit handoff back
    while True:
        # Read next message
        sub_msg = await anext(messages_agen)
        # this should only happen if the skip_stream flag is removed
        # (streaming tokens would then appear as plain strings here)
        # if isinstance(sub_msg, str):
        #     continue
        if is_new:
            st.session_state.messages.append(sub_msg)

        # Handle tool results with nested popovers
        if sub_msg.type == "tool" and sub_msg.tool_call_id in nested_popovers:
            popover = nested_popovers[sub_msg.tool_call_id]
            popover.write("**Output:**")
            popover.write(sub_msg.content)
            continue

        # Handle transfer_back_to tool calls - these indicate a sub-agent is returning control
        if (
            hasattr(sub_msg, "tool_calls")
            and sub_msg.tool_calls
            and any("transfer_back_to" in tc.get("name", "") for tc in sub_msg.tool_calls)
        ):
            # Process transfer_back_to tool calls
            for tc in sub_msg.tool_calls:
                if "transfer_back_to" in tc.get("name", ""):
                    # Read (and optionally record) the corresponding tool result
                    transfer_result = await anext(messages_agen)
                    if is_new:
                        st.session_state.messages.append(transfer_result)
            # After processing transfer back, we're done with this agent
            if status:
                status.update(state="complete")
            break

        # Display content and tool calls in the same nested status
        if status:
            if sub_msg.content:
                status.write(sub_msg.content)
            if hasattr(sub_msg, "tool_calls") and sub_msg.tool_calls:
                for tc in sub_msg.tool_calls:
                    # Check if this is a nested transfer/delegate
                    if "transfer_to" in tc["name"]:
                        # Create a nested status container for the sub-agent
                        nested_status = status.status(
                            f"""💼 Sub Agent: {tc["name"]}""",
                            state="running" if is_new else "complete",
                            expanded=True,
                        )
                        # Recursively handle sub-agents of this sub-agent
                        await handle_sub_agent_msgs(messages_agen, nested_status, is_new)
                    else:
                        # Regular tool call - create popover
                        popover = status.popover(f"{tc['name']}", icon="🛠️")
                        popover.write(f"**Tool:** {tc['name']}")
                        popover.write("**Input:**")
                        popover.write(tc["args"])
                        # Store the popover reference using the tool call ID
                        nested_popovers[tc["id"]] = popover
if __name__ == "__main__":
    # Streamlit re-executes this script on every rerun; drive the async
    # entry point to completion each time.
    asyncio.run(main())