Skip to content

Commit ef5dae3

Browse files
committed
added sql query toolkit for agent
1 parent 9297bf0 commit ef5dae3

File tree

4 files changed

+55
-49
lines changed

4 files changed

+55
-49
lines changed

src/any_chatbot/agent.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,18 @@
99
from langchain.chat_models import init_chat_model
1010

1111
from any_chatbot.indexing import embed_and_index_all_docs
12-
from any_chatbot.tools import initialize_retrieve_tool
12+
from any_chatbot.tools import initialize_retrieve_tool, initialize_sql_toolkit
13+
from any_chatbot.prompts import system_message
1314

1415
load_dotenv()
1516

1617
BASE = Path(__file__).parent.parent.parent
1718
DATA = BASE / "data"
1819
OUTPUTS = BASE / "outputs"
20+
DATABASE = DATA / "csv_excel_to_db" / "my_data.duckdb"
1921

2022
# INDEXING
21-
embeddings, vector_store = embed_and_index_all_docs(DATA)
23+
embeddings, vector_store = embed_and_index_all_docs(DATA, DATABASE)
2224

2325
# BUILD LLM
2426
if not os.environ.get("GOOGLE_API_KEY"):
@@ -27,12 +29,15 @@
2729

2830
# LOAD TOOLS
2931
retrieve_tool = initialize_retrieve_tool(vector_store)
32+
sql_tools = initialize_sql_toolkit(llm, DATABASE)
3033

3134
# BUILD AGENT
3235
# build checkpointer
3336
memory = MemorySaver()
3437
# build agent
35-
agent_executor = create_react_agent(llm, [retrieve_tool], checkpointer=memory)
38+
agent_executor = create_react_agent(
39+
llm, [retrieve_tool, *sql_tools], prompt=system_message, checkpointer=memory
40+
)
3641
# save architecture graph image
3742
png_bytes = agent_executor.get_graph().draw_mermaid_png()
3843
# save to file
@@ -52,8 +57,11 @@
5257
# )
5358

5459
input_message = (
55-
"What colums does the excel have? once you found the answer, tell me there types too.\n\n"
60+
"How many employees were working for Nike? The information is in the pdf.\n\n"
61+
# "What columns does the excel have? once you found the answer, tell me their types too.\n\n"
62+
# "Once you have that answer, I want you to calculate the median for each column.\n\n"
5663
"When you don't know which files the user is talking about, use the functional call to retrieve what data is available with a general prompt.\n\n"
64+
"You can refine your semantic search queries and try multiple times with different queries until you reasonably determine the data is not available on the given documents.\n\n"
5765
"Base your answers only on the retrieved information through the functional call you have. You can retrieve MULTIPLE TIMES"
5866
)
5967

src/any_chatbot/indexing.py

Lines changed: 1 addition & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def load_and_split_text_docs(data_dir):
3838
glob=globs,
3939
loader_cls=UnstructuredFileLoader,
4040
)
41-
print(f"Loading files from {data_dir}")
41+
print(f"Loading text files from {data_dir}")
4242
docs = loader.load()
4343
print(f"Loaded {len(docs)} text files")
4444
# split
@@ -204,47 +204,3 @@ def embed_and_index_all_docs(
204204

205205
vector_store.add_documents(text_chunks + image_text_docs + summary_cards)
206206
return embeddings, vector_store
207-
208-
209-
# def index_text_docs(
210-
# data_pth: Path = DATA,
211-
# ):
212-
# embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
213-
# vector_store = InMemoryVectorStore(embeddings)
214-
215-
# # Load the text documents
216-
# loader = DirectoryLoader(
217-
# str(data_pth),
218-
# glob=[
219-
# "**/*.pdf",
220-
# "**/*.docx",
221-
# "**/*.pptx",
222-
# "**/*.md",
223-
# "**/*.html",
224-
# "**/*.txt",
225-
# "**/*.png",
226-
# "**/*.jpg",
227-
# "**/*.jpeg",
228-
# "**/*.tiff",
229-
# ],
230-
# loader_cls=UnstructuredFileLoader,
231-
# )
232-
# print(f"Loading files from {data_pth}")
233-
# docs = loader.load()
234-
# print(f"Loaded {len(docs)} files")
235-
236-
# # Split the texts
237-
# text_splitter = RecursiveCharacterTextSplitter(
238-
# chunk_size=1000,
239-
# chunk_overlap=200,
240-
# add_start_index=True,
241-
# separators=["\n\n", "\n", " ", ""],
242-
# )
243-
# all_splits = text_splitter.split_documents(docs)
244-
# print(len(all_splits))
245-
246-
# # index the docs
247-
# ids = vector_store.add_documents(documents=all_splits)
248-
# print(len(ids))
249-
250-
# return embeddings, vector_store

src/any_chatbot/prompts.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
system_message = """
2+
You are an agent designed to interact with a SQL database.
3+
Given an input question, create a syntactically correct {dialect} query to run,
4+
then look at the results of the query and return the answer. Unless the user
5+
specifies a specific number of examples they wish to obtain, always limit your
6+
query to at most {top_k} results.
7+
8+
You can order the results by a relevant column to return the most interesting
9+
examples in the database. Never query for all the columns from a specific table,
10+
only ask for the relevant columns given the question.
11+
12+
You MUST double check your query before executing it. If you get an error while
13+
executing a query, rewrite the query and try again.
14+
15+
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the
16+
database.
17+
18+
To start you should ALWAYS look at the tables in the database to see what you
19+
can query. Do NOT skip this step.
20+
21+
Then you should query the schema of the most relevant tables.
22+
""".format(
23+
dialect="DuckDB",
24+
top_k=5,
25+
)

src/any_chatbot/tools.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
11
from typing import Tuple, List
2+
from pathlib import Path
3+
24
from langchain_core.tools import tool
35
from langchain.vectorstores.base import VectorStore
46
from langchain.schema import Document
7+
from langchain_community.utilities.sql_database import SQLDatabase
8+
from langchain_community.agent_toolkits import SQLDatabaseToolkit
9+
10+
BASE = Path(__file__).parent.parent.parent
11+
DATA = BASE / "data"
512

613

714
def initialize_retrieve_tool(vector_store: VectorStore):
@@ -18,3 +25,13 @@ def retrieve(query: str) -> Tuple[str, List[Document]]:
1825
return serialized, retrieved_docs
1926

2027
return retrieve
28+
29+
30+
def initialize_sql_toolkit(
31+
llm,
32+
db_path: Path = DATA / "csv_excel_to_db" / "my_data.duckdb",
33+
):
34+
db = SQLDatabase.from_uri(f"duckdb:///{db_path}")
35+
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
36+
tools = toolkit.get_tools()
37+
return tools

0 commit comments

Comments
 (0)