1 | 1 | import os |
| 2 | +import sys |
2 | 3 | import logging |
3 | | -from src.api.routers import conversations |
| 4 | +from uuid import UUID, uuid4 |
| 5 | +from dotenv import load_dotenv |
| 6 | +from rich.console import Console |
| 7 | +from rich.panel import Panel |
| 8 | +from rich.markdown import Markdown |
| 9 | +from rich.prompt import Prompt |
| 10 | +from rich.table import Table |
| 11 | +from sqlalchemy.orm import Session |
4 | 12 |
| 13 | +from src.agents.retriever_graph import RetrieverGraph |
| 14 | +from src.database import get_db, init_database |
| 15 | +from src.database import crud |
5 | 16 |
6 | | -def get_history_str(chat_history: list[dict[str, str]]) -> str: |
7 | | - history_str = "" |
8 | | - for i in chat_history: |
9 | | - history_str += f"User : {i['User']}\nAI : {i['AI']}\n\n" |
10 | | - return history_str |
| 17 | +from langchain_google_vertexai import ChatVertexAI |
| 18 | +from langchain_google_genai import ChatGoogleGenerativeAI |
| 19 | +from langchain_ollama import ChatOllama |
11 | 20 |
| 21 | +load_dotenv() |
| 22 | +logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO").upper()) |
12 | 23 |
13 | | -chat_history: list[dict[str, str]] = [] |
| 24 | +console = Console() |
14 | 25 |
15 | | -if __name__ == "__main__": |
16 | | - rg = conversations.rg |
17 | | - os.system("clear") |
18 | 26 |
19 | | - while True: |
20 | | - user_question = input(">>> ") |
| 27 | +def setup_llm() -> ChatVertexAI | ChatGoogleGenerativeAI | ChatOllama: |
| 28 | + temp = float(os.getenv("LLM_TEMP", "0.0")) |
21 | 29 |
22 | | - inputs = { |
23 | | - "messages": [ |
24 | | - ("user", user_question), |
25 | | - ], |
26 | | - "chat_history": get_history_str(chat_history=chat_history), |
27 | | - } |
| 30 | + if os.getenv("LLM_MODEL") == "ollama": |
| 31 | + model = str(os.getenv("OLLAMA_MODEL")) |
| 32 | + return ChatOllama(model=model, temperature=temp) |
28 | 33 |
29 | | - try: |
30 | | - if rg.graph is not None: |
31 | | - output = list(rg.graph.stream(inputs)) |
| 34 | + elif os.getenv("LLM_MODEL") == "gemini": |
| 35 | + gemini = os.getenv("GOOGLE_GEMINI") |
| 36 | + if gemini in {"1_pro", "1.5_flash", "1.5_pro"}: |
| 37 | + raise ValueError(f"Gemini {gemini} (v1.0-1.5) disabled. Use v2.0+") |
| 38 | + elif gemini == "2.0_flash": |
| 39 | + return ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=temp) |
| 40 | + elif gemini == "2.5_flash": |
| 41 | + return ChatVertexAI(model_name="gemini-2.5-flash", temperature=temp) |
| 42 | + elif gemini == "2.5_pro": |
| 43 | + return ChatVertexAI(model_name="gemini-2.5-pro", temperature=temp) |
| 44 | + else: |
| 45 | + raise ValueError(f"Invalid GOOGLE_GEMINI value: {gemini}") |
32 | 46 |
33 | | - else: |
34 | | - raise ValueError("RetrieverGraph not initialized.") |
35 | | - except RuntimeError: |
36 | | - logging.error("Runtime Error!") |
37 | | - |
38 | | - if ( |
39 | | - isinstance(output, list) |
40 | | - and len(output) > 2 |
41 | | - and "rag_generate" in output[-1] |
42 | | - and "messages" in output[-1]["rag_generate"] |
43 | | - and len(output[-1]["rag_generate"]["messages"]) > 0 |
44 | | - ): |
45 | | - llm_response = output[-1]["rag_generate"]["messages"][0] |
46 | | - |
47 | | - tool = list(output[-2].keys())[0] |
48 | | - srcs = set(output[-2][tool]["sources"]) |
49 | | - urls = output[-2][tool]["urls"] |
50 | | - chat_history.append({"User": user_question, "AI": llm_response}) |
51 | | - |
52 | | - print(f"LLM: {llm_response} \nSources: {srcs} \nURLs: {urls}\n\n") |
53 | | - |
54 | | - elif ( |
55 | | - isinstance(output, list) |
56 | | - and len(output) > 2 |
57 | | - and "mcp_tools" in output[-1] |
58 | | - and "messages" in output[-1]["mcp_tools"] |
59 | | - ): |
60 | | - logging.info(output) |
61 | | - llm_response = output[-1]["mcp_tools"]["messages"][0] |
62 | | - chat_history.append({"User": user_question, "AI": llm_response}) |
63 | | - result = output[-1]["mcp_tools"]["messages"] |
64 | | - print() |
65 | | - if len(result) > 0: |
66 | | - for _ in result: |
67 | | - print(_) |
68 | | - else: |
69 | | - print("No Message!") |
| 47 | + else: |
| 48 | + raise ValueError(f"Invalid LLM_MODEL: {os.getenv('LLM_MODEL')}") |
| 49 | + |
| 50 | + |
| 51 | +def setup_embeddings() -> dict[str, str]: |
| 52 | + embed_type = str(os.getenv("EMBEDDINGS_TYPE")) |
| 53 | + |
| 54 | + if embed_type == "HF": |
| 55 | + model = str(os.getenv("HF_EMBEDDINGS")) |
| 56 | + elif embed_type in {"GOOGLE_GENAI", "GOOGLE_VERTEXAI"}: |
| 57 | + model = str(os.getenv("GOOGLE_EMBEDDINGS")) |
| 58 | + else: |
| 59 | + raise ValueError(f"Invalid EMBEDDINGS_TYPE: {embed_type}") |
| 60 | + |
| 61 | + return {"type": embed_type, "name": model} |
| 62 | + |
| 63 | + |
| 64 | +def get_history( |
| 65 | + db: Session | None, conv_id: UUID | None, local_history: list[dict] |
| 66 | +) -> str: |
| 67 | + if db and conv_id: |
| 68 | + history = crud.get_conversation_history(db, conv_id) |
| 69 | + result = "" |
| 70 | + for msg in history: |
| 71 | + user = msg.get("User", "") |
| 72 | + ai = msg.get("AI", "") |
| 73 | + if user and ai: |
| 74 | + result += f"User : {user}\nAI : {ai}\n\n" |
| 75 | + return result |
| 76 | + else: |
| 77 | + result = "" |
| 78 | + for msg in local_history: |
| 79 | + user = msg.get("User", "") |
| 80 | + ai = msg.get("AI", "") |
| 81 | + if user and ai: |
| 82 | + result += f"User : {user}\nAI : {ai}\n\n" |
| 83 | + return result |
| 84 | + |
| 85 | + |
| 86 | +def parse_output(output: list) -> tuple[str, list[str], list[str]]: |
| 87 | + fail_msg = "Failed to get response" |
| 88 | + MIN_OUTPUT_LENGTH = 3 |
| 89 | + |
| 90 | + # Validate outputs |
| 91 | + if not isinstance(output, list): |
| 92 | + logging.error(f"Output is not a list: {type(output)}") |
| 93 | + return fail_msg, [], [] |
| 94 | + if len(output) < MIN_OUTPUT_LENGTH: |
| 95 | + logging.error(f"Output too short: {len(output)} elements (expected >= {MIN_OUTPUT_LENGTH})") |
| 96 | + return fail_msg, [], [] |
| 97 | + |
| 98 | + # Validate last element contains generation |
| 99 | + last = output[-1] |
| 100 | + if not isinstance(last, dict): |
| 101 | + logging.error(f"Last element is not a dict: {type(last)}") |
| 102 | + return fail_msg, [], [] |
| 103 | + |
| 104 | + # Determine if RAG or agent mode |
| 105 | + is_rag = "rag_generate" in last |
| 106 | + key = "rag_generate" if is_rag else "generate" |
| 107 | + |
| 108 | + if key not in last: |
| 109 | + logging.error(f"Missing '{key}' key in final output. Available keys: {list(last.keys())}") |
| 110 | + return fail_msg, [], [] |
| 111 | + |
| 112 | + if "messages" not in last[key]: |
| 113 | + logging.error(f"Missing 'messages' in {key}") |
| 114 | + return fail_msg, [], [] |
| 115 | + |
| 116 | + msgs = last[key]["messages"] |
| 117 | + if not msgs: |
| 118 | + logging.error("Empty messages list in generation output") |
| 119 | + return fail_msg, [], [] |
| 120 | + |
| 121 | + response = str(msgs[0]) |
| 122 | + sources = [] |
| 123 | + tools = [] |
| 124 | + |
| 125 | + if is_rag: |
| 126 | + for item in output[1:-1]: |
| 127 | + if isinstance(item, dict): |
| 128 | + for k, v in item.items(): |
| 129 | + if k.startswith("retrieve_") and isinstance(v, dict): |
| 130 | + urls = v.get("urls", []) |
| 131 | + sources.extend(urls) |
| 132 | + else: |
| 133 | + if "agent" in output[0] and "tools" in output[0]["agent"]: |
| 134 | + tools = output[0]["agent"]["tools"] |
| 135 | + for i in range(len(tools)): |
| 136 | + if i + 1 < len(output): |
| 137 | + tool_out = list(output[i + 1].values())[0] |
| 138 | + urls = tool_out.get("urls", []) |
| 139 | + sources.extend(urls) |
| 140 | + |
| 141 | + # Deduplicate sources |
| 142 | + return response, list(set(sources)), tools |
| 143 | + |
| 144 | + |
| 145 | +def show_response(text: str, sources: list[str], tools: list[str]) -> None: |
| 146 | + console.print( |
| 147 | + Panel(Markdown(text), title="[bold cyan]Assistant", border_style="cyan") |
| 148 | + ) |
| 149 | + |
| 150 | + if tools: |
| 151 | + console.print(f"[yellow]Tools:[/yellow] {', '.join(tools)}") |
| 152 | + |
| 153 | + if sources: |
| 154 | + src_table = Table(title="Sources", show_header=False, border_style="dim") |
| 155 | + src_table.add_column("URL", style="blue") |
| 156 | + for src in sources: |
| 157 | + src_table.add_row(src) |
| 158 | + console.print(src_table) |
| 159 | + |
| 160 | + console.print() |
| 161 | + |
| 162 | + |
| 163 | +def main() -> None: |
| 164 | + console.clear() |
| 165 | + console.print( |
| 166 | + Panel("[bold green]ORAssistant Chatbot[/bold green]", border_style="green") |
| 167 | + ) |
| 168 | + |
| 169 | + cuda = str(os.getenv("USE_CUDA")).lower() == "true" |
| 170 | + fast = str(os.getenv("FAST_MODE")).lower() == "true" |
| 171 | + debug = str(os.getenv("DEBUG")).lower() == "true" |
| 172 | + mcp = str(os.getenv("ENABLE_MCP")).lower() == "true" |
| 173 | + use_db = str(os.getenv("USE_DB", "true")).lower() == "true" |
| 174 | + |
| 175 | + llm = setup_llm() |
| 176 | + embed_cfg = setup_embeddings() |
| 177 | + reranker = str(os.getenv("HF_RERANKER")) |
| 178 | + |
| 179 | + with console.status("[bold green]Initializing graph...", spinner="dots"): |
| 180 | + graph = RetrieverGraph( |
| 181 | + llm_model=llm, |
| 182 | + embeddings_config=embed_cfg, |
| 183 | + reranking_model_name=reranker, |
| 184 | + use_cuda=cuda, |
| 185 | + inbuilt_tool_calling=True, |
| 186 | + fast_mode=fast, |
| 187 | + debug=debug, |
| 188 | + enable_mcp=mcp, |
| 189 | + ) |
| 190 | + graph.initialize() |
| 191 | + |
| 192 | + if graph.graph is None: |
| 193 | + console.print("[bold red]Failed to initialize graph[/bold red]") |
| 194 | + sys.exit(1) |
| 195 | + |
| 196 | + db = None |
| 197 | + db_generator = None |
| 198 | + conv_id = None |
| 199 | + local_history: list[dict[str, str]] = [] |
| 200 | + |
| 201 | + if use_db: |
| 202 | + if init_database(): |
| 203 | + db_generator = get_db() |
| 204 | + db = next(db_generator) |
| 205 | + conv_id = uuid4() |
| 206 | + crud.create_conversation(db, conversation_uuid=conv_id, title=None) |
| 207 | + console.print("[dim]Database: enabled[/dim]") |
70 | 208 | else: |
71 | | - logging.info(output) |
72 | | - print("LLM response extraction failed") |
| 209 | + console.print("[yellow]Database unavailable, using local memory[/yellow]") |
| 210 | + use_db = False |
| 211 | + else: |
| 212 | + console.print("[dim]Database: disabled[/dim]") |
| 213 | + |
| 214 | + console.print("[dim]Type 'exit' or 'quit' to end session[/dim]\n") |
| 215 | + |
| 216 | + try: |
| 217 | + while True: |
| 218 | + query = Prompt.ask("[bold blue]You[/bold blue]") |
| 219 | + |
| 220 | + if query.lower() in {"exit", "quit", "q"}: |
| 221 | + console.print("[yellow]Goodbye![/yellow]") |
| 222 | + break |
| 223 | + |
| 224 | + if not query.strip(): |
| 225 | + continue |
| 226 | + |
| 227 | + if use_db and db and conv_id: |
| 228 | + crud.create_message( |
| 229 | + db=db, |
| 230 | + conversation_uuid=conv_id, |
| 231 | + role="user", |
| 232 | + content=query, |
| 233 | + ) |
| 234 | + |
| 235 | + inputs = { |
| 236 | + "messages": [("user", query)], |
| 237 | + "chat_history": get_history(db, conv_id, local_history), |
| 238 | + } |
| 239 | + |
| 240 | + with console.status("[bold green]Thinking...", spinner="dots"): |
| 241 | + output = list(graph.graph.stream(inputs, stream_mode="updates")) |
| 242 | + |
| 243 | + response, sources, tools = parse_output(output) |
| 244 | + |
| 245 | + if use_db and db and conv_id: |
| 246 | + ctx_srcs = {"sources": [{"source": s, "context": ""} for s in sources]} |
| 247 | + crud.create_message( |
| 248 | + db=db, |
| 249 | + conversation_uuid=conv_id, |
| 250 | + role="assistant", |
| 251 | + content=response, |
| 252 | + context_sources=ctx_srcs, |
| 253 | + tools=tools, |
| 254 | + ) |
| 255 | + else: |
| 256 | + local_history.append({"User": query, "AI": response}) |
| 257 | + |
| 258 | + show_response(response, sources, tools) |
| 259 | + |
| 260 | + except KeyboardInterrupt: |
| 261 | + console.print("\n[yellow]Interrupted. Goodbye![/yellow]") |
| 262 | + except ValueError as e: |
| 263 | + console.print(f"[bold red]Configuration Error:[/bold red] {str(e)}") |
| 264 | + console.print("[yellow]Check your environment variables and try again.[/yellow]") |
| 265 | + if debug: |
| 266 | + logging.exception("Configuration error") |
| 267 | + except ConnectionError as e: |
| 268 | + console.print(f"[bold red]Connection Error:[/bold red] {str(e)}") |
| 269 | + console.print("[yellow]Check your network connection and database availability.[/yellow]") |
| 270 | + if debug: |
| 271 | + logging.exception("Connection error") |
| 272 | + except Exception as e: |
| 273 | + console.print(f"[bold red]Unexpected Error:[/bold red] {str(e)}") |
| 274 | + console.print("[yellow]Please report this issue if it persists.[/yellow]") |
| 275 | + logging.exception("Unexpected error in main loop") |
| 276 | + finally: |
| 277 | + # Clean up database session |
| 278 | + if db_generator is not None: |
| 279 | + try: |
| 280 | + next(db_generator) |
| 281 | + except StopIteration: |
| 282 | + pass |
| 283 | + |
| 284 | + |
| 285 | +if __name__ == "__main__": |
| 286 | + main() |
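For readers of this change, a minimal configuration sketch (not part of the commit): the environment variables the new CLI reads, written here as Python defaults rather than the `.env` file that `load_dotenv()` would normally supply. The variable names all appear in the diff above; the model names are placeholder assumptions, not values required by the project.

```python
import os

# Hypothetical defaults for the variables read by setup_llm(), setup_embeddings()
# and main(); in practice these would live in .env and be loaded by load_dotenv().
os.environ.setdefault("LLM_MODEL", "gemini")          # "ollama" or "gemini"
os.environ.setdefault("GOOGLE_GEMINI", "2.0_flash")   # "2.0_flash", "2.5_flash" or "2.5_pro"
os.environ.setdefault("OLLAMA_MODEL", "llama3")       # placeholder; only used when LLM_MODEL=ollama
os.environ.setdefault("LLM_TEMP", "0.0")
os.environ.setdefault("EMBEDDINGS_TYPE", "HF")        # "HF", "GOOGLE_GENAI" or "GOOGLE_VERTEXAI"
os.environ.setdefault("HF_EMBEDDINGS", "BAAI/bge-large-en-v1.5")  # placeholder model name
os.environ.setdefault("HF_RERANKER", "BAAI/bge-reranker-base")    # placeholder model name
os.environ.setdefault("USE_CUDA", "false")
os.environ.setdefault("FAST_MODE", "false")
os.environ.setdefault("DEBUG", "false")
os.environ.setdefault("ENABLE_MCP", "false")
os.environ.setdefault("USE_DB", "true")
os.environ.setdefault("LOGLEVEL", "INFO")
```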
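Also for reference, a hedged sketch of the data `parse_output()` consumes: one list element per update emitted by `graph.graph.stream(inputs, stream_mode="updates")`. The node names ("agent", "retrieve_general") and payload keys below are inferred from the parsing logic in this diff, not taken from the RetrieverGraph implementation, so treat them as assumptions.

```python
# Hypothetical RAG-mode stream output; shape inferred from parse_output() in the
# diff above (the snippet assumes parse_output is importable from that module).
sample_output = [
    {"agent": {"tools": []}},          # first update (skipped by the RAG branch)
    {"retrieve_general": {             # any key starting with "retrieve_" contributes URLs
        "urls": ["https://example.com/doc"],
        "sources": ["doc.md"],
    }},
    {"rag_generate": {"messages": ["Here is the generated answer."]}},
]

response, sources, tools = parse_output(sample_output)
# response -> "Here is the generated answer."
# sources  -> ["https://example.com/doc"]   (deduplicated)
# tools    -> []   (tool names are only reported in the non-RAG branch)
```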