diff --git a/api/agents/__init__.py b/api/agents/__init__.py index efd63f4e..a15e120e 100644 --- a/api/agents/__init__.py +++ b/api/agents/__init__.py @@ -4,6 +4,7 @@ from .relevancy_agent import RelevancyAgent from .follow_up_agent import FollowUpAgent from .response_formatter_agent import ResponseFormatterAgent +from .healer_agent import HealerAgent from .utils import parse_response __all__ = [ @@ -11,5 +12,6 @@ "RelevancyAgent", "FollowUpAgent", "ResponseFormatterAgent", + "HealerAgent", "parse_response" ] diff --git a/api/agents/analysis_agent.py b/api/agents/analysis_agent.py index ccd7c98a..c7bccc8a 100644 --- a/api/agents/analysis_agent.py +++ b/api/agents/analysis_agent.py @@ -18,18 +18,29 @@ def get_analysis( # pylint: disable=too-many-arguments, too-many-positional-arg db_description: str, instructions: str | None = None, memory_context: str | None = None, + database_type: str | None = None, ) -> dict: """Get analysis of user query against database schema.""" formatted_schema = self._format_schema(combined_tables) + # Add system message with database type if not already present + if not self.messages or self.messages[0].get("role") != "system": + self.messages.insert(0, { + "role": "system", + "content": ( + f"You are a SQL expert. TARGET DATABASE: " + f"{database_type.upper() if database_type else 'UNKNOWN'}" + ) + }) + prompt = self._build_prompt( - user_query, formatted_schema, db_description, instructions, memory_context + user_query, formatted_schema, db_description, + instructions, memory_context, database_type ) self.messages.append({"role": "user", "content": prompt}) completion_result = completion( model=Config.COMPLETION_MODEL, messages=self.messages, temperature=0, - top_p=1, ) response = completion_result.choices[0].message.content @@ -158,7 +169,8 @@ def _format_foreign_keys(self, foreign_keys: dict) -> str: def _build_prompt( # pylint: disable=too-many-arguments, too-many-positional-arguments self, user_input: str, formatted_schema: str, - db_description: str, instructions, memory_context: str | None = None + db_description: str, instructions, memory_context: str | None = None, + database_type: str | None = None, ) -> str: """ Build the prompt for Claude to analyze the query. @@ -169,6 +181,7 @@ def _build_prompt( # pylint: disable=too-many-arguments, too-many-positional-a db_description: Description of the database instructions: Custom instructions for the query memory_context: User and database memory context from previous interactions + database_type: Target database type (sqlite, postgresql, mysql, etc.) Returns: The formatted prompt for Claude @@ -196,6 +209,8 @@ def _build_prompt( # pylint: disable=too-many-arguments, too-many-positional-a prompt = f""" You must strictly follow the instructions below. Deviations will result in a penalty to your confidence score. + TARGET DATABASE: {database_type.upper() if database_type else 'UNKNOWN'} + MANDATORY RULES: - Always explain if you cannot fully follow the instructions. - Always reduce the confidence score if instructions cannot be fully applied. diff --git a/api/agents/healer_agent.py b/api/agents/healer_agent.py new file mode 100644 index 00000000..e0ab66a6 --- /dev/null +++ b/api/agents/healer_agent.py @@ -0,0 +1,328 @@ +""" +HealerAgent - Specialized agent for fixing SQL syntax errors. + +This agent focuses solely on correcting SQL queries that failed execution, +without requiring full graph context. It uses the error message and the +failed query to generate a corrected version. +""" +# pylint: disable=trailing-whitespace,line-too-long,too-many-arguments +# pylint: disable=too-many-positional-arguments,broad-exception-caught + +import re +from typing import Dict, Callable, Any +from litellm import completion +from api.config import Config +from .utils import parse_response + + +class HealerAgent: + """Agent specialized in fixing SQL syntax errors.""" + + def __init__(self, max_healing_attempts: int = 3): + """Initialize the healer agent. + + Args: + max_healing_attempts: Maximum number of healing attempts before giving up + """ + self.max_healing_attempts = max_healing_attempts + self.messages = [] + + @staticmethod + def validate_sql_syntax(sql_query: str) -> dict: + """ + Validate SQL query for basic syntax errors. + Similar to CypherValidator in the text-to-cypher PR. + + Args: + sql_query: The SQL query to validate + + Returns: + dict with 'is_valid', 'errors', and 'warnings' keys + """ + errors = [] + warnings = [] + + query = sql_query.strip() + + # Check if query is empty + if not query: + errors.append("Query is empty") + return {"is_valid": False, "errors": errors, "warnings": warnings} + + # Check for basic SQL keywords + query_upper = query.upper() + has_sql_keywords = any( + kw in query_upper for kw in ["SELECT", "INSERT", "UPDATE", "DELETE", "WITH", "CREATE"] + ) + if not has_sql_keywords: + errors.append("Query does not contain valid SQL keywords") + + # Check for dangerous operations (for dev/test safety) + dangerous_patterns = [ + r'\bDROP\s+TABLE\b', r'\bTRUNCATE\b', r'\bDELETE\s+FROM\s+\w+\s*;?\s*$' + ] + for pattern in dangerous_patterns: + if re.search(pattern, query_upper): + warnings.append(f"Query contains potentially dangerous operation: {pattern}") + + # Check for balanced parentheses + paren_count = 0 + for char in query: + if char == '(': + paren_count += 1 + elif char == ')': + paren_count -= 1 + if paren_count < 0: + errors.append("Unbalanced parentheses in query") + break + if paren_count != 0: + errors.append("Unbalanced parentheses in query") + + # Check for SELECT queries have proper structure + if query_upper.startswith("SELECT") or "SELECT" in query_upper: + if "FROM" not in query_upper and "DUAL" not in query_upper: + warnings.append("SELECT query missing FROM clause") + + return { + "is_valid": len(errors) == 0, + "errors": errors, + "warnings": warnings + } + + def _build_healing_prompt( + self, + failed_sql: str, + error_message: str, + db_description: str, + question: str, + database_type: str + ) -> str: + """Build a focused prompt for SQL query healing.""" + + # Analyze error to provide targeted hints + error_hints = self._analyze_error(error_message, database_type) + + prompt = f"""You are a SQL query debugging expert. Your task is to fix a SQL query that failed execution. + +DATABASE TYPE: {database_type.upper()} + +FAILED SQL QUERY: +```sql +{failed_sql} +``` + +EXECUTION ERROR: +{error_message} + +{f"ORIGINAL QUESTION: {question}" if question else ""} + +{f"DATABASE INFO: {db_description}"} + +COMMON ERROR PATTERNS: +{error_hints} + +YOUR TASK: +1. Identify the exact cause of the error +2. Fix ONLY what's broken - don't rewrite the entire query +3. Ensure the fix is compatible with {database_type.upper()} +4. Maintain the original query logic and intent + +CRITICAL RULES FOR {database_type.upper()}: +""" + + if database_type == "sqlite": + prompt += """ +- SQLite does NOT support EXTRACT() function - use strftime() instead + * EXTRACT(YEAR FROM date_col) → strftime('%Y', date_col) + * EXTRACT(MONTH FROM date_col) → strftime('%m', date_col) + * EXTRACT(DAY FROM date_col) → strftime('%d', date_col) +- SQLite column/table names are case-insensitive BUT must exist +- SQLite uses double quotes "column" for identifiers with special characters +- Use backticks `column` for compatibility +- No schema qualifiers (database.table.column) +""" + elif database_type == "postgresql": + prompt += """ +- PostgreSQL is case-sensitive - use double quotes for mixed-case identifiers +- EXTRACT() is supported: EXTRACT(YEAR FROM date_col) +- Column references must match exact case when quoted +""" + + prompt += """ +RESPONSE FORMAT (valid JSON only): +{ + "sql_query": "-- your fixed SQL query here", + "confidence": 85, + "explanation": "Brief explanation of what was fixed", + "changes_made": ["Changed EXTRACT to strftime", "Fixed column casing"] +} + +IMPORTANT: +- Return ONLY the JSON object, no other text +- Fix ONLY the specific error, preserve the rest +- Test your fix mentally before responding +- If error is about a column/table name, check spelling carefully +""" + + return prompt + + def heal_and_execute( # pylint: disable=too-many-locals + self, + initial_sql: str, + initial_error: str, + execute_sql_func: Callable[[str], Any], + db_description: str = "", + question: str = "", + database_type: str = "sqlite" + ) -> Dict[str, Any]: + """Iteratively heal and execute SQL query until success or max attempts. + + This method creates a conversation loop between the healer and the database: + 1. Build initial prompt once with the failed SQL and error (including syntax validation) + 2. Loop: Call LLM → Parse healed SQL → Execute → Check if successful + 3. If successful, return results + 4. If failed and not last attempt, add error feedback and repeat + 5. If failed on last attempt, return failure + + Args: + initial_sql: The initial SQL query that failed + initial_error: The error message from the initial execution failure + execute_sql_func: Function that executes SQL and returns results or raises exception + db_description: Optional database description + question: Optional original question + database_type: Type of database (sqlite, postgresql, mysql, etc.) + + Returns: + Dict containing: + - success: Whether healing succeeded + - sql_query: Final SQL query (healed or original) + - query_results: Results from successful execution (if success=True) + - attempts: Number of healing attempts made + - final_error: Final error message (if success=False) + """ + self.messages = [] + + # Validate SQL syntax for additional error context + validation_result = self.validate_sql_syntax(initial_sql) + additional_context = "" + if validation_result["errors"]: + additional_context += f"\nSyntax errors: {', '.join(validation_result['errors'])}" + if validation_result["warnings"]: + additional_context += f"\nWarnings: {', '.join(validation_result['warnings'])}" + # Enhance error message with validation context + enhanced_error = initial_error + additional_context + + # Build initial prompt once before the loop + prompt = self._build_healing_prompt( + failed_sql=initial_sql, + error_message=enhanced_error, + db_description=db_description, + question=question, + database_type=database_type + ) + self.messages.append({"role": "user", "content": prompt}) + + for attempt in range(self.max_healing_attempts): + # Call LLM + response = completion( + model=Config.COMPLETION_MODEL, + messages=self.messages, + temperature=0.1, + max_tokens=2000 + ) + + content = response.choices[0].message.content + self.messages.append({"role": "assistant", "content": content}) + + # Parse response + result = parse_response(content) + healed_sql = result.get("sql_query", "") + + # Execute against database + error = None + try: + query_results = execute_sql_func(healed_sql) + except Exception as e: + error = str(e) + + # Check if it worked + if error is None: + # Success! + return { + "success": True, + "sql_query": healed_sql, + "query_results": query_results, + "attempts": attempt + 1, + "final_error": None + } + + # Failed - check if last attempt + if attempt >= self.max_healing_attempts - 1: + return { + "success": False, + "sql_query": healed_sql, + "query_results": None, + "attempts": attempt + 1, + "final_error": error + } + + # Not last attempt - add feedback and continue + feedback = f"""The healed query failed with error: + +```sql +{healed_sql} +``` + +ERROR: +{error} + +Please fix this error.""" + self.messages.append({"role": "user", "content": feedback}) + + # Fallback return + return { + "success": False, + "sql_query": initial_sql, + "query_results": None, + "attempts": self.max_healing_attempts, + "final_error": initial_error + } + + + def _analyze_error(self, error_message: str, database_type: str) -> str: + """Analyze error message and provide targeted hints.""" + + error_lower = error_message.lower() + hints = [] + + # Common SQLite errors + if database_type == "sqlite": + if "near \"from\"" in error_lower or "syntax error" in error_lower: + hints.append("⚠️ EXTRACT() is NOT supported in SQLite - use strftime() instead!") + hints.append(" Example: strftime('%Y', date_column) for year") + + if "no such column" in error_lower: + hints.append("⚠️ Column name doesn't exist - check spelling and case") + hints.append(" SQLite is case-insensitive but the column must exist") + + if "no such table" in error_lower: + hints.append("⚠️ Table name doesn't exist - check spelling") + + if "ambiguous column" in error_lower: + hints.append("⚠️ Ambiguous column - use table alias: table.column or alias.column") + + # PostgreSQL errors + elif database_type == "postgresql": + if "column" in error_lower and "does not exist" in error_lower: + hints.append("⚠️ Column case mismatch - PostgreSQL is case-sensitive") + hints.append(' Use double quotes for mixed-case: "ColumnName"') + + if "relation" in error_lower and "does not exist" in error_lower: + hints.append("⚠️ Table doesn't exist or case mismatch") + + # Generic hints if no specific patterns matched + if not hints: + hints.append("⚠️ Check syntax compatibility with " + database_type.upper()) + hints.append("⚠️ Verify column and table names exist") + + return "\n".join(hints) diff --git a/api/agents/utils.py b/api/agents/utils.py index 53e678a0..ceafff23 100644 --- a/api/agents/utils.py +++ b/api/agents/utils.py @@ -21,6 +21,7 @@ def __init__(self, queries_history: list, result_history: list): def parse_response(response: str) -> Dict[str, Any]: """ Parse Claude's response to extract the analysis. + Handles cases where LLM returns multiple JSON blocks by extracting the last valid one. Args: response: Claude's response string @@ -29,12 +30,38 @@ def parse_response(response: str) -> Dict[str, Any]: Parsed analysis results """ try: - # Extract JSON from the response + # Try to find all JSON blocks (anything between { and }) + # and parse the last valid one (LLM sometimes corrects itself) + # Find all potential JSON blocks + json_blocks = [] + depth = 0 + start_idx = None + + for i, char in enumerate(response): + if char == '{': + if depth == 0: + start_idx = i + depth += 1 + elif char == '}': + depth -= 1 + if depth == 0 and start_idx is not None: + json_blocks.append(response[start_idx:i+1]) + start_idx = None + + # Try to parse JSON blocks from last to first (prefer the corrected version) + for json_str in reversed(json_blocks): + try: + analysis = json.loads(json_str) + # Validate it has required fields + if "is_sql_translatable" in analysis and "sql_query" in analysis: + return analysis + except json.JSONDecodeError: + continue + + # Fallback to original method if block parsing fails json_start = response.find("{") json_end = response.rfind("}") + 1 json_str = response[json_start:json_end] - - # Parse the JSON analysis = json.loads(json_str) return analysis except (json.JSONDecodeError, ValueError) as e: diff --git a/api/core/text2sql.py b/api/core/text2sql.py index 4df9db08..9db90c4e 100644 --- a/api/core/text2sql.py +++ b/api/core/text2sql.py @@ -1,4 +1,5 @@ """Graph-related routes for the text2sql API.""" +# pylint: disable=line-too-long,trailing-whitespace import asyncio import json @@ -12,6 +13,7 @@ from api.core.errors import GraphNotFoundError, InternalError, InvalidArgumentError from api.core.schema_loader import load_database from api.agents import AnalysisAgent, RelevancyAgent, ResponseFormatterAgent, FollowUpAgent +from api.agents.healer_agent import HealerAgent from api.config import Config from api.extensions import db from api.graph import find, get_db_description @@ -252,7 +254,7 @@ async def generate(): # pylint: disable=too-many-locals,too-many-branches,too-m db_description, db_url = await get_db_description(graph_id) # Determine database type and get appropriate loader - _, loader_class = get_database_type_and_loader(db_url) + db_type, loader_class = get_database_type_and_loader(db_url) if not loader_class: overall_elapsed = time.perf_counter() - overall_start @@ -309,7 +311,8 @@ async def generate(): # pylint: disable=too-many-locals,too-many-branches,too-m logging.info("Starting SQL generation with analysis agent") answer_an = agent_an.get_analysis( - queries_history[-1], result, db_description, instructions, memory_context + queries_history[-1], result, db_description, instructions, memory_context, + db_type ) # Initialize response variables @@ -317,14 +320,27 @@ async def generate(): # pylint: disable=too-many-locals,too-many-branches,too-m follow_up_result = "" execution_error = False - # Auto-quote table names with special characters (like dashes) - original_sql = answer_an['sql_query'] - if original_sql: + logging.info("Generated SQL query: %s", answer_an['sql_query']) # nosemgrep + yield json.dumps( + { + "type": "sql_query", + "data": answer_an["sql_query"], + "conf": answer_an["confidence"], + "miss": answer_an["missing_information"], + "amb": answer_an["ambiguities"], + "exp": answer_an["explanation"], + "is_valid": answer_an["is_sql_translatable"], + "final_response": False, + } + ) + MESSAGE_DELIMITER + + # If the SQL query is valid, execute it using the configured database and db_url + if answer_an["is_sql_translatable"]: + # Auto-quote table names with special characters (like dashes) # Extract known table names from the result schema known_tables = {table[0] for table in result} if result else set() # Determine database type and get appropriate quote character - db_type, _ = get_database_type_and_loader(db_url) quote_char = DatabaseSpecificQuoter.get_quote_char( db_type or 'postgresql' ) @@ -332,7 +348,7 @@ async def generate(): # pylint: disable=too-many-locals,too-many-branches,too-m # Auto-quote identifiers with special characters sanitized_sql, was_modified = ( SQLIdentifierQuoter.auto_quote_identifiers( - original_sql, known_tables, quote_char + answer_an['sql_query'], known_tables, quote_char ) ) @@ -344,22 +360,6 @@ async def generate(): # pylint: disable=too-many-locals,too-many-branches,too-m logging.info(msg) answer_an['sql_query'] = sanitized_sql - logging.info("Generated SQL query: %s", answer_an['sql_query']) # nosemgrep - yield json.dumps( - { - "type": "sql_query", - "data": answer_an["sql_query"], - "conf": answer_an["confidence"], - "miss": answer_an["missing_information"], - "amb": answer_an["ambiguities"], - "exp": answer_an["explanation"], - "is_valid": answer_an["is_sql_translatable"], - "final_response": False, - } - ) + MESSAGE_DELIMITER - - # If the SQL query is valid, execute it using the postgres database db_url - if answer_an["is_sql_translatable"]: # Check if this is a destructive operation that requires confirmation sql_query = answer_an["sql_query"] sql_type = sql_query.strip().split()[0].upper() if sql_query else "" @@ -441,10 +441,76 @@ async def generate(): # pylint: disable=too-many-locals,too-many-branches,too-m loader_class.is_schema_modifying_query(sql_query) ) - query_results = loader_class.execute_sql_query( - answer_an["sql_query"], - db_url - ) + # Try executing the SQL query first + try: + query_results = loader_class.execute_sql_query( + answer_an["sql_query"], + db_url + ) + except Exception as exec_error: # pylint: disable=broad-exception-caught + # Initial execution failed - start iterative healing process + step = { + "type": "reasoning_step", + "final_response": False, + "message": "Step 2a: SQL execution failed, attempting to heal query..." + } + yield json.dumps(step) + MESSAGE_DELIMITER + + # Create healer agent and attempt iterative healing + healer_agent = HealerAgent(max_healing_attempts=3) + + # Create a wrapper function for execute_sql_query + def execute_sql(sql: str): + return loader_class.execute_sql_query(sql, db_url) + + healing_result = healer_agent.heal_and_execute( + initial_sql=answer_an["sql_query"], + initial_error=str(exec_error), + execute_sql_func=execute_sql, + db_description=db_description, + question=queries_history[-1], + database_type=db_type + ) + + if not healing_result.get("success"): + # Healing failed after all attempts + yield json.dumps({ + "type": "healing_failed", + "final_response": False, + "message": f"❌ Failed to heal query after {healing_result['attempts']} attempt(s)", + "final_error": healing_result.get("final_error", str(exec_error)), + "healing_log": healing_result.get("healing_log", []) + }) + MESSAGE_DELIMITER + raise exec_error + + # Healing succeeded! + healing_log = healing_result.get("healing_log", []) + + # Show healing progress + for log_entry in healing_log: + if log_entry.get("status") == "healed": + changes_msg = ", ".join(log_entry.get("changes_made", [])) + yield json.dumps({ + "type": "healing_attempt", + "final_response": False, + "message": f"Attempt {log_entry['attempt']}: {changes_msg}", + "attempt": log_entry["attempt"], + "changes": log_entry.get("changes_made", []), + "confidence": log_entry.get("confidence", 0) + }) + MESSAGE_DELIMITER + + # Update the SQL query to the healed version + answer_an["sql_query"] = healing_result["sql_query"] + query_results = healing_result["query_results"] + + yield json.dumps({ + "type": "healing_success", + "final_response": False, + "message": f"✅ Query healed and executed successfully after {healing_result['attempts'] + 1} attempt(s)", + "healed_sql": healing_result["sql_query"], + "attempts": healing_result["attempts"] + 1 + }) + MESSAGE_DELIMITER + if len(query_results) != 0: yield json.dumps( {