|
| 1 | +import atexit |
| 2 | +import asyncio |
| 3 | +import gc |
| 4 | +import shutil |
| 5 | +import sqlite3 |
| 6 | +import uuid |
| 7 | +import duckdb |
| 8 | +import gradio as gr |
| 9 | +from pathlib import Path |
| 10 | +from typing import Generator, List, Tuple |
| 11 | + |
| 12 | +from any_chatbot.indexing import embed_and_index_all_docs |
| 13 | +from any_chatbot.prompts import system_message |
| 14 | +from any_chatbot.tools import initialize_retrieve_tool, initialize_sql_toolkit |
| 15 | +from any_chatbot.utils import load_environ_vars |
| 16 | + |
| 17 | +from langchain.chat_models import init_chat_model |
| 18 | +from langgraph.checkpoint.sqlite import SqliteSaver |
| 19 | +from langgraph.prebuilt import create_react_agent |
| 20 | + |
| 21 | +load_environ_vars() |
| 22 | +ROOT = Path(__file__).parent |
| 23 | +TMP_DIR = ROOT / "tmp" |
| 24 | +TMP_DIR.mkdir(exist_ok=True) |
| 25 | + |
| 26 | + |
| 27 | +class Session: |
| 28 | + def __init__(self): |
| 29 | + self.sid = uuid.uuid4().hex |
| 30 | + self.db_path = TMP_DIR / "csv_excel_to_db.duckdb" |
| 31 | + self.index_path = TMP_DIR / "faiss_index" |
| 32 | + self.hist_db_path = TMP_DIR / "hist.db" |
| 33 | + self.hist_db = sqlite3.connect(self.hist_db_path, check_same_thread=False) |
| 34 | + self.agent = None |
| 35 | + self.sql_engines: List = [] |
| 36 | + |
| 37 | + def cleanup(self): |
| 38 | + # dispose any SQLAlchemy/SQL-toolkit engines |
| 39 | + for eng in self.sql_engines: |
| 40 | + try: |
| 41 | + eng.dispose(close=True) |
| 42 | + except Exception: |
| 43 | + pass |
| 44 | + self.sql_engines.clear() |
| 45 | + # close agent history db |
| 46 | + try: |
| 47 | + self.hist_db.close() |
| 48 | + except Exception: |
| 49 | + pass |
| 50 | + # map agent to none |
| 51 | + self.agent = None |
| 52 | + # delete the tmp dir itself |
| 53 | + shutil.rmtree(TMP_DIR, ignore_errors=True) |
| 54 | + |
| 55 | + |
| 56 | +# initialize session |
| 57 | +sess = Session() |
| 58 | + |
| 59 | + |
| 60 | +# shutdown hook that is called when session ends |
| 61 | +@atexit.register |
| 62 | +def _purge_all(): |
| 63 | + sess.cleanup() |
| 64 | + |
| 65 | + |
| 66 | +def _safe_copy(src: Path, dst_dir: Path): |
| 67 | + dst = dst_dir / src.name |
| 68 | + if dst.exists(): |
| 69 | + dst = dst.with_name(f"{dst.stem}_{uuid.uuid4().hex[:4]}{dst.suffix}") |
| 70 | + shutil.copy2(src, dst) |
| 71 | + |
| 72 | + |
| 73 | +# upload & sync |
| 74 | +def cb_upload_and_sync(files: List[gr.File]) -> Generator[Tuple[str, list], None, None]: |
| 75 | + # GUARDRAIL FOR EMPTY FILES |
| 76 | + if not files: |
| 77 | + yield "⚠️ No files selected.", [] |
| 78 | + return |
| 79 | + |
| 80 | + # RESETTING TMP_DIR |
| 81 | + sess.cleanup() |
| 82 | + TMP_DIR.mkdir(exist_ok=True) |
| 83 | + sess.__init__() |
| 84 | + |
| 85 | + # PREPARE UPLOADED FILES |
| 86 | + # copy uploaded files to TMP_DIR |
| 87 | + for f in files: |
| 88 | + _safe_copy(Path(f.name), TMP_DIR) |
| 89 | + yield "📂 Files uploaded. Indexing...", [] |
| 90 | + # shutdown DuckDB internals |
| 91 | + try: |
| 92 | + duckdb.shutdown() |
| 93 | + except Exception: |
| 94 | + pass |
| 95 | + # force garbage collection |
| 96 | + gc.collect() |
| 97 | + |
| 98 | + # INDEXING |
| 99 | + # embedding and indexing uploaded documents |
| 100 | + loop = asyncio.new_event_loop() |
| 101 | + asyncio.set_event_loop(loop) |
| 102 | + _, vector_store = embed_and_index_all_docs( |
| 103 | + data_dir=TMP_DIR, |
| 104 | + db_path=sess.db_path, |
| 105 | + index_path=sess.index_path, |
| 106 | + load_data=True, |
| 107 | + ) |
| 108 | + asyncio.set_event_loop(None) |
| 109 | + loop.close() |
| 110 | + |
| 111 | + # CREATING AGENT |
| 112 | + yield "🤖 Building agent...", [] |
| 113 | + # build llm |
| 114 | + llm = init_chat_model("gemini-2.5-flash", model_provider="google_genai") |
| 115 | + # load tools |
| 116 | + retrieve = initialize_retrieve_tool(vector_store) |
| 117 | + sql_tools = initialize_sql_toolkit(llm, sess.db_path) |
| 118 | + # store on-disk state engines to be properly sess.cleanip() later |
| 119 | + for tool in sql_tools: |
| 120 | + eng = getattr(tool, "engine", None) or getattr( |
| 121 | + getattr(tool, "db", None), "engine", None |
| 122 | + ) |
| 123 | + if eng: |
| 124 | + sess.sql_engines.append(eng) |
| 125 | + memory = SqliteSaver(sess.hist_db) |
| 126 | + # build agent |
| 127 | + sess.agent = create_react_agent( |
| 128 | + llm, tools=[retrieve, *sql_tools], prompt=system_message, checkpointer=memory |
| 129 | + ) |
| 130 | + yield "✅ Sync complete!", [] |
| 131 | + |
| 132 | + |
| 133 | +# chat |
| 134 | +def cb_chat(hist: List[dict], msg: str) -> Tuple[List[dict], str]: |
| 135 | + if sess.agent is None: |
| 136 | + hist.append( |
| 137 | + { |
| 138 | + "role": "assistant", |
| 139 | + "content": "Please upload files and click 'Upload & Sync' first.", |
| 140 | + } |
| 141 | + ) |
| 142 | + return hist, "" |
| 143 | + hist.append({"role": "user", "content": msg}) |
| 144 | + messages = [{"role": m["role"], "content": m["content"]} for m in hist] |
| 145 | + reply = "" |
| 146 | + for event in sess.agent.stream( |
| 147 | + {"messages": messages}, |
| 148 | + stream_mode="values", |
| 149 | + config={"configurable": {"thread_id": sess.sid}}, |
| 150 | + ): |
| 151 | + reply = event["messages"][-1].content |
| 152 | + hist.append({"role": "assistant", "content": reply}) |
| 153 | + return hist, "" |
| 154 | + |
| 155 | + |
| 156 | +# UI |
| 157 | +with gr.Blocks(theme="default") as demo: |
| 158 | + gr.Markdown("## Learn-Anything Chatbot Agent - Upload and Ask About Your Files") |
| 159 | + with gr.Row(): |
| 160 | + file_box = gr.Files(file_count="multiple", label="Files to upload") |
| 161 | + sync_btn = gr.Button("Upload & Sync") |
| 162 | + status_md = gr.Markdown() |
| 163 | + chatbox = gr.Chatbot(label="Chat", type="messages", height=400) |
| 164 | + user_in = gr.Textbox(placeholder="Ask...", scale=8) |
| 165 | + |
| 166 | + sync_btn.click(cb_upload_and_sync, [file_box], [status_md, chatbox]) |
| 167 | + user_in.submit(cb_chat, [chatbox, user_in], [chatbox, user_in]) |
| 168 | + |
| 169 | +if __name__ == "__main__": |
| 170 | + demo.queue().launch() |
0 commit comments