1616from api .agents .healer_agent import HealerAgent
1717from api .config import Config
1818from api .extensions import db
19- from api .graph import find , get_db_description
19+ from api .graph import find , get_db_description , get_user_rules
2020from api .loaders .postgres_loader import PostgresLoader
2121from api .loaders .mysql_loader import MySQLLoader
2222from api .memory .graphiti_tool import MemoryTool
@@ -45,6 +45,8 @@ class ChatRequest(BaseModel):
4545 chat : list [str ]
4646 result : list [str ] | None = None
4747 instructions : str | None = None
48+ use_user_rules : bool = True # If True, fetch rules from database; if False, don't use rules
49+ use_memory : bool = True
4850
4951
5052class ConfirmRequest (BaseModel ):
@@ -213,6 +215,7 @@ async def query_database(user_id: str, graph_id: str, chat_data: ChatRequest):
213215 queries_history = chat_data .chat if hasattr (chat_data , 'chat' ) else None
214216 result_history = chat_data .result if hasattr (chat_data , 'result' ) else None
215217 instructions = chat_data .instructions if hasattr (chat_data , 'instructions' ) else None
218+ use_user_rules = chat_data .use_user_rules if hasattr (chat_data , 'use_user_rules' ) else True
216219
217220 if not queries_history or not isinstance (queries_history , list ):
218221 raise InvalidArgumentError ("Invalid or missing chat history" )
@@ -233,7 +236,10 @@ async def query_database(user_id: str, graph_id: str, chat_data: ChatRequest):
233236
234237 logging .info ("User Query: %s" , sanitize_query (queries_history [- 1 ]))
235238
236- memory_tool_task = asyncio .create_task (MemoryTool .create (user_id , graph_id ))
239+ if chat_data .use_memory :
240+ memory_tool_task = asyncio .create_task (MemoryTool .create (user_id , graph_id ))
241+ else :
242+ memory_tool_task = None
237243
238244 # Create a generator function for streaming
239245 async def generate (): # pylint: disable=too-many-locals,too-many-branches,too-many-statements
@@ -252,6 +258,8 @@ async def generate(): # pylint: disable=too-many-locals,too-many-branches,too-m
252258 yield json .dumps (step ) + MESSAGE_DELIMITER
253259 # Ensure the database description is loaded
254260 db_description , db_url = await get_db_description (graph_id )
261+ # Fetch user rules from database only if toggle is enabled
262+ user_rules_spec = await get_user_rules (graph_id ) if use_user_rules else None
255263
256264 # Determine database type and get appropriate loader
257265 db_type , loader_class = get_database_type_and_loader (db_url )
@@ -304,15 +312,18 @@ async def generate(): # pylint: disable=too-many-locals,too-many-branches,too-m
304312
305313 logging .info ("Calling to analysis agent with query: %s" ,
306314 sanitize_query (queries_history [- 1 ])) # nosemgrep
307- memory_tool = await memory_tool_task
308- memory_context = await memory_tool .search_memories (
309- query = queries_history [- 1 ]
310- )
315+
316+ memory_context = None
317+ if memory_tool_task :
318+ memory_tool = await memory_tool_task
319+ memory_context = await memory_tool .search_memories (
320+ query = queries_history [- 1 ]
321+ )
311322
312323 logging .info ("Starting SQL generation with analysis agent" )
313324 answer_an = agent_an .get_analysis (
314325 queries_history [- 1 ], result , db_description , instructions , memory_context ,
315- db_type
326+ db_type , user_rules_spec
316327 )
317328
318329 # Initialize response variables
@@ -625,56 +636,58 @@ def execute_sql(sql: str):
625636 )
626637
627638 # Save conversation to memory (only for on-topic queries)
628- # Determine the final answer based on which path was taken
629- final_answer = user_readable_response if user_readable_response else follow_up_result
630-
631- # Build comprehensive response for memory
632- full_response = {
633- "question" : queries_history [- 1 ],
634- "generated_sql" : answer_an .get ('sql_query' , "" ),
635- "answer" : final_answer
636- }
639+ # Only save to memory if use_memory is enabled
640+ if memory_tool_task :
641+ # Determine the final answer based on which path was taken
642+ final_answer = user_readable_response if user_readable_response else follow_up_result
643+
644+ # Build comprehensive response for memory
645+ full_response = {
646+ "question" : queries_history [- 1 ],
647+ "generated_sql" : answer_an .get ('sql_query' , "" ),
648+ "answer" : final_answer
649+ }
637650
638- # Add error information if SQL execution failed
639- if execution_error :
640- full_response ["error" ] = execution_error
641- full_response ["success" ] = False
642- else :
643- full_response ["success" ] = True
651+ # Add error information if SQL execution failed
652+ if execution_error :
653+ full_response ["error" ] = execution_error
654+ full_response ["success" ] = False
655+ else :
656+ full_response ["success" ] = True
644657
645658
646- # Save query to memory
647- save_query_task = asyncio .create_task (
648- memory_tool .save_query_memory (
649- query = queries_history [- 1 ],
650- sql_query = answer_an ["sql_query" ],
651- success = full_response ["success" ],
652- error = execution_error
659+ # Save query to memory
660+ save_query_task = asyncio .create_task (
661+ memory_tool .save_query_memory (
662+ query = queries_history [- 1 ],
663+ sql_query = answer_an ["sql_query" ],
664+ success = full_response ["success" ],
665+ error = execution_error
666+ )
667+ )
668+ save_query_task .add_done_callback (
669+ lambda t : logging .error ("Query memory save failed: %s" , t .exception ()) # nosemgrep
670+ if t .exception () else logging .info ("Query memory saved successfully" )
653671 )
654- )
655- save_query_task .add_done_callback (
656- lambda t : logging .error ("Query memory save failed: %s" , t .exception ()) # nosemgrep
657- if t .exception () else logging .info ("Query memory saved successfully" )
658- )
659672
660- # Save conversation with memory tool (run in background)
661- save_task = asyncio .create_task (
662- memory_tool .add_new_memory (full_response ,
663- [queries_history , result_history ])
664- )
665- # Add error handling callback to prevent silent failures
666- save_task .add_done_callback (
667- lambda t : logging .error ("Memory save failed: %s" , t .exception ()) # nosemgrep
668- if t .exception () else logging .info ("Conversation saved to memory tool" )
669- )
670- logging .info ("Conversation save task started in background" )
673+ # Save conversation with memory tool (run in background)
674+ save_task = asyncio .create_task (
675+ memory_tool .add_new_memory (full_response ,
676+ [queries_history , result_history ])
677+ )
678+ # Add error handling callback to prevent silent failures
679+ save_task .add_done_callback (
680+ lambda t : logging .error ("Memory save failed: %s" , t .exception ()) # nosemgrep
681+ if t .exception () else logging .info ("Conversation saved to memory tool" )
682+ )
683+ logging .info ("Conversation save task started in background" )
671684
672- # Clean old memory in background (once per week cleanup)
673- clean_memory_task = asyncio .create_task (memory_tool .clean_memory ())
674- clean_memory_task .add_done_callback (
675- lambda t : logging .error ("Memory cleanup failed: %s" , t .exception ()) # nosemgrep
676- if t .exception () else logging .info ("Memory cleanup completed successfully" )
677- )
685+ # Clean old memory in background (once per week cleanup)
686+ clean_memory_task = asyncio .create_task (memory_tool .clean_memory ())
687+ clean_memory_task .add_done_callback (
688+ lambda t : logging .error ("Memory cleanup failed: %s" , t .exception ()) # nosemgrep
689+ if t .exception () else logging .info ("Memory cleanup completed successfully" )
690+ )
678691
679692 # Log timing summary at the end of processing
680693 overall_elapsed = time .perf_counter () - overall_start
0 commit comments