Skip to content

Commit f3dd0a3

Browse files
authored
Merge pull request #363 from FalkorDB/rules-and-optimize-prompt
Prompts Usage updates
2 parents 3074b29 + 588c56b commit f3dd0a3

File tree

19 files changed

+1360
-210
lines changed

19 files changed

+1360
-210
lines changed

api/agents/analysis_agent.py

Lines changed: 143 additions & 72 deletions
Large diffs are not rendered by default.

api/core/text2sql.py

Lines changed: 64 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from api.agents.healer_agent import HealerAgent
1717
from api.config import Config
1818
from api.extensions import db
19-
from api.graph import find, get_db_description
19+
from api.graph import find, get_db_description, get_user_rules
2020
from api.loaders.postgres_loader import PostgresLoader
2121
from api.loaders.mysql_loader import MySQLLoader
2222
from api.memory.graphiti_tool import MemoryTool
@@ -45,6 +45,8 @@ class ChatRequest(BaseModel):
4545
chat: list[str]
4646
result: list[str] | None = None
4747
instructions: str | None = None
48+
use_user_rules: bool = True # If True, fetch rules from database; if False, don't use rules
49+
use_memory: bool = True
4850

4951

5052
class ConfirmRequest(BaseModel):
@@ -213,6 +215,7 @@ async def query_database(user_id: str, graph_id: str, chat_data: ChatRequest):
213215
queries_history = chat_data.chat if hasattr(chat_data, 'chat') else None
214216
result_history = chat_data.result if hasattr(chat_data, 'result') else None
215217
instructions = chat_data.instructions if hasattr(chat_data, 'instructions') else None
218+
use_user_rules = chat_data.use_user_rules if hasattr(chat_data, 'use_user_rules') else True
216219

217220
if not queries_history or not isinstance(queries_history, list):
218221
raise InvalidArgumentError("Invalid or missing chat history")
@@ -233,7 +236,10 @@ async def query_database(user_id: str, graph_id: str, chat_data: ChatRequest):
233236

234237
logging.info("User Query: %s", sanitize_query(queries_history[-1]))
235238

236-
memory_tool_task = asyncio.create_task(MemoryTool.create(user_id, graph_id))
239+
if chat_data.use_memory:
240+
memory_tool_task = asyncio.create_task(MemoryTool.create(user_id, graph_id))
241+
else:
242+
memory_tool_task = None
237243

238244
# Create a generator function for streaming
239245
async def generate(): # pylint: disable=too-many-locals,too-many-branches,too-many-statements
@@ -252,6 +258,8 @@ async def generate(): # pylint: disable=too-many-locals,too-many-branches,too-m
252258
yield json.dumps(step) + MESSAGE_DELIMITER
253259
# Ensure the database description is loaded
254260
db_description, db_url = await get_db_description(graph_id)
261+
# Fetch user rules from database only if toggle is enabled
262+
user_rules_spec = await get_user_rules(graph_id) if use_user_rules else None
255263

256264
# Determine database type and get appropriate loader
257265
db_type, loader_class = get_database_type_and_loader(db_url)
@@ -304,15 +312,18 @@ async def generate(): # pylint: disable=too-many-locals,too-many-branches,too-m
304312

305313
logging.info("Calling to analysis agent with query: %s",
306314
sanitize_query(queries_history[-1])) # nosemgrep
307-
memory_tool = await memory_tool_task
308-
memory_context = await memory_tool.search_memories(
309-
query=queries_history[-1]
310-
)
315+
316+
memory_context = None
317+
if memory_tool_task:
318+
memory_tool = await memory_tool_task
319+
memory_context = await memory_tool.search_memories(
320+
query=queries_history[-1]
321+
)
311322

312323
logging.info("Starting SQL generation with analysis agent")
313324
answer_an = agent_an.get_analysis(
314325
queries_history[-1], result, db_description, instructions, memory_context,
315-
db_type
326+
db_type, user_rules_spec
316327
)
317328

318329
# Initialize response variables
@@ -625,56 +636,58 @@ def execute_sql(sql: str):
625636
)
626637

627638
# Save conversation to memory (only for on-topic queries)
628-
# Determine the final answer based on which path was taken
629-
final_answer = user_readable_response if user_readable_response else follow_up_result
630-
631-
# Build comprehensive response for memory
632-
full_response = {
633-
"question": queries_history[-1],
634-
"generated_sql": answer_an.get('sql_query', ""),
635-
"answer": final_answer
636-
}
639+
# Only save to memory if use_memory is enabled
640+
if memory_tool_task:
641+
# Determine the final answer based on which path was taken
642+
final_answer = user_readable_response if user_readable_response else follow_up_result
643+
644+
# Build comprehensive response for memory
645+
full_response = {
646+
"question": queries_history[-1],
647+
"generated_sql": answer_an.get('sql_query', ""),
648+
"answer": final_answer
649+
}
637650

638-
# Add error information if SQL execution failed
639-
if execution_error:
640-
full_response["error"] = execution_error
641-
full_response["success"] = False
642-
else:
643-
full_response["success"] = True
651+
# Add error information if SQL execution failed
652+
if execution_error:
653+
full_response["error"] = execution_error
654+
full_response["success"] = False
655+
else:
656+
full_response["success"] = True
644657

645658

646-
# Save query to memory
647-
save_query_task = asyncio.create_task(
648-
memory_tool.save_query_memory(
649-
query=queries_history[-1],
650-
sql_query=answer_an["sql_query"],
651-
success=full_response["success"],
652-
error=execution_error
659+
# Save query to memory
660+
save_query_task = asyncio.create_task(
661+
memory_tool.save_query_memory(
662+
query=queries_history[-1],
663+
sql_query=answer_an["sql_query"],
664+
success=full_response["success"],
665+
error=execution_error
666+
)
667+
)
668+
save_query_task.add_done_callback(
669+
lambda t: logging.error("Query memory save failed: %s", t.exception()) # nosemgrep
670+
if t.exception() else logging.info("Query memory saved successfully")
653671
)
654-
)
655-
save_query_task.add_done_callback(
656-
lambda t: logging.error("Query memory save failed: %s", t.exception()) # nosemgrep
657-
if t.exception() else logging.info("Query memory saved successfully")
658-
)
659672

660-
# Save conversation with memory tool (run in background)
661-
save_task = asyncio.create_task(
662-
memory_tool.add_new_memory(full_response,
663-
[queries_history, result_history])
664-
)
665-
# Add error handling callback to prevent silent failures
666-
save_task.add_done_callback(
667-
lambda t: logging.error("Memory save failed: %s", t.exception()) # nosemgrep
668-
if t.exception() else logging.info("Conversation saved to memory tool")
669-
)
670-
logging.info("Conversation save task started in background")
673+
# Save conversation with memory tool (run in background)
674+
save_task = asyncio.create_task(
675+
memory_tool.add_new_memory(full_response,
676+
[queries_history, result_history])
677+
)
678+
# Add error handling callback to prevent silent failures
679+
save_task.add_done_callback(
680+
lambda t: logging.error("Memory save failed: %s", t.exception()) # nosemgrep
681+
if t.exception() else logging.info("Conversation saved to memory tool")
682+
)
683+
logging.info("Conversation save task started in background")
671684

672-
# Clean old memory in background (once per week cleanup)
673-
clean_memory_task = asyncio.create_task(memory_tool.clean_memory())
674-
clean_memory_task.add_done_callback(
675-
lambda t: logging.error("Memory cleanup failed: %s", t.exception()) # nosemgrep
676-
if t.exception() else logging.info("Memory cleanup completed successfully")
677-
)
685+
# Clean old memory in background (once per week cleanup)
686+
clean_memory_task = asyncio.create_task(memory_tool.clean_memory())
687+
clean_memory_task.add_done_callback(
688+
lambda t: logging.error("Memory cleanup failed: %s", t.exception()) # nosemgrep
689+
if t.exception() else logging.info("Memory cleanup completed successfully")
690+
)
678691

679692
# Log timing summary at the end of processing
680693
overall_elapsed = time.perf_counter() - overall_start

api/graph.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,34 @@ async def get_db_description(graph_id: str) -> tuple[str, str]:
5353
return (query_result.result_set[0][0],
5454
query_result.result_set[0][1]) # Return the first result's description
5555

56+
57+
async def get_user_rules(graph_id: str) -> str:
58+
"""Get the user rules from the graph."""
59+
graph = db.select_graph(graph_id)
60+
query_result = await graph.query(
61+
"""
62+
MATCH (d:Database)
63+
RETURN d.user_rules
64+
"""
65+
)
66+
67+
if not query_result.result_set or not query_result.result_set[0][0]:
68+
return ""
69+
70+
return query_result.result_set[0][0]
71+
72+
73+
async def set_user_rules(graph_id: str, user_rules: str) -> None:
74+
"""Set the user rules in the graph."""
75+
graph = db.select_graph(graph_id)
76+
await graph.query(
77+
"""
78+
MERGE (d:Database)
79+
SET d.user_rules = $user_rules
80+
""",
81+
{"user_rules": user_rules}
82+
)
83+
5684
async def _query_graph(
5785
graph,
5886
query: str,

api/routes/graphs.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
get_schema,
1919
query_database,
2020
refresh_database_schema,
21+
_graph_name,
2122
)
23+
from api.graph import get_user_rules, set_user_rules
2224
from api.auth.user_management import token_required
2325
from api.routes.tokens import UNAUTHORIZED_RESPONSE
2426

@@ -225,3 +227,44 @@ async def delete_graph(request: Request, graph_id: str):
225227
content={"error": "Failed to delete database"},
226228
status_code=500
227229
)
230+
231+
232+
class UserRulesRequest(BaseModel):
233+
"""User rules request model."""
234+
user_rules: str
235+
236+
237+
@graphs_router.get("/{graph_id}/user-rules", responses={401: UNAUTHORIZED_RESPONSE})
238+
@token_required
239+
async def get_graph_user_rules(request: Request, graph_id: str):
240+
"""Get user rules for the specified graph."""
241+
try:
242+
full_graph_id = _graph_name(request.state.user_id, graph_id)
243+
user_rules = await get_user_rules(full_graph_id)
244+
logging.info("Retrieved user rules length: %d", len(user_rules) if user_rules else 0)
245+
return JSONResponse(content={"user_rules": user_rules})
246+
except GraphNotFoundError:
247+
return JSONResponse(content={"error": "Database not found"}, status_code=404)
248+
except Exception as e: # pylint: disable=broad-exception-caught
249+
logging.error("Error getting user rules: %s", str(e))
250+
return JSONResponse(content={"error": "Failed to get user rules"}, status_code=500)
251+
252+
253+
@graphs_router.put("/{graph_id}/user-rules", responses={401: UNAUTHORIZED_RESPONSE})
254+
@token_required
255+
async def update_graph_user_rules(request: Request, graph_id: str, data: UserRulesRequest):
256+
"""Update user rules for the specified graph."""
257+
try:
258+
logging.info(
259+
"Received request to update user rules, content length: %d", len(data.user_rules)
260+
)
261+
full_graph_id = _graph_name(request.state.user_id, graph_id)
262+
await set_user_rules(full_graph_id, data.user_rules)
263+
logging.info("User rules updated successfully")
264+
return JSONResponse(content={"success": True, "user_rules": data.user_rules})
265+
except GraphNotFoundError:
266+
logging.error("Graph not found")
267+
return JSONResponse(content={"error": "Database not found"}, status_code=404)
268+
except Exception as e: # pylint: disable=broad-exception-caught
269+
logging.error("Error updating user rules: %s", str(e))
270+
return JSONResponse(content={"error": "Failed to update user rules"}, status_code=500)

app/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{
22
"name": "queryweaver-app",
33
"private": true,
4-
"version": "0.0.1",
4+
"version": "0.0.14",
55
"type": "module",
66
"scripts": {
77
"dev": "vite",

app/src/App.tsx

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
44
import { BrowserRouter, Routes, Route } from "react-router-dom";
55
import { AuthProvider } from "@/contexts/AuthContext";
66
import { DatabaseProvider } from "@/contexts/DatabaseContext";
7+
import { ChatProvider } from "@/contexts/ChatContext";
78
import Index from "./pages/Index";
9+
import Settings from "./pages/Settings";
810
import NotFound from "./pages/NotFound";
911

1012
const queryClient = new QueryClient();
@@ -13,16 +15,19 @@ const App = () => (
1315
<QueryClientProvider client={queryClient}>
1416
<AuthProvider>
1517
<DatabaseProvider>
16-
<TooltipProvider>
17-
<Toaster />
18-
<BrowserRouter future={{ v7_startTransition: true, v7_relativeSplatPath: true }}>
19-
<Routes>
20-
<Route path="/" element={<Index />} />
21-
{/* ADD ALL CUSTOM ROUTES ABOVE THE CATCH-ALL "*" ROUTE */}
22-
<Route path="*" element={<NotFound />} />
23-
</Routes>
24-
</BrowserRouter>
25-
</TooltipProvider>
18+
<ChatProvider>
19+
<TooltipProvider>
20+
<Toaster />
21+
<BrowserRouter future={{ v7_startTransition: true, v7_relativeSplatPath: true }}>
22+
<Routes>
23+
<Route path="/" element={<Index />} />
24+
<Route path="/settings" element={<Settings />} />
25+
{/* ADD ALL CUSTOM ROUTES ABOVE THE CATCH-ALL "*" ROUTE */}
26+
<Route path="*" element={<NotFound />} />
27+
</Routes>
28+
</BrowserRouter>
29+
</TooltipProvider>
30+
</ChatProvider>
2631
</DatabaseProvider>
2732
</AuthProvider>
2833
</QueryClientProvider>

0 commit comments

Comments
 (0)