Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions api/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
from .relevancy_agent import RelevancyAgent
from .follow_up_agent import FollowUpAgent
from .response_formatter_agent import ResponseFormatterAgent
from .healer_agent import HealerAgent
from .utils import parse_response

__all__ = [
"AnalysisAgent",
"RelevancyAgent",
"FollowUpAgent",
"ResponseFormatterAgent",
"HealerAgent",
"parse_response"
]
25 changes: 22 additions & 3 deletions api/agents/analysis_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,25 @@ def get_analysis( # pylint: disable=too-many-arguments, too-many-positional-arg
db_description: str,
instructions: str | None = None,
memory_context: str | None = None,
database_type: str | None = None,
) -> dict:
"""Get analysis of user query against database schema."""
formatted_schema = self._format_schema(combined_tables)
# Add system message with database type if not already present
if not self.messages or self.messages[0].get("role") != "system":
self.messages.insert(0, {
"role": "system",
"content": f"You are a SQL expert. TARGET DATABASE: {database_type.upper() if database_type else 'UNKNOWN'}"
})

prompt = self._build_prompt(
user_query, formatted_schema, db_description, instructions, memory_context
user_query, formatted_schema, db_description, instructions, memory_context, database_type
)
self.messages.append({"role": "user", "content": prompt})
completion_result = completion(
model=Config.COMPLETION_MODEL,
messages=self.messages,
temperature=0,
top_p=1,
)

response = completion_result.choices[0].message.content
Expand Down Expand Up @@ -158,7 +165,8 @@ def _format_foreign_keys(self, foreign_keys: dict) -> str:

def _build_prompt( # pylint: disable=too-many-arguments, too-many-positional-arguments
self, user_input: str, formatted_schema: str,
db_description: str, instructions, memory_context: str | None = None
db_description: str, instructions, memory_context: str | None = None,
database_type: str | None = None,
) -> str:
"""
Build the prompt for Claude to analyze the query.
Expand All @@ -169,6 +177,7 @@ def _build_prompt( # pylint: disable=too-many-arguments, too-many-positional-a
db_description: Description of the database
instructions: Custom instructions for the query
memory_context: User and database memory context from previous interactions
database_type: Target database type (sqlite, postgresql, mysql, etc.)

Returns:
The formatted prompt for Claude
Expand Down Expand Up @@ -196,13 +205,19 @@ def _build_prompt( # pylint: disable=too-many-arguments, too-many-positional-a
prompt = f"""
You must strictly follow the instructions below. Deviations will result in a penalty to your confidence score.

TARGET DATABASE: {database_type.upper() if database_type else 'UNKNOWN'}

MANDATORY RULES:
- Always explain if you cannot fully follow the instructions.
- Always reduce the confidence score if instructions cannot be fully applied.
- Never skip explaining missing information, ambiguities, or instruction issues.
- Respond ONLY in strict JSON format, without extra text.
- If the query relates to a previous question, you MUST take into account the previous question and its answer, and answer based on the context and information provided so far.
- CRITICAL: When table or column names contain special characters (especially dashes/hyphens like '-'), you MUST wrap them in double quotes for PostgreSQL (e.g., "table-name") or backticks for MySQL (e.g., `table-name`). This is NON-NEGOTIABLE.
- CRITICAL NULL HANDLING: When using calculated columns (divisions, ratios, arithmetic) with ORDER BY or LIMIT, you MUST filter out NULL values. Add "WHERE calculated_expression IS NOT NULL" or include the NULL check in your WHERE clause. NULL values sort first in ascending order and can produce incorrect results.
- CRITICAL SELECT CLAUSE: Only return columns explicitly requested in the question. If the question asks for "the highest rate" or "the lowest value", return ONLY that calculated value, not additional columns like names or IDs unless specifically asked. Use aggregate functions (MAX, MIN, AVG) when appropriate for "highest", "lowest", "average" queries instead of ORDER BY + LIMIT.
- CRITICAL VALUE MATCHING: When multiple columns could answer a question (e.g., "continuation schools"), prefer the column whose allowed values list contains an EXACT or CLOSEST string match to the question term. For example, if the question mentions "continuation schools", prefer a column with value "Continuation School" over "Continuation High Schools". Check the column descriptions for "Optional values" lists and match question terminology to those exact value strings.
- CRITICAL SINGLE SQL STATEMENT: You MUST generate exactly ONE SQL statement that answers all parts of the question. NEVER generate multiple separate SELECT statements. If a question asks multiple things (e.g., "How many X? List Y"), combine them into a single query using subqueries, JOINs, multiple columns in SELECT, or aggregate functions. Multiple SQL statements separated by semicolons are FORBIDDEN and will fail execution.

If the user is asking a follow-up or continuing question, use the conversation history and previous answers to resolve references, context, or ambiguities. Always base your analysis on the cumulative context, not just the current question.

Expand Down Expand Up @@ -299,6 +314,10 @@ def _build_prompt( # pylint: disable=too-many-arguments, too-many-positional-a
12. Learn from successful query patterns in memory context and avoid failed approaches.
13. For personal queries, FIRST check memory context for user identification. If user identity is found in memory context (user name, previous personal queries, etc.), the query IS translatable.
14. CRITICAL PERSONALIZATION CHECK: If missing user identification/personalization is a significant or primary component of the query (e.g., "show my orders", "my account balance", "my recent purchases", "how many employees I have", "products I own") AND no user identification is available in memory context or schema, set "is_sql_translatable" to false. However, if memory context contains user identification (like user name or previous successful personal queries), then personal queries ARE translatable even if they are the primary component of the query.
15. CRITICAL: When generating queries with calculated columns (division, multiplication, etc.) that are used in ORDER BY or compared with LIMIT, ALWAYS add NULL filtering. For example: "WHERE (column1 / column2) IS NOT NULL" before ORDER BY. This prevents NULL values (from NULL numerators or denominators) from appearing in results.
16. SELECT CLAUSE PRECISION: Only include columns explicitly requested in the question. If a question asks "What is the highest rate?" return ONLY the rate value, not additional columns. Questions asking for "the highest/lowest/average X" should prefer aggregate functions (MAX, MIN, AVG) over ORDER BY + LIMIT, as aggregates are more concise and automatically handle what to return.
17. VALUE-BASED COLUMN SELECTION: When choosing between similar columns (e.g., "School Type" vs "Educational Option Type"), examine the "Optional values" lists in column descriptions. Prefer the column where a value EXACTLY or MOST CLOSELY matches the terminology in the question. For example, "continuation schools" should map to a column with value "Continuation School" rather than "Continuation High Schools". This string matching takes priority over column name similarity.
18. NULL HANDLING IN CALCULATIONS: When a query involves calculated expressions (like col1/col2) used with ORDER BY, filtering (WHERE), or LIMIT, ensure NULL values are explicitly filtered out. Use "AND (expression) IS NOT NULL" in the WHERE clause. This is especially important for division operations where either numerator or denominator can be NULL.

Again: OUTPUT ONLY VALID JSON. No explanations outside the JSON block. """ # pylint: disable=line-too-long
return prompt
288 changes: 288 additions & 0 deletions api/agents/healer_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,288 @@
"""
HealerAgent - Specialized agent for fixing SQL syntax errors.
This agent focuses solely on correcting SQL queries that failed execution,
without requiring full graph context. It uses the error message and the
failed query to generate a corrected version.
"""

import json
import re
from typing import Dict, Optional
from litellm import completion
from .utils import parse_response
from api.config import Config



class HealerAgent:
"""Agent specialized in fixing SQL syntax errors."""

def __init__(self):
"""
Initialize the HealerAgent.
"""

@staticmethod
def validate_sql_syntax(sql_query: str) -> dict:
"""
Validate SQL query for basic syntax errors.
Similar to CypherValidator in the text-to-cypher PR.
Args:
sql_query: The SQL query to validate
Returns:
dict with 'is_valid', 'errors', and 'warnings' keys
"""
errors = []
warnings = []

query = sql_query.strip()

# Check if query is empty
if not query:
errors.append("Query is empty")
return {"is_valid": False, "errors": errors, "warnings": warnings}

# Check for basic SQL keywords
query_upper = query.upper()
has_sql_keywords = any(
kw in query_upper for kw in ["SELECT", "INSERT", "UPDATE", "DELETE", "WITH", "CREATE"]
)
if not has_sql_keywords:
errors.append("Query does not contain valid SQL keywords")

# Check for dangerous operations (for dev/test safety)
dangerous_patterns = [
r'\bDROP\s+TABLE\b', r'\bTRUNCATE\b', r'\bDELETE\s+FROM\s+\w+\s*;?\s*$'
]
for pattern in dangerous_patterns:
if re.search(pattern, query_upper):
warnings.append(f"Query contains potentially dangerous operation: {pattern}")

# Check for balanced parentheses
paren_count = 0
for char in query:
if char == '(':
paren_count += 1
elif char == ')':
paren_count -= 1
if paren_count < 0:
errors.append("Unbalanced parentheses in query")
break
if paren_count != 0:
errors.append("Unbalanced parentheses in query")

# Check for SELECT queries have proper structure
if query_upper.startswith("SELECT") or "SELECT" in query_upper:
if "FROM" not in query_upper and "DUAL" not in query_upper:
warnings.append("SELECT query missing FROM clause")

return {
"is_valid": len(errors) == 0,
"errors": errors,
"warnings": warnings
}

def heal_query(
self,
failed_sql: str,
error_message: str,
db_description: str = "",
question: str = "",
database_type: str = "sqlite"
) -> Dict[str, any]:
"""
Attempt to fix a failed SQL query using only the error message.
Args:
failed_sql: The SQL query that failed
error_message: The error message from execution
db_description: Optional database description
question: Optional original question
database_type: Type of database (sqlite, postgresql, mysql, etc.)
Returns:
Dict containing:
- sql_query: Fixed SQL query
- confidence: Confidence score
- explanation: Explanation of the fix
- changes_made: List of changes applied
"""
# Validate SQL syntax for additional error context
validation_result = self.validate_sql_syntax(failed_sql)
additional_context = ""
if validation_result["errors"]:
additional_context += f"\nSyntax errors: {', '.join(validation_result['errors'])}"
if validation_result["warnings"]:
additional_context += f"\nWarnings: {', '.join(validation_result['warnings'])}"

# Enhance error message with validation context
enhanced_error = error_message + additional_context

# Build focused prompt for SQL healing
prompt = self._build_healing_prompt(
failed_sql=failed_sql,
error_message=enhanced_error,
db_description=db_description,
question=question,
database_type=database_type
)

try:
# Call LLM for healing
response = completion(
model=Config.COMPLETION_MODEL,
messages=[{"role": "user", "content": prompt}],
temperature=0.1, # Low temperature for precision
max_tokens=2000
)

content = response.choices[0].message.content

# Parse the response
result = parse_response(content)

# Validate the result has required fields
if not result.get("sql_query"):
return {
"sql_query": failed_sql, # Return original if healing failed
"confidence": 0.0,
"explanation": "Failed to parse healed SQL from response",
"changes_made": [],
"healing_failed": True
}

return {
"sql_query": result.get("sql_query", ""),
"confidence": result.get("confidence", 50),
"explanation": result.get("explanation", ""),
"changes_made": result.get("changes_made", []),
"healing_failed": False
}

except Exception as e:
return {
"sql_query": failed_sql, # Return original on error
"confidence": 0.0,
"explanation": f"Healing error: {str(e)}",
"changes_made": [],
"healing_failed": True
}

def _build_healing_prompt(
self,
failed_sql: str,
error_message: str,
db_description: str,
question: str,
database_type: str
) -> str:
"""Build a focused prompt for SQL query healing."""

# Analyze error to provide targeted hints
error_hints = self._analyze_error(error_message, database_type)

prompt = f"""You are a SQL query debugging expert. Your task is to fix a SQL query that failed execution.
DATABASE TYPE: {database_type.upper()}
FAILED SQL QUERY:
```sql
{failed_sql}
```
EXECUTION ERROR:
{error_message}
{f"ORIGINAL QUESTION: {question}" if question else ""}
{f"DATABASE INFO: {db_description[:500]}" if db_description else ""}
COMMON ERROR PATTERNS:
{error_hints}
YOUR TASK:
1. Identify the exact cause of the error
2. Fix ONLY what's broken - don't rewrite the entire query
3. Ensure the fix is compatible with {database_type.upper()}
4. Maintain the original query logic and intent
CRITICAL RULES FOR {database_type.upper()}:
"""

if database_type == "sqlite":
prompt += """
- SQLite does NOT support EXTRACT() function - use strftime() instead
* EXTRACT(YEAR FROM date_col) → strftime('%Y', date_col)
* EXTRACT(MONTH FROM date_col) → strftime('%m', date_col)
* EXTRACT(DAY FROM date_col) → strftime('%d', date_col)
- SQLite column/table names are case-insensitive BUT must exist
- SQLite uses double quotes "column" for identifiers with special characters
- Use backticks `column` for compatibility
- No schema qualifiers (database.table.column)
"""
elif database_type == "postgresql":
prompt += """
- PostgreSQL is case-sensitive - use double quotes for mixed-case identifiers
- EXTRACT() is supported: EXTRACT(YEAR FROM date_col)
- Column references must match exact case when quoted
"""

prompt += """
RESPONSE FORMAT (valid JSON only):
{
"sql_query": "-- your fixed SQL query here",
"confidence": 85,
"explanation": "Brief explanation of what was fixed",
"changes_made": ["Changed EXTRACT to strftime", "Fixed column casing"]
}
IMPORTANT:
- Return ONLY the JSON object, no other text
- Fix ONLY the specific error, preserve the rest
- Test your fix mentally before responding
- If error is about a column/table name, check spelling carefully
"""

return prompt

def _analyze_error(self, error_message: str, database_type: str) -> str:
"""Analyze error message and provide targeted hints."""

error_lower = error_message.lower()
hints = []

# Common SQLite errors
if database_type == "sqlite":
if "near \"from\"" in error_lower or "syntax error" in error_lower:
hints.append("⚠️ EXTRACT() is NOT supported in SQLite - use strftime() instead!")
hints.append(" Example: strftime('%Y', date_column) for year")

if "no such column" in error_lower:
hints.append("⚠️ Column name doesn't exist - check spelling and case")
hints.append(" SQLite is case-insensitive but the column must exist")

if "no such table" in error_lower:
hints.append("⚠️ Table name doesn't exist - check spelling")

if "ambiguous column" in error_lower:
hints.append("⚠️ Ambiguous column - use table alias: table.column or alias.column")

# PostgreSQL errors
elif database_type == "postgresql":
if "column" in error_lower and "does not exist" in error_lower:
hints.append("⚠️ Column case mismatch - PostgreSQL is case-sensitive")
hints.append(' Use double quotes for mixed-case: "ColumnName"')

if "relation" in error_lower and "does not exist" in error_lower:
hints.append("⚠️ Table doesn't exist or case mismatch")

# Generic hints if no specific patterns matched
if not hints:
hints.append("⚠️ Check syntax compatibility with " + database_type.upper())
hints.append("⚠️ Verify column and table names exist")

return "\n".join(hints)
Loading
Loading