Merged
55 commits
3c504da
chore: update deps
palaniappan-r Oct 20, 2025
346a6ce
feat: add session id to input object
palaniappan-r Oct 21, 2025
7b453ca
feat: implement database models and CRUD operations
palaniappan-r Oct 21, 2025
175fa14
feat: add conversations API with DB integration
palaniappan-r Oct 21, 2025
970f26e
feat: add api docs
palaniappan-r Oct 21, 2025
b2b9e22
chore: move api docs
palaniappan-r Oct 21, 2025
046e80d
feat: delete unused chains api
palaniappan-r Oct 21, 2025
f618d05
feat: add database configuration to .env.example
palaniappan-r Oct 21, 2025
eab4938
chore: update deps
palaniappan-r Oct 21, 2025
faca754
feat: update dockerfile
palaniappan-r Oct 24, 2025
0eb2c0c
feat: Add postgresql and pgadmin services
palaniappan-r Oct 24, 2025
4b75461
feat: add response models for conversations endpoints
palaniappan-r Oct 24, 2025
d3660f5
chore: update deps
palaniappan-r Oct 24, 2025
0e02868
feat: include conversations router
palaniappan-r Oct 24, 2025
30759dc
feat: initialize database module
palaniappan-r Oct 24, 2025
045383a
fix: Expose core components in package init
palaniappan-r Oct 30, 2025
108d70d
feat: load environment variables and configure CUDA settings
palaniappan-r Nov 2, 2025
adc9ea9
chore: update deps
palaniappan-r Nov 2, 2025
a4b7ca4
fix: update embeddings type
palaniappan-r Nov 2, 2025
7b8ba17
refactor: move CUDA environment variable setup to main entry point
palaniappan-r Nov 2, 2025
6e4b3bf
fix: update healthcheck command to include database name
palaniappan-r Nov 2, 2025
7b6db80
refactor: replace session_id with conversation_id in db models
palaniappan-r Nov 2, 2025
a8c5f8e
refactor: update type hints
palaniappan-r Nov 2, 2025
83537bc
refactor: formatting changes
palaniappan-r Nov 2, 2025
4e85fbd
fix: handle incomplete message pairs
palaniappan-r Nov 2, 2025
ebf88f8
fix: remove hardcoded PostgreSQL environment variables
palaniappan-r Nov 2, 2025
d635686
Merge branch 'master' into backend-db
palaniappan-r Nov 2, 2025
2e9d22a
feat: update psycopg2 dependency to psycopg2-binary
palaniappan-r Nov 2, 2025
fbb0307
fix: update import from graphs to conversations
palaniappan-r Nov 2, 2025
6dd183e
chore: update google-cloud-storage dependency
palaniappan-r Nov 2, 2025
e8cf6f2
fix: use retriever graph
palaniappan-r Nov 2, 2025
43ac935
fix: revert google-cloud-storage dependency
palaniappan-r Nov 2, 2025
d28d45e
fix: add return type hints
palaniappan-r Nov 2, 2025
8c8ef8e
fix: add return type hint to lifespan function
palaniappan-r Nov 2, 2025
f14f31f
docs: add docstrings for endpoints
palaniappan-r Nov 6, 2025
76c7bc0
fix: untrack data folder
palaniappan-r Nov 7, 2025
7e28524
docs: update README with PostgreSQL setup instructions
palaniappan-r Nov 7, 2025
8cb48bb
remove api_docs
luarss Nov 7, 2025
2551efa
remove unused response_models
luarss Nov 7, 2025
b4d4c74
db models:
luarss Nov 7, 2025
9438d9a
specific docker-compose commands
luarss Nov 7, 2025
115a603
improve history string robustness
luarss Nov 7, 2025
e7a3146
fix LLM response extraction (docker compose)
luarss Nov 7, 2025
989df79
add unit tests
luarss Nov 7, 2025
8f32f66
reduce docstring verbosity
luarss Nov 7, 2025
dcbbb7f
feat: save streamed conversation messages to the database
palaniappan-r Nov 12, 2025
ad28a80
feat: add tui interface with rich
palaniappan-r Nov 13, 2025
1cf7ae1
feat: add optional db usage flag
palaniappan-r Nov 27, 2025
4042bbd
Merge branch 'master' into chatbot-tui
palaniappan-r Nov 27, 2025
8540ba4
fix: mypy check fixes
palaniappan-r Nov 27, 2025
7bf5820
fix lint
luarss Dec 2, 2025
a6404e6
close db instance
luarss Dec 2, 2025
1f47e1e
add more logging and improve exception handling
luarss Dec 2, 2025
b82865f
add backend env
luarss Dec 4, 2025
10fe935
seed secret CI with HF_TOKEN
luarss Dec 4, 2025
2 changes: 2 additions & 0 deletions .github/workflows/ci-secret.yaml
@@ -47,11 +47,13 @@ jobs:
          sed -i 's|{{GOOGLE_API_KEY}}|${{ secrets.GOOGLE_API_KEY }}|g' backend/.env
          sed -i 's|{{GOOGLE_PROJECT_ID}}|${{ secrets.GOOGLE_PROJECT_ID }}|g' backend/.env
          sed -i 's|{{PATH_TO_GOOGLE_APPLICATION_CREDENTIALS}}|src/secret.json|g' backend/.env
          sed -i 's|HF_TOKEN=|HF_TOKEN=${{ secrets.HF_TOKEN }}|g' backend/.env

          cp backend/.env.example evaluation/.env
          sed -i 's|{{GOOGLE_API_KEY}}|${{ secrets.GOOGLE_API_KEY }}|g' evaluation/.env
          sed -i 's|{{GOOGLE_PROJECT_ID}}|${{ secrets.GOOGLE_PROJECT_ID }}|g' evaluation/.env
          sed -i 's|{{PATH_TO_GOOGLE_APPLICATION_CREDENTIALS}}|src/secret.json|g' evaluation/.env
          sed -i 's|HF_TOKEN=|HF_TOKEN=${{ secrets.HF_TOKEN }}|g' evaluation/.env
      - name: Copy Google credentials
        run: |
          cp ${{ secrets.PATH_TO_GOOGLE_APPLICATION_CREDENTIALS }} backend/src
334 changes: 274 additions & 60 deletions backend/chatbot.py
@@ -1,72 +1,286 @@
Removed (previous chatbot.py implementation):

import os
import logging
from src.api.routers import conversations


def get_history_str(chat_history: list[dict[str, str]]) -> str:
    history_str = ""
    for i in chat_history:
        history_str += f"User : {i['User']}\nAI : {i['AI']}\n\n"
    return history_str


chat_history: list[dict[str, str]] = []

if __name__ == "__main__":
    rg = conversations.rg
    os.system("clear")

    while True:
        user_question = input(">>> ")

        inputs = {
            "messages": [
                ("user", user_question),
            ],
            "chat_history": get_history_str(chat_history=chat_history),
        }

        try:
            if rg.graph is not None:
                output = list(rg.graph.stream(inputs))
            else:
                raise ValueError("RetrieverGraph not initialized.")
        except RuntimeError:
            logging.error("Runtime Error!")

        if (
            isinstance(output, list)
            and len(output) > 2
            and "rag_generate" in output[-1]
            and "messages" in output[-1]["rag_generate"]
            and len(output[-1]["rag_generate"]["messages"]) > 0
        ):
            llm_response = output[-1]["rag_generate"]["messages"][0]

            tool = list(output[-2].keys())[0]
            srcs = set(output[-2][tool]["sources"])
            urls = output[-2][tool]["urls"]
            chat_history.append({"User": user_question, "AI": llm_response})

            print(f"LLM: {llm_response} \nSources: {srcs} \nURLs: {urls}\n\n")

        elif (
            isinstance(output, list)
            and len(output) > 2
            and "mcp_tools" in output[-1]
            and "messages" in output[-1]["mcp_tools"]
        ):
            logging.info(output)
            llm_response = output[-1]["mcp_tools"]["messages"][0]
            chat_history.append({"User": user_question, "AI": llm_response})
            result = output[-1]["mcp_tools"]["messages"]
            print()
            if len(result) > 0:
                for _ in result:
                    print(_)
            else:
                print("No Message!")
        else:
            logging.info(output)
            print("LLM response extraction failed")

Added (new chatbot.py implementation):

import os
import sys
import logging
from uuid import UUID, uuid4
from dotenv import load_dotenv
from rich.console import Console
from rich.panel import Panel
from rich.markdown import Markdown
from rich.prompt import Prompt
from rich.table import Table
from sqlalchemy.orm import Session

from src.agents.retriever_graph import RetrieverGraph
from src.database import get_db, init_database
from src.database import crud
from langchain_google_vertexai import ChatVertexAI
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_ollama import ChatOllama

load_dotenv()
logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO").upper())

console = Console()


def setup_llm() -> ChatVertexAI | ChatGoogleGenerativeAI | ChatOllama:
    temp = float(os.getenv("LLM_TEMP", "0.0"))

    if os.getenv("LLM_MODEL") == "ollama":
        model = str(os.getenv("OLLAMA_MODEL"))
        return ChatOllama(model=model, temperature=temp)

    elif os.getenv("LLM_MODEL") == "gemini":
        gemini = os.getenv("GOOGLE_GEMINI")
        if gemini in {"1_pro", "1.5_flash", "1.5_pro"}:
            raise ValueError(f"Gemini {gemini} (v1.0-1.5) disabled. Use v2.0+")
        elif gemini == "2.0_flash":
            return ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=temp)
        elif gemini == "2.5_flash":
            return ChatVertexAI(model_name="gemini-2.5-flash", temperature=temp)
        elif gemini == "2.5_pro":
            return ChatVertexAI(model_name="gemini-2.5-pro", temperature=temp)
        else:
            raise ValueError(f"Invalid GOOGLE_GEMINI value: {gemini}")

    else:
        raise ValueError(f"Invalid LLM_MODEL: {os.getenv('LLM_MODEL')}")


def setup_embeddings() -> dict[str, str]:
    embed_type = str(os.getenv("EMBEDDINGS_TYPE"))

    if embed_type == "HF":
        model = str(os.getenv("HF_EMBEDDINGS"))
    elif embed_type in {"GOOGLE_GENAI", "GOOGLE_VERTEXAI"}:
        model = str(os.getenv("GOOGLE_EMBEDDINGS"))
    else:
        raise ValueError(f"Invalid EMBEDDINGS_TYPE: {embed_type}")

    return {"type": embed_type, "name": model}


def get_history(
    db: Session | None, conv_id: UUID | None, local_history: list[dict]
) -> str:
    if db and conv_id:
        history = crud.get_conversation_history(db, conv_id)
        result = ""
        for msg in history:
            user = msg.get("User", "")
            ai = msg.get("AI", "")
            if user and ai:
                result += f"User : {user}\nAI : {ai}\n\n"
        return result
    else:
        result = ""
        for msg in local_history:
            user = msg.get("User", "")
            ai = msg.get("AI", "")
            if user and ai:
                result += f"User : {user}\nAI : {ai}\n\n"
        return result
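# Example (illustrative): two complete turns serialize to
#   "User : hi\nAI : hello\n\nUser : thanks\nAI : you're welcome\n\n"
# while a pair missing either the User or AI side is skipped, keeping
# incomplete message pairs out of the prompt history.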


def parse_output(output: list) -> tuple[str, list[str], list[str]]:
    fail_msg = "Failed to get response"
    MIN_OUTPUT_LENGTH = 3

    # Validate outputs
    if not isinstance(output, list):
        logging.error(f"Output is not a list: {type(output)}")
        return fail_msg, [], []
    if len(output) < MIN_OUTPUT_LENGTH:
        logging.error(f"Output too short: {len(output)} elements (expected >= {MIN_OUTPUT_LENGTH})")
        return fail_msg, [], []

    # Validate last element contains generation
    last = output[-1]
    if not isinstance(last, dict):
        logging.error(f"Last element is not a dict: {type(last)}")
        return fail_msg, [], []

    # Determine if RAG or agent mode
    is_rag = "rag_generate" in last
    key = "rag_generate" if is_rag else "generate"

    if key not in last:
        logging.error(f"Missing '{key}' key in final output. Available keys: {list(last.keys())}")
        return fail_msg, [], []

    if "messages" not in last[key]:
        logging.error(f"Missing 'messages' in {key}")
        return fail_msg, [], []

    msgs = last[key]["messages"]
    if not msgs:
        logging.error("Empty messages list in generation output")
        return fail_msg, [], []

    response = str(msgs[0])
    sources = []
    tools = []

    if is_rag:
        for item in output[1:-1]:
            if isinstance(item, dict):
                for k, v in item.items():
                    if k.startswith("retrieve_") and isinstance(v, dict):
                        urls = v.get("urls", [])
                        sources.extend(urls)
    else:
        if "agent" in output[0] and "tools" in output[0]["agent"]:
            tools = output[0]["agent"]["tools"]
            for i in range(len(tools)):
                if i + 1 < len(output):
                    tool_out = list(output[i + 1].values())[0]
                    urls = tool_out.get("urls", [])
                    sources.extend(urls)

    # Deduplicate sources
    return response, list(set(sources)), tools
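# Illustrative RAG-mode payload (assumed shape, inferred from the keys read above):
#   output = [
#       {"agent": {...}},
#       {"retrieve_docs": {"urls": ["https://example.com/doc"], "sources": [...]}},
#       {"rag_generate": {"messages": ["Answer text"]}},
#   ]
#   parse_output(output) -> ("Answer text", ["https://example.com/doc"], [])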


def show_response(text: str, sources: list[str], tools: list[str]) -> None:
    console.print(
        Panel(Markdown(text), title="[bold cyan]Assistant", border_style="cyan")
    )

    if tools:
        console.print(f"[yellow]Tools:[/yellow] {', '.join(tools)}")

    if sources:
        src_table = Table(title="Sources", show_header=False, border_style="dim")
        src_table.add_column("URL", style="blue")
        for src in sources:
            src_table.add_row(src)
        console.print(src_table)

    console.print()


def main() -> None:
    console.clear()
    console.print(
        Panel("[bold green]ORAssistant Chatbot[/bold green]", border_style="green")
    )

    cuda = str(os.getenv("USE_CUDA")).lower() == "true"
    fast = str(os.getenv("FAST_MODE")).lower() == "true"
    debug = str(os.getenv("DEBUG")).lower() == "true"
    mcp = str(os.getenv("ENABLE_MCP")).lower() == "true"
    use_db = str(os.getenv("USE_DB", "true")).lower() == "true"

    llm = setup_llm()
    embed_cfg = setup_embeddings()
    reranker = str(os.getenv("HF_RERANKER"))

    with console.status("[bold green]Initializing graph...", spinner="dots"):
        graph = RetrieverGraph(
            llm_model=llm,
            embeddings_config=embed_cfg,
            reranking_model_name=reranker,
            use_cuda=cuda,
            inbuilt_tool_calling=True,
            fast_mode=fast,
            debug=debug,
            enable_mcp=mcp,
        )
        graph.initialize()

    if graph.graph is None:
        console.print("[bold red]Failed to initialize graph[/bold red]")
        sys.exit(1)

    db = None
    db_generator = None
    conv_id = None
    local_history: list[dict[str, str]] = []

    if use_db:
        if init_database():
            db_generator = get_db()
            db = next(db_generator)
            conv_id = uuid4()
            crud.create_conversation(db, conversation_uuid=conv_id, title=None)
            console.print("[dim]Database: enabled[/dim]")
        else:
            console.print("[yellow]Database unavailable, using local memory[/yellow]")
            use_db = False
    else:
        console.print("[dim]Database: disabled[/dim]")

    console.print("[dim]Type 'exit' or 'quit' to end session[/dim]\n")

    try:
        while True:
            query = Prompt.ask("[bold blue]You[/bold blue]")

            if query.lower() in {"exit", "quit", "q"}:
                console.print("[yellow]Goodbye![/yellow]")
                break

            if not query.strip():
                continue

            if use_db and db and conv_id:
                crud.create_message(
                    db=db,
                    conversation_uuid=conv_id,
                    role="user",
                    content=query,
                )

            inputs = {
                "messages": [("user", query)],
                "chat_history": get_history(db, conv_id, local_history),
            }

            with console.status("[bold green]Thinking...", spinner="dots"):
                output = list(graph.graph.stream(inputs, stream_mode="updates"))

            response, sources, tools = parse_output(output)

            if use_db and db and conv_id:
                ctx_srcs = {"sources": [{"source": s, "context": ""} for s in sources]}
                crud.create_message(
                    db=db,
                    conversation_uuid=conv_id,
                    role="assistant",
                    content=response,
                    context_sources=ctx_srcs,
                    tools=tools,
                )
            else:
                local_history.append({"User": query, "AI": response})

            show_response(response, sources, tools)

    except KeyboardInterrupt:
        console.print("\n[yellow]Interrupted. Goodbye![/yellow]")
    except ValueError as e:
        console.print(f"[bold red]Configuration Error:[/bold red] {str(e)}")
        console.print("[yellow]Check your environment variables and try again.[/yellow]")
        if debug:
            logging.exception("Configuration error")
    except ConnectionError as e:
        console.print(f"[bold red]Connection Error:[/bold red] {str(e)}")
        console.print("[yellow]Check your network connection and database availability.[/yellow]")
        if debug:
            logging.exception("Connection error")
    except Exception as e:
        console.print(f"[bold red]Unexpected Error:[/bold red] {str(e)}")
        console.print("[yellow]Please report this issue if it persists.[/yellow]")
        logging.exception("Unexpected error in main loop")
    finally:
        # Clean up database session
        if db_generator is not None:
            try:
                next(db_generator)
            except StopIteration:
                pass
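# The extra next() drives a generator-style get_db() past its yield so the
# generator's own cleanup runs; StopIteration is the normal completion signal.
# Minimal sketch of the assumed generator shape (the real one lives in
# src.database):
#   def get_db():
#       db = SessionLocal()
#       try:
#           yield db
#       finally:
#           db.close()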


if __name__ == "__main__":
    main()