Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 43 additions & 43 deletions api/agents/healer_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
without requiring full graph context. It uses the error message and the
failed query to generate a corrected version.
"""
# pylint: disable=trailing-whitespace,line-too-long,too-many-arguments
# pylint: disable=line-too-long,too-many-arguments
# pylint: disable=too-many-positional-arguments,broad-exception-caught

import re
Expand All @@ -18,53 +18,53 @@

class HealerAgent:
"""Agent specialized in fixing SQL syntax errors."""

def __init__(self, max_healing_attempts: int = 3):
"""Initialize the healer agent.

Args:
max_healing_attempts: Maximum number of healing attempts before giving up
"""
self.max_healing_attempts = max_healing_attempts

@staticmethod
def validate_sql_syntax(sql_query: str) -> dict:
"""
Validate SQL query for basic syntax errors.
Similar to CypherValidator in the text-to-cypher PR.

Args:
sql_query: The SQL query to validate

Returns:
dict with 'is_valid', 'errors', and 'warnings' keys
"""
errors = []
warnings = []

query = sql_query.strip()

# Check if query is empty
if not query:
errors.append("Query is empty")
return {"is_valid": False, "errors": errors, "warnings": warnings}

# Check for basic SQL keywords
query_upper = query.upper()
has_sql_keywords = any(
kw in query_upper for kw in ["SELECT", "INSERT", "UPDATE", "DELETE", "WITH", "CREATE"]
)
if not has_sql_keywords:
errors.append("Query does not contain valid SQL keywords")

# Check for dangerous operations (for dev/test safety)
dangerous_patterns = [
r'\bDROP\s+TABLE\b', r'\bTRUNCATE\b', r'\bDELETE\s+FROM\s+\w+\s*;?\s*$'
]
for pattern in dangerous_patterns:
if re.search(pattern, query_upper):
warnings.append(f"Query contains potentially dangerous operation: {pattern}")

# Check for balanced parentheses
paren_count = 0
for char in query:
Expand All @@ -77,18 +77,18 @@ def validate_sql_syntax(sql_query: str) -> dict:
break
if paren_count != 0:
errors.append("Unbalanced parentheses in query")

# Check that SELECT queries have a proper structure
if query_upper.startswith("SELECT") or "SELECT" in query_upper:
if "FROM" not in query_upper and "DUAL" not in query_upper:
warnings.append("SELECT query missing FROM clause")

return {
"is_valid": len(errors) == 0,
"errors": errors,
"warnings": warnings
}

def _build_healing_prompt(
self,
failed_sql: str,
Expand All @@ -98,10 +98,10 @@ def _build_healing_prompt(
database_type: str
) -> str:
"""Build a focused prompt for SQL query healing."""

# Analyze error to provide targeted hints
error_hints = self._analyze_error(error_message, database_type)

prompt = f"""You are a SQL query debugging expert. Your task is to fix a SQL query that failed execution.

DATABASE TYPE: {database_type.upper()}
Expand Down Expand Up @@ -129,7 +129,7 @@ def _build_healing_prompt(

CRITICAL RULES FOR {database_type.upper()}:
"""

if database_type == "sqlite":
prompt += """
- SQLite does NOT support EXTRACT() function - use strftime() instead
Expand All @@ -147,7 +147,7 @@ def _build_healing_prompt(
- EXTRACT() is supported: EXTRACT(YEAR FROM date_col)
- Column references must match exact case when quoted
"""

prompt += """
RESPONSE FORMAT (valid JSON only):
{
Expand All @@ -163,9 +163,9 @@ def _build_healing_prompt(
- Test your fix mentally before responding
- If error is about a column/table name, check spelling carefully
"""

return prompt

def heal_and_execute(
self,
initial_sql: str,
Expand All @@ -176,22 +176,22 @@ def heal_and_execute(
database_type: str = "sqlite"
) -> Dict[str, Any]:
"""Iteratively heal and execute SQL query until success or max attempts.

This method creates a conversation loop between the healer and the database:
1. Build initial prompt once with the failed SQL and error (including syntax validation)
2. Loop: Call LLM → Parse healed SQL → Execute → Check if successful
3. If successful, return results
4. If failed and not last attempt, add error feedback and repeat
5. If failed on last attempt, return failure

Args:
initial_sql: The initial SQL query that failed
initial_error: The error message from the initial execution failure
execute_sql_func: Function that executes SQL and returns results or raises exception
db_description: Optional database description
question: Optional original question
database_type: Type of database (sqlite, postgresql, mysql, etc.)

Returns:
Dict containing:
- success: Whether healing succeeded
Expand All @@ -201,7 +201,7 @@ def heal_and_execute(
- final_error: Final error message (if success=False)
"""
self.messages = []

# Validate SQL syntax for additional error context
validation_result = self.validate_sql_syntax(initial_sql)
additional_context = ""
Expand All @@ -211,7 +211,7 @@ def heal_and_execute(
additional_context += f"\nWarnings: {', '.join(validation_result['warnings'])}"
# Enhance error message with validation context
enhanced_error = initial_error + additional_context

# Build initial prompt once before the loop
prompt = self._build_healing_prompt(
failed_sql=initial_sql,
Expand All @@ -221,7 +221,7 @@ def heal_and_execute(
database_type=database_type
)
self.messages.append({"role": "user", "content": prompt})

for attempt in range(self.max_healing_attempts):
# Call LLM
response = completion(
Expand All @@ -230,21 +230,21 @@ def heal_and_execute(
temperature=0.1,
max_tokens=2000
)

content = response.choices[0].message.content
self.messages.append({"role": "assistant", "content": content})

# Parse response
result = parse_response(content)
healed_sql = result.get("sql_query", "")

# Execute against database
error = None
try:
query_results = execute_sql_func(healed_sql)
except Exception as e:
error = str(e)

# Check if it worked
if error is None:
# Success!
Expand All @@ -255,7 +255,7 @@ def heal_and_execute(
"attempts": attempt + 1,
"final_error": None
}

# Failed - check if last attempt
if attempt >= self.max_healing_attempts - 1:
return {
Expand All @@ -265,7 +265,7 @@ def heal_and_execute(
"attempts": attempt + 1,
"final_error": error
}

# Not last attempt - add feedback and continue
feedback = f"""The healed query failed with error:

Expand All @@ -278,42 +278,42 @@ def heal_and_execute(

Please fix this error."""
self.messages.append({"role": "user", "content": feedback})


def _analyze_error(self, error_message: str, database_type: str) -> str:
"""Analyze error message and provide targeted hints."""

error_lower = error_message.lower()
hints = []

# Common SQLite errors
if database_type == "sqlite":
if "near \"from\"" in error_lower or "syntax error" in error_lower:
hints.append("⚠️ EXTRACT() is NOT supported in SQLite - use strftime() instead!")
hints.append(" Example: strftime('%Y', date_column) for year")

if "no such column" in error_lower:
hints.append("⚠️ Column name doesn't exist - check spelling and case")
hints.append(" SQLite is case-insensitive but the column must exist")

if "no such table" in error_lower:
hints.append("⚠️ Table name doesn't exist - check spelling")

if "ambiguous column" in error_lower:
hints.append("⚠️ Ambiguous column - use table alias: table.column or alias.column")

# PostgreSQL errors
elif database_type == "postgresql":
if "column" in error_lower and "does not exist" in error_lower:
hints.append("⚠️ Column case mismatch - PostgreSQL is case-sensitive")
hints.append(' Use double quotes for mixed-case: "ColumnName"')

if "relation" in error_lower and "does not exist" in error_lower:
hints.append("⚠️ Table doesn't exist or case mismatch")

# Generic hints if no specific patterns matched
if not hints:
hints.append("⚠️ Check syntax compatibility with " + database_type.upper())
hints.append("⚠️ Verify column and table names exist")

return "\n".join(hints)
Loading