Skip to content

Commit 7484e36

Browse files
committed
added gradio interface
1 parent 1c43bb5 commit 7484e36

File tree

2 files changed

+173
-0
lines changed

2 files changed

+173
-0
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,8 @@ old-data/
2323
# outputs
2424
outputs/
2525

26+
# gradio temp
27+
tmp/
28+
2629
# venv
2730
.venv

app.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
import atexit
2+
import asyncio
3+
import gc
4+
import shutil
5+
import sqlite3
6+
import uuid
7+
import duckdb
8+
import gradio as gr
9+
from pathlib import Path
10+
from typing import Generator, List, Tuple
11+
12+
from any_chatbot.indexing import embed_and_index_all_docs
13+
from any_chatbot.prompts import system_message
14+
from any_chatbot.tools import initialize_retrieve_tool, initialize_sql_toolkit
15+
from any_chatbot.utils import load_environ_vars
16+
17+
from langchain.chat_models import init_chat_model
18+
from langgraph.checkpoint.sqlite import SqliteSaver
19+
from langgraph.prebuilt import create_react_agent
20+
21+
load_environ_vars()
22+
ROOT = Path(__file__).parent
23+
TMP_DIR = ROOT / "tmp"
24+
TMP_DIR.mkdir(exist_ok=True)
25+
26+
27+
class Session:
28+
def __init__(self):
29+
self.sid = uuid.uuid4().hex
30+
self.db_path = TMP_DIR / "csv_excel_to_db.duckdb"
31+
self.index_path = TMP_DIR / "faiss_index"
32+
self.hist_db_path = TMP_DIR / "hist.db"
33+
self.hist_db = sqlite3.connect(self.hist_db_path, check_same_thread=False)
34+
self.agent = None
35+
self.sql_engines: List = []
36+
37+
def cleanup(self):
38+
# dispose any SQLAlchemy/SQL-toolkit engines
39+
for eng in self.sql_engines:
40+
try:
41+
eng.dispose(close=True)
42+
except Exception:
43+
pass
44+
self.sql_engines.clear()
45+
# close agent history db
46+
try:
47+
self.hist_db.close()
48+
except Exception:
49+
pass
50+
# map agent to none
51+
self.agent = None
52+
# delete the tmp dir itself
53+
shutil.rmtree(TMP_DIR, ignore_errors=True)
54+
55+
56+
# initialize session
57+
sess = Session()
58+
59+
60+
# shutdown hook that is called when session ends
61+
@atexit.register
62+
def _purge_all():
63+
sess.cleanup()
64+
65+
66+
def _safe_copy(src: Path, dst_dir: Path):
67+
dst = dst_dir / src.name
68+
if dst.exists():
69+
dst = dst.with_name(f"{dst.stem}_{uuid.uuid4().hex[:4]}{dst.suffix}")
70+
shutil.copy2(src, dst)
71+
72+
73+
# upload & sync
74+
def cb_upload_and_sync(files: List[gr.File]) -> Generator[Tuple[str, list], None, None]:
75+
# GUARDRAIL FOR EMPTY FILES
76+
if not files:
77+
yield "⚠️ No files selected.", []
78+
return
79+
80+
# RESETTING TMP_DIR
81+
sess.cleanup()
82+
TMP_DIR.mkdir(exist_ok=True)
83+
sess.__init__()
84+
85+
# PREPARE UPLOADED FILES
86+
# copy uploaded files to TMP_DIR
87+
for f in files:
88+
_safe_copy(Path(f.name), TMP_DIR)
89+
yield "📂 Files uploaded. Indexing...", []
90+
# shutdown DuckDB internals
91+
try:
92+
duckdb.shutdown()
93+
except Exception:
94+
pass
95+
# force garbage collection
96+
gc.collect()
97+
98+
# INDEXING
99+
# embedding and indexing uploaded documents
100+
loop = asyncio.new_event_loop()
101+
asyncio.set_event_loop(loop)
102+
_, vector_store = embed_and_index_all_docs(
103+
data_dir=TMP_DIR,
104+
db_path=sess.db_path,
105+
index_path=sess.index_path,
106+
load_data=True,
107+
)
108+
asyncio.set_event_loop(None)
109+
loop.close()
110+
111+
# CREATING AGENT
112+
yield "🤖 Building agent...", []
113+
# build llm
114+
llm = init_chat_model("gemini-2.5-flash", model_provider="google_genai")
115+
# load tools
116+
retrieve = initialize_retrieve_tool(vector_store)
117+
sql_tools = initialize_sql_toolkit(llm, sess.db_path)
118+
# store on-disk state engines to be properly sess.cleanip() later
119+
for tool in sql_tools:
120+
eng = getattr(tool, "engine", None) or getattr(
121+
getattr(tool, "db", None), "engine", None
122+
)
123+
if eng:
124+
sess.sql_engines.append(eng)
125+
memory = SqliteSaver(sess.hist_db)
126+
# build agent
127+
sess.agent = create_react_agent(
128+
llm, tools=[retrieve, *sql_tools], prompt=system_message, checkpointer=memory
129+
)
130+
yield "✅ Sync complete!", []
131+
132+
133+
# chat
134+
def cb_chat(hist: List[dict], msg: str) -> Tuple[List[dict], str]:
135+
if sess.agent is None:
136+
hist.append(
137+
{
138+
"role": "assistant",
139+
"content": "Please upload files and click 'Upload & Sync' first.",
140+
}
141+
)
142+
return hist, ""
143+
hist.append({"role": "user", "content": msg})
144+
messages = [{"role": m["role"], "content": m["content"]} for m in hist]
145+
reply = ""
146+
for event in sess.agent.stream(
147+
{"messages": messages},
148+
stream_mode="values",
149+
config={"configurable": {"thread_id": sess.sid}},
150+
):
151+
reply = event["messages"][-1].content
152+
hist.append({"role": "assistant", "content": reply})
153+
return hist, ""
154+
155+
156+
# UI
157+
with gr.Blocks(theme="default") as demo:
158+
gr.Markdown("## Learn-Anything Chatbot Agent - Upload and Ask About Your Files")
159+
with gr.Row():
160+
file_box = gr.Files(file_count="multiple", label="Files to upload")
161+
sync_btn = gr.Button("Upload & Sync")
162+
status_md = gr.Markdown()
163+
chatbox = gr.Chatbot(label="Chat", type="messages", height=400)
164+
user_in = gr.Textbox(placeholder="Ask...", scale=8)
165+
166+
sync_btn.click(cb_upload_and_sync, [file_box], [status_md, chatbox])
167+
user_in.submit(cb_chat, [chatbox, user_in], [chatbox, user_in])
168+
169+
if __name__ == "__main__":
170+
demo.queue().launch()

0 commit comments

Comments
 (0)