
Commit 46789b4: added persistent memory and parse args
Parent: 07f0717

6 files changed, 126 insertions(+), 56 deletions(-)


requirements.txt

Lines changed: 3 additions & 1 deletion
@@ -10,4 +10,6 @@ langgraph
 unstructured[pdf,docx,pptx,md,image]
 duckdb
 duckdb-engine
-openpyxl
+openpyxl
+faiss-cpu
+langgraph-checkpoint-sqlite
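
The two new pins line up with the rest of the commit: faiss-cpu backs the on-disk FAISS index that replaces the in-memory vector store in indexing.py, and langgraph-checkpoint-sqlite provides the SqliteSaver checkpointer behind the new persistent memory. A minimal import check after installing, not part of the commit itself:

# Quick smoke test for the new dependencies (assumes `pip install -r requirements.txt` has run).
from langchain_community.vectorstores import FAISS       # uses faiss-cpu once an index is built
from langgraph.checkpoint.sqlite import SqliteSaver      # shipped by langgraph-checkpoint-sqlite

print(FAISS.__name__, SqliteSaver.__name__)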

scripts/run_agent.sh

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+python -m any_chatbot.agent "$@"

src/any_chatbot/agent.py

Lines changed: 88 additions & 50 deletions
@@ -1,73 +1,111 @@
-import getpass
 import os
+import argparse
 import random
+import sqlite3
 from dotenv import load_dotenv
 from pathlib import Path
 
 from langgraph.prebuilt import create_react_agent
-from langgraph.checkpoint.memory import MemorySaver
+from langgraph.checkpoint.sqlite import SqliteSaver
 from langchain.chat_models import init_chat_model
 
 from any_chatbot.indexing import embed_and_index_all_docs
 from any_chatbot.tools import initialize_retrieve_tool, initialize_sql_toolkit
 from any_chatbot.prompts import system_message
+from any_chatbot.utils import load_environ_vars
 
 load_dotenv()
 
 BASE = Path(__file__).parent.parent.parent
-DATA = BASE / "data"
-OUTPUTS = BASE / "outputs"
-DATABASE = DATA / "csv_excel_to_db" / "my_data.duckdb"
 
-# INDEXING
-embeddings, vector_store = embed_and_index_all_docs(DATA, DATABASE)
 
-# BUILD LLM
-if not os.environ.get("GOOGLE_API_KEY"):
-    os.environ["GOOGLE_API_KEY"] = getpass.getpass("Enter API key for Google Gemini: ")
-llm = init_chat_model("gemini-2.0-flash", model_provider="google_genai")
+def parse_args() -> argparse.Namespace:
+    """Parse command-line arguments for the agent run."""
+    p = argparse.ArgumentParser()
 
-# LOAD TOOLS
-retrieve_tool = initialize_retrieve_tool(vector_store)
-sql_tools = initialize_sql_toolkit(llm, DATABASE)
+    p.add_argument(
+        "--ask",
+        type=str,
+        default=(
+            "What kinds (images, text docs, or excel sheets) are available in the documents I have provided to you? Use the function call to retrieve information for each type first.\n\n"
+            # "What columns does the excel have? Once you have found the answer, tell me their types too.\n\n"
+            # "Once you have that answer, I want you to calculate the median for each column.\n\n"
+            "When you don't know which files the user is talking about, use the function call to retrieve what data is available with a general prompt.\n\n"
+            "You can refine your semantic search queries and try multiple times with different queries until you reasonably determine the data is not available in the given documents.\n\n"
+            "Base your answers only on the information retrieved through the function call you have. You can retrieve MULTIPLE TIMES."
+        ),
+        help="Your input to the agent",
+    )
+    p.add_argument(
+        "--data_dir",
+        type=Path,
+        default=BASE / "data",
+        help="Path to the data dir where your files are uploaded",
+    )
+    p.add_argument(
+        "--thread_id",
+        type=str,
+        default=str(random.random()),
+        help="Conversation history ID. Different IDs keep separate chat histories with the agent",
+    )
+    p.add_argument(
+        "--outputs_dir",
+        type=Path,
+        default=BASE / "outputs",
+        help="Path to the output dir where the agent architecture image is saved",
+    )
+    p.add_argument(
+        "--database_dir",
+        type=Path,
+        default=BASE / "data" / "generated_db" / "csv_excel_to_db.duckdb",
+        help="Path to the database file where the SQL version of CSV/Excel files is stored",
+    )
+    return p.parse_args()
 
-# BUILD AGENT
-# build checkpointer
-memory = MemorySaver()
-# build agent
-agent_executor = create_react_agent(
-    llm, [retrieve_tool, *sql_tools], prompt=system_message, checkpointer=memory
-)
-# save architecture graph image
-png_bytes = agent_executor.get_graph().draw_mermaid_png()
-# save to file
-with open(OUTPUTS / "graph.png", "wb") as f:
-    f.write(png_bytes)
-print("Created graph.png")
 
-# PROMPT
-# specify an ID for the thread
-# config = {"configurable": {"thread_id": "abc123"}}
-config = {"configurable": {"thread_id": random.random()}}
+def main() -> None:
+    cfg = parse_args()
+    load_environ_vars()
+    # INDEXING
+    _, vector_store = embed_and_index_all_docs(cfg.data_dir, cfg.database_dir)
 
-# input_message = (
-#     "What is the content of the image?\n\n"
-#     "When you don't know while files the user is talking about, use the functional call to retrieve what data is available with a general prompt.\n\n"
-#     "Base your answers only on the retrieved information thorugh the functional call you have. You can retreive MULTIPLE TIMES"
-# )
+    # BUILD LLM
+    llm = init_chat_model("gemini-2.0-flash", model_provider="google_genai")
 
-input_message = (
-    "What kinds (images, text docs, or excel sheets) are available in the documents I have provided to you? Use the functional call to retrieve information for each type first.\n\n"
-    # "What colums does the excel have? once you found the answer, tell me there types too.\n\n"
-    # "Once you have that answer, I want you to calculate the median for each column.\n\n"
-    "When you don't know while files the user is talking about, use the functional call to retrieve what data is available with a general prompt.\n\n"
-    "You can refine your semantic search queries and try multiple times with different queries until you resonably determine the data is not available on the given documents.\n\n"
-    "Base your answers only on the retrieved information thorugh the functional call you have. You can retreive MULTIPLE TIMES"
-)
+    # LOAD TOOLS
+    retrieve_tool = initialize_retrieve_tool(vector_store)
+    sql_tools = initialize_sql_toolkit(llm, cfg.database_dir)
 
-for event in agent_executor.stream(
-    {"messages": [{"role": "user", "content": input_message}]},
-    stream_mode="values",
-    config=config,
-):
-    event["messages"][-1].pretty_print()
+    # BUILD AGENT
+    # build persistent checkpointer
+    con = sqlite3.connect(
+        cfg.data_dir / "generated_db" / "agent_history.db", check_same_thread=False
+    )
+    memory = SqliteSaver(con)
+    # build agent
+    agent_executor = create_react_agent(
+        llm, [retrieve_tool, *sql_tools], prompt=system_message, checkpointer=memory
+    )
+    # save architecture graph image
+    png_bytes = agent_executor.get_graph().draw_mermaid_png()
+    # ensure the output folder exists
+    os.makedirs(cfg.outputs_dir, exist_ok=True)
+    # save to file
+    with open(cfg.outputs_dir / "graph.png", "wb") as f:
+        f.write(png_bytes)
+    print("Created graph.png")
+
+    # PROMPT
+    # specify an ID for the thread
+    config = {"configurable": {"thread_id": cfg.thread_id}}
+    # stream conversation
+    for event in agent_executor.stream(
+        {"messages": [{"role": "user", "content": cfg.ask}]},
+        stream_mode="values",
+        config=config,
+    ):
+        event["messages"][-1].pretty_print()
+
+
+if __name__ == "__main__":
+    main()
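
The MemorySaver-to-SqliteSaver switch is what makes conversations survive between runs: checkpoints now land in data/generated_db/agent_history.db, and rerunning the agent with the same --thread_id resumes that thread instead of starting fresh. A minimal sketch of the same pattern outside the agent (the path and thread ID here are illustrative, not taken from the commit):

import sqlite3

from langgraph.checkpoint.sqlite import SqliteSaver

# Open (or create) the checkpoint database. check_same_thread=False lets
# LangGraph reuse this connection from its own worker threads.
con = sqlite3.connect("agent_history.db", check_same_thread=False)
memory = SqliteSaver(con)

# Any graph compiled with this checkpointer keeps state per thread_id, so a
# later process that passes the same ID continues the earlier conversation.
config = {"configurable": {"thread_id": "my-chat"}}
# agent_executor = create_react_agent(llm, tools, checkpointer=memory)
# agent_executor.invoke({"messages": [{"role": "user", "content": "hi"}]}, config=config)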

src/any_chatbot/indexing.py

Lines changed: 20 additions & 4 deletions
@@ -2,10 +2,11 @@
 import re
 import pandas as pd
 import duckdb
+import shutil
 from dotenv import load_dotenv
 from pathlib import Path
 
-from langchain_core.vectorstores import InMemoryVectorStore
+from langchain_community.vectorstores import FAISS
 from langchain_google_genai import GoogleGenerativeAIEmbeddings
 from langchain_community.document_loaders import DirectoryLoader, UnstructuredFileLoader
 from langchain_text_splitters import RecursiveCharacterTextSplitter
@@ -188,12 +189,22 @@ def build_duckdb_and_summary_cards(
     return summary_cards
 
 
+def reset_faiss_index(index_path: Path):
+    if index_path.exists():
+        print("Resetting previous index...")
+        shutil.rmtree(index_path)
+
+
 def embed_and_index_all_docs(
-    data_dir: Path = DATA, db_path: Path = DATA / "csv_excel_to_db" / "my_data.duckdb"
+    data_dir: Path = DATA,
+    db_path: Path = DATA / "generated_db" / "csv_excel_to_db.duckdb",
+    index_path: Path = DATA / "generated_db" / "faiss_index",
 ):
+    # delete old FAISS index if it exists
+    reset_faiss_index(index_path)
+
     # load embeedings and vector store
     embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
-    vector_store = InMemoryVectorStore(embeddings)
 
     # LOAD AND SPLIT TEXT DOCS
     text_chunks = load_and_split_text_docs(data_dir)
@@ -202,5 +213,10 @@ def embed_and_index_all_docs(
     # LOAD AND SPLIT CSV/EXCEL DOCS
     summary_cards = build_duckdb_and_summary_cards(data_dir, db_path)
 
-    vector_store.add_documents(text_chunks + image_text_docs + summary_cards)
+    # vector_store.add_documents(text_chunks + image_text_docs + summary_cards)
+    vector_store = FAISS.from_documents(
+        text_chunks + image_text_docs + summary_cards, embeddings
+    )
+    vector_store.save_local(index_path)
+
     return embeddings, vector_store
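
Because the index is now written to disk with save_local, a later process can reload it instead of re-embedding every document. A rough sketch of the reload path, assuming the same embedding model, the default index location, and GOOGLE_API_KEY already set (FAISS pickles its docstore, hence the explicit opt-in flag):

from pathlib import Path

from langchain_community.vectorstores import FAISS
from langchain_google_genai import GoogleGenerativeAIEmbeddings

# Must match the model used when the index was built, or the vectors won't line up.
embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
index_path = Path("data") / "generated_db" / "faiss_index"

# load_local deserializes a pickled docstore, so loading is an explicit opt-in.
vector_store = FAISS.load_local(
    str(index_path), embeddings, allow_dangerous_deserialization=True
)
docs = vector_store.similarity_search("what data is available?", k=2)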

src/any_chatbot/tools.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ def retrieve(
     retrieved_docs = vector_store.similarity_search(
         query,
         k=2,
-        filter=lambda doc: doc.metadata.get("source_type") == tag,
+        filter={"source_type": tag},
     )
     serialized = "\n\n".join(
         (f"Source: {doc.metadata}\nContent: {doc.page_content}")

src/any_chatbot/utils.py

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+import getpass
+import os
+
+
+def load_environ_vars() -> None:
+    """Set basic environment variables needed for a run."""
+    if not os.environ.get("GOOGLE_API_KEY"):
+        os.environ["GOOGLE_API_KEY"] = getpass.getpass(
+            "Enter API key for Google Gemini: "
+        )
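
load_environ_vars only prompts when GOOGLE_API_KEY is missing, so exporting the key beforehand (or keeping it in the .env file that agent.py already loads) skips the interactive getpass prompt entirely. For example (the key value is a placeholder):

import os

from any_chatbot.utils import load_environ_vars

os.environ.setdefault("GOOGLE_API_KEY", "<your-api-key>")  # placeholder, normally set via .env
load_environ_vars()  # no prompt: the key is already present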
