diff --git a/.env.template b/.env.template index e7ad741..e500db2 100644 --- a/.env.template +++ b/.env.template @@ -41,6 +41,9 @@ COSMOSDB_DATABASE_NAME="template_langgraph" COSMOSDB_CONTAINER_NAME="kabuto" COSMOSDB_PARTITION_KEY="/id" +# SQL Database Settings +SQL_DATABASE_URI="sqlite:///template_langgraph.db" + # --------- # Utilities # --------- diff --git a/.gitignore b/.gitignore index 0a4737f..2e4449d 100644 --- a/.gitignore +++ b/.gitignore @@ -165,3 +165,4 @@ requirements.txt assets/ .langgraph_api generated/ +*.db diff --git a/template_langgraph/services/streamlits/pages/chat_with_tools_agent.py b/template_langgraph/services/streamlits/pages/chat_with_tools_agent.py index 46a1775..f01c749 100644 --- a/template_langgraph/services/streamlits/pages/chat_with_tools_agent.py +++ b/template_langgraph/services/streamlits/pages/chat_with_tools_agent.py @@ -5,22 +5,26 @@ from template_langgraph.agents.chat_with_tools_agent.agent import AgentState, graph +if "chat_history" not in st.session_state: + st.session_state["chat_history"] = [] + +for msg in st.session_state["chat_history"]: + if isinstance(msg, dict): + st.chat_message(msg["role"]).write(msg["content"]) + else: + st.chat_message("assistant").write(msg.content) + if prompt := st.chat_input(): + st.session_state["chat_history"].append({"role": "user", "content": prompt}) st.chat_message("user").write(prompt) with st.chat_message("assistant"): response: AgentState = graph.invoke( - { - "messages": [ - { - "role": "user", - "content": prompt, - }, - ] - }, + {"messages": st.session_state["chat_history"]}, { "callbacks": [ StreamlitCallbackHandler(st.container()), ] }, ) + st.session_state["chat_history"].append(response["messages"][-1]) st.write(response["messages"][-1].content) diff --git a/template_langgraph/tools/common.py b/template_langgraph/tools/common.py index bc320f9..5153789 100644 --- a/template_langgraph/tools/common.py +++ b/template_langgraph/tools/common.py @@ -1,11 +1,15 @@ +from template_langgraph.llms.azure_openais import AzureOpenAiWrapper from template_langgraph.tools.cosmosdb_tool import search_cosmosdb from template_langgraph.tools.dify_tool import run_dify_workflow from template_langgraph.tools.elasticsearch_tool import search_elasticsearch from template_langgraph.tools.qdrant_tool import search_qdrant +from template_langgraph.tools.sql_database_tool import SqlDatabaseClientWrapper DEFAULT_TOOLS = [ search_cosmosdb, run_dify_workflow, search_qdrant, search_elasticsearch, -] +] + SqlDatabaseClientWrapper().get_tools( + llm=AzureOpenAiWrapper().chat_model, +) diff --git a/template_langgraph/tools/sql_database_tool.py b/template_langgraph/tools/sql_database_tool.py new file mode 100644 index 0000000..2ae10d6 --- /dev/null +++ b/template_langgraph/tools/sql_database_tool.py @@ -0,0 +1,45 @@ +from functools import lru_cache + +from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit +from langchain_community.utilities.sql_database import SQLDatabase +from langchain_core.language_models import BaseLanguageModel +from langchain_core.tools.base import BaseTool +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class Settings(BaseSettings): + sql_database_uri: str = "sqlite:///template_langgraph.db" + + model_config = SettingsConfigDict( + env_file=".env", + env_ignore_empty=True, + extra="ignore", + ) + + +@lru_cache +def get_sql_database_settings() -> Settings: + """Get SQL Database settings.""" + return Settings() + + +class SqlDatabaseClientWrapper: + def __init__( + self, + settings: Settings = None, + ): + if settings is None: + settings = get_sql_database_settings() + self.db = SQLDatabase.from_uri( + database_uri=settings.sql_database_uri, + ) + + def get_tools( + self, + llm: BaseLanguageModel, + ) -> list[BaseTool]: + """Get SQL Database tools.""" + return SQLDatabaseToolkit( + db=self.db, + llm=llm, + ).get_tools()