|
1 | | -import getpass |
2 | 1 | import os |
| 2 | +import argparse |
3 | 3 | import random |
| 4 | +import sqlite3 |
4 | 5 | from dotenv import load_dotenv |
5 | 6 | from pathlib import Path |
6 | 7 |
|
7 | 8 | from langgraph.prebuilt import create_react_agent |
8 | | -from langgraph.checkpoint.memory import MemorySaver |
| 9 | +from langgraph.checkpoint.sqlite import SqliteSaver |
9 | 10 | from langchain.chat_models import init_chat_model |
10 | 11 |
|
11 | 12 | from any_chatbot.indexing import embed_and_index_all_docs |
12 | 13 | from any_chatbot.tools import initialize_retrieve_tool, initialize_sql_toolkit |
13 | 14 | from any_chatbot.prompts import system_message |
| 15 | +from any_chatbot.utils import load_environ_vars |
14 | 16 |
|
15 | 17 | load_dotenv() |
16 | 18 |
|
17 | 19 | BASE = Path(__file__).parent.parent.parent |
18 | | -DATA = BASE / "data" |
19 | | -OUTPUTS = BASE / "outputs" |
20 | | -DATABASE = DATA / "csv_excel_to_db" / "my_data.duckdb" |
21 | 20 |
|
22 | | -# INDEXING |
23 | | -embeddings, vector_store = embed_and_index_all_docs(DATA, DATABASE) |
24 | 21 |
|
25 | | -# BUILD LLM |
26 | | -if not os.environ.get("GOOGLE_API_KEY"): |
27 | | - os.environ["GOOGLE_API_KEY"] = getpass.getpass("Enter API key for Google Gemini: ") |
28 | | -llm = init_chat_model("gemini-2.0-flash", model_provider="google_genai") |
| 22 | +def parse_args() -> argparse.Namespace: |
| 23 | + """Parse command-line arguments for an agent run: the question to ask, the data/output/database paths, and the conversation thread ID.""" |
| 24 | + p = argparse.ArgumentParser() |
29 | 25 |
|
30 | | -# LOAD TOOLS |
31 | | -retrieve_tool = initialize_retrieve_tool(vector_store) |
32 | | -sql_tools = initialize_sql_toolkit(llm, DATABASE) |
| 26 | + p.add_argument( |
| 27 | + "--ask", |
| 28 | + type=str, |
| 29 | + default=( |
| 30 | + "What kinds (images, text docs, or excel sheets) are available in the documents I have provided to you? Use the functional call to retrieve information for each type first.\n\n" |
| 31 | + # "What columns does the excel have? Once you found the answer, tell me their types too.\n\n" |
| 32 | + # "Once you have that answer, I want you to calculate the median for each column.\n\n" |
| 33 | + "When you don't know while files the user is talking about, use the functional call to retrieve what data is available with a general prompt.\n\n" |
| 34 | + "You can refine your semantic search queries and try multiple times with different queries until you resonably determine the data is not available on the given documents.\n\n" |
| 35 | + "Base your answers only on the retrieved information thorugh the functional call you have. You can retreive MULTIPLE TIMES" |
| 36 | + ), |
| 37 | + help="Your input to agent", |
| 38 | + ) |
| 39 | + p.add_argument( |
| 40 | + "--data_dir", |
| 41 | + type=Path, |
| 42 | + default=BASE / "data", |
| 43 | + help="Path to data dir where your files are uploaded", |
| 44 | + ) |
| 45 | + p.add_argument( |
| 46 | + "--thread_id", |
| 47 | + type=str, |
| 48 | + default=str(random.random()), |
| 49 | + help="Your conversation history ID. Different IDs save different chat histories with agent", |
| 50 | + ) |
| 51 | + p.add_argument( |
| 52 | + "--outputs_dir", |
| 53 | + type=Path, |
| 54 | + default=BASE / "outputs", |
| 55 | + help="Path to output dir where image of agent architecture is saved", |
| 56 | + ) |
| 57 | + p.add_argument( |
| 58 | + "--database_dir", |
| 59 | + type=Path, |
| 60 | + default=BASE / "data" / "generated_db" / "csv_excel_to_db.duckdb", |
| 61 | + help="Path to database dir where the sql version of CSV/EXCEL files are stored", |
| 62 | + ) |
| 63 | + return p.parse_args() |
33 | 64 |
|
34 | | -# BUILD AGENT |
35 | | -# build checkpointer |
36 | | -memory = MemorySaver() |
37 | | -# build agent |
38 | | -agent_executor = create_react_agent( |
39 | | - llm, [retrieve_tool, *sql_tools], prompt=system_message, checkpointer=memory |
40 | | -) |
41 | | -# save architecture graph image |
42 | | -png_bytes = agent_executor.get_graph().draw_mermaid_png() |
43 | | -# save to file |
44 | | -with open(OUTPUTS / "graph.png", "wb") as f: |
45 | | - f.write(png_bytes) |
46 | | -print("Created graph.png") |
47 | 65 |
|
48 | | -# PROMPT |
49 | | -# specify an ID for the thread |
50 | | -# config = {"configurable": {"thread_id": "abc123"}} |
51 | | -config = {"configurable": {"thread_id": random.random()}} |
| 66 | +def main() -> None: |
| 67 | + cfg = parse_args() |
| 68 | + load_environ_vars() |
| 69 | + # INDEXING |
| 70 | + _, vector_store = embed_and_index_all_docs(cfg.data_dir, cfg.database_dir) |
52 | 71 |
|
53 | | -# input_message = ( |
54 | | -# "What is the content of the image?\n\n" |
55 | | -# "When you don't know while files the user is talking about, use the functional call to retrieve what data is available with a general prompt.\n\n" |
56 | | -# "Base your answers only on the retrieved information thorugh the functional call you have. You can retreive MULTIPLE TIMES" |
57 | | -# ) |
| 72 | + # BUILD LLM |
| 73 | + llm = init_chat_model("gemini-2.0-flash", model_provider="google_genai") |
58 | 74 |
|
59 | | -input_message = ( |
60 | | - "What kinds (images, text docs, or excel sheets) are available in the documents I have provided to you? Use the functional call to retrieve information for each type first.\n\n" |
61 | | - # "What colums does the excel have? once you found the answer, tell me there types too.\n\n" |
62 | | - # "Once you have that answer, I want you to calculate the median for each column.\n\n" |
63 | | - "When you don't know while files the user is talking about, use the functional call to retrieve what data is available with a general prompt.\n\n" |
64 | | - "You can refine your semantic search queries and try multiple times with different queries until you resonably determine the data is not available on the given documents.\n\n" |
65 | | - "Base your answers only on the retrieved information thorugh the functional call you have. You can retreive MULTIPLE TIMES" |
66 | | -) |
| 75 | + # LOAD TOOLS |
| 76 | + retrieve_tool = initialize_retrieve_tool(vector_store) |
| 77 | + sql_tools = initialize_sql_toolkit(llm, cfg.database_dir) |
67 | 78 |
|
68 | | -for event in agent_executor.stream( |
69 | | - {"messages": [{"role": "user", "content": input_message}]}, |
70 | | - stream_mode="values", |
71 | | - config=config, |
72 | | -): |
73 | | - event["messages"][-1].pretty_print() |
| 79 | + # BUILD AGENT |
| 80 | + # build persistent checkpointer |
| 81 | + con = sqlite3.connect( |
| 82 | + cfg.data_dir / "generated_db" / "agent_history.db", check_same_thread=False |
| 83 | + ) |
| 84 | + memory = SqliteSaver(con) |
| 85 | + # build agent |
| 86 | + agent_executor = create_react_agent( |
| 87 | + llm, [retrieve_tool, *sql_tools], prompt=system_message, checkpointer=memory |
| 88 | + ) |
| 89 | + # save architecture graph image |
| 90 | + png_bytes = agent_executor.get_graph().draw_mermaid_png() |
| 91 | + # ensure the output folder exists |
| 92 | + os.makedirs(cfg.outputs_dir, exist_ok=True) |
| 93 | + # save to file |
| 94 | + with open(cfg.outputs_dir / "graph.png", "wb") as f: |
| 95 | + f.write(png_bytes) |
| 96 | + print("Created graph.png") |
| 97 | + |
| 98 | + # PROMPT |
| 99 | + # specify an ID for the thread |
| 100 | + config = {"configurable": {"thread_id": cfg.thread_id}} |
| 101 | + # stream conversation |
| 102 | + for event in agent_executor.stream( |
| 103 | + {"messages": [{"role": "user", "content": cfg.ask}]}, |
| 104 | + stream_mode="values", |
| 105 | + config=config, |
| 106 | + ): |
| 107 | + event["messages"][-1].pretty_print() |
| 108 | + |
| 109 | + |
| 110 | +if __name__ == "__main__": |
| 111 | + main() |
0 commit comments