Skip to content

Commit 9c260c4

Browse files
authored
Merge pull request #55 from ks6088ts-labs/feature/issue-48_sql
add sql database tool
2 parents 230992a + 54267ef commit 9c260c4

File tree

5 files changed

+66
-9
lines changed

5 files changed

+66
-9
lines changed

.env.template

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ COSMOSDB_DATABASE_NAME="template_langgraph"
4141
COSMOSDB_CONTAINER_NAME="kabuto"
4242
COSMOSDB_PARTITION_KEY="/id"
4343

44+
# SQL Database Settings
45+
SQL_DATABASE_URI="sqlite:///template_langgraph.db"
46+
4447
# ---------
4548
# Utilities
4649
# ---------

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,3 +165,4 @@ requirements.txt
165165
assets/
166166
.langgraph_api
167167
generated/
168+
*.db

template_langgraph/services/streamlits/pages/chat_with_tools_agent.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,26 @@
55

66
from template_langgraph.agents.chat_with_tools_agent.agent import AgentState, graph
77

8+
if "chat_history" not in st.session_state:
9+
st.session_state["chat_history"] = []
10+
11+
for msg in st.session_state["chat_history"]:
12+
if isinstance(msg, dict):
13+
st.chat_message(msg["role"]).write(msg["content"])
14+
else:
15+
st.chat_message("assistant").write(msg.content)
16+
817
if prompt := st.chat_input():
18+
st.session_state["chat_history"].append({"role": "user", "content": prompt})
919
st.chat_message("user").write(prompt)
1020
with st.chat_message("assistant"):
1121
response: AgentState = graph.invoke(
12-
{
13-
"messages": [
14-
{
15-
"role": "user",
16-
"content": prompt,
17-
},
18-
]
19-
},
22+
{"messages": st.session_state["chat_history"]},
2023
{
2124
"callbacks": [
2225
StreamlitCallbackHandler(st.container()),
2326
]
2427
},
2528
)
29+
st.session_state["chat_history"].append(response["messages"][-1])
2630
st.write(response["messages"][-1].content)

template_langgraph/tools/common.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
1+
from template_langgraph.llms.azure_openais import AzureOpenAiWrapper
12
from template_langgraph.tools.cosmosdb_tool import search_cosmosdb
23
from template_langgraph.tools.dify_tool import run_dify_workflow
34
from template_langgraph.tools.elasticsearch_tool import search_elasticsearch
45
from template_langgraph.tools.qdrant_tool import search_qdrant
6+
from template_langgraph.tools.sql_database_tool import SqlDatabaseClientWrapper
57

68
DEFAULT_TOOLS = [
79
search_cosmosdb,
810
run_dify_workflow,
911
search_qdrant,
1012
search_elasticsearch,
11-
]
13+
] + SqlDatabaseClientWrapper().get_tools(
14+
llm=AzureOpenAiWrapper().chat_model,
15+
)
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from functools import lru_cache
2+
3+
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
4+
from langchain_community.utilities.sql_database import SQLDatabase
5+
from langchain_core.language_models import BaseLanguageModel
6+
from langchain_core.tools.base import BaseTool
7+
from pydantic_settings import BaseSettings, SettingsConfigDict
8+
9+
10+
class Settings(BaseSettings):
11+
sql_database_uri: str = "sqlite:///template_langgraph.db"
12+
13+
model_config = SettingsConfigDict(
14+
env_file=".env",
15+
env_ignore_empty=True,
16+
extra="ignore",
17+
)
18+
19+
20+
@lru_cache
21+
def get_sql_database_settings() -> Settings:
22+
"""Get SQL Database settings."""
23+
return Settings()
24+
25+
26+
class SqlDatabaseClientWrapper:
27+
def __init__(
28+
self,
29+
settings: Settings = None,
30+
):
31+
if settings is None:
32+
settings = get_sql_database_settings()
33+
self.db = SQLDatabase.from_uri(
34+
database_uri=settings.sql_database_uri,
35+
)
36+
37+
def get_tools(
38+
self,
39+
llm: BaseLanguageModel,
40+
) -> list[BaseTool]:
41+
"""Get SQL Database tools."""
42+
return SQLDatabaseToolkit(
43+
db=self.db,
44+
llm=llm,
45+
).get_tools()

0 commit comments

Comments
 (0)