1 change: 1 addition & 0 deletions .env.copy
@@ -6,3 +6,4 @@ DB_USER=postgres
DB_PASSWORD=postgres

OPENAI_API_KEY=your_actual_openai_api_key_here
GEMINI_API_KEY=your_actual_gemini_api_key_here
2 changes: 2 additions & 0 deletions .gitignore
@@ -275,3 +275,5 @@ Thumbs.db
# config/local.py
# uploads/
# media/

.venv311
8 changes: 6 additions & 2 deletions app.py
@@ -233,9 +233,9 @@
id="query-strategy",
options=[
{"label": "Schema-Based Querying", "value": "schema"},
{"label": "Basic Text-to-SQL", "value": "basic"},
{"label": "RAG (Retrieval-Augmented Generation)", "value": "rag"},
{"label": "Visualize", "value": "visualize"},
{"label": "RAG (Retrieval-Augmented Generation)", "value": "rag"},
{"label": "Multi-Table Join", "value": "multitablejoin"}
],
value="schema",
@@ -413,7 +413,11 @@ def update_chat(n_clicks, n_submit, input_value, chat_history, settings, connect

try:
# Create query engine
engine_config = {"openai_api_key": Config.OPENAI_API_KEY, "db_uri": SQLITE_DB_PATH}
engine_config = {
"OPENAI_API_KEY": Config.OPENAI_API_KEY,
"GEMINI_API_KEY": Config.GEMINI_API_KEY,
"db_uri": SQLITE_DB_PATH
}
query_engine = query_engine_factory.create_query_engine(strategy, engine_config)

# Create security guardrail if enabled
1 change: 1 addition & 0 deletions config.py
@@ -17,6 +17,7 @@ class Config:

# LLM Configuration
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
GEMINI_API_KEY = os.getenv('GEMINI_API_KEY')
ANTHROPIC_API_KEY = os.getenv('ANTHROPIC_API_KEY')

# App Configuration
154 changes: 154 additions & 0 deletions database/query_engine.py
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from typing import Dict, Any, Tuple, List
import openai
import google.generativeai as genai # For Gemini API
from openai import OpenAI
import logging
from datetime import datetime
@@ -496,6 +497,153 @@ def execute_query(self, sql_query: str) -> Tuple[bool, Any]:



class BasicTextToSQLEngine(QueryEngine):
"""Basic text-to-SQL using manual prompt construction with schema and few-shot examples"""

def __init__(self, gemini_api_key: str):
self.gemini_api_key = gemini_api_key
genai.configure(api_key=gemini_api_key) # Configure Gemini API
logger.info("Initialized BasicTextToSQLEngine with Gemini API")

def get_name(self) -> str:
return "Basic Text-to-SQL"

def generate_query(self, user_query: str, context: Dict[str, Any]) -> Tuple[bool, str]:
"""Generate SQL query using basic text-to-SQL with manual prompt construction"""
logger.info(f"Starting basic text-to-SQL generation for: '{user_query}'")

try:
# Get database schema information
if not db_connection.is_connected():
return False, "Not connected to database"

# Get tables and their schemas
tables_success, tables_result = db_connection.get_tables()
if not tables_success:
return False, f"Failed to get tables: {tables_result}"

if tables_result.empty:
return False, "No tables found in the database"

# Build the manual prompt with schema and examples
prompt = self._build_manual_prompt(user_query, tables_result, context)

# Call Gemini API with the constructed prompt (free tier compatible)
model = genai.GenerativeModel('gemini-1.5-flash') # Free tier model
response = model.generate_content(
prompt,
generation_config=genai.types.GenerationConfig(
max_output_tokens=300, # Reduced for free tier
temperature=0.1,
candidate_count=1, # Free tier supports only 1 candidate
)
)
sql_query = response.text.strip()

# Fallback to OpenAI if Gemini fails (commented out)
# response = openai.ChatCompletion.create(
# model="gpt-3.5-turbo",
# messages=[
# {"role": "user", "content": prompt}
# ],
# max_tokens=500,
# temperature=0.1
# )
# sql_query = response.choices[0].message.content.strip()

# Clean up the response
if sql_query.startswith('```sql'):
sql_query = sql_query[6:]
if sql_query.endswith('```'):
sql_query = sql_query[:-3]
sql_query = sql_query.strip()

logger.info(f"Generated SQL: {sql_query}")
return True, sql_query

except Exception as e:
logger.error(f"Error in basic text-to-SQL generation: {str(e)}")
return False, f"Error generating SQL: {str(e)}"

def execute_query(self, sql_query: str) -> Tuple[bool, Any]:
"""Execute the generated SQL query"""
try:
return db_connection.execute_query(sql_query)
except Exception as e:
logger.error(f"Error executing query: {str(e)}")
return False, f"Error executing query: {str(e)}"

def _build_manual_prompt(self, user_query: str, tables_df, context: Dict[str, Any]) -> str:
"""Build a comprehensive manual prompt with schema and few-shot examples"""
db_type = context.get('db_type', 'postgresql')

# Build full database schema
schema_info = self._build_full_schema(tables_df)

# Create the manual prompt
prompt = f"""You are a MySQL expert. Your role is to generate a valid SQL query based on the user's natural language question.

DATABASE SCHEMA:
{schema_info}

FEW-SHOT EXAMPLES:

Example 1:
Question: "Which customers in California spent the most last quarter?"
SQL: SELECT customer_name, SUM(amount) as total_spent FROM orders o JOIN customers c ON o.customer_id = c.id WHERE c.state = 'California' AND o.order_date >= DATE_SUB(NOW(), INTERVAL 3 MONTH) GROUP BY c.id, customer_name ORDER BY total_spent DESC LIMIT 10;

Example 2:
Question: "Show me all books published after 2020"
SQL: SELECT title, author, publication_year FROM books WHERE publication_year > 2020 ORDER BY publication_year DESC;

Example 3:
Question: "How many users have borrowed books in the last month?"
SQL: SELECT COUNT(DISTINCT user_id) as active_borrowers FROM book_loans WHERE loan_date >= DATE_SUB(NOW(), INTERVAL 1 MONTH);

INSTRUCTIONS:
1. Generate a valid {db_type} SQL query
2. Use ONLY the tables and columns shown in the schema above
3. Include appropriate WHERE clauses, JOINs, and ORDER BY as needed
4. Add LIMIT clause if the query might return many rows
5. Use proper SQL syntax and formatting
6. Return only the SQL query, no explanations

USER QUESTION: {user_query}

SQL QUERY:"""

return prompt

def _build_full_schema(self, tables_df) -> str:
"""Build complete database schema information"""
schema_parts = []

for _, table_row in tables_df.iterrows():
table_name = table_row['table_name']
schema_parts.append(f"\nTable: {table_name}")

# Get column information for this table
schema_success, schema_result = db_connection.get_table_schema(table_name)
if schema_success and not schema_result.empty:
schema_parts.append("Columns:")
for _, col_row in schema_result.iterrows():
col_name = col_row['column_name']
col_type = col_row['data_type']
is_nullable = col_row.get('is_nullable', 'YES')
col_default = col_row.get('column_default', '')

col_info = f" - {col_name} ({col_type})"
if is_nullable == 'NO':
col_info += " NOT NULL"
if col_default:
col_info += f" DEFAULT {col_default}"

schema_parts.append(col_info)
else:
schema_parts.append(" (Schema information not available)")

return "\n".join(schema_parts)

class RAGQueryEngine(QueryEngine):
"""RAG-based query generation (placeholder)"""

@@ -590,6 +738,12 @@ def create_query_engine(engine_type: str, config: Dict[str, Any]) -> QueryEngine:
if not api_key:
raise ValueError("OpenAI API key required for schema-based querying")
return SchemaBasedQueryEngine(api_key)
elif engine_type == "basic":
# Using Gemini API for Basic Text-to-SQL
gemini_key = config.get('GEMINI_API_KEY')
if not gemini_key:
raise ValueError("Gemini API key required for basic text-to-SQL")
return BasicTextToSQLEngine(gemini_key)
elif engine_type == "rag":
return RAGQueryEngine()
elif engine_type == "multitablejoin":
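Reviewer note: a minimal sketch of how the new "basic" strategy would be exercised end to end, assuming the factory and config keys shown in this diff. The import path, database path, and sample question are illustrative only; in app.py the factory is reached via query_engine_factory.create_query_engine, and db_connection must already be connected before generate_query is called.

# Hypothetical walkthrough of the new "basic" strategy (illustrative, not part of the PR)
from config import Config
from database.query_engine import query_engine_factory  # assumed export used by app.py

engine_config = {
    "OPENAI_API_KEY": Config.OPENAI_API_KEY,
    "GEMINI_API_KEY": Config.GEMINI_API_KEY,
    "db_uri": "library.db",  # illustrative SQLite path, stands in for SQLITE_DB_PATH
}

# "basic" dispatches to the new Gemini-backed BasicTextToSQLEngine
engine = query_engine_factory.create_query_engine("basic", engine_config)

ok, sql = engine.generate_query(
    "Show me all books published after 2020",  # sample question, matches few-shot Example 2
    {"db_type": "sqlite"},                     # context consumed by _build_manual_prompt
)
if ok:
    ok, result = engine.execute_query(sql)     # delegates to db_connection.execute_query
    print(result if ok else f"Execution failed: {result}")
else:
    print(f"Generation failed: {sql}")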
3 changes: 2 additions & 1 deletion requirements.txt
@@ -3,13 +3,14 @@ dash-bootstrap-components==1.5.0
dash-extensions==1.0.4
plotly==5.17.0
pandas==2.1.4
numpy==1.25.2
numpy>=1.26.4,<3
python-dotenv==1.0.0
requests==2.31.0
sqlalchemy==2.0.23
psycopg2-binary==2.9.9
pymysql==1.1.0
cryptography==41.0.7
google-generativeai==0.3.2
openai==1.102.0
flask==2.3.3
gunicorn==21.2.0
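Since google-generativeai is a new dependency, a quick standalone check that the pinned package and GEMINI_API_KEY are wired up could look like the sketch below; the prompt text is arbitrary, and the model name matches the one used by BasicTextToSQLEngine.

# Hypothetical smoke test for the new Gemini dependency; not part of the application code.
import os
import google.generativeai as genai
from dotenv import load_dotenv

load_dotenv()  # reads GEMINI_API_KEY from .env, as config.py does
genai.configure(api_key=os.environ["GEMINI_API_KEY"])

model = genai.GenerativeModel("gemini-1.5-flash")  # same model as BasicTextToSQLEngine
response = model.generate_content("Reply with the single word OK.")
print(response.text.strip())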