215 changes: 143 additions & 72 deletions api/agents/analysis_agent.py

Large diffs are not rendered by default.
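The analysis_agent.py diff is collapsed here, but the call site in text2sql.py below passes a new trailing argument (user_rules_spec) to get_analysis. A minimal sketch of what that signature change might look like — only the extra parameter is confirmed by this diff; the prompt-assembly details and the _generate_sql helper are assumptions:

# Hypothetical sketch of the collapsed analysis_agent.py change; only the
# new user_rules_spec parameter is confirmed by the call site in text2sql.py.
def get_analysis(self, question, result, db_description, instructions,
                 memory_context, db_type, user_rules_spec=None):
    sections = [f"Database ({db_type}):\n{db_description}"]
    if result:
        sections.append(f"Previous results:\n{result}")
    if instructions:
        sections.append(f"Instructions:\n{instructions}")
    if memory_context:
        sections.append(f"Relevant past queries:\n{memory_context}")
    if user_rules_spec:
        # Rules are injected only when the use_user_rules toggle fetched them
        sections.append(f"User rules:\n{user_rules_spec}")
    sections.append(f"Question:\n{question}")
    return self._generate_sql("\n\n".join(sections))  # assumed helper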

115 changes: 64 additions & 51 deletions api/core/text2sql.py
@@ -16,7 +16,7 @@
from api.agents.healer_agent import HealerAgent
from api.config import Config
from api.extensions import db
from api.graph import find, get_db_description
from api.graph import find, get_db_description, get_user_rules
from api.loaders.postgres_loader import PostgresLoader
from api.loaders.mysql_loader import MySQLLoader
from api.memory.graphiti_tool import MemoryTool
@@ -45,6 +45,8 @@ class ChatRequest(BaseModel):
chat: list[str]
result: list[str] | None = None
instructions: str | None = None
use_user_rules: bool = True # Whether to fetch and apply the user rules stored on the graph
use_memory: bool = True


class ConfirmRequest(BaseModel):
@@ -213,6 +215,7 @@ async def query_database(user_id: str, graph_id: str, chat_data: ChatRequest):
queries_history = chat_data.chat if hasattr(chat_data, 'chat') else None
result_history = chat_data.result if hasattr(chat_data, 'result') else None
instructions = chat_data.instructions if hasattr(chat_data, 'instructions') else None
use_user_rules = chat_data.use_user_rules if hasattr(chat_data, 'use_user_rules') else True

if not queries_history or not isinstance(queries_history, list):
raise InvalidArgumentError("Invalid or missing chat history")
@@ -233,7 +236,10 @@ async def query_database(user_id: str, graph_id: str, chat_data: ChatRequest):

logging.info("User Query: %s", sanitize_query(queries_history[-1]))

memory_tool_task = asyncio.create_task(MemoryTool.create(user_id, graph_id))
if chat_data.use_memory:
memory_tool_task = asyncio.create_task(MemoryTool.create(user_id, graph_id))
else:
memory_tool_task = None

# Create a generator function for streaming
async def generate(): # pylint: disable=too-many-locals,too-many-branches,too-many-statements
@@ -252,6 +258,8 @@ async def generate(): # pylint: disable=too-many-locals,too-many-branches,too-many-statements
yield json.dumps(step) + MESSAGE_DELIMITER
# Ensure the database description is loaded
db_description, db_url = await get_db_description(graph_id)
# Fetch user rules from database only if toggle is enabled
user_rules_spec = await get_user_rules(graph_id) if use_user_rules else None

# Determine database type and get appropriate loader
db_type, loader_class = get_database_type_and_loader(db_url)
@@ -304,15 +312,18 @@

logging.info("Calling to analysis agent with query: %s",
sanitize_query(queries_history[-1])) # nosemgrep
memory_tool = await memory_tool_task
memory_context = await memory_tool.search_memories(
query=queries_history[-1]
)

memory_context = None
if memory_tool_task:
memory_tool = await memory_tool_task
memory_context = await memory_tool.search_memories(
query=queries_history[-1]
)

logging.info("Starting SQL generation with analysis agent")
answer_an = agent_an.get_analysis(
queries_history[-1], result, db_description, instructions, memory_context,
db_type
db_type, user_rules_spec
)

# Initialize response variables
@@ -625,56 +636,58 @@ def execute_sql(sql: str):
)

# Save conversation to memory (only for on-topic queries)
# Determine the final answer based on which path was taken
final_answer = user_readable_response if user_readable_response else follow_up_result

# Build comprehensive response for memory
full_response = {
"question": queries_history[-1],
"generated_sql": answer_an.get('sql_query', ""),
"answer": final_answer
}
# Only save to memory if use_memory is enabled
if memory_tool_task:
# Determine the final answer based on which path was taken
final_answer = user_readable_response if user_readable_response else follow_up_result

# Build comprehensive response for memory
full_response = {
"question": queries_history[-1],
"generated_sql": answer_an.get('sql_query', ""),
"answer": final_answer
}

# Add error information if SQL execution failed
if execution_error:
full_response["error"] = execution_error
full_response["success"] = False
else:
full_response["success"] = True
# Add error information if SQL execution failed
if execution_error:
full_response["error"] = execution_error
full_response["success"] = False
else:
full_response["success"] = True


# Save query to memory
save_query_task = asyncio.create_task(
memory_tool.save_query_memory(
query=queries_history[-1],
sql_query=answer_an["sql_query"],
success=full_response["success"],
error=execution_error
# Save query to memory
save_query_task = asyncio.create_task(
memory_tool.save_query_memory(
query=queries_history[-1],
sql_query=answer_an["sql_query"],
success=full_response["success"],
error=execution_error
)
)
save_query_task.add_done_callback(
lambda t: logging.error("Query memory save failed: %s", t.exception()) # nosemgrep
if t.exception() else logging.info("Query memory saved successfully")
)
)
save_query_task.add_done_callback(
lambda t: logging.error("Query memory save failed: %s", t.exception()) # nosemgrep
if t.exception() else logging.info("Query memory saved successfully")
)

# Save conversation with memory tool (run in background)
save_task = asyncio.create_task(
memory_tool.add_new_memory(full_response,
[queries_history, result_history])
)
# Add error handling callback to prevent silent failures
save_task.add_done_callback(
lambda t: logging.error("Memory save failed: %s", t.exception()) # nosemgrep
if t.exception() else logging.info("Conversation saved to memory tool")
)
logging.info("Conversation save task started in background")
# Save conversation with memory tool (run in background)
save_task = asyncio.create_task(
memory_tool.add_new_memory(full_response,
[queries_history, result_history])
)
# Add error handling callback to prevent silent failures
save_task.add_done_callback(
lambda t: logging.error("Memory save failed: %s", t.exception()) # nosemgrep
if t.exception() else logging.info("Conversation saved to memory tool")
)
logging.info("Conversation save task started in background")

# Clean old memory in background (once per week cleanup)
clean_memory_task = asyncio.create_task(memory_tool.clean_memory())
clean_memory_task.add_done_callback(
lambda t: logging.error("Memory cleanup failed: %s", t.exception()) # nosemgrep
if t.exception() else logging.info("Memory cleanup completed successfully")
)
# Clean old memory in background (once per week cleanup)
clean_memory_task = asyncio.create_task(memory_tool.clean_memory())
clean_memory_task.add_done_callback(
lambda t: logging.error("Memory cleanup failed: %s", t.exception()) # nosemgrep
if t.exception() else logging.info("Memory cleanup completed successfully")
)

# Log timing summary at the end of processing
overall_elapsed = time.perf_counter() - overall_start
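With both toggles now on ChatRequest, a request that disables rules and memory might look like the sketch below. The endpoint path, port, and auth header are assumptions — only the field names and defaults come from the ChatRequest model above:

# Example ChatRequest payload; the /query path and port are assumptions.
import httpx

payload = {
    "chat": ["How many orders shipped last month?"],
    "use_user_rules": False,  # skip the get_user_rules lookup entirely
    "use_memory": False,      # skip MemoryTool creation and all memory saves
}
resp = httpx.post(
    "http://localhost:8000/graphs/my-db/query",  # hypothetical URL
    json=payload,
    headers={"Authorization": "Bearer <token>"},
)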
28 changes: 28 additions & 0 deletions api/graph.py
@@ -53,6 +53,34 @@ async def get_db_description(graph_id: str) -> tuple[str, str]:
return (query_result.result_set[0][0],
query_result.result_set[0][1]) # Return the first result's description


async def get_user_rules(graph_id: str) -> str:
"""Get the user rules from the graph."""
graph = db.select_graph(graph_id)
query_result = await graph.query(
"""
MATCH (d:Database)
RETURN d.user_rules
"""
)

if not query_result.result_set or not query_result.result_set[0][0]:
return ""

return query_result.result_set[0][0]


async def set_user_rules(graph_id: str, user_rules: str) -> None:
"""Set the user rules in the graph."""
graph = db.select_graph(graph_id)
await graph.query(
"""
MERGE (d:Database)
SET d.user_rules = $user_rules
""",
{"user_rules": user_rules}
)

async def _query_graph(
graph,
query: str,
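A quick round trip through the two new helpers — the graph id is illustrative; get_user_rules returns an empty string when nothing is stored:

# Round trip through the new graph helpers; the graph id is illustrative.
from api.graph import get_user_rules, set_user_rules

async def demo(graph_id: str) -> None:
    await set_user_rules(graph_id, "Exclude soft-deleted rows by default.")
    rules = await get_user_rules(graph_id)  # "" when no rules are stored
    assert rules == "Exclude soft-deleted rows by default."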
43 changes: 43 additions & 0 deletions api/routes/graphs.py
@@ -18,7 +18,9 @@
get_schema,
query_database,
refresh_database_schema,
_graph_name,
)
from api.graph import get_user_rules, set_user_rules
from api.auth.user_management import token_required
from api.routes.tokens import UNAUTHORIZED_RESPONSE

@@ -225,3 +227,44 @@ async def delete_graph(request: Request, graph_id: str):
content={"error": "Failed to delete database"},
status_code=500
)


class UserRulesRequest(BaseModel):
"""User rules request model."""
user_rules: str


@graphs_router.get("/{graph_id}/user-rules", responses={401: UNAUTHORIZED_RESPONSE})
@token_required
async def get_graph_user_rules(request: Request, graph_id: str):
"""Get user rules for the specified graph."""
try:
full_graph_id = _graph_name(request.state.user_id, graph_id)
user_rules = await get_user_rules(full_graph_id)
logging.info("Retrieved user rules length: %d", len(user_rules) if user_rules else 0)
return JSONResponse(content={"user_rules": user_rules})
except GraphNotFoundError:
return JSONResponse(content={"error": "Database not found"}, status_code=404)
except Exception as e: # pylint: disable=broad-exception-caught
logging.error("Error getting user rules: %s", str(e))
return JSONResponse(content={"error": "Failed to get user rules"}, status_code=500)


@graphs_router.put("/{graph_id}/user-rules", responses={401: UNAUTHORIZED_RESPONSE})
@token_required
async def update_graph_user_rules(request: Request, graph_id: str, data: UserRulesRequest):
"""Update user rules for the specified graph."""
try:
logging.info(
"Received request to update user rules, content length: %d", len(data.user_rules)
)
full_graph_id = _graph_name(request.state.user_id, graph_id)
await set_user_rules(full_graph_id, data.user_rules)
logging.info("User rules updated successfully")
return JSONResponse(content={"success": True, "user_rules": data.user_rules})
except GraphNotFoundError:
logging.error("Graph not found")
return JSONResponse(content={"error": "Database not found"}, status_code=404)
except Exception as e: # pylint: disable=broad-exception-caught
logging.error("Error updating user rules: %s", str(e))
return JSONResponse(content={"error": "Failed to update user rules"}, status_code=500)
2 changes: 1 addition & 1 deletion app/package.json
@@ -1,7 +1,7 @@
{
"name": "queryweaver-app",
"private": true,
"version": "0.0.1",
"version": "0.0.14",
"type": "module",
"scripts": {
"dev": "vite",
25 changes: 15 additions & 10 deletions app/src/App.tsx
@@ -4,7 +4,9 @@ import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
import { BrowserRouter, Routes, Route } from "react-router-dom";
import { AuthProvider } from "@/contexts/AuthContext";
import { DatabaseProvider } from "@/contexts/DatabaseContext";
import { ChatProvider } from "@/contexts/ChatContext";
import Index from "./pages/Index";
import Settings from "./pages/Settings";
import NotFound from "./pages/NotFound";

const queryClient = new QueryClient();
@@ -13,16 +15,19 @@ const App = () => (
<QueryClientProvider client={queryClient}>
<AuthProvider>
<DatabaseProvider>
<TooltipProvider>
<Toaster />
<BrowserRouter future={{ v7_startTransition: true, v7_relativeSplatPath: true }}>
<Routes>
<Route path="/" element={<Index />} />
{/* ADD ALL CUSTOM ROUTES ABOVE THE CATCH-ALL "*" ROUTE */}
<Route path="*" element={<NotFound />} />
</Routes>
</BrowserRouter>
</TooltipProvider>
<ChatProvider>
<TooltipProvider>
<Toaster />
<BrowserRouter future={{ v7_startTransition: true, v7_relativeSplatPath: true }}>
<Routes>
<Route path="/" element={<Index />} />
<Route path="/settings" element={<Settings />} />
{/* ADD ALL CUSTOM ROUTES ABOVE THE CATCH-ALL "*" ROUTE */}
<Route path="*" element={<NotFound />} />
</Routes>
</BrowserRouter>
</TooltipProvider>
</ChatProvider>
</DatabaseProvider>
</AuthProvider>
</QueryClientProvider>