11"""Adapted from https://github.com/langchain-ai/streamlit-agent/blob/main/streamlit_agent/basic_memory.py"""
22
3+ import sqlite3
34import tempfile
45from os import environ
56from random import getrandbits
67
78import streamlit as st
89from langchain_community .agent_toolkits .sql .toolkit import SQLDatabaseToolkit
910from langchain_community .utilities .sql_database import SQLDatabase
10- from langchain_core .messages import SystemMessage
11+ from langchain_core .messages import HumanMessage , SystemMessage
1112from langchain_core .messages .base import BaseMessage
1213from langchain_core .runnables .config import (
1314 RunnableConfig ,
3536title = "LangGraph SQL Agent Demo"
3637st .set_page_config (page_title = title , page_icon = "📖" , layout = "wide" )
3738st .title (title )
39+ _streamlit_helpers .styles ()
3840
3941
4042model = ChatVertexAI (
@@ -75,12 +77,19 @@ def search(query: str):
7577
7678
7779system_prompt = SystemMessage (
78- content = "You are a helpful AI assistant with a mastery of database management and querying. You have "
79- "access to an ephemeral sqlite database that you can query and modify through some tools. Help "
80- "answer questions and perform actions."
81- "\n "
82- "Make sure you always use QuerySQLCheckerTool to validate queries before running them! If you "
83- "make a mistake, try to recover."
80+ content = f"""\
81+ You are a careful and helpful AI assistant with a mastery of database design and querying. You
82+ have access to an ephemeral sqlite3 database that you can query and modify through some tools.
83+ Help answer questions and perform actions. Follow these rules:
84+
85+ - Make sure you always use sql_db_query_checker to validate SQL statements **before** running
86+ them! In pseudocode: `checked_query = sql_db_query_checker(query);
87+ sql_db_query(checked_query)`.
88+ - The sqlite version is { sqlite3 .sqlite_version } which supports multiple row inserts.
89+ - Always prefer to insert multiple rows in a single call to the sql_db_query tool, if possible.
90+ - You may request to execute multiple sql_db_query tool calls which will be run in parallel.
91+
92+ If you make a mistake, try to recover."""
8493)
8594
8695
@@ -119,35 +128,49 @@ def get_db(thread_id: str) -> SQLDatabase:
119128else :
120129 messages = []
121130
131+
132+ @st .cache_resource
133+ def get_trace_ids (thread_id : str ) -> "dict[str, str]" :
134+ # Stores the trace IDs. Unfortunately I can't find a way to easily retrieve this from the
135+ # checkpointer, so just store it separately.
136+ return {}
137+
138+
139+ trace_ids = get_trace_ids (st .query_params .thread_id )
140+
122141col1 , col2 = st .columns ([0.6 , 0.4 ])
123142with col1 :
124143 _streamlit_helpers .render_intro ()
125144 st .divider ()
126145
127146 # Add system message
128- st .chat_message (
129- "human " , avatar = ":material/precision_manufacturing:"
130- ).markdown (f"**System Instructions** \n > { system_prompt .content } " )
147+ st .expander (
148+ "System Instructions " , icon = ":material/precision_manufacturing:"
149+ ).markdown (system_prompt .content )
131150
132151 # Render current messages
133- for msg in messages :
134- # Filter out tool calls
135- if msg .type in ("human" , "ai" ) and msg .content :
136- st .chat_message (msg .type ).write (msg .content )
152+ for message in messages :
153+ trace_id = trace_ids .get (message .id or "" )
154+ _streamlit_helpers .render_message (message , trace_id )
137155
138156# If user inputs a new prompt, generate and draw a new response
139157if prompt := st .chat_input ():
158+ message = HumanMessage (prompt )
140159 with col1 :
141- st .chat_message ("human" ).write (prompt )
142160 with tracer .start_as_current_span (
143161 "chain invoke" ,
144162 attributes = {"thread_id" : st .query_params .thread_id },
145- ) as span , st . spinner ( "Thinking..." ) :
146- st . toast (
147- f"Trace ID { format_trace_id ( span . get_span_context (). trace_id ) } "
148- )
163+ ) as span :
164+ trace_id = format_trace_id ( span . get_span_context (). trace_id )
165+ _streamlit_helpers . render_message ( message , trace_id = trace_id )
166+
149167 # Invoke the agent
150- app .invoke ({"messages" : [prompt ]}, config = config )
168+ with st .spinner ("Thinking..." ):
169+ res = app .invoke ({"messages" : [message ]}, config = config )
170+
171+ # Store trace ID for rendering
172+ trace_ids [message .id or "" ] = trace_id
173+ trace_ids [res ["messages" ][- 1 ].id ] = trace_id
151174
152175 st .rerun ()
153176
@@ -160,6 +183,3 @@ def get_db(thread_id: str) -> SQLDatabase:
160183
161184 with st .expander ("View the message contents in session state" ):
162185 st .json (messages )
163-
164- # https://pantheon.corp.google.com/traces/explorer;traceId=8a69802b84077b162277f175e0f15276;duration=PT30M?project=otlp-test-deleteme
165- # https://pantheon.corp.google.com/traces/explorer;query=%7B%22plotType%22:%22HEATMAP%22,%22targetAxis%22:%22Y1%22,%22traceQuery%22:%7B%22resourceContainer%22:%22projects%2Fotlp-test-deleteme%2Flocations%2Fglobal%2FtraceScopes%2F_Default%22,%22spanDataValue%22:%22SPAN_DURATION%22,%22spanFilters%22:%7B%22attributes%22:%5B%5D,%22displayNames%22:%5B%22chain%20invoke%22%5D,%22isRootSpan%22:true,%22kinds%22:%5B%5D,%22maxDuration%22:%22%22,%22minDuration%22:%22%22,%22services%22:%5B%5D,%22status%22:%5B%5D%7D%7D%7D;traceId=8a69802b84077b162277f175e0f15276;duration=PT30M
0 commit comments