
Commit cf38402

Chatbot TUI and DB enhancements (#197)
* feat: save streamed conversation messages to the database
* feat: add tui interface with rich
* feat: add optional db usage flag
* fix: mypy check fixes
* fix lint
* close db instance
* add more logging and improve exception handling
* add backend env
* seed secret CI with HF_TOKEN

---------

Signed-off-by: Palaniappan R <[email protected]>
Signed-off-by: Jack Luar <[email protected]>
Co-authored-by: Jack Luar <[email protected]>
1 parent aefbcd2 commit cf38402

4 files changed: +398 -117 lines changed


.github/workflows/ci-secret.yaml

Lines changed: 2 additions & 0 deletions
@@ -47,11 +47,13 @@ jobs:
           sed -i 's|{{GOOGLE_API_KEY}}|${{ secrets.GOOGLE_API_KEY }}|g' backend/.env
           sed -i 's|{{GOOGLE_PROJECT_ID}}|${{ secrets.GOOGLE_PROJECT_ID }}|g' backend/.env
           sed -i 's|{{PATH_TO_GOOGLE_APPLICATION_CREDENTIALS}}|src/secret.json|g' backend/.env
+          sed -i 's|HF_TOKEN=|HF_TOKEN=${{ secrets.HF_TOKEN }}|g' backend/.env
 
           cp backend/.env.example evaluation/.env
           sed -i 's|{{GOOGLE_API_KEY}}|${{ secrets.GOOGLE_API_KEY }}|g' evaluation/.env
           sed -i 's|{{GOOGLE_PROJECT_ID}}|${{ secrets.GOOGLE_PROJECT_ID }}|g' evaluation/.env
           sed -i 's|{{PATH_TO_GOOGLE_APPLICATION_CREDENTIALS}}|src/secret.json|g' evaluation/.env
+          sed -i 's|HF_TOKEN=|HF_TOKEN=${{ secrets.HF_TOKEN }}|g' evaluation/.env
       - name: Copy Google credentials
         run: |
           cp ${{ secrets.PATH_TO_GOOGLE_APPLICATION_CREDENTIALS }} backend/src
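The added `sed` lines splice the `HF_TOKEN` secret into both copied `.env` files. Unlike the `{{PLACEHOLDER}}` substitutions above them, they match the bare prefix `HF_TOKEN=`, which only works cleanly if `.env.example` ships an empty `HF_TOKEN=` line (an assumption of this pattern, since the `g` flag rewrites every occurrence of the prefix). A minimal Python sketch of the same seeding step under that assumption; the helper `seed_env` is illustrative and not part of the repo:

```python
from pathlib import Path

def seed_env(path: str, secrets: dict[str, str]) -> None:
    """Mirror the workflow's sed calls on a copied .env file (illustrative only)."""
    text = Path(path).read_text()
    for key, value in secrets.items():
        # 's|{{KEY}}|value|g' -- fill template placeholders like {{GOOGLE_API_KEY}}
        text = text.replace("{{" + key + "}}", value)
    # 's|HF_TOKEN=|HF_TOKEN=<token>|g' -- assumes a single empty HF_TOKEN= line,
    # so appending after '=' injects the secret value
    text = text.replace("HF_TOKEN=", "HF_TOKEN=" + secrets.get("HF_TOKEN", ""))
    Path(path).write_text(text)

seed_env("backend/.env", {"GOOGLE_API_KEY": "...", "HF_TOKEN": "hf_xxx"})
```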

backend/chatbot.py

Lines changed: 274 additions & 60 deletions
@@ -1,72 +1,286 @@
 import os
+import sys
 import logging
-from src.api.routers import conversations
+from uuid import UUID, uuid4
+from dotenv import load_dotenv
+from rich.console import Console
+from rich.panel import Panel
+from rich.markdown import Markdown
+from rich.prompt import Prompt
+from rich.table import Table
+from sqlalchemy.orm import Session
 
+from src.agents.retriever_graph import RetrieverGraph
+from src.database import get_db, init_database
+from src.database import crud
 
-def get_history_str(chat_history: list[dict[str, str]]) -> str:
-    history_str = ""
-    for i in chat_history:
-        history_str += f"User : {i['User']}\nAI : {i['AI']}\n\n"
-    return history_str
+from langchain_google_vertexai import ChatVertexAI
+from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_ollama import ChatOllama
 
+load_dotenv()
+logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO").upper())
 
-chat_history: list[dict[str, str]] = []
+console = Console()
 
-if __name__ == "__main__":
-    rg = conversations.rg
-    os.system("clear")
 
-    while True:
-        user_question = input(">>> ")
+def setup_llm() -> ChatVertexAI | ChatGoogleGenerativeAI | ChatOllama:
+    temp = float(os.getenv("LLM_TEMP", "0.0"))
 
-        inputs = {
-            "messages": [
-                ("user", user_question),
-            ],
-            "chat_history": get_history_str(chat_history=chat_history),
-        }
+    if os.getenv("LLM_MODEL") == "ollama":
+        model = str(os.getenv("OLLAMA_MODEL"))
+        return ChatOllama(model=model, temperature=temp)
 
-        try:
-            if rg.graph is not None:
-                output = list(rg.graph.stream(inputs))
+    elif os.getenv("LLM_MODEL") == "gemini":
+        gemini = os.getenv("GOOGLE_GEMINI")
+        if gemini in {"1_pro", "1.5_flash", "1.5_pro"}:
+            raise ValueError(f"Gemini {gemini} (v1.0-1.5) disabled. Use v2.0+")
+        elif gemini == "2.0_flash":
+            return ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=temp)
+        elif gemini == "2.5_flash":
+            return ChatVertexAI(model_name="gemini-2.5-flash", temperature=temp)
+        elif gemini == "2.5_pro":
+            return ChatVertexAI(model_name="gemini-2.5-pro", temperature=temp)
+        else:
+            raise ValueError(f"Invalid GOOGLE_GEMINI value: {gemini}")
 
-            else:
-                raise ValueError("RetrieverGraph not initialized.")
-        except RuntimeError:
-            logging.error("Runtime Error!")
-
-        if (
-            isinstance(output, list)
-            and len(output) > 2
-            and "rag_generate" in output[-1]
-            and "messages" in output[-1]["rag_generate"]
-            and len(output[-1]["rag_generate"]["messages"]) > 0
-        ):
-            llm_response = output[-1]["rag_generate"]["messages"][0]
-
-            tool = list(output[-2].keys())[0]
-            srcs = set(output[-2][tool]["sources"])
-            urls = output[-2][tool]["urls"]
-            chat_history.append({"User": user_question, "AI": llm_response})
-
-            print(f"LLM: {llm_response} \nSources: {srcs} \nURLs: {urls}\n\n")
-
-        elif (
-            isinstance(output, list)
-            and len(output) > 2
-            and "mcp_tools" in output[-1]
-            and "messages" in output[-1]["mcp_tools"]
-        ):
-            logging.info(output)
-            llm_response = output[-1]["mcp_tools"]["messages"][0]
-            chat_history.append({"User": user_question, "AI": llm_response})
-            result = output[-1]["mcp_tools"]["messages"]
-            print()
-            if len(result) > 0:
-                for _ in result:
-                    print(_)
-            else:
-                print("No Message!")
+    else:
+        raise ValueError(f"Invalid LLM_MODEL: {os.getenv('LLM_MODEL')}")
+
+
+def setup_embeddings() -> dict[str, str]:
+    embed_type = str(os.getenv("EMBEDDINGS_TYPE"))
+
+    if embed_type == "HF":
+        model = str(os.getenv("HF_EMBEDDINGS"))
+    elif embed_type in {"GOOGLE_GENAI", "GOOGLE_VERTEXAI"}:
+        model = str(os.getenv("GOOGLE_EMBEDDINGS"))
+    else:
+        raise ValueError(f"Invalid EMBEDDINGS_TYPE: {embed_type}")
+
+    return {"type": embed_type, "name": model}
+
+
+def get_history(
+    db: Session | None, conv_id: UUID | None, local_history: list[dict]
+) -> str:
+    if db and conv_id:
+        history = crud.get_conversation_history(db, conv_id)
+        result = ""
+        for msg in history:
+            user = msg.get("User", "")
+            ai = msg.get("AI", "")
+            if user and ai:
+                result += f"User : {user}\nAI : {ai}\n\n"
+        return result
+    else:
+        result = ""
+        for msg in local_history:
+            user = msg.get("User", "")
+            ai = msg.get("AI", "")
+            if user and ai:
+                result += f"User : {user}\nAI : {ai}\n\n"
+        return result
+
+
+def parse_output(output: list) -> tuple[str, list[str], list[str]]:
+    fail_msg = "Failed to get response"
+    MIN_OUTPUT_LENGTH = 3
+
+    # Validate outputs
+    if not isinstance(output, list):
+        logging.error(f"Output is not a list: {type(output)}")
+        return fail_msg, [], []
+    if len(output) < MIN_OUTPUT_LENGTH:
+        logging.error(f"Output too short: {len(output)} elements (expected >= {MIN_OUTPUT_LENGTH})")
+        return fail_msg, [], []
+
+    # Validate last element contains generation
+    last = output[-1]
+    if not isinstance(last, dict):
+        logging.error(f"Last element is not a dict: {type(last)}")
+        return fail_msg, [], []
+
+    # Determine if RAG or agent mode
+    is_rag = "rag_generate" in last
+    key = "rag_generate" if is_rag else "generate"
+
+    if key not in last:
+        logging.error(f"Missing '{key}' key in final output. Available keys: {list(last.keys())}")
+        return fail_msg, [], []
+
+    if "messages" not in last[key]:
+        logging.error(f"Missing 'messages' in {key}")
+        return fail_msg, [], []
+
+    msgs = last[key]["messages"]
+    if not msgs:
+        logging.error("Empty messages list in generation output")
+        return fail_msg, [], []
+
+    response = str(msgs[0])
+    sources = []
+    tools = []
+
+    if is_rag:
+        for item in output[1:-1]:
+            if isinstance(item, dict):
+                for k, v in item.items():
+                    if k.startswith("retrieve_") and isinstance(v, dict):
+                        urls = v.get("urls", [])
+                        sources.extend(urls)
+    else:
+        if "agent" in output[0] and "tools" in output[0]["agent"]:
+            tools = output[0]["agent"]["tools"]
+        for i in range(len(tools)):
+            if i + 1 < len(output):
+                tool_out = list(output[i + 1].values())[0]
+                urls = tool_out.get("urls", [])
+                sources.extend(urls)
+
+    # Deduplicate sources
+    return response, list(set(sources)), tools
+
+
+def show_response(text: str, sources: list[str], tools: list[str]) -> None:
+    console.print(
+        Panel(Markdown(text), title="[bold cyan]Assistant", border_style="cyan")
+    )
+
+    if tools:
+        console.print(f"[yellow]Tools:[/yellow] {', '.join(tools)}")
+
+    if sources:
+        src_table = Table(title="Sources", show_header=False, border_style="dim")
+        src_table.add_column("URL", style="blue")
+        for src in sources:
+            src_table.add_row(src)
+        console.print(src_table)
+
+    console.print()
+
+
+def main() -> None:
+    console.clear()
+    console.print(
+        Panel("[bold green]ORAssistant Chatbot[/bold green]", border_style="green")
+    )
+
+    cuda = str(os.getenv("USE_CUDA")).lower() == "true"
+    fast = str(os.getenv("FAST_MODE")).lower() == "true"
+    debug = str(os.getenv("DEBUG")).lower() == "true"
+    mcp = str(os.getenv("ENABLE_MCP")).lower() == "true"
+    use_db = str(os.getenv("USE_DB", "true")).lower() == "true"
+
+    llm = setup_llm()
+    embed_cfg = setup_embeddings()
+    reranker = str(os.getenv("HF_RERANKER"))
+
+    with console.status("[bold green]Initializing graph...", spinner="dots"):
+        graph = RetrieverGraph(
+            llm_model=llm,
+            embeddings_config=embed_cfg,
+            reranking_model_name=reranker,
+            use_cuda=cuda,
+            inbuilt_tool_calling=True,
+            fast_mode=fast,
+            debug=debug,
+            enable_mcp=mcp,
+        )
+        graph.initialize()
+
+    if graph.graph is None:
+        console.print("[bold red]Failed to initialize graph[/bold red]")
+        sys.exit(1)
+
+    db = None
+    db_generator = None
+    conv_id = None
+    local_history: list[dict[str, str]] = []
+
+    if use_db:
+        if init_database():
+            db_generator = get_db()
+            db = next(db_generator)
+            conv_id = uuid4()
+            crud.create_conversation(db, conversation_uuid=conv_id, title=None)
+            console.print("[dim]Database: enabled[/dim]")
         else:
-            logging.info(output)
-            print("LLM response extraction failed")
+            console.print("[yellow]Database unavailable, using local memory[/yellow]")
+            use_db = False
+    else:
+        console.print("[dim]Database: disabled[/dim]")
+
+    console.print("[dim]Type 'exit' or 'quit' to end session[/dim]\n")
+
+    try:
+        while True:
+            query = Prompt.ask("[bold blue]You[/bold blue]")
+
+            if query.lower() in {"exit", "quit", "q"}:
+                console.print("[yellow]Goodbye![/yellow]")
+                break
+
+            if not query.strip():
+                continue
+
+            if use_db and db and conv_id:
+                crud.create_message(
+                    db=db,
+                    conversation_uuid=conv_id,
+                    role="user",
+                    content=query,
+                )
+
+            inputs = {
+                "messages": [("user", query)],
+                "chat_history": get_history(db, conv_id, local_history),
+            }
+
+            with console.status("[bold green]Thinking...", spinner="dots"):
+                output = list(graph.graph.stream(inputs, stream_mode="updates"))
+
+            response, sources, tools = parse_output(output)
+
+            if use_db and db and conv_id:
+                ctx_srcs = {"sources": [{"source": s, "context": ""} for s in sources]}
+                crud.create_message(
+                    db=db,
+                    conversation_uuid=conv_id,
+                    role="assistant",
+                    content=response,
+                    context_sources=ctx_srcs,
+                    tools=tools,
+                )
+            else:
+                local_history.append({"User": query, "AI": response})
+
+            show_response(response, sources, tools)
+
+    except KeyboardInterrupt:
+        console.print("\n[yellow]Interrupted. Goodbye![/yellow]")
+    except ValueError as e:
+        console.print(f"[bold red]Configuration Error:[/bold red] {str(e)}")
+        console.print("[yellow]Check your environment variables and try again.[/yellow]")
+        if debug:
+            logging.exception("Configuration error")
+    except ConnectionError as e:
+        console.print(f"[bold red]Connection Error:[/bold red] {str(e)}")
+        console.print("[yellow]Check your network connection and database availability.[/yellow]")
+        if debug:
+            logging.exception("Connection error")
+    except Exception as e:
+        console.print(f"[bold red]Unexpected Error:[/bold red] {str(e)}")
+        console.print("[yellow]Please report this issue if it persists.[/yellow]")
+        logging.exception("Unexpected error in main loop")
+    finally:
+        # Clean up database session
+        if db_generator is not None:
+            try:
+                next(db_generator)
+            except StopIteration:
+                pass
+
+
+if __name__ == "__main__":
+    main()
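The new `parse_output` replaces the old inline `if`/`elif` chain: it validates the streamed node updates, pulls the final message from the `rag_generate` (or `generate`) node, gathers URLs from the intermediate `retrieve_*` or tool nodes, and deduplicates them. A hedged example of the list shape it expects; the node payloads below are invented for illustration (real ones come from `RetrieverGraph`), and the import assumes you run from `backend/` with the repo's dependencies installed:

```python
from chatbot import parse_output

# Hypothetical RAG-mode stream: one update dict per graph node, in order.
rag_stream = [
    {"agent": {"tools": []}},                                        # routing node (assumed)
    {"retrieve_docs": {"urls": ["https://example.com/doc"]}},        # a retrieve_* node
    {"rag_generate": {"messages": ["Place macros before cells."]}},  # generation node
]

response, sources, tools = parse_output(rag_stream)
assert response == "Place macros before cells."
assert sources == ["https://example.com/doc"]
assert tools == []
```

Because any malformed stream short-circuits to `("Failed to get response", [], [])` with a logged reason, the TUI loop never crashes on a bad graph update; it just renders the failure message.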
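The `finally` block closes the database session by advancing the `get_db()` generator one extra step, which resumes it past its `yield` and runs its cleanup (the "close db instance" item in the commit message). A minimal sketch of the generator contract this relies on, assuming the usual SQLAlchemy session-per-generator pattern; the real implementation lives in `src/database` and may differ:

```python
from collections.abc import Generator
from sqlalchemy.orm import Session, sessionmaker

SessionLocal = sessionmaker()  # bound to a real engine in src/database

def get_db() -> Generator[Session, None, None]:
    db = SessionLocal()
    try:
        yield db    # first next() hands the session to the caller
    finally:
        db.close()  # second next() resumes here, closes, then raises StopIteration
```

This is why `main()` wraps the second `next(db_generator)` in `try`/`except StopIteration`: the exception is the expected signal that the generator finished and the session is closed.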
