Merged
Changes from 2 commits
211 changes: 140 additions & 71 deletions api/agents/analysis_agent.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion api/app_factory.py
@@ -253,7 +253,7 @@ async def handle_oauth_error(

# Serve React app for all non-API routes (SPA catch-all)
@app.get("/{full_path:path}", include_in_schema=False)
async def serve_react_app(_full_path: str):
async def serve_react_app(full_path: str):
"""Serve the React app for all routes not handled by API endpoints."""
# Serve index.html for the React SPA
index_path = os.path.join(dist_path, "index.html")
118 changes: 66 additions & 52 deletions api/core/text2sql.py
@@ -16,7 +16,7 @@
from api.agents.healer_agent import HealerAgent
from api.config import Config
from api.extensions import db
from api.graph import find, get_db_description
from api.graph import find, get_db_description, get_user_rules
from api.loaders.postgres_loader import PostgresLoader
from api.loaders.mysql_loader import MySQLLoader
from api.memory.graphiti_tool import MemoryTool
@@ -45,6 +45,8 @@ class ChatRequest(BaseModel):
chat: list[str]
result: list[str] | None = None
instructions: str | None = None
use_user_rules: bool = True # If True, fetch rules from database; if False, don't use rules
use_memory: bool = True


class ConfirmRequest(BaseModel):
@@ -56,6 +58,7 @@ class ConfirmRequest(BaseModel):
sql_query: str
confirmation: str = ""
chat: list = []
use_user_rules: bool = True # If True, fetch rules from database; if False, don't use rules


def get_database_type_and_loader(db_url: str):
@@ -213,6 +216,7 @@ async def query_database(user_id: str, graph_id: str, chat_data: ChatRequest):
queries_history = chat_data.chat if hasattr(chat_data, 'chat') else None
result_history = chat_data.result if hasattr(chat_data, 'result') else None
instructions = chat_data.instructions if hasattr(chat_data, 'instructions') else None
use_user_rules = chat_data.use_user_rules if hasattr(chat_data, 'use_user_rules') else True

if not queries_history or not isinstance(queries_history, list):
raise InvalidArgumentError("Invalid or missing chat history")
@@ -233,7 +237,10 @@ async def query_database(user_id: str, graph_id: str, chat_data: ChatRequest):

logging.info("User Query: %s", sanitize_query(queries_history[-1]))

memory_tool_task = asyncio.create_task(MemoryTool.create(user_id, graph_id))
if chat_data.use_memory:
memory_tool_task = asyncio.create_task(MemoryTool.create(user_id, graph_id))
else:
memory_tool_task = None

# Create a generator function for streaming
async def generate(): # pylint: disable=too-many-locals,too-many-branches,too-many-statements
@@ -252,6 +259,8 @@ async def generate(): # pylint: disable=too-many-locals,too-many-branches,too-m
yield json.dumps(step) + MESSAGE_DELIMITER
# Ensure the database description is loaded
db_description, db_url = await get_db_description(graph_id)
# Fetch user rules from database only if toggle is enabled
user_rules_spec = await get_user_rules(graph_id) if use_user_rules else None

# Determine database type and get appropriate loader
db_type, loader_class = get_database_type_and_loader(db_url)
@@ -304,15 +313,18 @@ async def generate(): # pylint: disable=too-many-locals,too-many-branches,too-m

logging.info("Calling to analysis agent with query: %s",
sanitize_query(queries_history[-1])) # nosemgrep
memory_tool = await memory_tool_task
memory_context = await memory_tool.search_memories(
query=queries_history[-1]
)

memory_context = None
if memory_tool_task:
memory_tool = await memory_tool_task
memory_context = await memory_tool.search_memories(
query=queries_history[-1]
)

logging.info("Starting SQL generation with analysis agent")
answer_an = agent_an.get_analysis(
queries_history[-1], result, db_description, instructions, memory_context,
db_type
db_type, user_rules_spec
)

# Initialize response variables
@@ -625,56 +637,58 @@ def execute_sql(sql: str):
)

# Save conversation to memory (only for on-topic queries)
# Determine the final answer based on which path was taken
final_answer = user_readable_response if user_readable_response else follow_up_result

# Build comprehensive response for memory
full_response = {
"question": queries_history[-1],
"generated_sql": answer_an.get('sql_query', ""),
"answer": final_answer
}
# Only save to memory if use_memory is enabled
if memory_tool_task:
# Determine the final answer based on which path was taken
final_answer = user_readable_response if user_readable_response else follow_up_result

# Build comprehensive response for memory
full_response = {
"question": queries_history[-1],
"generated_sql": answer_an.get('sql_query', ""),
"answer": final_answer
}

# Add error information if SQL execution failed
if execution_error:
full_response["error"] = execution_error
full_response["success"] = False
else:
full_response["success"] = True
# Add error information if SQL execution failed
if execution_error:
full_response["error"] = execution_error
full_response["success"] = False
else:
full_response["success"] = True


# Save query to memory
save_query_task = asyncio.create_task(
memory_tool.save_query_memory(
query=queries_history[-1],
sql_query=answer_an["sql_query"],
success=full_response["success"],
error=execution_error
# Save query to memory
save_query_task = asyncio.create_task(
memory_tool.save_query_memory(
query=queries_history[-1],
sql_query=answer_an["sql_query"],
success=full_response["success"],
error=execution_error
)
)
save_query_task.add_done_callback(
lambda t: logging.error("Query memory save failed: %s", t.exception()) # nosemgrep
if t.exception() else logging.info("Query memory saved successfully")
)
)
save_query_task.add_done_callback(
lambda t: logging.error("Query memory save failed: %s", t.exception()) # nosemgrep
if t.exception() else logging.info("Query memory saved successfully")
)

# Save conversation with memory tool (run in background)
save_task = asyncio.create_task(
memory_tool.add_new_memory(full_response,
[queries_history, result_history])
)
# Add error handling callback to prevent silent failures
save_task.add_done_callback(
lambda t: logging.error("Memory save failed: %s", t.exception()) # nosemgrep
if t.exception() else logging.info("Conversation saved to memory tool")
)
logging.info("Conversation save task started in background")
# Save conversation with memory tool (run in background)
save_task = asyncio.create_task(
memory_tool.add_new_memory(full_response,
[queries_history, result_history])
)
# Add error handling callback to prevent silent failures
save_task.add_done_callback(
lambda t: logging.error("Memory save failed: %s", t.exception()) # nosemgrep
if t.exception() else logging.info("Conversation saved to memory tool")
)
logging.info("Conversation save task started in background")

# Clean old memory in background (once per week cleanup)
clean_memory_task = asyncio.create_task(memory_tool.clean_memory())
clean_memory_task.add_done_callback(
lambda t: logging.error("Memory cleanup failed: %s", t.exception()) # nosemgrep
if t.exception() else logging.info("Memory cleanup completed successfully")
)
# Clean old memory in background (once per week cleanup)
clean_memory_task = asyncio.create_task(memory_tool.clean_memory())
clean_memory_task.add_done_callback(
lambda t: logging.error("Memory cleanup failed: %s", t.exception()) # nosemgrep
if t.exception() else logging.info("Memory cleanup completed successfully")
)

# Log timing summary at the end of processing
overall_elapsed = time.perf_counter() - overall_start
@@ -714,7 +728,7 @@ async def generate_confirmation(): # pylint: disable=too-many-locals,too-many-s
if confirmation == "CONFIRM":
try:
db_description, db_url = await get_db_description(graph_id)

# Determine database type and get appropriate loader
_, loader_class = get_database_type_and_loader(db_url)

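For illustration only, here is a minimal client sketch that exercises the new ChatRequest toggles (use_user_rules, use_memory) added in this file. The endpoint path, port, and the plain urllib client are assumptions, not part of this PR; only the JSON fields mirror the model above.

# Hypothetical client for the streaming query route; URL and transport details are assumptions.
import json
import urllib.request

payload = {
    "chat": ["How many orders were placed last month?"],
    "result": None,
    "instructions": None,
    "use_user_rules": False,  # backend skips get_user_rules(graph_id)
    "use_memory": False,      # backend skips MemoryTool creation and memory writes
}

req = urllib.request.Request(
    "http://localhost:5000/api/graphs/my_graph/query",  # assumed route
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)

# The route streams JSON messages separated by MESSAGE_DELIMITER; read the raw stream here.
with urllib.request.urlopen(req) as resp:
    print(resp.read().decode("utf-8"))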
28 changes: 28 additions & 0 deletions api/graph.py
@@ -53,6 +53,34 @@ async def get_db_description(graph_id: str) -> tuple[str, str]:
return (query_result.result_set[0][0],
query_result.result_set[0][1]) # Return the first result's description


async def get_user_rules(graph_id: str) -> str:
"""Get the user rules from the graph."""
graph = db.select_graph(graph_id)
query_result = await graph.query(
"""
MATCH (d:Database)
RETURN d.user_rules
"""
)

if not query_result.result_set or not query_result.result_set[0][0]:
return ""

return query_result.result_set[0][0]


async def set_user_rules(graph_id: str, user_rules: str) -> None:
"""Set the user rules in the graph."""
graph = db.select_graph(graph_id)
await graph.query(
"""
MATCH (d:Database)
SET d.user_rules = $user_rules
""",
{"user_rules": user_rules}
)

async def _query_graph(
graph,
query: str,
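A short sketch of how the two new helpers could round-trip a rules string, assuming the db extension from api.extensions is already initialized; the graph id and the rule text are illustrative only.

# Hypothetical usage of the new helpers; requires the app's db extension to be initialized.
import asyncio

from api.graph import get_user_rules, set_user_rules

async def demo() -> None:
    graph_id = "user123_sales"  # assumed "<user_id>_<graph_id>" naming, as used in routes/graphs.py
    await set_user_rules(graph_id, "Exclude soft-deleted rows (deleted_at IS NULL) in every query.")
    rules = await get_user_rules(graph_id)  # returns "" when no rules are stored
    print(rules)

asyncio.run(demo())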
41 changes: 41 additions & 0 deletions api/routes/graphs.py
@@ -19,6 +19,7 @@
query_database,
refresh_database_schema,
)
from api.graph import get_user_rules, set_user_rules
from api.auth.user_management import token_required
from api.routes.tokens import UNAUTHORIZED_RESPONSE

@@ -225,3 +226,43 @@
content={"error": "Failed to delete database"},
status_code=500
)


class UserRulesRequest(BaseModel):
"""User rules request model."""
user_rules: str


@graphs_router.get("/{graph_id}/user-rules", responses={401: UNAUTHORIZED_RESPONSE})
@token_required
async def get_graph_user_rules(request: Request, graph_id: str):
"""Get user rules for the specified graph."""
try:
full_graph_id = f"{request.state.user_id}_{graph_id}"
user_rules = await get_user_rules(full_graph_id)
return JSONResponse(content={"user_rules": user_rules})
except GraphNotFoundError:
return JSONResponse(content={"error": "Database not found"}, status_code=404)
except Exception as e:
logging.error("Error getting user rules: %s", str(e))
return JSONResponse(content={"error": "Failed to get user rules"}, status_code=500)


@graphs_router.put("/{graph_id}/user-rules", responses={401: UNAUTHORIZED_RESPONSE})
@token_required
async def update_graph_user_rules(request: Request, graph_id: str, data: UserRulesRequest):
"""Update user rules for the specified graph."""
try:
logging.info("Received request to update user rules for graph: %s", graph_id)
logging.info("User rules content length: %d", len(data.user_rules))
full_graph_id = f"{request.state.user_id}_{graph_id}"
logging.info("Full graph_id: %s", full_graph_id)
await set_user_rules(full_graph_id, data.user_rules)
logging.info("User rules updated successfully for graph: %s", graph_id)
return JSONResponse(content={"success": True, "user_rules": data.user_rules})
except GraphNotFoundError:
logging.error("Graph not found: %s", graph_id)
return JSONResponse(content={"error": "Database not found"}, status_code=404)
except Exception as e:
logging.error("Error updating user rules: %s", str(e))
return JSONResponse(content={"error": "Failed to update user rules"}, status_code=500)
2 changes: 2 additions & 0 deletions app/src/App.tsx
@@ -5,6 +5,7 @@ import { BrowserRouter, Routes, Route } from "react-router-dom";
import { AuthProvider } from "@/contexts/AuthContext";
import { DatabaseProvider } from "@/contexts/DatabaseContext";
import Index from "./pages/Index";
import Settings from "./pages/Settings";
import NotFound from "./pages/NotFound";

const queryClient = new QueryClient();
@@ -18,6 +19,7 @@ const App = () => (
<BrowserRouter future={{ v7_startTransition: true, v7_relativeSplatPath: true }}>
<Routes>
<Route path="/" element={<Index />} />
<Route path="/settings" element={<Settings />} />
{/* ADD ALL CUSTOM ROUTES ABOVE THE CATCH-ALL "*" ROUTE */}
<Route path="*" element={<NotFound />} />
</Routes>
13 changes: 12 additions & 1 deletion app/src/components/chat/ChatInterface.tsx
@@ -40,9 +40,17 @@ export interface ChatInterfaceProps {
className?: string;
disabled?: boolean; // when true, block interactions
onProcessingChange?: (isProcessing: boolean) => void; // callback to notify parent of processing state
useMemory?: boolean; // Whether to use memory context
useRulesFromDatabase?: boolean; // Whether to use rules from database (backend fetches them)
}

const ChatInterface = ({ className, disabled = false, onProcessingChange }: ChatInterfaceProps) => {
const ChatInterface = ({
className,
disabled = false,
onProcessingChange,
useMemory = true,
useRulesFromDatabase = true
}: ChatInterfaceProps) => {
const { toast } = useToast();
const { selectedGraph } = useDatabase();
const [isProcessing, setIsProcessing] = useState(false);
@@ -168,6 +176,8 @@ const ChatInterface = ({ className, disabled = false, onProcessingChange }: Chat
query,
database: selectedGraph.id,
history: historySnapshot,
use_user_rules: useRulesFromDatabase, // Backend fetches from DB when true
use_memory: useMemory,
})) {

if (message.type === 'status' || message.type === 'reasoning' || message.type === 'reasoning_step') {
@@ -344,6 +354,7 @@ const ChatInterface = ({ className, disabled = false, onProcessingChange }: Chat
sql_query: confirmMessage.confirmationData.sqlQuery,
confirmation: 'CONFIRM',
chat: confirmMessage.confirmationData.chatHistory,
use_user_rules: useRulesFromDatabase, // Backend fetches from DB when true
}
)) {
if (message.type === 'status' || message.type === 'reasoning' || message.type === 'reasoning_step') {