diff --git a/.github/workflows/dependency-review.yml b/.github/workflows/dependency-review.yml index d19e21b7..dae0e66e 100644 --- a/.github/workflows/dependency-review.yml +++ b/.github/workflows/dependency-review.yml @@ -25,6 +25,7 @@ permissions: jobs: dependency-review: + if: github.repository_visibility == 'public' runs-on: ubuntu-latest steps: - name: 'Checkout repository' diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml index 1659747f..3689095f 100644 --- a/.github/workflows/pylint.yml +++ b/.github/workflows/pylint.yml @@ -17,8 +17,6 @@ jobs: - name: Install Poetry uses: abatilo/actions-poetry@v3 - with: - poetry-version: "1.8.3" - name: Configure Poetry run: | diff --git a/Dockerfile b/Dockerfile index e69de29b..34b7c0cc 100644 --- a/Dockerfile +++ b/Dockerfile @@ -0,0 +1,34 @@ +# Use a single stage build with FalkorDB base image +FROM falkordb/falkordb:latest + +ENV PYTHONUNBUFFERED=1 \ + FALKORDB_HOST=localhost \ + FALKORDB_PORT=6379 + +USER root + +# Install Python and pip, netcat for wait loop in start.sh +RUN apt-get update && apt-get install -y \ + python3 \ + python3-pip \ + netcat-openbsd \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /app + +# Copy requirements and install Python dependencies +COPY requirements.txt . +RUN python3 -m pip install --no-cache-dir --break-system-packages -r requirements.txt + +# Copy application code +COPY . . + +# Copy and make start.sh executable +COPY start.sh /start.sh +RUN chmod +x /start.sh + +EXPOSE 5000 6379 3000 + + +# Use start.sh as entrypoint +ENTRYPOINT ["/start.sh"] diff --git a/README.md b/README.md index cc2c7ac1..d9c759a3 100644 --- a/README.md +++ b/README.md @@ -5,10 +5,61 @@ # Text2SQL +Text2SQL is a web application that allows users to interact with databases using natural language queries, powered by AI and graph database technology. + +## Setup + +### Prerequisites + +- Python 3.8+ +- Poetry (for dependency management) +- FalkorDB instance (or Redis with FalkorDB module) + +### Installation + +1. Clone the repository +2. Install dependencies with Poetry: + ```bash + poetry install + ``` + +3. Set up environment variables by copying `.env.example` to `.env` and filling in your values: + ```bash + cp .env.example .env + ``` + +### OAuth Configuration + +This application supports authentication via Google and GitHub OAuth. You'll need to set up OAuth applications for both providers: + +#### Google OAuth Setup + +1. Go to [Google Cloud Console](https://console.developers.google.com/) +2. Create a new project or select an existing one +3. Enable the Google+ API +4. Go to "Credentials" and create an OAuth 2.0 Client ID +5. Add your domain to authorized origins (e.g., `http://localhost:5000`) +6. Add the callback URL: `http://localhost:5000/login/google/authorized` +7. Copy the Client ID and Client Secret to your `.env` file + +#### GitHub OAuth Setup + +1. Go to GitHub Settings → Developer settings → OAuth Apps +2. Click "New OAuth App" +3. Fill in the application details: + - Application name: Your app name + - Homepage URL: `http://localhost:5000` + - Authorization callback URL: `http://localhost:5000/login/github/authorized` +4. Copy the Client ID and Client Secret to your `.env` file + +### Running the Application + ```bash poetry run flask --app api.index run ``` +The application will be available at `http://localhost:5000`. 
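+### Environment Variables
+
+The authoritative list of settings lives in `.env.example`; the sketch below is illustrative only, and the OAuth key names in particular are assumptions rather than values taken from the repository. `FALKORDB_URL` is read by `api/extensions.py` and falls back to a local instance on port 6379 when unset.
+
+```bash
+# OAuth credentials from the Google and GitHub setup steps above (key names assumed)
+GOOGLE_OAUTH_CLIENT_ID=your-google-client-id
+GOOGLE_OAUTH_CLIENT_SECRET=your-google-client-secret
+GITHUB_OAUTH_CLIENT_ID=your-github-client-id
+GITHUB_OAUTH_CLIENT_SECRET=your-github-client-secret
+
+# FalkorDB connection (optional; defaults to localhost:6379 when not set)
+FALKORDB_URL=redis://localhost:6379
+```
+
+### Running with Docker (optional)
+
+The repository also ships a Dockerfile that bundles FalkorDB with the API. A possible way to build and run it locally is sketched below; the image name is arbitrary and the published ports simply mirror the Dockerfile's `EXPOSE` line.
+
+```bash
+docker build -t text2sql .
+docker run --env-file .env -p 5000:5000 -p 6379:6379 -p 3000:3000 text2sql
+```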
+ ## Introduction ![image](https://github.com/user-attachments/assets/8b1743a8-1d24-4cb7-89a8-a95f626e68d9) diff --git a/api/agents.py b/api/agents.py deleted file mode 100644 index 7082c0c8..00000000 --- a/api/agents.py +++ /dev/null @@ -1,376 +0,0 @@ -import json -from litellm import completion -from api.config import Config -from typing import List, Dict, Any - -class AnalysisAgent(): - def __init__(self, queries_history: list, result_history: list): - if result_history is None: - self.messages = [] - else: - self.messages = [] - for query, result in zip(queries_history[:-1], result_history): - self.messages.append({"role": "user", "content": query}) - self.messages.append({"role": "assistant", "content": result}) - - def get_analysis(self, user_query: str, combined_tables: list, db_description: str, instructions: str = None) -> dict: - formatted_schema = self._format_schema(combined_tables) - prompt = self._build_prompt(user_query, formatted_schema, db_description, instructions) - self.messages.append({"role": "user", "content": prompt}) - completion_result = completion(model=Config.COMPLETION_MODEL, - messages=self.messages, - temperature=0, - top_p=1, - ) - - response = completion_result.choices[0].message.content - analysis = _parse_response(response) - if isinstance(analysis['ambiguities'], list): - analysis['ambiguities'] = [item.replace('-', ' ') for item in analysis['ambiguities']] - analysis['ambiguities'] = "- " + "- ".join(analysis['ambiguities']) - if isinstance(analysis['missing_information'], list): - analysis['missing_information'] = [item.replace('-', ' ') for item in analysis['missing_information']] - analysis['missing_information'] = "- " + "- ".join(analysis['missing_information']) - self.messages.append({"role": "assistant", "content": analysis['sql_query']}) - return analysis - - def _format_schema(self, schema_data: List) -> str: - """ - Format the schema data into a readable format for the prompt. - - Args: - schema_data: Schema in the structure [...] - - Returns: - Formatted schema as a string - """ - formatted_schema = [] - - for table_info in schema_data: - table_name = table_info[0] - table_description = table_info[1] - foreign_keys = table_info[2] - columns = table_info[3] - - # Format table header - table_str = f"Table: {table_name} - {table_description}\n" - - # Format columns using the updated OrderedDict structure - for column in columns: - col_name = column.get("columnName", "") - col_type = column.get("dataType", None) - col_description = column.get("description", "") - col_key = column.get("keyType", None) - nullable = column.get("nullable", False) - - key_info = f", PRIMARY KEY" if col_key == "PRI" else f", FOREIGN KEY" if col_key == "FK" else "" - column_str = f" - {col_name} ({col_type},{key_info},{col_key},{nullable}): {col_description}" - table_str += column_str + "\n" - - # Format foreign keys - if isinstance(foreign_keys, dict) and foreign_keys: - table_str += " Foreign Keys:\n" - for fk_name, fk_info in foreign_keys.items(): - column = fk_info.get("column", "") - ref_table = fk_info.get("referenced_table", "") - ref_column = fk_info.get("referenced_column", "") - table_str += f" - {fk_name}: {column} references {ref_table}.{ref_column}\n" - - formatted_schema.append(table_str) - - return "\n".join(formatted_schema) - - def _build_prompt(self, user_input: str, formatted_schema: str, db_description: str, instructions) -> str: - """ - Build the prompt for Claude to analyze the query. 
- - Args: - user_input: The natural language query from the user - formatted_schema: Formatted database schema - - Returns: - The formatted prompt for Claude - """ - prompt = f""" - You must strictly follow the instructions below. Deviations will result in a penalty to your confidence score. - - MANDATORY RULES: - - Always explain if you cannot fully follow the instructions. - - Always reduce the confidence score if instructions cannot be fully applied. - - Never skip explaining missing information, ambiguities, or instruction issues. - - Respond ONLY in strict JSON format, without extra text. - - If the query relates to a previous question, you MUST take into account the previous question and its answer, and answer based on the context and information provided so far. - - If the user is asking a follow-up or continuing question, use the conversation history and previous answers to resolve references, context, or ambiguities. Always base your analysis on the cumulative context, not just the current question. - - Your output JSON MUST contain all fields, even if empty (e.g., "missing_information": []). - - --- - - Now analyze the user query based on the provided inputs: - - - {db_description} - - - - {instructions} - - - - {formatted_schema} - - - - {self.messages} - - - - {user_input} - - - --- - - Your task: - - - Analyze the query's translatability into SQL according to the instructions. - - Apply the instructions explicitly. - - If you CANNOT apply instructions in the SQL, explain why under "instructions_comments", "explanation" and reduce your confidence. - - Penalize confidence appropriately if any part of the instructions is unmet. - - When there several tables that can be used to answer the question, you can combine them in a single SQL query. - - Provide your output ONLY in the following JSON structure: - - ```json - {{ - "is_sql_translatable": true or false, - "instructions_comments": "Comments about any part of the instructions, especially if they are unclear, impossible, or partially met", - "explanation": "Detailed explanation why the query can or cannot be translated, mentioning instructions explicitly and referencing conversation history if relevant", - "sql_query": "High-level SQL query (you must to applying instructions and use previous answers if the question is a continuation)", - "tables_used": ["list", "of", "tables", "used", "in", "the", "query", "with", "the", "relationships", "between", "them"], - "missing_information": ["list", "of", "missing", "information"], - "ambiguities": ["list", "of", "ambiguities"], - "confidence": integer between 0 and 100 - }} - - Evaluation Guidelines: - - 1. Verify if all requested information exists in the schema. - 2. Check if the query's intent is clear enough for SQL translation. - 3. Identify any ambiguities in the query or instructions. - 4. List missing information explicitly if applicable. - 5. Confirm if necessary joins are possible. - 6. Consider if complex calculations are feasible in SQL. - 7. Identify multiple interpretations if they exist. - 8. Strictly apply instructions; explain and penalize if not possible. - 9. If the question is a follow-up, resolve references using the conversation history and previous answers. - - Again: OUTPUT ONLY VALID JSON. No explanations outside the JSON block. 
""" - return prompt - - -class RelevancyAgent(): - def __init__(self, queries_history: list, result_history: list): - if result_history is None: - self.messages = [] - else: - self.messages = [] - for query, result in zip(queries_history[:-1], result_history): - self.messages.append({"role": "user", "content": query}) - self.messages.append({"role": "assistant", "content": result}) - - def get_answer(self, user_question: str, database_desc: dict) -> dict: - self.messages.append({"role": "user", "content": RELEVANCY_PROMPT.format(QUESTION_PLACEHOLDER=user_question, DB_PLACEHOLDER=json.dumps(database_desc))}) - completion_result = completion( - model=Config.COMPLETION_MODEL, - messages=self.messages, - temperature=0, - ) - - answer = completion_result.choices[0].message.content - self.messages.append({"role": "assistant", "content": answer}) - return _parse_response(answer) - - -RELEVANCY_PROMPT = """ -You are an expert assistant tasked with determining whether the user’s question aligns with a given database description and whether the question is appropriate. You receive two inputs: - -The user’s question: {QUESTION_PLACEHOLDER} -The database description: {DB_PLACEHOLDER} -Please follow these instructions: - -Understand the question in the context of the database. -• Ask yourself: “Does this question relate to the data or concepts described in the database description?” -• Common tables that can be found in most of the systems considered "On-topic" even if it not explict in the database description. -• Don't answer questions that related to yourself. -• Don't answer questions that related to personal information unless it related to data in the schemas. -• Questions about the user's (first person) defined as "personal" and is Off-topic. -• Questions about yourself defined as "personal" and is Off-topic. - -Determine if the question is: -• On-topic and appropriate: -– If so, provide a JSON response in the following format: -{{ -"status": "On-topic", -"reason": "Brief explanation of why it is on-topic and appropriate." -"suggestions": [] -}} - -• Off-topic: -– If the question does not align with the data or use cases implied by the schema, provide a JSON response: -{{ -"status": "Off-topic", -"reason": "Short reason explaining why it is off-topic.", -"suggestions": [ -"An alternative, high-level question about the schema..." -] -}} - -• Inappropriate: -– If the question is offensive, illegal, or otherwise violates content guidelines, provide a JSON response: -{{ -"status": "Inappropriate", -"reason": "Short reason why it is inappropriate.", -"suggestions": [ -"Suggested topics that would be more appropriate..." -] -}} - -Ensure your response is concise, polite, and helpful. -""" - - -class FollowUpAgent(): - def __init__(self): - pass - - def get_answer(self, user_question: str, conversation_hist: list, database_schema: dict) -> dict: - completion_result = completion( - model=Config.COMPLETION_MODEL, - messages=[ - { - "content": FOLLOW_UP_PROMPT.format(QUESTION=user_question, HISTORY=conversation_hist, SCHEMA=json.dumps(database_schema)), - "role": "user" - } - ], - response_format={"type": "json_object"}, - temperature=0, - ) - - answer = completion_result.choices[0].message.content - return json.loads(answer) - - - -FOLLOW_UP_PROMPT = """You are an expert assistant that receives two inputs: - -1. The user’s question: {QUESTION} -2. The history of his questions: {HISTORY} -3. 
A detected database schema (all relevant tables, columns, and their descriptions): {SCHEMA} - -Your primary goal is to decide if the user’s questions can be addressed using the existing schema or if new or additional data is required. -Any thing that can be calculated from the provided tables is define the status Data-focused. -Please follow these steps: - -1. Understand the user’s question in the context of the provided schema. -• Determine whether the question directly relates to the tables, columns, or concepts in the schema or needed more information about the filtering. - -2. If the question relates to the existing schema: -• Provide a concise JSON response indicating: -{{ -"status": "Data-focused", -"reason": "Brief explanation why this question is answerable with the given schema." -"followUpQuestion": "" -}} -• If relevant, note any additional observations or suggested follow-up. - -3. If the question cannot be answered solely with the given schema or if there seems to be missing context: -• Ask clarifying questions to confirm the user’s intent or to gather any necessary information. -• Use a JSON format such as: -{{ -"status": "Needs more data", -"reason": "Reason why the current schema is insufficient.", -"followUpQuestion": "Single question to clarify user intent or additional data needed, can be a specific value..." - -}} - -4. Ensure your response is concise, polite, and helpful. When asking clarifying questions, be specific and guide the user toward providing the missing details so you can effectively address their query.""" - - - -class TaxonomyAgent(): - def __init__(self): - pass - - def get_answer(self, question: str, sql: str) -> str: - messages = [ - { - "content": TAXONOMY_PROMPT.format(QUESTION=question, SQL=sql), - "role": "user" - } - ] - completion_result = completion( - model=Config.COMPLETION_MODEL, - messages=messages, - temperature=0, - ) - - answer = completion_result.choices[0].message.content - return answer - - - -TAXONOMY_PROMPT = """You are an advanced taxonomy generator. For a pair of question and SQL query provde a single clarification question to the user. -* For any SQL query that contain WHERE clause, provide a clarification question to the user about the generated value. -* Your question can contain more than one clarification related to WHERE clause. -* Please asked only about the clarifications that you need and not extand the answer. -* Please ask in a polite, humen, and concise manner. -* Do not meantion any tables or columns in your ouput!. -* If you dont need any clarification, please answer with "I don't need any clarification." -* The user didnt saw the SQL queryor the tables, so please understand this position and ask the clarification in that way he have the relevent information to answer. -* When you ask the user to confirm a value, please provide the value in your answer. -* Mention only question about values and dont mention the SQL query or the tables in your answer. - -Please create the clarification question step by step. - -Question: -{QUESTION} - -SQL: -{SQL} - -For example: -question: "How many diabetic patients are there?" -SQL: "SELECT COUNT(*) FROM patients WHERE disease_code = 'E11'" -Your output: "The diabitic desease code is E11? If not, please provide the correct diabitic desease code. - -The question to the user:" -""" - -def _parse_response(response: str) -> Dict[str, Any]: - """ - Parse Claude's response to extract the analysis. 
- - Args: - response: Claude's response string - - Returns: - Parsed analysis results - """ - try: - # Extract JSON from the response - json_start = response.find('{') - json_end = response.rfind('}') + 1 - json_str = response[json_start:json_end] - - # Parse the JSON - analysis = json.loads(json_str) - return analysis - except (json.JSONDecodeError, ValueError) as e: - # Fallback if JSON parsing fails - return { - "is_sql_translatable": False, - "confidence": 0, - "explanation": f"Failed to parse response: {str(e)}", - "error": str(response) - } \ No newline at end of file diff --git a/api/agents/README.md b/api/agents/README.md new file mode 100644 index 00000000..75f66bb7 --- /dev/null +++ b/api/agents/README.md @@ -0,0 +1,67 @@ +# Agents Module + +This module contains various AI agents for the text2sql application. Each agent is responsible for a specific task in the query processing pipeline. + +## Agents + +### AnalysisAgent (`analysis_agent.py`) +- **Purpose**: Analyzes user queries and generates database analysis +- **Key Method**: `get_analysis()` - Analyzes user queries against database schema +- **Features**: Schema formatting, prompt building, conversation history tracking + +### RelevancyAgent (`relevancy_agent.py`) +- **Purpose**: Determines if queries are relevant to the database schema +- **Key Method**: `get_answer()` - Assesses query relevancy against database description +- **Features**: Topic classification (On-topic, Off-topic, Inappropriate) + +### FollowUpAgent (`follow_up_agent.py`) +- **Purpose**: Handles follow-up questions and conversational context +- **Key Method**: `get_answer()` - Processes follow-up questions using conversation history +- **Features**: Context awareness, data availability assessment + +### TaxonomyAgent (`taxonomy_agent.py`) +- **Purpose**: Provides taxonomy classification and clarification for SQL queries +- **Key Method**: `get_answer()` - Generates clarification questions for SQL queries +- **Features**: WHERE clause analysis, user-friendly clarifications + +### ResponseFormatterAgent (`response_formatter_agent.py`) +- **Purpose**: Formats SQL query results into user-readable responses +- **Key Method**: `format_response()` - Converts raw SQL results to natural language +- **Features**: Result formatting, operation type detection, user-friendly explanations + +## Utilities + +### utils.py +- **parse_response()**: Shared utility function for parsing JSON responses from AI models +- Used by multiple agents for consistent response parsing + +## Usage + +```python +from api.agents import AnalysisAgent, RelevancyAgent, ResponseFormatterAgent + +# Initialize agents +analysis_agent = AnalysisAgent(queries_history, result_history) +relevancy_agent = RelevancyAgent(queries_history, result_history) +formatter_agent = ResponseFormatterAgent() + +# Use agents +analysis = analysis_agent.get_analysis(query, tables, db_description) +relevancy = relevancy_agent.get_answer(question, database_desc) +response = formatter_agent.format_response(query, sql, results, db_description) +``` + +## Architecture + +Each agent follows a consistent pattern: +1. **Initialization**: Set up with necessary context (history, configuration) +2. **Main Method**: Primary interface for the agent's functionality +3. **Helper Methods**: Private methods for internal processing +4. **Prompt Templates**: Stored as module-level constants for easy maintenance +5. 
**LLM Integration**: Uses litellm for AI model interactions + +This modular structure improves: +- **Maintainability**: Each agent is self-contained +- **Testability**: Agents can be tested independently +- **Reusability**: Agents can be used in different contexts +- **Scalability**: New agents can be added without affecting existing ones diff --git a/api/agents/__init__.py b/api/agents/__init__.py new file mode 100644 index 00000000..1d508cc4 --- /dev/null +++ b/api/agents/__init__.py @@ -0,0 +1,17 @@ +"""Agents package for text2sql application.""" + +from .analysis_agent import AnalysisAgent +from .relevancy_agent import RelevancyAgent +from .follow_up_agent import FollowUpAgent +from .taxonomy_agent import TaxonomyAgent +from .response_formatter_agent import ResponseFormatterAgent +from .utils import parse_response + +__all__ = [ + "AnalysisAgent", + "RelevancyAgent", + "FollowUpAgent", + "TaxonomyAgent", + "ResponseFormatterAgent", + "parse_response" +] diff --git a/api/agents/analysis_agent.py b/api/agents/analysis_agent.py new file mode 100644 index 00000000..52494d66 --- /dev/null +++ b/api/agents/analysis_agent.py @@ -0,0 +1,209 @@ +"""Analysis agent for analyzing user queries and generating database analysis.""" + +from typing import List +from litellm import completion +from api.config import Config +from .utils import parse_response + + +class AnalysisAgent: + """Agent for analyzing user queries and generating database analysis.""" + + def __init__(self, queries_history: list, result_history: list): + """Initialize the analysis agent with query and result history.""" + if result_history is None: + self.messages = [] + else: + self.messages = [] + for query, result in zip(queries_history[:-1], result_history): + self.messages.append({"role": "user", "content": query}) + self.messages.append({"role": "assistant", "content": result}) + + def get_analysis( + self, + user_query: str, + combined_tables: list, + db_description: str, + instructions: str = None, + ) -> dict: + """Get analysis of user query against database schema.""" + formatted_schema = self._format_schema(combined_tables) + prompt = self._build_prompt( + user_query, formatted_schema, db_description, instructions + ) + self.messages.append({"role": "user", "content": prompt}) + completion_result = completion( + model=Config.COMPLETION_MODEL, + messages=self.messages, + temperature=0, + top_p=1, + ) + + response = completion_result.choices[0].message.content + analysis = parse_response(response) + if isinstance(analysis["ambiguities"], list): + analysis["ambiguities"] = [ + item.replace("-", " ") for item in analysis["ambiguities"] + ] + analysis["ambiguities"] = "- " + "- ".join(analysis["ambiguities"]) + if isinstance(analysis["missing_information"], list): + analysis["missing_information"] = [ + item.replace("-", " ") for item in analysis["missing_information"] + ] + analysis["missing_information"] = "- " + "- ".join( + analysis["missing_information"] + ) + self.messages.append({"role": "assistant", "content": analysis["sql_query"]}) + return analysis + + def _format_schema(self, schema_data: List) -> str: + """ + Format the schema data into a readable format for the prompt. + + Args: + schema_data: Schema in the structure [...] 
+ + Returns: + Formatted schema as a string + """ + formatted_schema = [] + + for table_info in schema_data: + table_name = table_info[0] + table_description = table_info[1] + foreign_keys = table_info[2] + columns = table_info[3] + + # Format table header + table_str = f"Table: {table_name} - {table_description}\n" + + # Format columns using the updated OrderedDict structure + for column in columns: + col_name = column.get("columnName", "") + col_type = column.get("dataType", None) + col_description = column.get("description", "") + col_key = column.get("keyType", None) + nullable = column.get("nullable", False) + + key_info = ( + ", PRIMARY KEY" + if col_key == "PRI" + else ", FOREIGN KEY" if col_key == "FK" else "" + ) + column_str = (f" - {col_name} ({col_type},{key_info},{col_key}," + f"{nullable}): {col_description}") + table_str += column_str + "\n" + + # Format foreign keys + if isinstance(foreign_keys, dict) and foreign_keys: + table_str += " Foreign Keys:\n" + for fk_name, fk_info in foreign_keys.items(): + column = fk_info.get("column", "") + ref_table = fk_info.get("referenced_table", "") + ref_column = fk_info.get("referenced_column", "") + table_str += ( + f" - {fk_name}: {column} references {ref_table}.{ref_column}\n" + ) + + formatted_schema.append(table_str) + + return "\n".join(formatted_schema) + + def _build_prompt( + self, user_input: str, formatted_schema: str, db_description: str, instructions + ) -> str: + """ + Build the prompt for Claude to analyze the query. + + Args: + user_input: The natural language query from the user + formatted_schema: Formatted database schema + + Returns: + The formatted prompt for Claude + """ + prompt = f""" + You must strictly follow the instructions below. Deviations will result in a penalty to your confidence score. + + MANDATORY RULES: + - Always explain if you cannot fully follow the instructions. + - Always reduce the confidence score if instructions cannot be fully applied. + - Never skip explaining missing information, ambiguities, or instruction issues. + - Respond ONLY in strict JSON format, without extra text. + - If the query relates to a previous question, you MUST take into account the previous question and its answer, and answer based on the context and information provided so far. + + If the user is asking a follow-up or continuing question, use the conversation history and previous answers to resolve references, context, or ambiguities. Always base your analysis on the cumulative context, not just the current question. + + Your output JSON MUST contain all fields, even if empty (e.g., "missing_information": []). + + --- + + Now analyze the user query based on the provided inputs: + + + {db_description} + + + + {instructions} + + + + {formatted_schema} + + + + {self.messages} + + + + {user_input} + + + --- + + Your task: + + - Analyze the query's translatability into SQL according to the instructions. + - Apply the instructions explicitly. + - If you CANNOT apply instructions in the SQL, explain why under + "instructions_comments", "explanation" and reduce your confidence. + - Penalize confidence appropriately if any part of the instructions is unmet. + - When there several tables that can be used to answer the question, + you can combine them in a single SQL query. 
+
+        Provide your output ONLY in the following JSON structure:
+
+        ```json
+        {{
+            "is_sql_translatable": true or false,
+            "instructions_comments": "Comments about any part of the instructions, especially if they are unclear, impossible, or partially met",
+            "explanation": "Detailed explanation of why the query can or cannot be translated, mentioning instructions explicitly and referencing conversation history if relevant",
+            "sql_query": "High-level SQL query (you must apply the instructions and use previous answers if the question is a continuation)",
+            "tables_used": ["list", "of", "tables", "used", "in", "the", "query", "with", "the", "relationships", "between", "them"],
+            "missing_information": ["list", "of", "missing", "information"],
+            "ambiguities": ["list", "of", "ambiguities"],
+            "confidence": integer between 0 and 100
+        }}
+
+        Evaluation Guidelines:
+
+        1. Verify if all requested information exists in the schema.
+        2. Check if the query's intent is clear enough for SQL translation.
+        3. Identify any ambiguities in the query or instructions.
+        4. List missing information explicitly if applicable.
+        5. Confirm if necessary joins are possible.
+        6. Consider if complex calculations are feasible in SQL.
+        7. Identify multiple interpretations if they exist.
+        8. Strictly apply instructions; explain and penalize if not possible.
+        9. If the question is a follow-up, resolve references using the
+           conversation history and previous answers.
+
+        Again: OUTPUT ONLY VALID JSON. No explanations outside the JSON block.
+        """
+        return prompt
diff --git a/api/agents/follow_up_agent.py b/api/agents/follow_up_agent.py
new file mode 100644
index 00000000..21f40e90
--- /dev/null
+++ b/api/agents/follow_up_agent.py
@@ -0,0 +1,72 @@
+"""Follow-up agent for handling follow-up questions and conversational context."""
+
+import json
+from litellm import completion
+from api.config import Config
+
+
+FOLLOW_UP_PROMPT = """You are an expert assistant that receives three inputs:
+
+1. The user's question: {QUESTION}
+2. The history of their questions: {HISTORY}
+3. A detected database schema (all relevant tables, columns, and their descriptions): {SCHEMA}
+
+Your primary goal is to decide if the user's questions can be addressed using the existing schema or if new or additional data is required.
+Anything that can be calculated from the provided tables is given the status Data-focused.
+Please follow these steps:
+
+1. Understand the user's question in the context of the provided schema.
+• Determine whether the question directly relates to the tables, columns, or concepts in the schema, or whether more information about the filtering is needed.
+
+2. If the question relates to the existing schema:
+• Provide a concise JSON response indicating:
+{{
+"status": "Data-focused",
+"reason": "Brief explanation why this question is answerable with the given schema.",
+"followUpQuestion": ""
+}}
+• If relevant, note any additional observations or suggested follow-up.
+
+3. If the question cannot be answered solely with the given schema or if there seems to be missing context:
+• Ask clarifying questions to confirm the user's intent or to gather any necessary information.
+• Use a JSON format such as:
+{{
+"status": "Needs more data",
+"reason": "Reason why the current schema is insufficient.",
+"followUpQuestion": "Single question to clarify user intent or additional data needed, can be a specific value..."
+}}
+
+4. Ensure your response is concise, polite, and helpful. When asking clarifying questions, be specific and guide the user toward providing the missing details so you can effectively address their query."""
+
+
+class FollowUpAgent:
+    """Agent for handling follow-up questions and conversational context."""
+
+    def __init__(self):
+        """Initialize the follow-up agent."""
+
+    def get_answer(
+        self, user_question: str, conversation_hist: list, database_schema: dict
+    ) -> dict:
+        """Get answer for follow-up questions using conversation history."""
+        completion_result = completion(
+            model=Config.COMPLETION_MODEL,
+            messages=[
+                {
+                    "content": FOLLOW_UP_PROMPT.format(
+                        QUESTION=user_question,
+                        HISTORY=conversation_hist,
+                        SCHEMA=json.dumps(database_schema),
+                    ),
+                    "role": "user",
+                }
+            ],
+            response_format={"type": "json_object"},
+            temperature=0,
+        )
+
+        answer = completion_result.choices[0].message.content
+        return json.loads(answer)
diff --git a/api/agents/relevancy_agent.py b/api/agents/relevancy_agent.py
new file mode 100644
index 00000000..931d8ee4
--- /dev/null
+++ b/api/agents/relevancy_agent.py
@@ -0,0 +1,89 @@
+"""Relevancy agent for determining relevancy of queries to database schema."""
+
+import json
+from litellm import completion
+from api.config import Config
+from .utils import parse_response
+
+
+RELEVANCY_PROMPT = """
+You are an expert assistant tasked with determining whether the user's question aligns with a given database description and whether the question is appropriate. You receive two inputs:
+
+The user's question: {QUESTION_PLACEHOLDER}
+The database description: {DB_PLACEHOLDER}
+Please follow these instructions:
+
+Understand the question in the context of the database.
+• Ask yourself: "Does this question relate to the data or concepts described in the database description?"
+• Common tables that can be found in most systems are considered "On-topic" even if they are not explicit in the database description.
+• Don't answer questions that relate to yourself.
+• Don't answer questions that relate to personal information unless they relate to data in the schemas.
+• Questions about the user (first person) are defined as "personal" and are Off-topic.
+• Questions about yourself are defined as "personal" and are Off-topic.
+
+Determine if the question is:
+• On-topic and appropriate:
+– If so, provide a JSON response in the following format:
+{{
+"status": "On-topic",
+"reason": "Brief explanation of why it is on-topic and appropriate.",
+"suggestions": []
+}}
+
+• Off-topic:
+– If the question does not align with the data or use cases implied by the schema, provide a JSON response:
+{{
+"status": "Off-topic",
+"reason": "Short reason explaining why it is off-topic.",
+"suggestions": [
+"An alternative, high-level question about the schema..."
+]
+}}
+
+• Inappropriate:
+– If the question is offensive, illegal, or otherwise violates content guidelines, provide a JSON response:
+{{
+"status": "Inappropriate",
+"reason": "Short reason why it is inappropriate.",
+"suggestions": [
+"Suggested topics that would be more appropriate..."
+]
+}}
+
+Ensure your response is concise, polite, and helpful.
+""" + + +class RelevancyAgent: + """Agent for determining relevancy of queries to database schema.""" + + def __init__(self, queries_history: list, result_history: list): + """Initialize the relevancy agent with query and result history.""" + if result_history is None: + self.messages = [] + else: + self.messages = [] + for query, result in zip(queries_history[:-1], result_history): + self.messages.append({"role": "user", "content": query}) + self.messages.append({"role": "assistant", "content": result}) + + def get_answer(self, user_question: str, database_desc: dict) -> dict: + """Get relevancy assessment for user question against database description.""" + self.messages.append( + { + "role": "user", + "content": RELEVANCY_PROMPT.format( + QUESTION_PLACEHOLDER=user_question, + DB_PLACEHOLDER=json.dumps(database_desc), + ), + } + ) + completion_result = completion( + model=Config.COMPLETION_MODEL, + messages=self.messages, + temperature=0, + ) + + answer = completion_result.choices[0].message.content + self.messages.append({"role": "assistant", "content": answer}) + return parse_response(answer) diff --git a/api/agents/response_formatter_agent.py b/api/agents/response_formatter_agent.py new file mode 100644 index 00000000..413a1558 --- /dev/null +++ b/api/agents/response_formatter_agent.py @@ -0,0 +1,133 @@ +"""Response formatter agent for generating user-readable responses from SQL query results.""" + +from typing import List, Dict +from litellm import completion +from api.config import Config + + +RESPONSE_FORMATTER_PROMPT = """ +You are an AI assistant that helps users understand database query results. Your task is to analyze the SQL query results and provide a clear, concise, and user-friendly explanation. + +**Context:** +Database Description: {DB_DESCRIPTION} + +**User's Original Question:** +{USER_QUERY} + +**SQL Query Executed:** +{SQL_QUERY} + +**Query Type:** {SQL_TYPE} + +**Query Results:** +{FORMATTED_RESULTS} + +**Instructions:** +1. Provide a clear, natural language answer to the user's question based on the query results +2. For SELECT queries: Focus on the key insights and findings from the data +3. For INSERT/UPDATE/DELETE queries: Confirm the operation was successful and mention how many records were affected +4. For other operations (CREATE, DROP, etc.): Confirm the operation was completed successfully +5. Use bullet points or numbered lists when presenting multiple items +6. Include relevant numbers, percentages, or trends if applicable +7. Be concise but comprehensive - avoid unnecessary technical jargon +8. If the results are empty, explain that no data was found matching the criteria +9. If there are many results, provide a summary with highlights +10. Do not mention the SQL query or technical database details unless specifically relevant to the user's understanding + +**Response Format:** +Provide a direct answer to the user's question in a conversational tone, as if you were explaining the findings to a colleague. +""" + + +class ResponseFormatterAgent: + """Agent for generating user-readable responses from SQL query results.""" + + def __init__(self): + """Initialize the response formatter agent.""" + pass + + def format_response(self, user_query: str, sql_query: str, query_results: List[Dict], db_description: str = "") -> str: + """ + Generate a user-readable response based on the SQL query results. 
+
+        Args:
+            user_query: The original user question
+            sql_query: The SQL query that was executed
+            query_results: The results from the SQL query execution
+            db_description: Description of the database context
+
+        Returns:
+            A formatted, user-readable response string
+        """
+        prompt = self._build_response_prompt(user_query, sql_query, query_results, db_description)
+
+        messages = [{"role": "user", "content": prompt}]
+
+        completion_result = completion(
+            model=Config.COMPLETION_MODEL,
+            messages=messages,
+            temperature=0.3,  # Slightly higher temperature for more natural responses
+            top_p=1,
+        )
+
+        response = completion_result.choices[0].message.content
+        return response.strip()
+
+    def _build_response_prompt(self, user_query: str, sql_query: str, query_results: List[Dict], db_description: str) -> str:
+        """Build the prompt for generating user-readable responses."""
+
+        # Format the query results for better readability
+        formatted_results = self._format_query_results(query_results)
+
+        # Determine the type of SQL operation
+        sql_type = sql_query.strip().split()[0].upper() if sql_query else "UNKNOWN"
+
+        prompt = RESPONSE_FORMATTER_PROMPT.format(
+            DB_DESCRIPTION=db_description if db_description else "Not provided",
+            USER_QUERY=user_query,
+            SQL_QUERY=sql_query,
+            SQL_TYPE=sql_type,
+            FORMATTED_RESULTS=formatted_results
+        )
+
+        return prompt
+
+    def _format_query_results(self, query_results: List[Dict]) -> str:
+        """Format query results for inclusion in the prompt."""
+        if not query_results:
+            return "No results found."
+
+        # Check if this is an operation result (INSERT/UPDATE/DELETE)
+        if len(query_results) == 1 and "operation" in query_results[0]:
+            result = query_results[0]
+            operation = result.get("operation", "UNKNOWN")
+            affected_rows = result.get("affected_rows")
+            status = result.get("status", "unknown")
+
+            if affected_rows is not None:
+                return f"Operation: {operation}, Status: {status}, Affected rows: {affected_rows}"
+            return f"Operation: {operation}, Status: {status}"
+
+        # Handle regular SELECT query results
+        # Limit the number of results shown in the prompt to avoid token limits
+        max_results_to_show = 50
+        results_to_show = query_results[:max_results_to_show]
+
+        formatted = []
+        for i, result in enumerate(results_to_show, 1):
+            if isinstance(result, dict):
+                result_str = ", ".join([f"{k}: {v}" for k, v in result.items()])
+                formatted.append(f"{i}. {result_str}")
+            else:
+                formatted.append(f"{i}. {result}")
+
+        result_text = "\n".join(formatted)
+
+        if len(query_results) > max_results_to_show:
+            result_text += f"\n... and {len(query_results) - max_results_to_show} more results"
+
+        return result_text
diff --git a/api/agents/taxonomy_agent.py b/api/agents/taxonomy_agent.py
new file mode 100644
index 00000000..f3088a39
--- /dev/null
+++ b/api/agents/taxonomy_agent.py
@@ -0,0 +1,59 @@
+"""Taxonomy agent for taxonomy classification of questions and SQL queries."""
+
+from litellm import completion
+from api.config import Config
+
+
+TAXONOMY_PROMPT = """You are an advanced taxonomy generator. For a pair of a question and an SQL query, \
+provide a single clarification question to the user.
+* For any SQL query that contains a WHERE clause, provide a clarification question to the user about the \
+generated value.
+* Your question can contain more than one clarification related to the WHERE clause.
+* Please ask only about the clarifications that you need and do not extend the answer.
+* Please ask in a polite, human, and concise manner.
+* Do not mention any tables or columns in your output.
+* If you don't need any clarification, please answer with "I don't need any clarification."
+* The user has not seen the SQL query or the tables, so please keep this in mind and phrase the \
+clarification so that they have the relevant information to answer.
+* When you ask the user to confirm a value, please provide the value in your answer.
+* Ask only about values and do not mention the SQL query or the tables in your answer.
+
+Please create the clarification question step by step.
+
+Question:
+{QUESTION}
+
+SQL:
+{SQL}
+
+For example:
+question: "How many diabetic patients are there?"
+SQL: "SELECT COUNT(*) FROM patients WHERE disease_code = 'E11'"
+Your output: "Is E11 the correct disease code for diabetes? If not, please provide the correct disease code.
+
+The question to the user:"
+"""
+
+
+class TaxonomyAgent:
+    """Agent for taxonomy classification of questions and SQL queries."""
+
+    def __init__(self):
+        """Initialize the taxonomy agent."""
+
+    def get_answer(self, question: str, sql: str) -> str:
+        """Get taxonomy classification for a question and SQL pair."""
+        messages = [
+            {
+                "content": TAXONOMY_PROMPT.format(QUESTION=question, SQL=sql),
+                "role": "user",
+            }
+        ]
+        completion_result = completion(
+            model=Config.COMPLETION_MODEL,
+            messages=messages,
+            temperature=0,
+        )
+
+        answer = completion_result.choices[0].message.content
+        return answer
diff --git a/api/agents/utils.py b/api/agents/utils.py
new file mode 100644
index 00000000..25fefd2a
--- /dev/null
+++ b/api/agents/utils.py
@@ -0,0 +1,33 @@
+"""Utility functions for agents."""
+
+import json
+from typing import Any, Dict
+
+
+def parse_response(response: str) -> Dict[str, Any]:
+    """
+    Parse the model's response to extract the analysis.
+
+    Args:
+        response: The model's response string
+
+    Returns:
+        Parsed analysis results
+    """
+    try:
+        # Extract JSON from the response
+        json_start = response.find("{")
+        json_end = response.rfind("}") + 1
+        json_str = response[json_start:json_end]
+
+        # Parse the JSON
+        analysis = json.loads(json_str)
+        return analysis
+    except (json.JSONDecodeError, ValueError) as e:
+        # Fallback if JSON parsing fails
+        return {
+            "is_sql_translatable": False,
+            "confidence": 0,
+            "explanation": f"Failed to parse response: {str(e)}",
+            "error": str(response),
+        }
diff --git a/api/config.py b/api/config.py
index 2642ea9d..386155f2 100644
--- a/api/config.py
+++ b/api/config.py
@@ -1,59 +1,58 @@
-"""
""" -import os -from typing import Union + import dataclasses +from typing import Union + from litellm import embedding -import boto3 -class EmbeddingsModel(): - - def __init__( - self, - model_name: str, - config: dict = None - ): +class EmbeddingsModel: + """Embeddings model wrapper for text embedding operations.""" + + def __init__(self, model_name: str, config: dict = None): self.model_name = model_name self.config = config - + def embed(self, text: Union[str, list]) -> list: """ Get the embeddings of the text - + Args: text (str|list): The text(s) to embed - + Returns: list: The embeddings of the text - + """ embeddings = embedding(model=self.model_name, input=text) embeddings = [embedding["embedding"] for embedding in embeddings.data] return embeddings - + def get_vector_size(self) -> int: """ Get the size of the vector - + Returns: int: The size of the vector - + """ - response = embedding(input = ["Hello World"], model=self.model_name) - size = len(response.data[0]['embedding']) + response = embedding(input=["Hello World"], model=self.model_name) + size = len(response.data[0]["embedding"]) return size @dataclasses.dataclass class Config: """ - Configuration class for the text2sql module. + Configuration class for the text2sql module. """ + SCHEMA_PATH = "api/schema_schema.json" EMBEDDING_MODEL_NAME = "azure/text-embedding-ada-002" COMPLETION_MODEL = "azure/gpt-4.1" + VALIDATOR_MODEL = "azure/gpt-4.1" TEMPERATURE = 0 # client = boto3.client('sts') # AWS_PROFILE = os.getenv("aws_profile_name") @@ -66,11 +65,7 @@ class Config: # config["aws_region_name"] = AWS_REGION # config["aws_profile_name"] = AWS_PROFILE - EMBEDDING_MODEL = EmbeddingsModel( - model_name=EMBEDDING_MODEL_NAME, - config=config - ) - + EMBEDDING_MODEL = EmbeddingsModel(model_name=EMBEDDING_MODEL_NAME, config=config) FIND_SYSTEM_PROMPT = """ You are an expert in analyzing natural language queries into SQL tables descriptions. @@ -139,4 +134,4 @@ class Config: * **User Query (Natural Language):** You will be given a user's current question or request in natural language. 
- """ \ No newline at end of file + """ diff --git a/api/constants.py b/api/constants.py index a1eacddf..93532b47 100644 --- a/api/constants.py +++ b/api/constants.py @@ -1,86 +1,174 @@ -EXAMPLES = {'crm_usecase': ["Which companies have generated the most revenue through closed deals, and how much revenue did they generate?", - "How many leads converted into deals over the last month", - "Which companies have open sales opportunities and active SLA agreements in place?", - "Which high-value sales opportunities (value > $50,000) have upcoming meetings scheduled, and what companies are they associated with?"], - 'ERP_system': [ - # "What is the total value of all purchase orders created in the last quarter?", - # "Which suppliers have the highest number of active purchase orders, and what is the total value of those orders?", - "What is the total order value for customer Almo Office?", - "Show the total amount of all orders placed on 11/24", - "What's the profit for order SO2400002?", - "List all confirmed orders form today with their final prices", - "How many items are in order SO2400002?", +"""Constants and benchmark data for the text2sql application.""" - # Product-Specific Questions - "What is the price of Office Chair (part 0001100)?", - "List all items with quantity greater than 3 units", - "Show me all products with price above $20", - "What's the total cost of all A4 Paper items ordered?", - "Which items have the highest profit margin?", - - # Financial Analysis Questions - "Calculate the total profit for this year", - "Show me orders with overall discount greater than 5%", - "What's the average profit percentage across all items?", - "List orders with final price exceeding $700", - "Show me items with profit margin above 50%", - - # Customer-Related Questions - "How many orders has customer 100038 placed?", - "What's the total purchase amount by Almo Office?", - "List all orders with their customer names and contact details", - "Show me customers with orders above $500", - "What's the average order value per customer?", - - # Inventory/Stock Questions - "Which items have zero quantity?", - "Show me all items with their crate types", - "List products with their packaging details", - "What's the total quantity ordered for each product?", - "Show me items with pending shipments" - ] - } +EXAMPLES = { + "crm_usecase": [ + ("Which companies have generated the most revenue through closed deals, " + "and how much revenue did they generate?"), + "How many leads converted into deals over the last month", + ("Which companies have open sales opportunities and active SLA agreements " + "in place?"), + ("Which high-value sales opportunities (value > $50,000) have upcoming meetings " + "scheduled, and what companies are they associated with?"), + ], + "ERP_system": [ + # ("What is the total value of all purchase orders created in the last " + # "quarter?"), + # ("Which suppliers have the highest number of active purchase orders, " + # "and what is the total value of those orders?"), + "What is the total order value for customer Almo Office?", + "Show the total amount of all orders placed on 11/24", + "What's the profit for order SO2400002?", + "List all confirmed orders form today with their final prices", + "How many items are in order SO2400002?", + # Product-Specific Questions + "What is the price of Office Chair (part 0001100)?", + "List all items with quantity greater than 3 units", + "Show me all products with price above $20", + "What's the total cost of all A4 Paper items ordered?", + 
"Which items have the highest profit margin?", + # Financial Analysis Questions + "Calculate the total profit for this year", + "Show me orders with overall discount greater than 5%", + "What's the average profit percentage across all items?", + "List orders with final price exceeding $700", + "Show me items with profit margin above 50%", + # Customer-Related Questions + "How many orders has customer 100038 placed?", + "What's the total purchase amount by Almo Office?", + "List all orders with their customer names and contact details", + "Show me customers with orders above $500", + "What's the average order value per customer?", + # Inventory/Stock Questions + "Which items have zero quantity?", + "Show me all items with their crate types", + "List products with their packaging details", + "What's the total quantity ordered for each product?", + "Show me items with pending shipments", + ], +} BENCHMARK = [ { - "question": "List all contacts who are associated with companies that have at least one active deal in the pipeline, and include the deal stage.", - "sql": "SELECT DISTINCT c.contact_id, c.first_name, c.last_name, d.deal_id, d.deal_name, ds.stage_name FROM contacts AS c JOIN company_contacts AS cc ON c.contact_id = cc.contact_id JOIN companies AS co ON cc.company_id = co.company_id JOIN deals AS d ON co.company_id = d.company_id JOIN deal_stages AS ds ON d.stage_id = ds.stage_id WHERE ds.is_active = 1;" + "question": ("List all contacts who are associated with companies that have at " + "least one active deal in the pipeline, and include the deal stage."), + "sql": ("SELECT DISTINCT c.contact_id, c.first_name, c.last_name, d.deal_id, " + "d.deal_name, ds.stage_name FROM contacts AS c " + "JOIN company_contacts AS cc ON c.contact_id = cc.contact_id " + "JOIN companies AS co ON cc.company_id = co.company_id " + "JOIN deals AS d ON co.company_id = d.company_id " + "JOIN deal_stages AS ds ON d.stage_id = ds.stage_id " + "WHERE ds.is_active = 1;"), }, { - "question": "Which sales representatives (users) have closed deals worth more than $100,000 in the past year, and what was the total value of deals they closed?", - "sql": "SELECT u.user_id, u.first_name, u.last_name, SUM(d.amount) AS total_closed_value FROM users AS u JOIN deals AS d ON u.user_id = d.owner_id JOIN deal_stages AS ds ON d.stage_id = ds.stage_id WHERE ds.stage_name = 'Closed Won' AND d.close_date >= DATE_SUB(CURDATE(), INTERVAL 1 YEAR) GROUP BY u.user_id HAVING total_closed_value > 100000;" + "question": ("Which sales representatives (users) have closed deals worth more " + "than $100,000 in the past year, and what was the total value of " + "deals they closed?"), + "sql": ("SELECT u.user_id, u.first_name, u.last_name, SUM(d.amount) AS " + "total_closed_value FROM users AS u " + "JOIN deals AS d ON u.user_id = d.owner_id " + "JOIN deal_stages AS ds ON d.stage_id = ds.stage_id " + "WHERE ds.stage_name = 'Closed Won' AND d.close_date >= " + "DATE_SUB(CURDATE(), INTERVAL 1 YEAR) GROUP BY u.user_id " + "HAVING total_closed_value > 100000;"), }, { - "question": "Find all contacts who attended at least one event and were later converted into leads that became opportunities within three months of the event.", - "sql": "SELECT DISTINCT c.contact_id, c.first_name, c.last_name FROM contacts AS c JOIN event_attendees AS ea ON c.contact_id = ea.contact_id JOIN events AS e ON ea.event_id = e.event_id JOIN leads AS l ON c.contact_id = l.contact_id JOIN opportunities AS o ON l.lead_id = o.lead_id WHERE o.created_date BETWEEN e.event_date 
AND DATE_ADD(e.event_date, INTERVAL 3 MONTH);" + "question": ("Find all contacts who attended at least one event and were later " + "converted into leads that became opportunities within three months " + "of the event."), + "sql": ("SELECT DISTINCT c.contact_id, c.first_name, c.last_name " + "FROM contacts AS c " + "JOIN event_attendees AS ea ON c.contact_id = ea.contact_id " + "JOIN events AS e ON ea.event_id = e.event_id " + "JOIN leads AS l ON c.contact_id = l.contact_id " + "JOIN opportunities AS o ON l.lead_id = o.lead_id " + "WHERE o.created_date BETWEEN e.event_date AND " + "DATE_ADD(e.event_date, INTERVAL 3 MONTH);"), }, { - "question": "Which customers have the highest lifetime value based on their total invoice payments, including refunds and discounts?", - "sql": "SELECT c.contact_id, c.first_name, c.last_name, SUM(i.total_amount - COALESCE(r.refund_amount, 0) - COALESCE(d.discount_amount, 0)) AS lifetime_value FROM contacts AS c JOIN orders AS o ON c.contact_id = o.contact_id JOIN invoices AS i ON o.order_id = i.order_id LEFT JOIN refunds AS r ON i.invoice_id = r.invoice_id LEFT JOIN discounts AS d ON i.invoice_id = d.invoice_id GROUP BY c.contact_id ORDER BY lifetime_value DESC LIMIT 10;" + "question": ("Which customers have the highest lifetime value based on their " + "total invoice payments, including refunds and discounts?"), + "sql": ("SELECT c.contact_id, c.first_name, c.last_name, " + "SUM(i.total_amount - COALESCE(r.refund_amount, 0) - " + "COALESCE(d.discount_amount, 0)) AS lifetime_value " + "FROM contacts AS c " + "JOIN orders AS o ON c.contact_id = o.contact_id " + "JOIN invoices AS i ON o.order_id = i.order_id " + "LEFT JOIN refunds AS r ON i.invoice_id = r.invoice_id " + "LEFT JOIN discounts AS d ON i.invoice_id = d.invoice_id " + "GROUP BY c.contact_id ORDER BY lifetime_value DESC LIMIT 10;"), }, { - "question": "Show all deals that have involved at least one email exchange, one meeting, and one phone call with a contact in the past six months.", - "sql": "SELECT DISTINCT d.deal_id, d.deal_name FROM deals AS d JOIN contacts AS c ON d.contact_id = c.contact_id JOIN emails AS e ON c.contact_id = e.contact_id JOIN meetings AS m ON c.contact_id = m.contact_id JOIN phone_calls AS p ON c.contact_id = p.contact_id WHERE e.sent_date >= DATE_SUB(CURDATE(), INTERVAL 6 MONTH) AND m.meeting_date >= DATE_SUB(CURDATE(), INTERVAL 6 MONTH) AND p.call_date >= DATE_SUB(CURDATE(), INTERVAL 6 MONTH);" + "question": ("Show all deals that have involved at least one email exchange, " + "one meeting, and one phone call with a contact in the past six months."), + "sql": ("SELECT DISTINCT d.deal_id, d.deal_name FROM deals AS d " + "JOIN contacts AS c ON d.contact_id = c.contact_id " + "JOIN emails AS e ON c.contact_id = e.contact_id " + "JOIN meetings AS m ON c.contact_id = m.contact_id " + "JOIN phone_calls AS p ON c.contact_id = p.contact_id " + "WHERE e.sent_date >= DATE_SUB(CURDATE(), INTERVAL 6 MONTH) " + "AND m.meeting_date >= DATE_SUB(CURDATE(), INTERVAL 6 MONTH) " + "AND p.call_date >= DATE_SUB(CURDATE(), INTERVAL 6 MONTH);"), }, { - "question": "Which companies have the highest number of active support tickets, and how does their number of tickets correlate with their total deal value?", - "sql": "SELECT co.company_id, co.company_name, COUNT(st.ticket_id) AS active_tickets, SUM(d.amount) AS total_deal_value FROM companies AS co LEFT JOIN support_tickets AS st ON co.company_id = st.company_id AND st.status = 'Open' LEFT JOIN deals AS d ON co.company_id = d.company_id GROUP BY 
co.company_id ORDER BY active_tickets DESC;" + "question": ("Which companies have the highest number of active support tickets, " + "and how does their number of tickets correlate with their total deal value?"), + "sql": ("SELECT co.company_id, co.company_name, COUNT(st.ticket_id) AS active_tickets, " + "SUM(d.amount) AS total_deal_value FROM companies AS co " + "LEFT JOIN support_tickets AS st ON co.company_id = st.company_id " + "AND st.status = 'Open' " + "LEFT JOIN deals AS d ON co.company_id = d.company_id " + "GROUP BY co.company_id ORDER BY active_tickets DESC;"), }, { - "question": "Retrieve all contacts who are assigned to a sales rep but have not been contacted via email, phone, or meeting in the past three months.", - "sql": "SELECT c.contact_id, c.first_name, c.last_name FROM contacts AS c JOIN users AS u ON c.owner_id = u.user_id LEFT JOIN emails AS e ON c.contact_id = e.contact_id AND e.sent_date >= DATE_SUB(CURDATE(), INTERVAL 3 MONTH) LEFT JOIN phone_calls AS p ON c.contact_id = p.contact_id AND p.call_date >= DATE_SUB(CURDATE(), INTERVAL 3 MONTH) LEFT JOIN meetings AS m ON c.contact_id = m.contact_id AND m.meeting_date >= DATE_SUB(CURDATE(), INTERVAL 3 MONTH) WHERE e.contact_id IS NULL AND p.contact_id IS NULL AND m.contact_id IS NULL;" + "question": ("Retrieve all contacts who are assigned to a sales rep but have not " + "been contacted via email, phone, or meeting in the past three months."), + "sql": ("SELECT c.contact_id, c.first_name, c.last_name FROM contacts AS c " + "JOIN users AS u ON c.owner_id = u.user_id " + "LEFT JOIN emails AS e ON c.contact_id = e.contact_id " + "AND e.sent_date >= DATE_SUB(CURDATE(), INTERVAL 3 MONTH) " + "LEFT JOIN phone_calls AS p ON c.contact_id = p.contact_id " + "AND p.call_date >= DATE_SUB(CURDATE(), INTERVAL 3 MONTH) " + "LEFT JOIN meetings AS m ON c.contact_id = m.contact_id " + "AND m.meeting_date >= DATE_SUB(CURDATE(), INTERVAL 3 MONTH) " + "WHERE e.contact_id IS NULL AND p.contact_id IS NULL " + "AND m.contact_id IS NULL;"), }, { - "question": "Which email campaigns resulted in the highest number of closed deals, and what was the average deal size for those campaigns?", - "sql": "SELECT ec.campaign_id, ec.campaign_name, COUNT(d.deal_id) AS closed_deals, AVG(d.amount) AS avg_deal_value FROM email_campaigns AS ec JOIN contacts AS c ON ec.campaign_id = c.campaign_id JOIN deals AS d ON c.contact_id = d.contact_id JOIN deal_stages AS ds ON d.stage_id = ds.stage_id WHERE ds.stage_name = 'Closed Won' GROUP BY ec.campaign_id ORDER BY closed_deals DESC;" + "question": ("Which email campaigns resulted in the highest number of closed deals, " + "and what was the average deal size for those campaigns?"), + "sql": ("SELECT ec.campaign_id, ec.campaign_name, COUNT(d.deal_id) AS closed_deals, " + "AVG(d.amount) AS avg_deal_value FROM email_campaigns AS ec " + "JOIN contacts AS c ON ec.campaign_id = c.campaign_id " + "JOIN deals AS d ON c.contact_id = d.contact_id " + "JOIN deal_stages AS ds ON d.stage_id = ds.stage_id " + "WHERE ds.stage_name = 'Closed Won' GROUP BY ec.campaign_id " + "ORDER BY closed_deals DESC;"), }, { - "question": "Find the average time it takes for a lead to go from creation to conversion into a deal, broken down by industry.", - "sql": "SELECT ind.industry_name, AVG(DATEDIFF(d.close_date, l.created_date)) AS avg_conversion_time FROM leads AS l JOIN companies AS co ON l.company_id = co.company_id JOIN industries AS ind ON co.industry_id = ind.industry_id JOIN opportunities AS o ON l.lead_id = o.lead_id JOIN deals AS d ON 
o.opportunity_id = d.opportunity_id WHERE d.stage_id IN (SELECT stage_id FROM deal_stages WHERE stage_name = 'Closed Won') GROUP BY ind.industry_name ORDER BY avg_conversion_time ASC;" + "question": ("Find the average time it takes for a lead to go from creation to " + "conversion into a deal, broken down by industry."), + "sql": ("SELECT ind.industry_name, AVG(DATEDIFF(d.close_date, l.created_date)) " + "AS avg_conversion_time FROM leads AS l " + "JOIN companies AS co ON l.company_id = co.company_id " + "JOIN industries AS ind ON co.industry_id = ind.industry_id " + "JOIN opportunities AS o ON l.lead_id = o.lead_id " + "JOIN deals AS d ON o.opportunity_id = d.opportunity_id " + "WHERE d.stage_id IN (SELECT stage_id FROM deal_stages " + "WHERE stage_name = 'Closed Won') GROUP BY ind.industry_name " + "ORDER BY avg_conversion_time ASC;"), }, { - "question": "Which sales reps (users) have the highest win rate, calculated as the percentage of their assigned leads that convert into closed deals?", - "sql": "SELECT u.user_id, u.first_name, u.last_name, COUNT(DISTINCT d.deal_id) / COUNT(DISTINCT l.lead_id) * 100 AS win_rate FROM users AS u JOIN leads AS l ON u.user_id = l.owner_id LEFT JOIN opportunities AS o ON l.lead_id = o.lead_id LEFT JOIN deals AS d ON o.opportunity_id = d.opportunity_id JOIN deal_stages AS ds ON d.stage_id = ds.stage_id WHERE ds.stage_name = 'Closed Won' GROUP BY u.user_id ORDER BY win_rate DESC;" - } + "question": ("Which sales reps (users) have the highest win rate, calculated as " + "the percentage of their assigned leads that convert into closed deals?"), + "sql": ("SELECT u.user_id, u.first_name, u.last_name, " + "COUNT(DISTINCT d.deal_id) / COUNT(DISTINCT l.lead_id) * 100 AS win_rate " + "FROM users AS u " + "JOIN leads AS l ON u.user_id = l.owner_id " + "LEFT JOIN opportunities AS o ON l.lead_id = o.lead_id " + "LEFT JOIN deals AS d ON o.opportunity_id = d.opportunity_id " + "JOIN deal_stages AS ds ON d.stage_id = ds.stage_id " + "WHERE ds.stage_name = 'Closed Won' GROUP BY u.user_id " + "ORDER BY win_rate DESC;"), + }, ] diff --git a/api/extensions.py b/api/extensions.py index 0c40559d..01dd2e1a 100644 --- a/api/extensions.py +++ b/api/extensions.py @@ -1,13 +1,15 @@ -""" Extensions for the text2sql library """ +"""Extensions for the text2sql library""" + import os + from falkordb import FalkorDB # Connect to FalkorDB url = os.getenv("FALKORDB_URL", None) if url is None: try: - db = FalkorDB(host='localhost', port=6379) + db = FalkorDB(host="localhost", port=6379) except Exception as e: - raise Exception(f"Failed to connect to FalkorDB: {e}") + raise ConnectionError(f"Failed to connect to FalkorDB: {e}") from e else: db = FalkorDB.from_url(os.getenv("FALKORDB_URL")) diff --git a/api/graph.py b/api/graph.py index 6632caef..9eed9bca 100644 --- a/api/graph.py +++ b/api/graph.py @@ -1,96 +1,119 @@ -""" Module to handle the graph data loading into the database. 
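The `api/extensions.py` hunk above keeps a single module-level `db` handle, built either from `FALKORDB_URL` or from a localhost default, and now raises `ConnectionError` with the original exception chained. A minimal sketch of how downstream code is expected to consume that handle, assuming a FalkorDB instance reachable on `localhost:6379`; the graph name is only illustrative:

```python
# Assumes FALKORDB_URL is unset, so api/extensions.py falls back to localhost:6379.
# Importing the module raises ConnectionError if no FalkorDB instance is reachable.
from api.extensions import db

graph = db.select_graph("crm_usecase")  # illustrative graph name

# Table nodes carry a `name` property in the schema graph built by the loaders.
result = graph.query("MATCH (t:Table) RETURN t.name LIMIT 5")
for row in result.result_set:
    print(row[0])
```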
""" +"""Module to handle the graph data loading into the database.""" + import json import logging +from itertools import combinations from typing import List, Tuple + from litellm import completion from pydantic import BaseModel + from api.config import Config from api.extensions import db -from itertools import combinations -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") + class TableDescription(BaseModel): - """ Table Description """ + """Table Description""" + name: str description: str + class ColumnDescription(BaseModel): - """ Column Description """ + """Column Description""" + name: str description: str + class Descriptions(BaseModel): - """ List of tables """ + """List of tables""" + tables_descriptions: list[TableDescription] columns_descriptions: list[ColumnDescription] -def get_db_description(graph_id: str) -> str: - """ Get the database description from the graph. """ + +def get_db_description(graph_id: str) -> (str, str): + """Get the database description from the graph.""" graph = db.select_graph(graph_id) - query_result = graph.query(""" + query_result = graph.query( + """ MATCH (d:Database) - RETURN d.description + RETURN d.description, d.url """ ) - + if not query_result.result_set: - return "No description available for this database." - - return query_result.result_set[0][0] # Return the first result's description + return ("No description available for this database.", "No URL available for this database.") + + return (query_result.result_set[0][0], query_result.result_set[0][1]) # Return the first result's description + def find( - graph_id: str, - queries_history: List[str], - db_description: str = None + graph_id: str, queries_history: List[str], db_description: str = None ) -> Tuple[bool, List[dict]]: - """ Find the tables and columns relevant to the user's query. """ - + """Find the tables and columns relevant to the user's query.""" + graph = db.select_graph(graph_id) user_query = queries_history[-1] previous_queries = queries_history[:-1] - logging.info(f"Calling to an LLM to find relevant tables and columns for the query: {user_query}") + logging.info( + "Calling to an LLM to find relevant tables and columns for the query: %s", + user_query + ) # Call the completion model to get the relevant Cypher queries to retrieve # from the Graph that represent the Database schema. # The completion model will generate a set of Cypher query to retrieve the relevant nodes. 
- completion_result = completion(model=Config.COMPLETION_MODEL, - response_format=Descriptions, - messages=[ - { - "content": Config.FIND_SYSTEM_PROMPT.format(db_description=db_description), - "role": "system" - }, - { - "content": json.dumps({ - "previous_user_queries:": previous_queries, - "user_query": user_query - }), - "role": "user" - } - ], - temperature=0, - ) + completion_result = completion( + model=Config.COMPLETION_MODEL, + response_format=Descriptions, + messages=[ + { + "content": Config.FIND_SYSTEM_PROMPT.format(db_description=db_description), + "role": "system", + }, + { + "content": json.dumps( + { + "previous_user_queries:": previous_queries, + "user_query": user_query, + } + ), + "role": "user", + }, + ], + temperature=0, + ) json_str = completion_result.choices[0].message.content # Parse JSON string and convert to Pydantic model json_data = json.loads(json_str) descriptions = Descriptions(**json_data) - logging.info(f"Find tables based on: {descriptions.tables_descriptions}") + logging.info("Find tables based on: %s", descriptions.tables_descriptions) tables_des = _find_tables(graph, descriptions.tables_descriptions) - logging.info(f"Find tables based on columns: {descriptions.columns_descriptions}") + logging.info("Find tables based on columns: %s", descriptions.columns_descriptions) tables_by_columns_des = _find_tables_by_columns(graph, descriptions.columns_descriptions) # table names for sphere and route extraction base_tables_names = [table[0] for table in tables_des] logging.info("Extracting tables by sphere") tables_by_sphere = _find_tables_sphere(graph, base_tables_names) - logging.info(f"Extracting tables by connecting routes {base_tables_names}") + logging.info("Extracting tables by connecting routes %s", base_tables_names) tables_by_route, _ = find_connecting_tables(graph, base_tables_names) - combined_tables = _get_unique_tables(tables_des + tables_by_columns_des + tables_by_route + tables_by_sphere) - - return True, combined_tables, [tables_des, tables_by_columns_des, tables_by_route, tables_by_sphere] + combined_tables = _get_unique_tables( + tables_des + tables_by_columns_des + tables_by_route + tables_by_sphere + ) + + return ( + True, + combined_tables, + [tables_des, tables_by_columns_des, tables_by_route, tables_by_sphere], + ) + def _find_tables(graph, descriptions: List[TableDescription]) -> List[dict]: @@ -99,7 +122,8 @@ def _find_tables(graph, descriptions: List[TableDescription]) -> List[dict]: # Get the table node from the graph embedding_result = Config.EMBEDDING_MODEL.embed(table.description) - query_result = graph.query(""" + query_result = graph.query( + """ CALL db.idx.vector.queryNodes( 'Table', 'embedding', @@ -115,20 +139,21 @@ def _find_tables(graph, descriptions: List[TableDescription]) -> List[dict]: nullable: columns.nullable }) """, - { - 'embedding': embedding_result[0] - }) + {"embedding": embedding_result[0]}, + ) for node in query_result.result_set: if node not in result: result.append(node) - + return result + def _find_tables_sphere(graph, tables: List[str]) -> List[dict]: result = [] for table_name in tables: - query_result = graph.query(""" + query_result = graph.query( + """ MATCH (node:Table {name: $name}) MATCH (node)-[:BELONGS_TO]-(column)-[:REFERENCES]-()-[:BELONGS_TO]-(table_ref) WITH table_ref @@ -141,9 +166,8 @@ def _find_tables_sphere(graph, tables: List[str]) -> List[dict]: nullable: columns.nullable }) """, - { - 'name': table_name - }) + {"name": table_name}, + ) for node in query_result.result_set: if node not in 
result: result.append(node) @@ -158,7 +182,8 @@ def _find_tables_by_columns(graph, descriptions: List[ColumnDescription]) -> Lis # Get the table node from the graph embedding_result = Config.EMBEDDING_MODEL.embed(column.description) - query_result = graph.query(""" + query_result = graph.query( + """ CALL db.idx.vector.queryNodes( 'Column', 'embedding', @@ -178,9 +203,8 @@ def _find_tables_by_columns(graph, descriptions: List[ColumnDescription]) -> Lis nullable: columns.nullable }) """, - { - 'embedding': embedding_result[0] - }) + {"embedding": embedding_result[0]}, + ) for node in query_result.result_set: if node not in result: @@ -188,22 +212,23 @@ def _find_tables_by_columns(graph, descriptions: List[ColumnDescription]) -> Lis return result + def _get_unique_tables(tables_list): # Dictionary to store unique tables with the table name as the key unique_tables = {} - + for table_info in tables_list: table_name = table_info[0] # The first element is the table name - + # Only add if this table name hasn't been seen before try: if table_name not in unique_tables: table_info[3] = [dict(od) for od in table_info[3]] - table_info[2] = 'Foreign keys: ' + table_info[2] + table_info[2] = "Foreign keys: " + table_info[2] unique_tables[table_name] = table_info - except: - print(f"Error: {table_info}") - + except Exception as e: + print(f"Error: {table_info}, Exception: {e}") + # Return the values (the unique table info lists) return list(unique_tables.values()) @@ -212,11 +237,11 @@ def find_connecting_tables(graph, table_names: List[str]) -> Tuple[List[dict], L """ Find all tables that form connections between any pair of tables in the input list. Handles both Table nodes and Column nodes with primary keys. - + Args: graph: The FalkorDB graph database connection table_names: List of table names to check connections between - + Returns: A set of all table names that form connections between any pair in the input """ @@ -259,5 +284,5 @@ def find_connecting_tables(graph, table_names: List[str]) -> Tuple[List[dict], L target_table.foreign_keys AS foreign_keys, columns """ - result = graph.query(query, {'pairs': pair_params}, timeout=300).result_set - return result, None \ No newline at end of file + result = graph.query(query, {"pairs": pair_params}, timeout=300).result_set + return result, None diff --git a/api/helpers/crm_data_generator.py b/api/helpers/crm_data_generator.py index 289e97e8..01aa1262 100644 --- a/api/helpers/crm_data_generator.py +++ b/api/helpers/crm_data_generator.py @@ -1,9 +1,17 @@ +""" +CRM data generator module for creating complete database schemas with relationships. + +This module provides functionality to generate comprehensive CRM database schemas +with proper primary/foreign key relationships and table structures. 
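Further down in this hunk, `call_llm_api` requests each table definition in JSON mode and retries on transient failures. A trimmed sketch of that call pattern, assuming the same model id and `MAX_RETRIES` used by the module (the retry delay value is an assumption):

```python
import json
import time
from typing import Optional

from litellm import completion

MAX_RETRIES = 3
RETRY_DELAY = 5  # seconds; assumed value, the module scales it by the attempt number


def generate_table_schema(prompt: str) -> Optional[dict]:
    """Ask the model for a single table definition as a JSON object."""
    for attempt in range(1, MAX_RETRIES + 1):
        try:
            response = completion(
                model="gemini/gemini-2.0-flash",
                messages=[{"role": "user", "content": prompt}],
                temperature=0.5,
                response_format={"type": "json_object"},
            )
            content = response.choices[0].message.content
            if content:
                return json.loads(content)
            print(f"Empty response (attempt {attempt}/{MAX_RETRIES})")
        except Exception as exc:  # network or rate-limit errors
            print(f"Attempt {attempt}/{MAX_RETRIES} failed: {exc}")
        if attempt < MAX_RETRIES:
            time.sleep(RETRY_DELAY * attempt)
    return None
```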
+""" + import json import os import time +from typing import Any, Dict, List, Optional + import requests -from typing import Dict, List, Any, Optional, Set, Tuple -from litellm import completion, validate_environment, utils as litellm_utils +from litellm import completion OUTPUT_FILE = "complete_crm_schema.json" MAX_RETRIES = 3 @@ -14,13 +22,14 @@ "primary_keys": {}, # table_name -> primary_key_column "foreign_keys": {}, # table_name -> {column_name -> (referenced_table, referenced_column)} "processed_tables": set(), # Set of tables that have been processed - "table_relationships": {} # table_name -> set of related tables + "table_relationships": {}, # table_name -> set of related tables } + def load_initial_schema(file_path: str) -> Dict[str, Any]: """Load the initial schema file with table names""" try: - with open(file_path, 'r') as file: + with open(file_path, "r", encoding="utf-8") as file: schema = json.load(file) print(f"Loaded initial schema with {len(schema.get('tables', {}))} tables") return schema @@ -28,145 +37,164 @@ def load_initial_schema(file_path: str) -> Dict[str, Any]: print(f"Error loading schema file: {e}") return {"database": "crm_system", "tables": {}} + def save_schema(schema: Dict[str, Any], output_file: str = OUTPUT_FILE) -> None: """Save the current schema to a file with metadata""" # Add metadata if "metadata" not in schema: schema["metadata"] = {} - + schema["metadata"]["last_updated"] = time.strftime("%Y-%m-%d %H:%M:%S") schema["metadata"]["completed_tables"] = len(key_registry["processed_tables"]) schema["metadata"]["total_tables"] = len(schema.get("tables", {})) schema["metadata"]["key_registry"] = { "primary_keys": key_registry["primary_keys"], "foreign_keys": key_registry["foreign_keys"], - "table_relationships": {k: list(v) for k, v in key_registry["table_relationships"].items()} + "table_relationships": {k: list(v) for k, v in key_registry["table_relationships"].items()}, } - - with open(output_file, 'w') as file: + + with open(output_file, "w", encoding="utf-8") as file: json.dump(schema, file, indent=2) print(f"Schema saved to {output_file}") + def update_key_registry(table_name: str, table_data: Dict[str, Any]) -> None: """Update the key registry with information from a processed table""" # Mark table as processed key_registry["processed_tables"].add(table_name) - + # Track primary keys if "columns" in table_data: for col_name, col_data in table_data["columns"].items(): if col_data.get("key") == "PRI": key_registry["primary_keys"][table_name] = col_name break - + # Track foreign keys and relationships if "foreign_keys" in table_data: if table_name not in key_registry["foreign_keys"]: key_registry["foreign_keys"][table_name] = {} - + if table_name not in key_registry["table_relationships"]: key_registry["table_relationships"][table_name] = set() - - for fk_name, fk_data in table_data["foreign_keys"].items(): + + for fk_data in table_data["foreign_keys"].values(): column = fk_data.get("column") ref_table = fk_data.get("referenced_table") ref_column = fk_data.get("referenced_column") - + if column and ref_table and ref_column: - key_registry["foreign_keys"][table_name][column] = (ref_table, ref_column) - + key_registry["foreign_keys"][table_name][column] = ( + ref_table, + ref_column, + ) + # Update relationships key_registry["table_relationships"][table_name].add(ref_table) - + # Ensure the referenced table has an entry if ref_table not in key_registry["table_relationships"]: key_registry["table_relationships"][ref_table] = set() - + # Add the reverse 
relationship key_registry["table_relationships"][ref_table].add(table_name) + def find_related_tables(table_name: str, all_tables: List[str]) -> List[str]: """Find tables that might be related to the current table""" related = [] - + # Check registry first for already established relationships if table_name in key_registry["table_relationships"]: related.extend(key_registry["table_relationships"][table_name]) - + # Extract base name - base_parts = table_name.split('_') - + base_parts = table_name.split("_") + for other_table in all_tables: if other_table == table_name or other_table in related: continue - + # Direct naming relationship if table_name in other_table or other_table in table_name: related.append(other_table) continue - + # Check for common roots - other_parts = other_table.split('_') + other_parts = other_table.split("_") for part in base_parts: if part in other_parts and len(part) > 3: # Avoid short common words related.append(other_table) break - + return list(set(related)) # Remove duplicates -def get_table_prompt(table_name: str, schema: Dict[str, Any], all_table_names: List[str], topology) -> str: + +def get_table_prompt( + table_name: str, schema: Dict[str, Any], all_table_names: List[str], topology +) -> str: """Generate a prompt for the LLM to create a table schema with proper relationships""" existing_tables = schema.get("tables", {}) - + # Find related tables related_tables = find_related_tables(table_name, all_table_names) related_tables_str = ", ".join(related_tables) if related_tables else "None identified yet" - - # Suggest primary key pattern - table_base = table_name.split('_')[0] if '_' in table_name else table_name - suggested_pk = f"{table_name}_id" # Default pattern - - # Check if related tables have primary keys to follow same pattern - for related in related_tables: - if related in key_registry["primary_keys"]: - related_pk = key_registry["primary_keys"][related] - if related_pk.endswith('_id') and related in related_pk: - # Follow the same pattern - suggested_pk = f"{table_name}_id" - break - + + # # Suggest primary key pattern + # table_base = table_name.split("_")[0] if "_" in table_name else table_name + # suggested_pk = f"{table_name}_id" # Default pattern + + # # Check if related tables have primary keys to follow same pattern + # for related in related_tables: + # if related in key_registry["primary_keys"]: + # related_pk = key_registry["primary_keys"][related] + # if related_pk.endswith("_id") and related in related_pk: + # # Follow the same pattern + # suggested_pk = f"{table_name}_id" + # break + # Prepare foreign key suggestions fk_suggestions = [] for related in related_tables: if related in key_registry["primary_keys"]: - fk_suggestions.append({ - "column": f"{related}_id", - "referenced_table": related, - "referenced_column": key_registry["primary_keys"][related] - }) - + fk_suggestions.append( + { + "column": f"{related}_id", + "referenced_table": related, + "referenced_column": key_registry["primary_keys"][related], + } + ) + fk_suggestions_str = "" if fk_suggestions: fk_suggestions_str = "Consider these foreign key relationships:\n" for i, fk in enumerate(fk_suggestions[:5]): # Limit to 5 suggestions - fk_suggestions_str += f"{i+1}. {fk['column']} -> {fk['referenced_table']}.{fk['referenced_column']}\n" - + fk_suggestions_str += ( + f"{i+1}. 
{fk['column']} -> {fk['referenced_table']}.{fk['referenced_column']}\n" + ) + # Include examples of related tables that have been processed related_examples = "" example_count = 0 for related in related_tables: - if (related in existing_tables and - isinstance(existing_tables[related], dict) and - 'columns' in existing_tables[related] and - example_count < 2): - related_examples += f"\nRelated table example:\n```json\n{json.dumps({related: existing_tables[related]}, indent=2)}\n```\n" + if ( + related in existing_tables + and isinstance(existing_tables[related], dict) + and "columns" in existing_tables[related] + and example_count < 2 + ): + related_examples += ( + f"\nRelated table example:\n```json\n" + f"{json.dumps({related: existing_tables[related]}, indent=2)}\n```\n" + ) example_count += 1 - + # Use contacts table as primary example if no related examples found contacts_example = """ { "contacts": { - "description": "Stores information about individual contacts within the CRM system, including personal details and relationship to companies.", + "description": ("Stores information about individual contacts within the CRM " + "system, including personal details and relationship to companies."), "columns": { "contact_id": { "description": "Unique identifier for each contact", @@ -264,9 +292,10 @@ def get_table_prompt(table_name: str, schema: Dict[str, Any], all_table_names: L """ # Create context about the table's purpose table_context = get_table_context(table_name, related_tables) - keys = json.dumps(topology['tables'][table_name]) + keys = json.dumps(topology["tables"][table_name]) prompt = f""" -You are an expert database architect specializing in CRM systems. Create a detailed JSON schema for the '{table_name}' table in our CRM database. +You are an expert database architect specializing in CRM systems. Create a detailed +JSON schema for the '{table_name}' table in our CRM database. 
CONTEXT ABOUT THIS TABLE: {table_context} @@ -314,7 +343,8 @@ def get_table_prompt(table_name: str, schema: Dict[str, Any], all_table_names: L - For many-to-many relationships, create appropriate junction tables - Ensure referential integrity with foreign key constraints -Return ONLY valid JSON for the '{table_name}' table structure without any explanation or additional text: +Return ONLY valid JSON for the '{table_name}' table structure without any +explanation or additional text: {{ "{table_name}": {{ "description": "...", @@ -326,11 +356,12 @@ def get_table_prompt(table_name: str, schema: Dict[str, Any], all_table_names: L """ return prompt + def get_table_context(table_name: str, related_tables: List[str]) -> str: """Generate contextual information about a table based on its name and related tables""" # Extract words from table name - words = table_name.replace('_', ' ').split() - + words = table_name.replace("_", " ").split() + # Common CRM entities entities = { "contact": "Contains information about individuals", @@ -349,9 +380,9 @@ def get_table_context(table_name: str, related_tables: List[str]) -> str: "order": "Contains information about customer orders", "subscription": "Contains information about recurring subscriptions", "ticket": "Contains information about support tickets", - "campaign": "Contains information about marketing campaigns" + "campaign": "Contains information about marketing campaigns", } - + # Common relationship patterns relationship_patterns = { "tags": "This is a tagging or categorization table that likely links to various entities", @@ -369,17 +400,18 @@ def get_table_context(table_name: str, related_tables: List[str]) -> str: "attachments": "This contains file attachments", "performance": "This tracks performance metrics", "feedback": "This contains feedback information", - "settings": "This contains configuration settings" + "settings": "This contains configuration settings", } - + context = f"The '{table_name}' table appears to be " - + # Check if this is a junction/linking table - if "_" in table_name and not any(p in table_name for p in relationship_patterns.keys()): + if "_" in table_name and not any(p in table_name for p in relationship_patterns): parts = table_name.split("_") if len(parts) == 2 and all(len(p) > 2 for p in parts): - return f"This appears to be a junction table linking '{parts[0]}' and '{parts[1]}', likely with a many-to-many relationship." - + return (f"This appears to be a junction table linking '{parts[0]}' and " + f"'{parts[1]}', likely with a many-to-many relationship.") + # Check for main entities for entity, description in entities.items(): if entity in words: @@ -387,56 +419,65 @@ def get_table_context(table_name: str, related_tables: List[str]) -> str: break else: context += "part of the CRM system. " - + # Check for relationship patterns for pattern, description in relationship_patterns.items(): if pattern in table_name: context += f"{description}. " break - + # Add related tables info if related_tables: - context += f"It appears to be related to the following tables: {', '.join(related_tables)}. " - + context += ( + f"It appears to be related to the following tables: {', '.join(related_tables)}. " + ) + # Guess if it's a child table for related in related_tables: if related in table_name and len(related) < len(table_name): context += f"It may be a child or detail table for the {related} table. 
" break - + return context + def call_llm_api(prompt: str, retries: int = MAX_RETRIES) -> Optional[str]: """Call the LLM API with the given prompt, with retry logic""" for attempt in range(1, retries + 1): try: config = {} - config['temperature'] = 0.5 - config['response_format'] = { "type": "json_object" } - - + config["temperature"] = 0.5 + config["response_format"] = {"type": "json_object"} + response = completion( model="gemini/gemini-2.0-flash", messages=[{"role": "user", "content": prompt}], - **config + **config, + ) + result = ( + response.json() + .get("choices", [{}])[0] + .get("message", "") + .get("content", "") + .strip() ) - result = response.json().get("choices", [{}])[0].get("message", "").get("content", "").strip() if result: return result - else: - print(f"Empty response from API (attempt {attempt}/{retries})") - + + print(f"Empty response from API (attempt {attempt}/{retries})") + except requests.exceptions.RequestException as e: print(f"API request error (attempt {attempt}/{retries}): {e}") - + if attempt < retries: sleep_time = RETRY_DELAY * attempt print(f"Retrying in {sleep_time} seconds...") time.sleep(sleep_time) - + print("All retry attempts failed") return None + def parse_llm_response(response: str, table_name: str) -> Optional[Dict[str, Any]]: """Parse the LLM response and extract the table schema with validation""" try: @@ -445,23 +486,23 @@ def parse_llm_response(response: str, table_name: str) -> Optional[Dict[str, Any response = response.split("```json")[1].split("```")[0].strip() elif "```" in response: response = response.split("```")[1].strip() - + # Handle common formatting issues - response = response.replace('\n', ' ').replace('\r', ' ') - + response = response.replace("\n", " ").replace("\r", " ") + # Cleanup any trailing/leading text - start_idx = response.find('{') - end_idx = response.rfind('}') + 1 - if start_idx >= 0 and end_idx > start_idx: + start_idx = response.find("{") + end_idx = response.rfind("}") + 1 + if 0 <= start_idx < end_idx: response = response[start_idx:end_idx] - + parsed = json.loads(response) - + # Validation of required components if table_name in parsed: table_data = parsed[table_name] required_keys = ["description", "columns", "indexes", "foreign_keys"] - + # Check if all required sections exist if all(key in table_data for key in required_keys): # Verify columns have required attributes @@ -469,86 +510,89 @@ def parse_llm_response(response: str, table_name: str) -> Optional[Dict[str, Any required_col_attrs = ["description", "type", "null"] if not all(attr in col_data for attr in required_col_attrs): print(f"Warning: Column {col_name} is missing required attributes") - + return {table_name: table_data} - else: - missing = [key for key in required_keys if key not in table_data] - print(f"Warning: Table schema missing required sections: {missing}") - return {table_name: table_data} # Return anyway, but with warning - else: - # Try to get the first key if table_name is not found - first_key = next(iter(parsed)) - print(f"Warning: Table name mismatch. Expected {table_name}, got {first_key}") - return {table_name: parsed[first_key]} + + missing = [key for key in required_keys if key not in table_data] + print(f"Warning: Table schema missing required sections: {missing}") + return {table_name: table_data} # Return anyway, but with warning + + # Try to get the first key if table_name is not found + first_key = next(iter(parsed)) + print(f"Warning: Table name mismatch. 
Expected {table_name}, got {first_key}") + return {table_name: parsed[first_key]} except Exception as e: print(f"Error parsing LLM response for {table_name}: {e}") print(f"Raw response: {response[:500]}...") # Show first 500 chars return None -def process_table(table_name: str, schema: Dict[str, Any], all_table_names: List[str], topology) -> Dict[str, Any]: + +def process_table( + table_name: str, schema: Dict[str, Any], all_table_names: List[str], topology +) -> Dict[str, Any]: """Process a single table and update the schema""" print(f"Processing table: {table_name}") - + # Skip if table already has detailed schema - if (table_name in schema["tables"] and - isinstance(schema["tables"][table_name], dict) and - "columns" in schema["tables"][table_name] and - "indexes" in schema["tables"][table_name] and - "foreign_keys" in schema["tables"][table_name]): + if ( + table_name in schema["tables"] + and isinstance(schema["tables"][table_name], dict) + and "columns" in schema["tables"][table_name] + and "indexes" in schema["tables"][table_name] + and "foreign_keys" in schema["tables"][table_name] + ): print(f"Table {table_name} already processed. Skipping.") return schema - + # Generate prompt for this table prompt = get_table_prompt(table_name, schema["tables"], all_table_names, topology) - + # Call LLM API response = call_llm_api(prompt) if not response: print(f"Failed to get response for {table_name}. Skipping.") return schema - + # Parse response table_schema = parse_llm_response(response, table_name) if not table_schema: print(f"Failed to parse response for {table_name}. Skipping.") return schema - + # Update schema schema["tables"].update(table_schema) print(f"Successfully processed {table_name}") - + # Save intermediate results save_schema(schema, f"intermediate_{table_name.replace('/', '_')}.json") - + return schema + def main(): + """Main function to generate complete CRM schema with relationships.""" # Load the initial schema with table names initial_schema_path = "examples/crm_tables.json" # Replace with your actual file path initial_schema = load_initial_schema(initial_schema_path) - + # Get the list of tables to process tables = list(initial_schema.get("tables", {}).keys()) all_table_names = tables.copy() # Keep a full list for reference - - topology = generate_keys(tables) + topology = generate_keys(tables) # Initialize our working schema - schema = { - "database": initial_schema.get("database", "crm_system"), - "tables": {} - } - + schema = {"database": initial_schema.get("database", "crm_system"), "tables": {}} + # If we have existing work, load it if os.path.exists(OUTPUT_FILE): try: - with open(OUTPUT_FILE, 'r') as file: + with open(OUTPUT_FILE, "r", encoding="utf-8") as file: schema = json.load(file) print(f"Loaded existing schema from {OUTPUT_FILE}") except Exception as e: print(f"Error loading existing schema: {e}") - + # Prioritize tables to process - process base tables first def table_priority(table_name): # Base tables should be processed first @@ -559,38 +603,47 @@ def table_priority(table_name): return 2 # Related tables in the middle return 1 - + # Sort tables by priority tables.sort(key=table_priority) - + # Process tables for i, table_name in enumerate(tables): - print(f"\nProcessing table {i+1}/{len(tables)}: {table_name} (Priority: {table_priority(table_name)})") + print( + f"\nProcessing table {i+1}/{len(tables)}: {table_name} " + f"(Priority: {table_priority(table_name)})" + ) schema = process_table(table_name, schema, all_table_names, topology) - + # Save 
progress after each table save_schema(schema) - + # Add delay to avoid rate limits if i < len(tables) - 1: delay = 2 + (0.5 * i % 5) # Varied delay to help avoid pattern detection print(f"Waiting {delay} seconds before next request...") time.sleep(delay) - + print(f"\nCompleted processing all {len(tables)} tables") print(f"Final schema saved to {OUTPUT_FILE}") - + # Validate the final schema validate_schema(schema) + def generate_keys(tables) -> Dict[str, Any]: + """Generate primary and foreign keys for CRM tables.""" path = "examples/crm_topology.json" + last_key = 0 # Initialize default value + schema = {"tables": {}} # Initialize default schema + # If we have existing work, load it if os.path.exists(path): try: - with open(path, 'r') as file: + with open(path, "r", encoding="utf-8") as file: schema = json.load(file) - last_key = tables.index(list(schema['tables'].keys())[-1]) + if schema.get("tables"): + last_key = tables.index(list(schema["tables"].keys())[-1]) print(f"Loaded existing schema from {path}") except Exception as e: print(f"Error loading existing schema: {e}") @@ -618,9 +671,9 @@ def generate_keys(tables) -> Dict[str, Any]: p = prompt.format(table_name=table, tables=tables) response = call_llm_api(p) new_table = json.loads(response) - schema['tables'].update(new_table) + schema["tables"].update(new_table) - with open(path, 'w') as file: + with open(path, "w", encoding="utf-8") as file: json.dump(schema, file, indent=2) print(f"Schema saved to {path}") print(f"Final schema saved to {path}") @@ -632,27 +685,30 @@ def validate_schema(schema: Dict[str, Any]) -> None: """Perform final validation on the complete schema""" print("\nValidating schema...") issues = [] - + table_count = len(schema["tables"]) - tables_with_columns = sum(1 for t in schema["tables"].values() - if isinstance(t, dict) and "columns" in t) - tables_with_indexes = sum(1 for t in schema["tables"].values() - if isinstance(t, dict) and "indexes" in t) - tables_with_foreign_keys = sum(1 for t in schema["tables"].values() - if isinstance(t, dict) and "foreign_keys" in t) - + tables_with_columns = sum( + 1 for t in schema["tables"].values() if isinstance(t, dict) and "columns" in t + ) + tables_with_indexes = sum( + 1 for t in schema["tables"].values() if isinstance(t, dict) and "indexes" in t + ) + tables_with_foreign_keys = sum( + 1 for t in schema["tables"].values() if isinstance(t, dict) and "foreign_keys" in t + ) + print(f"Total tables: {table_count}") print(f"Tables with columns: {tables_with_columns}") print(f"Tables with indexes: {tables_with_indexes}") print(f"Tables with foreign keys: {tables_with_foreign_keys}") - + # Check if all tables have required sections incomplete_tables = [] for table_name, table_data in schema["tables"].items(): if not isinstance(table_data, dict): incomplete_tables.append(f"{table_name} (empty)") continue - + missing = [] if "description" not in table_data or not table_data["description"]: missing.append("description") @@ -662,10 +718,10 @@ def validate_schema(schema: Dict[str, Any]) -> None: missing.append("indexes") if "foreign_keys" not in table_data: # Can be empty, just needs to exist missing.append("foreign_keys") - + if missing: incomplete_tables.append(f"{table_name} (missing: {', '.join(missing)})") - + if incomplete_tables: issues.append(f"Incomplete tables: {len(incomplete_tables)}") print("Incomplete tables:") @@ -673,17 +729,17 @@ def validate_schema(schema: Dict[str, Any]) -> None: print(f" - {table}") if len(incomplete_tables) > 10: print(f" ... 
and {len(incomplete_tables) - 10} more") - + # Check foreign key references invalid_fks = [] for table_name, table_data in schema["tables"].items(): if not isinstance(table_data, dict) or "foreign_keys" not in table_data: continue - + for fk_name, fk_data in table_data["foreign_keys"].items(): ref_table = fk_data.get("referenced_table") ref_column = fk_data.get("referenced_column") - + if ref_table and ref_table not in schema["tables"]: invalid_fks.append(f"{table_name}.{fk_name} -> {ref_table} (table not found)") elif ref_table and ref_column: @@ -691,8 +747,10 @@ def validate_schema(schema: Dict[str, Any]) -> None: if not isinstance(ref_table_data, dict) or "columns" not in ref_table_data: invalid_fks.append(f"{table_name}.{fk_name} -> {ref_table} (no columns)") elif ref_column not in ref_table_data.get("columns", {}): - invalid_fks.append(f"{table_name}.{fk_name} -> {ref_table}.{ref_column} (column not found)") - + invalid_fks.append( + f"{table_name}.{fk_name} -> {ref_table}.{ref_column} (column not found)" + ) + if invalid_fks: issues.append(f"Invalid foreign keys: {len(invalid_fks)}") print("Invalid foreign keys:") @@ -700,11 +758,12 @@ def validate_schema(schema: Dict[str, Any]) -> None: print(f" - {fk}") if len(invalid_fks) > 10: print(f" ... and {len(invalid_fks) - 10} more") - + if issues: print(f"\nValidation complete. Found {len(issues)} issue types.") else: print("\nValidation complete. No issues found!") + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/api/index.py b/api/index.py index bc7c3ac8..e0bdf743 100644 --- a/api/index.py +++ b/api/index.py @@ -1,93 +1,248 @@ -""" This module contains the routes for the text2sql API. """ +"""This module contains the routes for the text2sql API.""" + import json -import os import logging +import os +import time +from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import TimeoutError as FuturesTimeoutError from functools import wraps + from dotenv import load_dotenv -from flask import Blueprint, Response, jsonify, render_template, request, stream_with_context, Flask -from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError -from api.graph import find, get_db_description +from flask import Blueprint, Flask, Response, jsonify, render_template, request, stream_with_context, g +from flask import session, redirect, url_for +from flask_dance.contrib.google import make_google_blueprint, google +from flask_dance.contrib.github import make_github_blueprint, github +from flask_dance.consumer.storage.session import SessionStorage +from flask_dance.consumer import oauth_authorized + +from api.agents import AnalysisAgent, RelevancyAgent, ResponseFormatterAgent from api.extensions import db +from api.graph import find, get_db_description from api.loaders.csv_loader import CSVLoader from api.loaders.json_loader import JSONLoader +from api.loaders.postgres_loader import PostgresLoader from api.loaders.odata_loader import ODataLoader -from api.agents import RelevancyAgent, AnalysisAgent -from api.constants import BENCHMARK, EXAMPLES -import random # Load environment variables from .env file load_dotenv() -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") # Use the same delimiter as in the JavaScript -MESSAGE_DELIMITER = '|||FALKORDB_MESSAGE_BOUNDARY|||' +MESSAGE_DELIMITER = "|||FALKORDB_MESSAGE_BOUNDARY|||" main = 
Blueprint("main", __name__) -SECRET_TOKEN = os.getenv('SECRET_TOKEN') -SECRET_TOKEN_ERP = os.getenv('SECRET_TOKEN_ERP') -def verify_token(token): - """ Verify the token provided in the request """ - return token == SECRET_TOKEN or token == SECRET_TOKEN_ERP or token == "null" +def validate_and_cache_user(): + """ + Helper function to validate OAuth token and cache user info. + Returns (user_info, is_authenticated) tuple. + Supports both Google and GitHub OAuth. + """ + # Check for cached user info from either provider + user_info = session.get("user_info") + token_validated_at = session.get("token_validated_at", 0) + current_time = time.time() + + # Use cached user info if it's less than 15 minutes old + if user_info and (current_time - token_validated_at) < 900: # 15 minutes + return user_info, True + + # Check Google OAuth first + if google.authorized: + try: + resp = google.get("/oauth2/v2/userinfo") + if resp.ok: + google_user = resp.json() + # Normalize user info structure + user_info = { + "id": google_user.get("id"), + "name": google_user.get("name"), + "email": google_user.get("email"), + "picture": google_user.get("picture"), + "provider": "google" + } + session["user_info"] = user_info + session["token_validated_at"] = current_time + return user_info, True + except Exception as e: + logging.warning(f"Google OAuth validation error: {e}") + + # Check GitHub OAuth + if github.authorized: + try: + # Get user profile + resp = github.get("/user") + if resp.ok: + github_user = resp.json() + + # Get user email (GitHub may require separate call for email) + email_resp = github.get("/user/emails") + email = None + if email_resp.ok: + emails = email_resp.json() + # Find primary email + for email_obj in emails: + if email_obj.get("primary", False): + email = email_obj.get("email") + break + # If no primary email found, use the first one + if not email and emails: + email = emails[0].get("email") + + # Normalize user info structure + user_info = { + "id": str(github_user.get("id")), # Convert to string for consistency + "name": github_user.get("name") or github_user.get("login"), + "email": email, + "picture": github_user.get("avatar_url"), + "provider": "github" + } + session["user_info"] = user_info + session["token_validated_at"] = current_time + return user_info, True + except Exception as e: + logging.warning(f"GitHub OAuth validation error: {e}") + + # If no valid authentication found, clear session + session.clear() + return None, False + def token_required(f): - """ Decorator to protect routes with token authentication """ + """Decorator to protect routes with token authentication""" + @wraps(f) def decorated_function(*args, **kwargs): - token = request.args.get('token', 'null') # Get token from header - os.environ["USER_TOKEN"] = token - if not verify_token(token): - return jsonify(message="Unauthorized"), 401 + user_info, is_authenticated = validate_and_cache_user() + + if not is_authenticated: + return jsonify(message="Unauthorized - Please log in"), 401 + + g.user_id = user_info.get("id") + if not g.user_id: + session.clear() + return jsonify(message="Unauthorized - Invalid user"), 401 + return f(*args, **kwargs) + return decorated_function + app = Flask(__name__) +app.secret_key = os.getenv("FLASK_SECRET_KEY", "supersekrit") + +# Google OAuth setup +GOOGLE_CLIENT_ID = os.getenv("GOOGLE_CLIENT_ID") +GOOGLE_CLIENT_SECRET = os.getenv("GOOGLE_CLIENT_SECRET") +google_bp = make_google_blueprint( + client_id=GOOGLE_CLIENT_ID, + client_secret=GOOGLE_CLIENT_SECRET, + scope=[ + 
"https://www.googleapis.com/auth/userinfo.email", + "https://www.googleapis.com/auth/userinfo.profile", + "openid" + ] +) +app.register_blueprint(google_bp, url_prefix="/login") + +# GitHub OAuth setup +GITHUB_CLIENT_ID = os.getenv("GITHUB_CLIENT_ID") +GITHUB_CLIENT_SECRET = os.getenv("GITHUB_CLIENT_SECRET") +github_bp = make_github_blueprint( + client_id=GITHUB_CLIENT_ID, + client_secret=GITHUB_CLIENT_SECRET, + scope="user:email", + storage=SessionStorage() +) +app.register_blueprint(github_bp, url_prefix="/login") + +# GitHub OAuth signal handler + +@oauth_authorized.connect_via(github_bp) +def github_logged_in(blueprint, token): + if not token: + return False + + try: + # Get user profile + resp = github.get("/user") + if resp.ok: + github_user = resp.json() + + # Get user email (GitHub may require separate call for email) + email_resp = github.get("/user/emails") + email = None + if email_resp.ok: + emails = email_resp.json() + # Find primary email + for email_obj in emails: + if email_obj.get("primary", False): + email = email_obj.get("email") + break + # If no primary email found, use the first one + if not email and emails: + email = emails[0].get("email") + + # Normalize user info structure + user_info = { + "id": str(github_user.get("id")), # Convert to string for consistency + "name": github_user.get("name") or github_user.get("login"), + "email": email, + "picture": github_user.get("avatar_url"), + "provider": "github" + } + session["user_info"] = user_info + session["token_validated_at"] = time.time() + return False # Don't create default flask-dance entry in session -# @app.before_request -# def before_request_func(): -# oidc_token = request.headers.get('x-vercel-oidc-token') -# if oidc_token: -# set_oidc_token(oidc_token) -# credentials = assume_role() -# else: -# # Optional: require it for protected routes -# pass - -@app.route('/') -@token_required # Apply token authentication decorator + except Exception as e: + logging.error(f"GitHub OAuth signal error: {e}") + + return False + + +@app.errorhandler(Exception) +def handle_oauth_error(error): + """Handle OAuth-related errors gracefully""" + # Check if it's an OAuth-related error + if "token" in str(error).lower() or "oauth" in str(error).lower(): + logging.warning(f"OAuth error occurred: {error}") + session.clear() + return redirect(url_for("home")) + + # For other errors, let them bubble up + raise error + + +@app.route("/") def home(): - """ Home route """ - return render_template('chat.html') + """Home route""" + user_info, is_authenticated = validate_and_cache_user() -@app.route('/graphs') -@token_required # Apply token authentication decorator + # If not authenticated through OAuth, check for any stale session data + if not is_authenticated and not google.authorized and not github.authorized: + session.pop("user_info", None) + + return render_template("chat.j2", is_authenticated=is_authenticated, user_info=user_info) + + +@app.route("/graphs") +@token_required def graphs(): """ This route is used to list all the graphs that are available in the database. 
""" - graphs = db.list_graphs() - if os.getenv("USER_TOKEN") == SECRET_TOKEN: - if 'hospital' in graphs: - return ['hospital'] - else: - return [] - - if os.getenv("USER_TOKEN") == SECRET_TOKEN_ERP: - if 'ERP_system' in graphs: - return ['ERP_system'] - else: - return ['crm_usecase'] - elif os.getenv("USER_TOKEN") == "null": - if 'crm_usecase' in graphs: - return ['crm_usecase'] - else: - return [] - else: - graphs.remove('hospital') - return graphs + user_id = g.user_id + user_graphs = db.list_graphs() + # Only include graphs that start with user_id + '_', and strip the prefix + filtered_graphs = [graph[len(f"{user_id}_"):] + for graph in user_graphs if graph.startswith(f"{user_id}_")] + return jsonify(filtered_graphs) + @app.route("/graphs", methods=["POST"]) -@token_required # Apply token authentication decorator +@token_required def load(): """ This route is used to load the graph data into the database. @@ -106,20 +261,20 @@ def load(): if not data or "database" not in data: return jsonify({"error": "Invalid JSON data"}), 400 - graph_id = data["database"] + graph_id = g.user_id + "_" + data["database"] success, result = JSONLoader.load(graph_id, data) - # ✅ Handle XML Payload - elif content_type.startswith("application/xml") or content_type.startswith("text/xml"): - xml_data = request.data - graph_id = "" - success, result = ODataLoader.load(graph_id, xml_data) + # # ✅ Handle XML Payload + # elif content_type.startswith("application/xml") or content_type.startswith("text/xml"): + # xml_data = request.data + # graph_id = "" + # success, result = ODataLoader.load(graph_id, xml_data) - # ✅ Handle CSV Payload - elif content_type.startswith("text/csv"): - csv_data = request.data - graph_id = "" - success, result = CSVLoader.load(graph_id, csv_data) + # # ✅ Handle CSV Payload + # elif content_type.startswith("text/csv"): + # csv_data = request.data + # graph_id = "" + # success, result = CSVLoader.load(graph_id, csv_data) # ✅ Handle File Upload (FormData with JSON/XML) elif content_type.startswith("multipart/form-data"): @@ -134,7 +289,7 @@ def load(): if file.filename.endswith(".json"): try: data = json.load(file) - graph_id = data.get("database", "") + graph_id = g.user_id + "_" + data.get("database", "") success, result = JSONLoader.load(graph_id, data) except json.JSONDecodeError: return jsonify({"error": "Invalid JSON file"}), 400 @@ -142,15 +297,15 @@ def load(): # ✅ Check if file is XML elif file.filename.endswith(".xml"): xml_data = file.read().decode("utf-8") # Convert bytes to string - graph_id = file.filename.replace(".xml", "") + graph_id = g.user_id + "_" + file.filename.replace(".xml", "") success, result = ODataLoader.load(graph_id, xml_data) - + # ✅ Check if file is csv elif file.filename.endswith(".csv"): csv_data = file.read().decode("utf-8") # Convert bytes to string - graph_id = file.filename.replace(".csv", "") + graph_id = g.user_id + "_" + file.filename.replace(".csv", "") success, result = CSVLoader.load(graph_id, csv_data) - + else: return jsonify({"error": "Unsupported file type"}), 415 else: @@ -162,94 +317,417 @@ def load(): return jsonify({"error": result}), 400 + @app.route("/graphs/", methods=["POST"]) -@token_required # Apply token authentication decorator +@token_required def query(graph_id: str): """ text2sql """ + graph_id = g.user_id + "_" + graph_id.strip() request_data = request.get_json() queries_history = request_data.get("chat") result_history = request_data.get("result") instructions = request_data.get("instructions") if not queries_history: return 
jsonify({"error": "Invalid or missing JSON data"}), 400 - - logging.info(f"User Query: {queries_history[-1]}") + + logging.info("User Query: %s", queries_history[-1]) # Create a generator function for streaming def generate(): agent_rel = RelevancyAgent(queries_history, result_history) agent_an = AnalysisAgent(queries_history, result_history) - - step = {"type": "reasoning_step", "message": "Step 1: Analyzing the user query"} + step = {"type": "reasoning_step", "message": "Step 1: Analyzing user query and generating SQL..."} yield json.dumps(step) + MESSAGE_DELIMITER - db_description = get_db_description(graph_id) # Ensure the database description is loaded - - logging.info(f"Calling to relvancy agent with query: {queries_history[-1]}") + db_description, db_url = get_db_description(graph_id) # Ensure the database description is loaded + + logging.info("Calling to relvancy agent with query: %s", queries_history[-1]) answer_rel = agent_rel.get_answer(queries_history[-1], db_description) if answer_rel["status"] != "On-topic": - step = {"type": "followup_questions", "message": "Off topic question: " + answer_rel["reason"]} - logging.info(f"SQL Fail reason: {answer_rel["reason"]}") + step = { + "type": "followup_questions", + "message": "Off topic question: " + answer_rel["reason"], + } + logging.info("SQL Fail reason: %s", answer_rel["reason"]) yield json.dumps(step) + MESSAGE_DELIMITER else: # Use a thread pool to enforce timeout with ThreadPoolExecutor(max_workers=1) as executor: future = executor.submit(find, graph_id, queries_history, db_description) try: - success, result, _ = future.result(timeout=120) + _, result, _ = future.result(timeout=120) except FuturesTimeoutError: - yield json.dumps({"type": "error", "message": "Timeout error while finding tables relevant to your request."}) + MESSAGE_DELIMITER + yield json.dumps( + { + "type": "error", + "message": ("Timeout error while finding tables relevant to " + "your request."), + } + ) + MESSAGE_DELIMITER return except Exception as e: - logging.info(f"Error in find function: {e}") - yield json.dumps({"type": "error", "message": "Error in find function"}) + MESSAGE_DELIMITER + logging.info("Error in find function: %s", e) + yield json.dumps( + {"type": "error", "message": "Error in find function"} + ) + MESSAGE_DELIMITER return - step = {"type": "reasoning_step", - "message": "Step 2: Generating SQL query"} - yield json.dumps(step) + MESSAGE_DELIMITER - logging.info(f"Calling to analysis agent with query: {queries_history[-1]}") - answer_an = agent_an.get_analysis(queries_history[-1], result, db_description, instructions) + logging.info("Calling to analysis agent with query: %s", queries_history[-1]) + answer_an = agent_an.get_analysis( + queries_history[-1], result, db_description, instructions + ) + + logging.info("SQL Result: %s", answer_an['sql_query']) + yield json.dumps( + { + "type": "final_result", + "data": answer_an["sql_query"], + "conf": answer_an["confidence"], + "miss": answer_an["missing_information"], + "amb": answer_an["ambiguities"], + "exp": answer_an["explanation"], + "is_valid": answer_an["is_sql_translatable"], + } + ) + MESSAGE_DELIMITER + + # If the SQL query is valid, execute it using the postgress database db_url + if answer_an["is_sql_translatable"]: + # Check if this is a destructive operation that requires confirmation + sql_query = answer_an["sql_query"] + sql_type = sql_query.strip().split()[0].upper() if sql_query else "" + + if sql_type in ['INSERT', 'UPDATE', 'DELETE', 'DROP', 'CREATE', 'ALTER', 
'TRUNCATE']: + # This is a destructive operation - ask for user confirmation + confirmation_message = f"""⚠️ DESTRUCTIVE OPERATION DETECTED ⚠️ + +The generated SQL query will perform a **{sql_type}** operation: - logging.info(f"SQL Result: {answer_an['sql_query']}") - yield json.dumps({"type": "final_result", "data": answer_an['sql_query'], "conf": answer_an['confidence'], - "miss": answer_an['missing_information'], - "amb": answer_an['ambiguities'], - "exp": answer_an['explanation'], - "is_valid": answer_an['is_sql_translatable']}) + MESSAGE_DELIMITER +SQL: +{sql_query} - return Response(stream_with_context(generate()), content_type='application/json') +What this will do: +""" + if sql_type == 'INSERT': + confirmation_message += "• Add new data to the database" + elif sql_type == 'UPDATE': + confirmation_message += "• Modify existing data in the database" + elif sql_type == 'DELETE': + confirmation_message += "• **PERMANENTLY DELETE** data from the database" + elif sql_type == 'DROP': + confirmation_message += "• **PERMANENTLY DELETE** entire tables or database objects" + elif sql_type == 'CREATE': + confirmation_message += "• Create new tables or database objects" + elif sql_type == 'ALTER': + confirmation_message += "• Modify the structure of existing tables" + elif sql_type == 'TRUNCATE': + confirmation_message += "• **PERMANENTLY DELETE ALL DATA** from specified tables" -@app.route('/suggestions') -@token_required # Apply token authentication decorator -def suggestions(): + confirmation_message += """ + +⚠️ WARNING: This operation will make changes to your database and may be irreversible. +""" + + yield json.dumps( + { + "type": "destructive_confirmation", + "message": confirmation_message, + "sql_query": sql_query, + "operation_type": sql_type + } + ) + MESSAGE_DELIMITER + return # Stop here and wait for user confirmation + + try: + step = {"type": "reasoning_step", "message": "Step 2: Executing SQL query"} + yield json.dumps(step) + MESSAGE_DELIMITER + + # Check if this query modifies the database schema + is_schema_modifying, operation_type = PostgresLoader.is_schema_modifying_query(sql_query) + + query_results = PostgresLoader.execute_sql_query(answer_an["sql_query"], db_url) + yield json.dumps( + { + "type": "query_result", + "data": query_results, + } + ) + MESSAGE_DELIMITER + + # If schema was modified, refresh the graph + if is_schema_modifying: + step = {"type": "reasoning_step", "message": "Step 3: Schema change detected - refreshing graph..."} + yield json.dumps(step) + MESSAGE_DELIMITER + + refresh_success, refresh_message = PostgresLoader.refresh_graph_schema(graph_id, db_url) + + if refresh_success: + yield json.dumps( + { + "type": "schema_refresh", + "message": f"✅ Schema change detected ({operation_type} operation)\n\n🔄 Graph schema has been automatically refreshed with the latest database structure.", + "refresh_status": "success" + } + ) + MESSAGE_DELIMITER + else: + yield json.dumps( + { + "type": "schema_refresh", + "message": f"⚠️ Schema was modified but graph refresh failed: {refresh_message}", + "refresh_status": "failed" + } + ) + MESSAGE_DELIMITER + + # Generate user-readable response using AI + step_num = "4" if is_schema_modifying else "3" + step = {"type": "reasoning_step", "message": f"Step {step_num}: Generating user-friendly response"} + yield json.dumps(step) + MESSAGE_DELIMITER + + response_agent = ResponseFormatterAgent() + user_readable_response = response_agent.format_response( + user_query=queries_history[-1], + sql_query=answer_an["sql_query"], 
+ query_results=query_results, + db_description=db_description + ) + + yield json.dumps( + { + "type": "ai_response", + "message": user_readable_response, + } + ) + MESSAGE_DELIMITER + + except Exception as e: + logging.error("Error executing SQL query: %s", e) + yield json.dumps( + {"type": "error", "message": str(e)} + ) + MESSAGE_DELIMITER + + return Response(stream_with_context(generate()), content_type="application/json") + + +@app.route("/graphs//confirm", methods=["POST"]) +@token_required +def confirm_destructive_operation(graph_id: str): + """ + Handle user confirmation for destructive SQL operations + """ + graph_id = g.user_id + "_" + graph_id.strip() + request_data = request.get_json() + confirmation = request_data.get("confirmation", "").strip().upper() + sql_query = request_data.get("sql_query", "") + queries_history = request_data.get("chat", []) + + if not sql_query: + return jsonify({"error": "No SQL query provided"}), 400 + + # Create a generator function for streaming the confirmation response + def generate_confirmation(): + if confirmation == "CONFIRM": + try: + db_description, db_url = get_db_description(graph_id) + + step = {"type": "reasoning_step", "message": "Step 2: Executing confirmed SQL query"} + yield json.dumps(step) + MESSAGE_DELIMITER + + # Check if this query modifies the database schema + is_schema_modifying, operation_type = PostgresLoader.is_schema_modifying_query(sql_query) + + query_results = PostgresLoader.execute_sql_query(sql_query, db_url) + yield json.dumps( + { + "type": "query_result", + "data": query_results, + } + ) + MESSAGE_DELIMITER + + # If schema was modified, refresh the graph + if is_schema_modifying: + step = {"type": "reasoning_step", "message": "Step 3: Schema change detected - refreshing graph..."} + yield json.dumps(step) + MESSAGE_DELIMITER + + refresh_success, refresh_message = PostgresLoader.refresh_graph_schema(graph_id, db_url) + + if refresh_success: + yield json.dumps( + { + "type": "schema_refresh", + "message": f"✅ Schema change detected ({operation_type} operation)\n\n🔄 Graph schema has been automatically refreshed with the latest database structure.", + "refresh_status": "success" + } + ) + MESSAGE_DELIMITER + else: + yield json.dumps( + { + "type": "schema_refresh", + "message": f"⚠️ Schema was modified but graph refresh failed: {refresh_message}", + "refresh_status": "failed" + } + ) + MESSAGE_DELIMITER + + # Generate user-readable response using AI + step_num = "4" if is_schema_modifying else "3" + step = {"type": "reasoning_step", "message": f"Step {step_num}: Generating user-friendly response"} + yield json.dumps(step) + MESSAGE_DELIMITER + + response_agent = ResponseFormatterAgent() + user_readable_response = response_agent.format_response( + user_query=queries_history[-1] if queries_history else "Destructive operation", + sql_query=sql_query, + query_results=query_results, + db_description=db_description + ) + + yield json.dumps( + { + "type": "ai_response", + "message": user_readable_response, + } + ) + MESSAGE_DELIMITER + + except Exception as e: + logging.error("Error executing confirmed SQL query: %s", e) + yield json.dumps( + {"type": "error", "message": f"Error executing query: {str(e)}"} + ) + MESSAGE_DELIMITER + else: + # User cancelled or provided invalid confirmation + yield json.dumps( + { + "type": "operation_cancelled", + "message": "Operation cancelled. The destructive SQL query was not executed." 
+ } + ) + MESSAGE_DELIMITER + + return Response(stream_with_context(generate_confirmation()), content_type="application/json") + +@app.route("/login") +def login_google(): + if not google.authorized: + return redirect(url_for("google.login")) + + try: + resp = google.get("/oauth2/v2/userinfo") + if resp.ok: + google_user = resp.json() + # Normalize user info structure + user_info = { + "id": google_user.get("id"), + "name": google_user.get("name"), + "email": google_user.get("email"), + "picture": google_user.get("picture"), + "provider": "google" + } + session["user_info"] = user_info + session["token_validated_at"] = time.time() + return redirect(url_for("home")) + else: + # OAuth token might be expired, redirect to login + session.clear() + return redirect(url_for("google.login")) + except Exception as e: + logging.error("Google login error: %s", e) + session.clear() + return redirect(url_for("google.login")) + + + + +@app.route("/logout") +def logout(): + session.clear() + + # Revoke Google OAuth token if authorized + if google.authorized: + try: + google.get( + "https://accounts.google.com/o/oauth2/revoke", + params={"token": google.access_token} + ) + except Exception as e: + logging.warning("Error revoking Google token: %s", e) + + # Revoke GitHub OAuth token if authorized + if github.authorized: + try: + # GitHub doesn't have a simple revoke endpoint like Google + # The token will expire naturally or can be revoked from GitHub settings + pass + except Exception as e: + logging.warning("Error with GitHub token cleanup: %s", e) + + return redirect(url_for("home")) + +@app.route("/graphs//refresh", methods=["POST"]) +@token_required +def refresh_graph_schema(graph_id: str): """ - This route returns 3 random suggestions from the examples data for the chat interface. - It takes graph_id as a query parameter and returns examples specific to that graph. - If no examples exist for the graph, returns an empty list. + Manually refresh the graph schema from the database. + This endpoint allows users to manually trigger a schema refresh + if they suspect the graph is out of sync with the database. """ + graph_id = g.user_id + "_" + graph_id.strip() + try: - # Get graph_id from query parameters - graph_id = request.args.get('graph_id', '') + # Get database connection details + db_description, db_url = get_db_description(graph_id) - if not graph_id: - return jsonify([]), 400 - - # Check if graph has specific examples - if graph_id in EXAMPLES: - graph_examples = EXAMPLES[graph_id] - # Return up to 3 examples, or all if less than 3 - suggestion_questions = random.sample(graph_examples, min(3, len(graph_examples))) - return jsonify(suggestion_questions) + if not db_url or db_url == "No URL available for this database.": + return jsonify({ + "success": False, + "error": "No database URL found for this graph" + }), 400 + + # Perform schema refresh + success, message = PostgresLoader.refresh_graph_schema(graph_id, db_url) + + if success: + return jsonify({ + "success": True, + "message": f"Graph schema refreshed successfully. {message}" + }), 200 + else: + return jsonify({ + "success": False, + "error": f"Failed to refresh schema: {message}" + }), 500 + + except Exception as e: + logging.error("Error in manual schema refresh: %s", e) + return jsonify({ + "success": False, + "error": f"Error refreshing schema: {str(e)}" + }), 500 + +@app.route("/database", methods=["POST"]) +@token_required +def connect_database(): + """ + Accepts a JSON payload with a Postgres URL and attempts to connect. 
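Registering a database through the `/database` route is a single POST with the connection string in the JSON body; the sketch below is illustrative, and the base URL, auth header, and Postgres URL are placeholders.

```python
# Hypothetical request registering a Postgres connection via the /database
# route; the base URL, auth header, and connection string are placeholders.
import requests

BASE_URL = "http://localhost:5000"
HEADERS = {"Authorization": "Bearer <your-token>"}  # placeholder auth header

resp = requests.post(
    f"{BASE_URL}/database",
    json={"url": "postgresql://user:password@localhost:5432/mydb"},
    headers=HEADERS,
)
print(resp.status_code, resp.json())  # {"success": true, "message": ...} on success
```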
+ Returns success or error message. + """ + data = request.get_json() + url = data.get("url") if data else None + if not url: + return jsonify({"success": False, "error": "No URL provided"}), 400 + try: + # Check for Postgres URL + if url.startswith("postgres://") or url.startswith("postgresql://"): + try: + # Attempt to connect/load using the loader + success, result = PostgresLoader.load(g.user_id, url) + if success: + return jsonify({"success": True, "message": result}), 200 + else: + return jsonify({"success": False, "error": result}), 400 + except Exception as e: + return jsonify({"success": False, "error": str(e)}), 500 else: - # If graph doesn't exist in EXAMPLES, return empty list - return jsonify([]) - + return jsonify({"success": False, "error": "Invalid Postgres URL"}), 400 except Exception as e: - logging.error(f"Error fetching suggestions: {e}") - return jsonify([]), 500 + return jsonify({"success": False, "error": str(e)}), 500 + if __name__ == "__main__": app.register_blueprint(main) diff --git a/api/loaders/base_loader.py b/api/loaders/base_loader.py index 72d2b20b..d6418382 100644 --- a/api/loaders/base_loader.py +++ b/api/loaders/base_loader.py @@ -1,7 +1,11 @@ +"""Base loader module providing abstract base class for data loaders.""" + from abc import ABC from typing import Tuple + class BaseLoader(ABC): + """Abstract base class for data loaders.""" @staticmethod def load(_graph_id: str, _data) -> Tuple[bool, str]: diff --git a/api/loaders/csv_loader.py b/api/loaders/csv_loader.py index f0913e91..8beda5a5 100644 --- a/api/loaders/csv_loader.py +++ b/api/loaders/csv_loader.py @@ -1,38 +1,54 @@ -from typing import Tuple, Dict, List +"""CSV loader module for processing CSV files and generating database schemas.""" + import io -# import pandas as pd -import tqdm from collections import defaultdict -from litellm import embedding +from typing import Tuple + +import tqdm + from api.loaders.base_loader import BaseLoader -from api.extensions import db from api.loaders.graph_loader import load_to_graph class CSVLoader(BaseLoader): + """CSV data loader for processing CSV files and loading them into graph database.""" + @staticmethod def load(graph_id: str, data) -> Tuple[bool, str]: """ Load the data dictionary CSV file into the graph database. - + Args: graph_id: The ID of the graph to load the data into data: CSV file - + Returns: Tuple of (success, message) """ raise NotImplementedError("CSVLoader is not implemented yet") + import pandas as pd + try: # Parse CSV data using pandas for better handling of large files - df = pd.read_csv(io.StringIO(data), encoding='utf-8') + df = pd.read_csv(io.StringIO(data), encoding="utf-8") # Check if required columns exist - required_columns = ['Schema', 'Domain', 'Field', 'Type', 'Description', 'Related', 'Cardinality'] + required_columns = [ + "Schema", + "Domain", + "Field", + "Type", + "Description", + "Related", + "Cardinality", + ] missing_columns = [col for col in required_columns if col not in df.columns] - + if missing_columns: - return False, f"Missing required columns in CSV: {', '.join(missing_columns)}" + return ( + False, + f"Missing required columns in CSV: {', '.join(missing_columns)}", + ) db_name = """Abacus Domain Model 25.3.5 The Abacus Domain Model is a physical manifestation of the hierarchical object model that Abacus Insights uses to store data. (It is not a relational database.) 
It is a foundational aspect of @@ -41,160 +57,188 @@ def load(graph_id: str, data) -> Tuple[bool, str]: The Abacus Domain Model is organized into schemas, which group related domains. We implement each domain as a broad structure with minimal nesting. The model avoids inheritance and deep nesting to minimize complexity and optimize performance.""" - # Process data by grouping by Schema and Domain to identify tables # Group by Schema and Domain to get tables - tables = defaultdict(lambda: { - 'description': '', - 'columns': {}, - # 'relationships': [], - 'col_descriptions': [] - }) - - rel_table = defaultdict(lambda: { - 'primary_key_table': '', - 'fk_tables': [] - }) + tables = defaultdict( + lambda: { + "description": "", + "columns": {}, + # 'relationships': [], + "col_descriptions": [], + } + ) + + rel_table = defaultdict(lambda: {"primary_key_table": "", "fk_tables": []}) relationships = {} # First pass: Organize data into tables for idx, row in tqdm.tqdm(df.iterrows(), total=len(df), desc="Organizing data"): - schema = row['Schema'] - domain = row['Domain'] + schema = row["Schema"] + domain = row["Domain"] table_name = f"{schema}.{domain}" - + # Set table description (use Domain Description if available) - if 'Domain Description' in row and not pd.isna(row['Domain Description']) and not tables[table_name]['description']: - tables[table_name]['description'] = row['Domain Description'] - + if ( + "Domain Description" in row + and not pd.isna(row["Domain Description"]) + and not tables[table_name]["description"] + ): + tables[table_name]["description"] = row["Domain Description"] + # Add column information - field = row['Field'] - field_type = row['Type'] if not pd.isna(row['Type']) else 'STRING' - field_desc = row['Description'] if not pd.isna(row['Description']) else field - + field = row["Field"] + field_type = row["Type"] if not pd.isna(row["Type"]) else "STRING" + field_desc = row["Description"] if not pd.isna(row["Description"]) else field + nullable = True # Default to nullable since we don't have explicit null info if not pd.isna(field): - tables[table_name]['col_descriptions'].append(field_desc) - tables[table_name]['columns'][field] = { - 'type': field_type, - 'description': field_desc, - 'null': nullable, - 'key': 'PRI' if field.lower().endswith('_id') else '', # Assumption: *_id fields are primary keys - 'default': '', - 'extra': '' + tables[table_name]["col_descriptions"].append(field_desc) + tables[table_name]["columns"][field] = { + "type": field_type, + "description": field_desc, + "null": nullable, + "key": ( + "PRI" if field.lower().endswith("_id") else "" + ), # Assumption: *_id fields are primary keys + "default": "", + "extra": "", } - + # Add relationship information if available - if not pd.isna(row['Related']) and not pd.isna(row['Cardinality']): + if not pd.isna(row["Related"]) and not pd.isna(row["Cardinality"]): source_field = field - target_table = row['Related'] + target_table = row["Related"] # cardinality = row['Cardinality'] if table_name not in relationships: relationships[table_name] = [] - relationships[table_name].append({"from": table_name, - "to": target_table, - "source_column": source_field, - "target_column": df.to_dict("records")[idx+1]['Array Field'] if not pd.isna(df.to_dict("records")[idx+1]['Array Field']) else '', - "note": ""}) - + relationships[table_name].append( + { + "from": table_name, + "to": target_table, + "source_column": source_field, + "target_column": ( + df.to_dict("records")[idx + 1]["Array Field"] + if not 
pd.isna(df.to_dict("records")[idx + 1]["Array Field"]) + else "" + ), + "note": "", + } + ) + # tables[table_name]['relationships'].append({ # 'source_field': source_field, # 'target_table': target_table, # 'cardinality': cardinality, - # 'target_field': df.to_dict("records")[idx+1]['Array Field'] if not pd.isna(df.to_dict("records")[idx+1]['Array Field']) else '' + # 'target_field': df.to_dict("records")[idx+1]['Array Field'] \ + # if not pd.isna(df.to_dict("records")[idx+1] \ + # ['Array Field']) else '' # }) tables[target_table]["description"] = field_desc else: - field = row['Array Field'] + field = row["Array Field"] field_desc = field_desc if not pd.isna(field_desc) else field # if len(tables[target_table]['col_descriptions']) == 0: # tables[table_name]['relationships'][-1]['target_field'] = field - tables[target_table]['col_descriptions'].append(field_desc) - tables[target_table]['columns'][field] = { - 'type': field_type, - 'description': field_desc, - 'null': nullable, - 'key': 'PRI' if field.lower().endswith('_id') else '', # Assumption: *_id fields are primary keys - 'default': '', - 'extra': '' + tables[target_table]["col_descriptions"].append(field_desc) + tables[target_table]["columns"][field] = { + "type": field_type, + "description": field_desc, + "null": nullable, + "key": ( + "PRI" if field.lower().endswith("_id") else "" + ), # Assumption: *_id fields are primary keys + "default": "", + "extra": "", } - if field.endswith('_id'): - if len(tables[table_name]['columns']) == 1 and field.endswith('_id'): + if field.endswith("_id"): + if len(tables[table_name]["columns"]) == 1 and field.endswith("_id"): suspected_primary_key = field[:-3] if suspected_primary_key in domain: - rel_table[field]['primary_key_table'] = table_name + rel_table[field]["primary_key_table"] = table_name else: - rel_table[field]['fk_tables'].append(table_name) + rel_table[field]["fk_tables"].append(table_name) else: - rel_table[field]['fk_tables'].append(table_name) + rel_table[field]["fk_tables"].append(table_name) - for key, tables_info in tqdm.tqdm(rel_table.items(), desc="Creating relationships from names"): - if len(tables_info['fk_tables']) > 0: - fk_tables = list(set(tables_info['fk_tables'])) - if len(tables_info['primary_key_table']) > 0: + for key, tables_info in tqdm.tqdm( + rel_table.items(), desc="Creating relationships from names" + ): + if len(tables_info["fk_tables"]) > 0: + fk_tables = list(set(tables_info["fk_tables"])) + if len(tables_info["primary_key_table"]) > 0: for table in fk_tables: if table not in relationships: relationships[table_name] = [] - relationships[table].append({"from": table, - "to": tables_info['primary_key_table'], - "source_column": key, - "target_column": key, - "note": 'many-one'}) + relationships[table].append( + { + "from": table, + "to": tables_info["primary_key_table"], + "source_column": key, + "target_column": key, + "note": "many-one", + } + ) else: for table_1 in fk_tables: for table_2 in fk_tables: if table_1 != table_2: if table_1 not in relationships: relationships[table_1] = [] - relationships[table_1].append({"from": table_1, - "to": table_2, - "source_column": key, - "target_column": key, - "note": 'many-many'}) - + relationships[table_1].append( + { + "from": table_1, + "to": table_2, + "source_column": key, + "target_column": key, + "note": "many-many", + } + ) + load_to_graph(graph_id, tables, relationships, db_name=db_name) return True, "Data dictionary loaded successfully into graph" - + except Exception as e: return False, f"Error loading 
CSV: {str(e)}" - # else: - # # For case 2: when no primary key table exists, connect all FK tables to each other - # graph.query( - # """ - # CREATE (src: Column {name: $col, cardinality: $cardinality}) - # """, - # { - # 'col': key, - # 'cardinality': 'many-many' - # } - # ) - # for i in range(len(fk_tables)): - # graph.query( - # """ - # MATCH (src:Column {name: $source_col}) - # -[:BELONGS_TO]->(source:Table {name: $source_table}) - # MATCH (tgt:Column {name: $target_col, cardinality: $cardinality}) - # CREATE (src)-[:REFERENCES { - # constraint_name: $fk_name, - # cardinality: $cardinality - # }]->(tgt) - # """, - # { - # 'source_col': key, - # 'target_col': key, - # 'source_table': fk_tables[i], - # 'fk_name': key, - # 'cardinality': 'many-many' - # } - # ) + # else: + # # For case 2: when no primary key table exists, \ + # # connect all FK tables to each other + # graph.query( + # """ + # CREATE (src: Column {name: $col, cardinality: $cardinality}) + # """, + # { + # 'col': key, + # 'cardinality': 'many-many' + # } + # ) + # for i in range(len(fk_tables)): + # graph.query( + # """ + # MATCH (src:Column {name: $source_col}) + # -[:BELONGS_TO]->(source:Table {name: $source_table}) + # MATCH (tgt:Column {name: $target_col, cardinality: $cardinality}) + # CREATE (src)-[:REFERENCES { + # constraint_name: $fk_name, + # cardinality: $cardinality + # }]->(tgt) + # """, + # { + # 'source_col': key, + # 'target_col': key, + # 'source_table': fk_tables[i], + # 'fk_name': key, + # 'cardinality': 'many-many' + # } + # ) + # # Second pass: Create table nodes # for table_name, table_info in tqdm.tqdm(tables.items(), desc="Creating Table nodes"): # # Skip if no columns (probably just a reference) # if not table_info['columns']: # continue - + # # Generate embedding for table description # table_desc = table_info['description'] # embedding_result = client.models.embed_content( @@ -206,7 +250,7 @@ def load(graph_id: str, data) -> Tuple[bool, str]: # graph.query( # """ # CREATE (t:Table { -# name: $table_name, +# name: $table_name, # description: $description, # embedding: vecf32($embedding) # }) @@ -222,25 +266,33 @@ def load(graph_id: str, data) -> Tuple[bool, str]: # batch_size = 50 # col_descriptions = table_info['col_descriptions'] # for batch in tqdm.tqdm( -# [col_descriptions[i:i + batch_size] for i in range(0, len(col_descriptions), batch_size)], +# [col_descriptions[i:i + batch_size] \ +# for i in range(0, len(col_descriptions), batch_size)], # desc=f"Creating embeddings for {table_name}"): - -# embedding_result = embedding(model='bedrock/cohere.embed-english-v3', input=batch[:95], aws_profile_name=Config.AWS_PROFILE, aws_region_name=Config.AWS_REGION) + +# embedding_result = embedding( +# model='bedrock/cohere.embed-english-v3', +# input=batch[:95], +# aws_profile_name=Config.AWS_PROFILE, +# aws_region_name=Config.AWS_REGION) # embed_columns.extend([emb.values for emb in embedding_result.embeddings]) # except Exception as e: # print(f"Error creating embeddings: {str(e)}") - + # # Create column nodes -# for idx, (col_name, col_info) in tqdm.tqdm(enumerate(table_info['columns'].items()), desc=f"Creating columns for {table_name}", total=len(table_info['columns'])): +# for idx, (col_name, col_info) in tqdm.tqdm( +# enumerate(table_info['columns'].items()), +# desc=f"Creating columns for {table_name}", +# total=len(table_info['columns'])): # # embedding_result = embedding( # # model=Config.EMBEDDING_MODEL, # # input=[col_info['description'] if col_info['description'] else col_name] # # ) - + # ## 
Temp # # agent_tax = TaxonomyAgent() # # tax = agent_tax.get_answer(col_name, col_info) -# # # +# # # # graph.query( # """ # MATCH (t:Table {name: $table_name}) @@ -267,17 +319,23 @@ def load(graph_id: str, data) -> Tuple[bool, str]: # 'embedding': embed_columns[idx] # } # ) - + # # Third pass: Create relationships -# for table_name, table_info in tqdm.tqdm(tables.items(), desc="Creating relationships"): +# for table_name, table_info in tqdm.tqdm(tables.items(), \ +# desc="Creating relationships"): # for rel in table_info['relationships']: # source_field = rel['source_field'] # target_table = rel['target_table'] # cardinality = rel['cardinality'] -# target_field = rel['target_field']#list(tables[tables[table_name]['relationships'][-1]['target_table']]['columns'].keys())[0] +# target_field = rel['target_field'] # \ +# # list(tables[tables[table_name]['relationships'][-1] \ +# # ['target_table']]['columns'].keys())[0] # # Create constraint name -# constraint_name = f"fk_{table_name.replace('.', '_')}_{source_field}_to_{target_table.replace('.', '_')}" - +# constraint_name = ( +# f"fk_{table_name.replace('.', '_')}_{source_field}_to_" +# f"{target_table.replace('.', '_')}" +# ) + # # Create relationship if both tables and columns exist # try: # graph.query( @@ -303,7 +361,8 @@ def load(graph_id: str, data) -> Tuple[bool, str]: # except Exception as e: # print(f"Warning: Could not create relationship: {str(e)}") # continue -# for key, tables_info in tqdm.tqdm(rel_table.items(), desc="Creating relationships from names"): +# for key, tables_info in tqdm.tqdm(rel_table.items(), \ +# desc="Creating relationships from names"): # if len(tables_info['fk_tables']) > 0: # fk_tables = list(set(tables_info['fk_tables'])) # if len(tables_info['primary_key_table']) > 0: @@ -329,7 +388,8 @@ def load(graph_id: str, data) -> Tuple[bool, str]: # } # ) # else: -# # For case 2: when no primary key table exists, connect all FK tables to each other +# # For case 2: when no primary key table exists, \ +# # connect all FK tables to each other # graph.query( # """ # CREATE (src: Column {name: $col, cardinality: $cardinality}) @@ -359,15 +419,15 @@ def load(graph_id: str, data) -> Tuple[bool, str]: # } # ) - # load_to_graph(graph_id, entities, relationships, db_name="ERP system") - # return True, "Data dictionary loaded successfully into graph" - - # except Exception as e: - # return False, f"Error loading CSV: {str(e)}" +# load_to_graph(graph_id, entities, relationships, db_name="ERP system") +# return True, "Data dictionary loaded successfully into graph" + +# except Exception as e: +# return False, f"Error loading CSV: {str(e)}" # if __name__ == "__main__": # # Example usage # loader = CSVLoader() # success, message = loader.load("my_graph", "Data Dictionary.csv") -# print(message) \ No newline at end of file +# print(message) diff --git a/api/loaders/graph_loader.py b/api/loaders/graph_loader.py index eb551d7b..48f1f676 100644 --- a/api/loaders/graph_loader.py +++ b/api/loaders/graph_loader.py @@ -1,14 +1,26 @@ +"""Graph loader module for loading data into graph databases.""" + +import json + import tqdm + from api.config import Config from api.extensions import db from api.utils import generate_db_description -import json -def load_to_graph(graph_id: str, entities: dict, relationships: dict, batch_size: int=100, db_name: str="TBD") -> None: + +def load_to_graph( + graph_id: str, + entities: dict, + relationships: dict, + batch_size: int = 100, + db_name: str = "TBD", + db_url: str = "", +) -> None: """ Load the 
graph data into the database. It gets the Graph name as an argument and expects - + Input: - entities: A dictionary containing the entities and their attributes. - relationships: A dictionary containing the relationships between entities. @@ -19,51 +31,43 @@ def load_to_graph(graph_id: str, entities: dict, relationships: dict, batch_size embedding_model = Config.EMBEDDING_MODEL vec_len = embedding_model.get_vector_size() - try: + try: # Create vector indices - graph.query(""" + graph.query( + """ CREATE VECTOR INDEX FOR (t:Table) ON (t.embedding) OPTIONS {dimension:$size, similarityFunction:'euclidean'} """, - { - 'size': vec_len - }) - - graph.query(""" + {"size": vec_len}, + ) + + graph.query( + """ CREATE VECTOR INDEX FOR (c:Column) ON (c.embedding) OPTIONS {dimension:$size, similarityFunction:'euclidean'} """, - { - 'size': vec_len - }) + {"size": vec_len}, + ) graph.query("CREATE INDEX FOR (p:Table) ON (p.name)") except Exception as e: print(f"Error creating vector indices: {str(e)}") - - - db_des = generate_db_description( - db_name=db_name, - table_names=list(entities.keys()) - ) + db_des = generate_db_description(db_name=db_name, table_names=list(entities.keys())) graph.query( """ CREATE (d:Database { name: $db_name, - description: $description + description: $description, + url: $url }) """, - { - 'db_name': db_name, - 'description': db_des - } + {"db_name": db_name, "description": db_des, "url": db_url}, ) - for table_name, table_info in tqdm.tqdm(entities.items(), desc="Creating Graph Table Nodes"): - table_desc = table_info['description'] + table_desc = table_info["description"] embedding_result = embedding_model.embed(table_desc) - fk = json.dumps(table_info.get('foreign_keys', [])) + fk = json.dumps(table_info.get("foreign_keys", [])) # Create table node graph.query( @@ -76,38 +80,46 @@ def load_to_graph(graph_id: str, entities: dict, relationships: dict, batch_size }) """, { - 'table_name': table_name, - 'description': table_desc, - 'embedding': embedding_result[0], - 'foreign_keys': fk - } + "table_name": table_name, + "description": table_desc, + "embedding": embedding_result[0], + "foreign_keys": fk, + }, ) - # Batch embeddings for table columns - # TODO: Check if the embedding model and description are correct (without 2 sources of truth) + # TODO: Check if the embedding model and description are correct \ + # (without 2 sources of truth) batch_flag = True - col_descriptions = table_info.get('col_descriptions') + col_descriptions = table_info.get("col_descriptions") if col_descriptions is None: batch_flag = False else: try: embed_columns = [] for batch in tqdm.tqdm( - [col_descriptions[i:i + batch_size] for i in range(0, len(col_descriptions), batch_size)], - desc=f"Creating embeddings for {table_name} columns",): - + [ + col_descriptions[i : i + batch_size] + for i in range(0, len(col_descriptions), batch_size) + ], + desc=f"Creating embeddings for {table_name} columns", + ): + embedding_result = embedding_model.embed(batch) embed_columns.extend(embedding_result) except Exception as e: print(f"Error creating embeddings: {str(e)}") batch_flag = False - + # Create column nodes - for idx, (col_name, col_info) in tqdm.tqdm(enumerate(table_info['columns'].items()), desc=f"Creating Graph Columns for {table_name}", total=len(table_info['columns'])): + for idx, (col_name, col_info) in tqdm.tqdm( + enumerate(table_info["columns"].items()), + desc=f"Creating Graph Columns for {table_name}", + total=len(table_info["columns"]), + ): if not batch_flag: embed_columns = [] - 
embedding_result = embedding_model.embed(col_info['description']) + embedding_result = embedding_model.embed(col_info["description"]) embed_columns.extend(embedding_result) idx = 0 @@ -124,25 +136,27 @@ def load_to_graph(graph_id: str, entities: dict, relationships: dict, batch_size })-[:BELONGS_TO]->(t) """, { - 'table_name': table_name, - 'col_name': col_name, - 'type': col_info.get('type', 'unknown'), - 'nullable': col_info.get('null', 'unknown'), - 'key': col_info.get('key', 'unknown'), - 'description': col_info['description'], - 'embedding': embed_columns[idx] - } + "table_name": table_name, + "col_name": col_name, + "type": col_info.get("type", "unknown"), + "nullable": col_info.get("null", "unknown"), + "key": col_info.get("key", "unknown"), + "description": col_info["description"], + "embedding": embed_columns[idx], + }, ) - + # Create relationships - for rel_name, table_info in tqdm.tqdm(relationships.items(), desc="Creating Graph Table Relationships"): + for rel_name, table_info in tqdm.tqdm( + relationships.items(), desc="Creating Graph Table Relationships" + ): for rel in table_info: - source_table = rel['from'] - source_field = rel['source_column'] - target_table = rel['to'] - target_field = rel['target_column'] - note = rel.get('note', '') - + source_table = rel["from"] + source_field = rel["source_column"] + target_table = rel["to"] + target_field = rel["target_column"] + note = rel.get("note", "") + # Create relationship if both tables and columns exist try: graph.query( @@ -157,14 +171,14 @@ def load_to_graph(graph_id: str, entities: dict, relationships: dict, batch_size }]->(tgt) """, { - 'source_col': source_field, - 'target_col': target_field, - 'source_table': source_table, - 'target_table': target_table, - 'rel_name': rel_name, - 'note': note - } + "source_col": source_field, + "target_col": target_field, + "source_table": source_table, + "target_table": target_table, + "rel_name": rel_name, + "note": note, + }, ) except Exception as e: print(f"Warning: Could not create relationship: {str(e)}") - continue \ No newline at end of file + continue diff --git a/api/loaders/json_loader.py b/api/loaders/json_loader.py index 853e0809..0b74ec60 100644 --- a/api/loaders/json_loader.py +++ b/api/loaders/json_loader.py @@ -1,24 +1,27 @@ -from typing import Tuple +"""JSON loader module for processing JSON schema files.""" + import json +from typing import Tuple + import tqdm -from jsonschema import ValidationError, validate -from litellm import embedding +from jsonschema import ValidationError + from api.config import Config from api.loaders.base_loader import BaseLoader -from api.extensions import db -from api.utils import generate_db_description from api.loaders.graph_loader import load_to_graph from api.loaders.schema_validator import validate_table_schema try: - with open(Config.SCHEMA_PATH, 'r', encoding='utf-8') as f: + with open(Config.SCHEMA_PATH, "r", encoding="utf-8") as f: schema = json.load(f) except FileNotFoundError as exc: raise FileNotFoundError(f"Schema file not found: {Config.SCHEMA_PATH}") from exc except json.JSONDecodeError as exc: raise ValueError(f"Invalid schema JSON: {str(exc)}") from exc + class JSONLoader(BaseLoader): + """JSON schema loader for loading database schemas from JSON files.""" @staticmethod def load(graph_id: str, data) -> Tuple[bool, str]: @@ -37,22 +40,32 @@ def load(graph_id: str, data) -> Tuple[bool, str]: print("❌ Schema validation failed with the following issues:") for error in validation_errors: print(f" - {error}") - raise 
ValidationError("Schema validation failed. Please check the schema and try again.") + raise ValidationError( + "Schema validation failed. Please check the schema and try again." + ) except ValidationError as exc: return False, str(exc) - + relationships = {} - for table_name, table_info in tqdm.tqdm(data['tables'].items(), "Create Table relationships"): + for table_name, table_info in tqdm.tqdm( + data["tables"].items(), "Create Table relationships" + ): # Create Foreign Key relationships - for fk_name, fk_info in tqdm.tqdm(table_info['foreign_keys'].items(), "Create Foreign Key relationships"): + for fk_name, fk_info in tqdm.tqdm( + table_info["foreign_keys"].items(), "Create Foreign Key relationships" + ): if table_name not in relationships: relationships[table_name] = [] - relationships[table_name].append({"from": table_name, - "to": fk_info['referenced_table'], - "source_column": fk_info['column'], - "target_column": fk_info['referenced_column'], - "note": fk_name}) - load_to_graph(graph_id, data['tables'], relationships, db_name=data['database']) - - return True, "Graph loaded successfully" \ No newline at end of file + relationships[table_name].append( + { + "from": table_name, + "to": fk_info["referenced_table"], + "source_column": fk_info["column"], + "target_column": fk_info["referenced_column"], + "note": fk_name, + } + ) + load_to_graph(graph_id, data["tables"], relationships, db_name=data["database"]) + + return True, "Graph loaded successfully" diff --git a/api/loaders/odata_loader.py b/api/loaders/odata_loader.py index 762fe525..77125d4d 100644 --- a/api/loaders/odata_loader.py +++ b/api/loaders/odata_loader.py @@ -1,9 +1,10 @@ import re -from typing import Tuple import xml.etree.ElementTree as ET +from typing import Tuple + import tqdm + from api.loaders.base_loader import BaseLoader -from api.extensions import db from api.loaders.graph_loader import load_to_graph @@ -14,7 +15,7 @@ class ODataLoader(BaseLoader): @staticmethod def load(graph_id: str, data) -> Tuple[bool, str]: - """ Load XML ODATA schema into a Graph. 
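A minimal input document for `JSONLoader.load`, inferred from the keys the loader and `validate_table_schema` read, could look like the sketch below; the database, table, and column names are invented, and running it assumes the `api` package is importable and a FalkorDB instance is reachable.

```python
# Minimal schema document accepted by JSONLoader.load, based on the fields the
# loader and schema validator read. All names here are invented examples.
from api.loaders.json_loader import JSONLoader

example_schema = {
    "database": "shop",
    "tables": {
        "orders": {
            "description": "Customer orders",
            "columns": {
                "order_id": {"description": "Primary key", "type": "integer",
                             "null": "NO", "key": "PRI", "default": None},
                "customer_id": {"description": "FK to customers", "type": "integer",
                                "null": "NO", "key": "FK", "default": None},
            },
            "foreign_keys": {
                "fk_orders_customer": {
                    "column": "customer_id",
                    "referenced_table": "customers",
                    "referenced_column": "customer_id",
                },
            },
        },
    },
}

success, message = JSONLoader.load("my_graph", example_schema)
print(success, message)
```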
""" + """Load XML ODATA schema into a Graph.""" try: # Parse the OData schema @@ -38,50 +39,60 @@ def _parse_odata_schema(data) -> Tuple[dict, dict]: # Define namespaces namespaces = { - 'edmx': "http://docs.oasis-open.org/odata/ns/edmx", - 'edm': "http://docs.oasis-open.org/odata/ns/edm" + "edmx": "http://docs.oasis-open.org/odata/ns/edmx", + "edm": "http://docs.oasis-open.org/odata/ns/edm", } schema_element = root.find(".//edmx:DataServices/edm:Schema", namespaces) if schema_element is None: raise ET.ParseError("Schema element not found") - + entity_types = schema_element.findall("edm:EntityType", namespaces) for entity_type in tqdm.tqdm(entity_types, "Parsing OData schema"): entity_name = entity_type.get("Name") - entities[entity_name] = {'col_descriptions': []} + entities[entity_name] = {"col_descriptions": []} entities[entity_name]["columns"] = {} for prop in entity_type.findall("edm:Property", namespaces): prop_name = prop.get("Name") try: if prop_name is not None: entities[entity_name]["columns"][prop_name] = {} - entities[entity_name]["columns"][prop_name]["type"] = prop.get("Type").split(".")[-1] + entities[entity_name]["columns"][prop_name]["type"] = prop.get( + "Type" + ).split(".")[-1] col_des = entity_name if len(prop.findall("edm:Annotation", namespaces)) > 0: if len(prop.findall("edm:Annotation", namespaces)[0].get("String")) > 0: - col_des = prop.findall("edm:Annotation", namespaces)[0].get("String") + col_des = prop.findall("edm:Annotation", namespaces)[0].get( + "String" + ) entities[entity_name]["col_descriptions"].append(col_des) entities[entity_name]["columns"][prop_name]["description"] = col_des except Exception as e: print(f"Error parsing property {prop_name} for entity {entity_name}") continue - # = {prop.get("Name"): prop.get("Type") for prop in entity_type.findall("edm:Property", namespaces)} + # = {prop.get("Name"): prop.get("Type") \ + # for prop in entity_type.findall("edm:Property", namespaces)} description = entity_type.findall("edm:Annotation", namespaces) if len(description) > 0: - entities[entity_name]["description"] = description[0].get("String").replace("'", "\\'") + entities[entity_name]["description"] = ( + description[0].get("String").replace("'", "\\'") + ) else: try: - entities[entity_name]["description"] = entity_name + " with Primery key: " + entity_type.find("edm:Key/edm:PropertyRef", namespaces).attrib['Name'] + entities[entity_name]["description"] = ( + entity_name + + " with Primery key: " + + entity_type.find("edm:Key/edm:PropertyRef", namespaces).attrib["Name"] + ) except: print(f"Error parsing description for entity {entity_name}") entities[entity_name]["description"] = entity_name - for entity_type in tqdm.tqdm(entity_types, "Parsing OData schema - relationships"): - entity_name = entity_type.attrib['Name'] + entity_name = entity_type.attrib["Name"] for rel in entity_type.findall("edm:NavigationProperty", namespaces): rel_name = rel.get("Name") @@ -89,39 +100,43 @@ def _parse_odata_schema(data) -> Tuple[dict, dict]: # Clean 'Collection(...)' wrapper if exists if raw_type.startswith("Collection(") and raw_type.endswith(")"): - raw_type = raw_type[len("Collection("):-1] + raw_type = raw_type[len("Collection(") : -1] # Extract the target entity name - match = re.search(r'(\w+)$', raw_type) + match = re.search(r"(\w+)$", raw_type) target_entity = match.group(1) if match else "UNKNOWN" - - source_entity = entity_name target_entity = target_entity source_fields = entities.get(entity_name, {})["columns"] target_fields = 
entities.get(target_entity, {})["columns"] - #TODO This usage is for demonstration purposes only, it should be replaced with a more robust method + # TODO This usage is for demonstration purposes only, it should be \ + # replaced with a more robust method source_col, target_col = guess_relationship_columns(source_fields, target_fields) if source_col and target_col: # Store the relationship if rel_name not in relationships: relationships[rel_name] = [] - # src_col, tgt_col = guess_relationship_columns(source_entity, target_entity, entities[source_entity], entities[target_entity]) - relationships[rel_name].append({ - "from": source_entity, - "to": target_entity, - "source_column": source_col, - "target_column": target_col, - "note": "inferred" if source_col and target_col else "implicit/subform" - }) - + # src_col, tgt_col = guess_relationship_columns(source_entity, \ + # target_entity, entities[source_entity], entities[target_entity]) + relationships[rel_name].append( + { + "from": source_entity, + "to": target_entity, + "source_column": source_col, + "target_column": target_col, + "note": ( + "inferred" if source_col and target_col else "implicit/subform" + ), + } + ) return entities, relationships -#TODO: this funtion is for demonstration purposes only, it should be replaced with a more robust method +# TODO: this funtion is for demonstration purposes only, it should be \ +# replaced with a more robust method def guess_relationship_columns(source_fields, target_fields): for src_key, src_meta in source_fields.items(): if src_key == "description": diff --git a/api/loaders/postgres_loader.py b/api/loaders/postgres_loader.py index 0bb67160..a12e3ac2 100644 --- a/api/loaders/postgres_loader.py +++ b/api/loaders/postgres_loader.py @@ -1,25 +1,70 @@ from typing import Tuple, Dict, Any, List -import psycopg2 +import re +import logging import tqdm -from api.config import Config +import psycopg2 +import datetime +import decimal from api.loaders.base_loader import BaseLoader -from api.extensions import db -from api.utils import generate_db_description from api.loaders.graph_loader import load_to_graph +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") -class PostgreSQLLoader(BaseLoader): + +class PostgresLoader(BaseLoader): """ Loader for PostgreSQL databases that connects and extracts schema information. """ + # DDL operations that modify database schema + SCHEMA_MODIFYING_OPERATIONS = { + 'CREATE', 'ALTER', 'DROP', 'RENAME', 'TRUNCATE' + } + + # More specific patterns for schema-affecting operations + SCHEMA_PATTERNS = [ + r'^\s*CREATE\s+TABLE', + r'^\s*CREATE\s+INDEX', + r'^\s*CREATE\s+UNIQUE\s+INDEX', + r'^\s*ALTER\s+TABLE', + r'^\s*DROP\s+TABLE', + r'^\s*DROP\s+INDEX', + r'^\s*RENAME\s+TABLE', + r'^\s*TRUNCATE\s+TABLE', + r'^\s*CREATE\s+VIEW', + r'^\s*DROP\s+VIEW', + r'^\s*CREATE\s+SCHEMA', + r'^\s*DROP\s+SCHEMA', + ] + + @staticmethod + def _serialize_value(value): + """ + Convert non-JSON serializable values to JSON serializable format. 
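Rows fetched with psycopg2 routinely contain `Decimal`, `date`, and `datetime` values that `json.dumps` rejects, which is why the loader converts them before streaming results. A standalone sketch of the same conversion:

```python
# Standalone sketch of the value conversion performed by _serialize_value:
# psycopg2 returns Decimal/date/datetime objects that json.dumps cannot encode.
import datetime
import decimal
import json


def serialize_value(value):
    """Convert common non-JSON-serializable DB values to JSON-friendly types."""
    if isinstance(value, (datetime.date, datetime.datetime, datetime.time)):
        return value.isoformat()
    if isinstance(value, decimal.Decimal):
        return float(value)
    return value


row = {"order_id": 1, "total": decimal.Decimal("19.99"), "placed": datetime.date(2024, 5, 1)}
print(json.dumps({k: serialize_value(v) for k, v in row.items()}))
# {"order_id": 1, "total": 19.99, "placed": "2024-05-01"}
```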
+ + Args: + value: The value to serialize + + Returns: + JSON serializable version of the value + """ + if isinstance(value, (datetime.date, datetime.datetime)): + return value.isoformat() + elif isinstance(value, datetime.time): + return value.isoformat() + elif isinstance(value, decimal.Decimal): + return float(value) + elif value is None: + return None + else: + return value + @staticmethod - def load(graph_id: str, connection_url: str) -> Tuple[bool, str]: + def load(prefix: str, connection_url: str) -> Tuple[bool, str]: """ Load the graph data from a PostgreSQL database into the graph database. Args: - graph_id: The ID of the graph to load data into connection_url: PostgreSQL connection URL in format: postgresql://username:password@host:port/database @@ -30,27 +75,27 @@ def load(graph_id: str, connection_url: str) -> Tuple[bool, str]: # Connect to PostgreSQL database conn = psycopg2.connect(connection_url) cursor = conn.cursor() - + # Extract database name from connection URL db_name = connection_url.split('/')[-1] if '?' in db_name: db_name = db_name.split('?')[0] - + # Get all table information - entities = PostgreSQLLoader.extract_tables_info(cursor) - + entities = PostgresLoader.extract_tables_info(cursor) + # Get all relationship information - relationships = PostgreSQLLoader.extract_relationships(cursor) - + relationships = PostgresLoader.extract_relationships(cursor) + # Close database connection cursor.close() conn.close() - + # Load data into graph - load_to_graph(graph_id, entities, relationships, db_name=db_name) - + load_to_graph(prefix + "_" + db_name, entities, relationships, db_name=db_name, db_url=connection_url) + return True, f"PostgreSQL schema loaded successfully. Found {len(entities)} tables." - + except psycopg2.Error as e: return False, f"PostgreSQL connection error: {str(e)}" except Exception as e: @@ -68,7 +113,7 @@ def extract_tables_info(cursor) -> Dict[str, Any]: Dict containing table information """ entities = {} - + # Get all tables in public schema cursor.execute(""" SELECT table_name, table_comment @@ -84,31 +129,31 @@ def extract_tables_info(cursor) -> Dict[str, Any]: AND t.table_type = 'BASE TABLE' ORDER BY t.table_name; """) - + tables = cursor.fetchall() - + for table_name, table_comment in tqdm.tqdm(tables, desc="Extracting table information"): table_name = table_name.strip() - + # Get column information for this table - columns_info = PostgreSQLLoader.extract_columns_info(cursor, table_name) - + columns_info = PostgresLoader.extract_columns_info(cursor, table_name) + # Get foreign keys for this table - foreign_keys = PostgreSQLLoader.extract_foreign_keys(cursor, table_name) - + foreign_keys = PostgresLoader.extract_foreign_keys(cursor, table_name) + # Generate table description table_description = table_comment if table_comment else f"Table: {table_name}" - + # Get column descriptions for batch embedding col_descriptions = [col_info['description'] for col_info in columns_info.values()] - + entities[table_name] = { 'description': table_description, 'columns': columns_info, 'foreign_keys': foreign_keys, 'col_descriptions': col_descriptions } - + return entities @staticmethod @@ -159,29 +204,29 @@ def extract_columns_info(cursor, table_name: str) -> Dict[str, Any]: AND c.table_schema = 'public' ORDER BY c.ordinal_position; """, (table_name, table_name, table_name)) - + columns = cursor.fetchall() columns_info = {} - + for col_name, data_type, is_nullable, column_default, key_type, column_comment in columns: col_name = col_name.strip() - + # Generate 
column description description_parts = [] if column_comment: description_parts.append(column_comment) else: description_parts.append(f"Column {col_name} of type {data_type}") - + if key_type != 'NONE': description_parts.append(f"({key_type})") - + if is_nullable == 'NO': description_parts.append("(NOT NULL)") - + if column_default: description_parts.append(f"(Default: {column_default})") - + columns_info[col_name] = { 'type': data_type, 'null': is_nullable, @@ -189,7 +234,7 @@ def extract_columns_info(cursor, table_name: str) -> Dict[str, Any]: 'description': ' '.join(description_parts), 'default': column_default } - + return columns_info @staticmethod @@ -221,7 +266,7 @@ def extract_foreign_keys(cursor, table_name: str) -> List[Dict[str, str]]: AND tc.table_name = %s AND tc.table_schema = 'public'; """, (table_name,)) - + foreign_keys = [] for constraint_name, column_name, foreign_table, foreign_column in cursor.fetchall(): foreign_keys.append({ @@ -230,7 +275,7 @@ def extract_foreign_keys(cursor, table_name: str) -> List[Dict[str, str]]: 'referenced_table': foreign_table.strip(), 'referenced_column': foreign_column.strip() }) - + return foreign_keys @staticmethod @@ -262,15 +307,15 @@ def extract_relationships(cursor) -> Dict[str, List[Dict[str, str]]]: AND tc.table_schema = 'public' ORDER BY tc.table_name, tc.constraint_name; """) - + relationships = {} for table_name, constraint_name, column_name, foreign_table, foreign_column in cursor.fetchall(): table_name = table_name.strip() constraint_name = constraint_name.strip() - + if constraint_name not in relationships: relationships[constraint_name] = [] - + relationships[constraint_name].append({ 'from': table_name, 'to': foreign_table.strip(), @@ -278,5 +323,159 @@ def extract_relationships(cursor) -> Dict[str, List[Dict[str, str]]]: 'target_column': foreign_column.strip(), 'note': f'Foreign key constraint: {constraint_name}' }) - + return relationships + + @staticmethod + def is_schema_modifying_query(sql_query: str) -> Tuple[bool, str]: + """ + Check if a SQL query modifies the database schema. + + Args: + sql_query: The SQL query to check + + Returns: + Tuple of (is_schema_modifying, operation_type) + """ + if not sql_query or not sql_query.strip(): + return False, "" + + # Clean and normalize the query + normalized_query = sql_query.strip().upper() + + # Check for basic DDL operations + first_word = normalized_query.split()[0] if normalized_query.split() else "" + if first_word in PostgresLoader.SCHEMA_MODIFYING_OPERATIONS: + # Additional pattern matching for more precise detection + for pattern in PostgresLoader.SCHEMA_PATTERNS: + if re.match(pattern, normalized_query, re.IGNORECASE): + return True, first_word + + # If it's a known DDL operation but doesn't match specific patterns, + # still consider it schema-modifying (better safe than sorry) + return True, first_word + + return False, "" + + @staticmethod + def refresh_graph_schema(graph_id: str, db_url: str) -> Tuple[bool, str]: + """ + Refresh the graph schema by clearing existing data and reloading from the database. + + Args: + graph_id: The graph ID to refresh + db_url: Database connection URL + + Returns: + Tuple of (success, message) + """ + try: + logging.info("Schema modification detected. 
Refreshing graph schema for: %s", graph_id) + + # Import here to avoid circular imports + from api.extensions import db + + # Clear existing graph data + # Drop current graph before reloading + graph = db.select_graph(graph_id) + graph.delete() + + # Extract prefix from graph_id (remove database name part) + # graph_id format is typically "prefix_database_name" + parts = graph_id.split('_') + if len(parts) >= 2: + # Reconstruct prefix by joining all parts except the last one + prefix = '_'.join(parts[:-1]) + else: + prefix = graph_id + + # Reuse the existing load method to reload the schema + success, message = PostgresLoader.load(prefix, db_url) + + if success: + logging.info("Graph schema refreshed successfully.") + return True, message + else: + return False, f"Failed to reload schema: {message}" + + except Exception as e: + error_msg = f"Error refreshing graph schema: {str(e)}" + logging.error(error_msg) + return False, error_msg + + @staticmethod + def execute_sql_query(sql_query: str, db_url: str) -> List[Dict[str, Any]]: + """ + Execute a SQL query on the PostgreSQL database and return the results. + + Args: + sql_query: The SQL query to execute + db_url: PostgreSQL connection URL in format: + postgresql://username:password@host:port/database + + Returns: + List of dictionaries containing the query results + """ + try: + # Connect to PostgreSQL database + conn = psycopg2.connect(db_url) + cursor = conn.cursor() + + # Execute the SQL query + cursor.execute(sql_query) + + # Check if the query returns results (SELECT queries) + if cursor.description is not None: + # This is a SELECT query or similar that returns rows + columns = [desc[0] for desc in cursor.description] + results = cursor.fetchall() + result_list = [] + for row in results: + # Serialize each value to ensure JSON compatibility + serialized_row = { + columns[i]: PostgresLoader._serialize_value(row[i]) + for i in range(len(columns)) + } + result_list.append(serialized_row) + else: + # This is an INSERT, UPDATE, DELETE, or other non-SELECT query + # Return information about the operation + affected_rows = cursor.rowcount + sql_type = sql_query.strip().split()[0].upper() + + if sql_type in ['INSERT', 'UPDATE', 'DELETE']: + result_list = [{ + "operation": sql_type, + "affected_rows": affected_rows, + "status": "success" + }] + else: + # For other types of queries (CREATE, DROP, etc.) + result_list = [{ + "operation": sql_type, + "status": "success" + }] + + # Commit the transaction for write operations + conn.commit() + + # Close database connection + cursor.close() + conn.close() + + return result_list + + except psycopg2.Error as e: + # Rollback in case of error + if 'conn' in locals(): + conn.rollback() + cursor.close() + conn.close() + raise Exception(f"PostgreSQL query execution error: {str(e)}") + except Exception as e: + # Rollback in case of error + if 'conn' in locals(): + conn.rollback() + cursor.close() + conn.close() + raise Exception(f"Error executing SQL query: {str(e)}") diff --git a/api/loaders/schema_validator.py b/api/loaders/schema_validator.py index 5ad2206f..32c54cb1 100644 --- a/api/loaders/schema_validator.py +++ b/api/loaders/schema_validator.py @@ -1,8 +1,19 @@ +"""Schema validation module for table schemas.""" REQUIRED_COLUMN_KEYS = {"description", "type", "null", "key", "default"} VALID_NULL_VALUES = {"YES", "NO"} + def validate_table_schema(schema): + """ + Validate a table schema structure. 
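Taken together, `is_schema_modifying_query`, `execute_sql_query`, and `refresh_graph_schema` give the route handlers a simple execute-then-refresh flow. A hedged sketch of how they compose, where the connection URL and graph id are placeholders and the refresh step assumes a reachable FalkorDB instance behind `api.extensions.db`:

```python
# Sketch of the execute-then-refresh flow built from PostgresLoader's helpers.
# The connection URL and graph id are placeholders for illustration only.
from api.loaders.postgres_loader import PostgresLoader

DB_URL = "postgresql://user:password@localhost:5432/mydb"  # placeholder
GRAPH_ID = "user123_mydb"                                   # placeholder

sql = "ALTER TABLE orders ADD COLUMN shipped_at timestamp;"

is_ddl, operation = PostgresLoader.is_schema_modifying_query(sql)
results = PostgresLoader.execute_sql_query(sql, DB_URL)
print(results)  # e.g. [{"operation": "ALTER", "status": "success"}]

if is_ddl:
    # Keep the graph model in sync with the changed database schema.
    ok, message = PostgresLoader.refresh_graph_schema(GRAPH_ID, DB_URL)
    print(ok, message)
```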
+ + Args: + schema (dict): The schema dictionary to validate + + Returns: + list: List of validation errors found + """ errors = [] # Validate top-level database key @@ -15,36 +26,76 @@ def validate_table_schema(schema): return errors for table_name, table_data in schema["tables"].items(): - if not table_data.get("description"): - errors.append(f"Table '{table_name}' is missing a description") - - if "columns" not in table_data or not isinstance(table_data["columns"], dict): - errors.append(f"Table '{table_name}' has no valid 'columns' definition") - continue - - for column_name, column_data in table_data["columns"].items(): - # Check for missing required keys - missing_keys = REQUIRED_COLUMN_KEYS - column_data.keys() - if missing_keys: - errors.append(f"Column '{column_name}' in table '{table_name}' is missing keys: {missing_keys}") - continue - - # Validate non-empty description - if not column_data.get("description"): - errors.append(f"Column '{column_name}' in table '{table_name}' has an empty description") - - # Validate 'null' field - if column_data["null"] not in VALID_NULL_VALUES: - errors.append(f"Column '{column_name}' in table '{table_name}' has invalid 'null' value: {column_data['null']}") - - # Optional: validate foreign keys - if "foreign_keys" in table_data: - if not isinstance(table_data["foreign_keys"], dict): - errors.append(f"Foreign keys for table '{table_name}' must be a dictionary") - else: - for fk_name, fk_data in table_data["foreign_keys"].items(): - for key in ("column", "referenced_table", "referenced_column"): - if key not in fk_data or not fk_data[key]: - errors.append(f"Foreign key '{fk_name}' in table '{table_name}' is missing '{key}'") + errors.extend(_validate_table(table_name, table_data)) + + return errors + + +def _validate_table(table_name, table_data): + """Validate a single table's structure.""" + errors = [] + + if not table_data.get("description"): + errors.append(f"Table '{table_name}' is missing a description") + + if "columns" not in table_data or not isinstance(table_data["columns"], dict): + errors.append(f"Table '{table_name}' has no valid 'columns' definition") + return errors + + for column_name, column_data in table_data["columns"].items(): + errors.extend(_validate_column(table_name, column_name, column_data)) + + # Optional: validate foreign keys + if "foreign_keys" in table_data: + errors.extend(_validate_foreign_keys(table_name, table_data["foreign_keys"])) + + return errors + + +def _validate_column(table_name, column_name, column_data): + """Validate a single column's structure.""" + errors = [] + + # Check for missing required keys + missing_keys = REQUIRED_COLUMN_KEYS - column_data.keys() + if missing_keys: + errors.append( + f"Column '{column_name}' in table '{table_name}' " + f"is missing keys: {missing_keys}" + ) + return errors + + # Validate non-empty description + if not column_data.get("description"): + errors.append( + f"Column '{column_name}' in table '{table_name}' has an empty description" + ) + + # Validate 'null' field + if column_data["null"] not in VALID_NULL_VALUES: + errors.append( + f"Column '{column_name}' in table '{table_name}' " + f"has invalid 'null' value: {column_data['null']}" + ) + + return errors + + +def _validate_foreign_keys(table_name, foreign_keys): + """Validate foreign keys structure.""" + errors = [] + + if not isinstance(foreign_keys, dict): + errors.append( + f"Foreign keys for table '{table_name}' must be a dictionary" + ) + return errors + + for fk_name, fk_data in foreign_keys.items(): + for 
key in ("column", "referenced_table", "referenced_column"): + if key not in fk_data or not fk_data[key]: + errors.append( + f"Foreign key '{fk_name}' in table '{table_name}' is missing '{key}'" + ) return errors diff --git a/api/static/css/chat.css b/api/static/css/chat.css index fea8e19b..c0fd2bf6 100644 --- a/api/static/css/chat.css +++ b/api/static/css/chat.css @@ -5,33 +5,93 @@ } :root { - /* FalkorDB brand colors - based on browser.falkordb.com */ - --falkor-primary: #7466FF; - /* FalkorDB primary teal */ - --falkor-secondary: #191919; - /* Dark navy blue - main background */ - --falkor-tertiary: #FF66B3; - /* Falkor Tertiary Color */ - --falkor-quaternary: #393939; - /* Falkor Quaternary Color */ - --dark-bg: black; - /* Slightly lighter navy for surfaces */ - --falkor-accent: #19B6C9; - /* Secondary teal for hover states */ - --falkor-border-primary: #7466FF; + /* Professional color palette - muted and business-appropriate */ + --falkor-primary: #5B6BC0; + /* Muted indigo - professional primary */ + --falkor-secondary: #1A1A1A; + /* Dark charcoal - main background */ + --falkor-tertiary: #B39DDB; + /* Muted lavender - subtle tertiary */ + --falkor-quaternary: #424242; + /* Medium gray - quaternary */ + --dark-bg: #0F0F0F; + /* Deep dark for surfaces */ + --falkor-accent: #26A69A; + /* Professional teal - subdued accent */ + --falkor-border-primary: #5B6BC0; /* Primary border color*/ - --falkor-border-secondary: #FF804D; - /* Secondary border color*/ - --falkor-border-tertiary: #FFFFFF; - /* Tertiary border color*/ - --text-primary: #FFFFFF; - /* Primary text color */ - --text-secondary: #CDD3DF; - /* Secondary text color */ - --text-tertiary: #525252; - /* Tertiary text color */ - --falkor-highlight: #20C9D8; - /* Highlight color */ + --falkor-border-secondary: #90A4AE; + /* Muted blue-gray - professional secondary border */ + --falkor-border-tertiary: #E0E0E0; + /* Light gray border */ + --text-primary: #F5F5F5; + /* Soft white text */ + --text-secondary: #B0BEC5; + /* Muted blue-gray text */ + --text-tertiary: #616161; + /* Medium gray text */ + --falkor-highlight: #4DB6AC; + /* Subtle teal highlight */ + --accent-green: #66BB6A; + /* Professional green */ + --icon-filter: invert(1); + /* Dark theme - invert icons for white appearance */ + --bg-tertiary: #2E2E2E; + /* Professional dark gray */ + --border-color: #616161; + /* Subtle border color */ +} + +/* Light theme variables */ +[data-theme="light"] { + --falkor-primary: #5B6BC0; + --falkor-secondary: #FAFAFA; + --falkor-tertiary: #B39DDB; + --falkor-quaternary: #F5F5F5; + --dark-bg: #FFFFFF; + --falkor-accent: #26A69A; + --falkor-border-primary: #5B6BC0; + --falkor-border-secondary: #90A4AE; + --falkor-border-tertiary: #424242; + --text-primary: #212121; + --text-secondary: #616161; + --text-tertiary: #9E9E9E; + --falkor-highlight: #4DB6AC; + --accent-green: #388E3C; + /* Professional dark green for light theme */ + --icon-filter: invert(0); + /* Light theme - no inversion for dark icons */ + --bg-tertiary: #F8F8F8; + /* Very light gray background */ + --border-color: #E0E0E0; + /* Light border color */ +} + +/* System theme detection */ +@media (prefers-color-scheme: light) { + [data-theme="system"] { + --falkor-primary: #5B6BC0; + --falkor-secondary: #FAFAFA; + --falkor-tertiary: #B39DDB; + --falkor-quaternary: #F5F5F5; + --dark-bg: #FFFFFF; + --falkor-accent: #26A69A; + --falkor-border-primary: #5B6BC0; + --falkor-border-secondary: #90A4AE; + --falkor-border-tertiary: #424242; + --text-primary: #212121; + --text-secondary: 
#616161; + --text-tertiary: #9E9E9E; + --falkor-highlight: #4DB6AC; + --accent-green: #388E3C; + /* Professional dark green for light theme */ + --icon-filter: invert(0); + /* Light theme - no inversion for dark icons */ + --bg-tertiary: #F8F8F8; + /* Very light gray background */ + --border-color: #E0E0E0; + /* Light border color */ + } } @font-face { @@ -50,6 +110,11 @@ body { overflow: hidden; } +/* Ensure all form elements inherit the consistent font */ +button, input, select, textarea { + font-family: inherit; +} + #container { height: 98%; width: 100%; @@ -61,7 +126,7 @@ body { #gradient { width: 100%; height: 2%; - background: linear-gradient(to right, #C15CFF, #FF5454); + background: linear-gradient(to right, var(--falkor-primary), var(--falkor-accent)); } .logo { @@ -69,6 +134,11 @@ body { width: auto; } +[data-theme="light"] .logo { + filter: invert(1); +} + + .chat-container { flex: 1 1 0; display: flex; @@ -83,8 +153,21 @@ body { transition: margin-left 0.4s cubic-bezier(0.4, 0, 0.2, 1); } -#menu-container.open~#chat-container { - margin-left: 0; +/* Mobile responsive adjustments */ +@media (max-width: 768px) { + .chat-container { + padding-right: 10px; + padding-left: 10px; + padding-top: 10px; + padding-bottom: 10px; + } +} + +/* Mobile: Menu overlays content (no pushing) */ +@media (max-width: 768px) { + #menu-container.open~#chat-container { + margin-left: 0; + } } .chat-header { @@ -94,7 +177,7 @@ body { justify-content: center; gap: 20px; background: var(--falkor-secondary); - color: white; + color: var(--text-primary); padding: 16px 20px; text-align: center; position: relative; @@ -122,19 +205,40 @@ body { } .user-message-container { - justify-content: flex-start; + justify-content: flex-end; + position: relative; +} + +/* Hide the default "User" text when avatar is present */ +.user-message-container.has-avatar::after { + display: none; +} + +/* User message avatar styling */ +.user-message-avatar { + height: 40px; + width: 40px; + border-radius: 50%; + object-fit: cover; + margin-left: 10px; + border: 2px solid var(--falkor-quaternary); + font-weight: 500; + font-size: 16px; + justify-content: center; + align-items: center; + display: flex; } .bot-message-container, .followup-message-container, .final-result-message-container { - justify-content: flex-end; + justify-content: flex-start; } -.user-message-container::before, -.bot-message-container::after, -.followup-message-container::after, -.final-result-message-container::after { +.user-message-container::after, +.bot-message-container::before, +.followup-message-container::before, +.final-result-message-container::before { height: 32px; width: 32px; content: 'Bot'; @@ -144,26 +248,29 @@ body { justify-content: center; font-weight: 500; font-size: 16px; - margin-left: 10px; + margin-right: 10px; border-radius: 100%; padding: 4px; } -.user-message-container::before { - margin-right: 10px; +.user-message-container::after { + margin-left: 10px; background: var(--falkor-quaternary); + content: 'User'; } -.bot-message-container::after { +.bot-message-container::before { background: color-mix(in srgb, var(--falkor-tertiary) 33%, transparent); } -.followup-message-container::after { - background: #492231; +.followup-message-container::before { + background: color-mix(in srgb, var(--falkor-tertiary) 40%, transparent); + border: 1px solid var(--falkor-tertiary); } -.final-result-message-container::after { - background: #1f4122; +.final-result-message-container::before { + background: color-mix(in srgb, var(--accent-green) 40%, 
transparent); + border: 1px solid var(--accent-green); } .loading-message-container::before { @@ -185,6 +292,38 @@ body { margin: 5px 0; line-height: 1.4; color: var(--text-primary); + word-wrap: break-word; + overflow-wrap: break-word; + overflow-x: auto; +} + +/* Mobile responsive messages */ +@media (max-width: 768px) { + .message { + max-width: 85%; + padding: 10px 12px; + font-size: 14px; + } + + .chat-header h1 { + font-size: 18px; + } + + #message-input { + font-size: 16px !important; + } + + #message-input::placeholder { + font-size: 16px !important; + } +} + +/* Styles for formatted text blocks */ +.sql-line, .array-line, .plain-line { + word-wrap: break-word; + overflow-wrap: break-word; + white-space: pre-wrap; + margin: 2px 0; } .bot-message { @@ -197,11 +336,13 @@ body { } .followup-message { - background: #492231; + background: color-mix(in srgb, var(--falkor-tertiary) 20%, transparent); + border-left: 3px solid var(--falkor-tertiary); } .final-result-message { - background: #1f4122; + background: color-mix(in srgb, var(--accent-green) 15%, transparent); + border-left: 3px solid var(--accent-green); } .chat-input { @@ -220,7 +361,8 @@ body { flex-grow: 1; padding: 10px; border-radius: 6px; - box-shadow: 0 0 20px 3px var(--falkor-tertiary); + box-shadow: 0 0 8px 1px var(--falkor-primary); + border: 1px solid var(--border-color); } .input-container.loading { @@ -229,8 +371,7 @@ body { #message-input { color: var(--text-primary); - width: 100%; - height: 100%; + flex-grow: 1; background-color: transparent; border: none; font-size: 18px !important; @@ -253,6 +394,19 @@ body { cursor: pointer; transition: opacity 0.2s; border: none; + width: 48px; + height: 48px; + overflow: hidden; + display: flex; + align-items: center; + justify-content: center; +} + +.input-button img { + filter: var(--icon-filter) brightness(0.8) saturate(0.3); + width: 100%; + height: 100%; + object-fit: contain; } .input-button:hover { @@ -263,6 +417,15 @@ body { display: none; } +/* Mobile responsive reset button */ +@media (max-width: 768px) { + .input-button { + width: 40px; + height: 40px; + } +} + + #menu-container { width: 0; min-width: 0; @@ -286,6 +449,22 @@ body { pointer-events: auto; } +/* Mobile responsive menu */ +@media (max-width: 768px) { + #menu-container { + position: fixed; + top: 0; + left: 0; + height: 100vh; + z-index: 999; + } + + #menu-container.open { + width: 80vw; + padding: 15px; + } +} + #menu-header { display: flex; flex-direction: row; @@ -303,22 +482,26 @@ body { position: absolute; top: 20px; left: 20px; - z-index: 10; -} - -.menu-trigger { - border: none; - height: 32px; - width: 32px; } -#side-menu-button img { - background: var(--falkor-secondary); +/* Mobile responsive side menu button */ +@media (max-width: 768px) { + #side-menu-button { + top: 15px; + left: 15px; + } } #menu-button img { rotate: 180deg; - background: var(--falkor-quaternary); + filter: brightness(0) invert(1); +} + +/* Menu user image styling */ +.menu-user-img { + width: 32px; + height: 32px; + object-fit: cover; } .menu-item { @@ -394,6 +577,14 @@ body { gap: 10px; } +/* Mobile responsive button container */ +@media (max-width: 768px) { + .button-container { + flex-direction: row; + gap: 8px; + } +} + #graph-select { height: 100%; padding: 8px 12px; @@ -410,6 +601,20 @@ body { background-position: calc(100% - 20px) center, calc(100% - 15px) center; background-size: 5px 5px, 5px 5px; background-repeat: no-repeat; + cursor:pointer; +} + +/* Mobile responsive select elements */ +@media (max-width: 
768px) { + #graph-select, + #custom-file-upload, + #open-pg-modal { + min-width: 120px; + width: auto; + padding: 8px 10px; + font-size: 14px; + flex: 1; + } } #graph-select:focus { @@ -433,6 +638,28 @@ body { background-position: calc(100% - 20px) center, calc(100% - 15px) center; background-size: 5px 5px, 5px 5px; background-repeat: no-repeat; + cursor:pointer; +} + +#open-pg-modal { + height: 100%; + padding: 8px 12px; + border-radius: 6px; + border: 1px solid var(--text-primary); + font-size: 14px; + background-color: var(--falkor-secondary); + color: var(--text-primary); + transition: border-color 0.2s; + min-width: 180px; + appearance: none; + background-position: calc(100% - 20px) center, calc(100% - 15px) center; + background-size: 5px 5px, 5px 5px; + background-repeat: no-repeat; + cursor: pointer; + display: inline-block; + margin-bottom: 0; + margin-right: 0; + margin-left: 0; } #schema-upload:disabled+label { @@ -450,155 +677,262 @@ body { right: -60px; top: 50%; transform: translateY(-50%); - background: transparent; - color: var(--text-primary); - border: none; - cursor: pointer; - padding: 4px; } -#reset-button:disabled { - opacity: 0.5; - cursor: not-allowed; +#reset-button svg { + width: 20px; + height: 20px; } -.suggestions-input-container { - display: flex; - flex-direction: column; - gap: 10px; - width: 100%; +/* Mobile responsive reset button */ +@media (max-width: 768px) { + #reset-button { + position: relative; + right: auto; + top: auto; + transform: none; + margin-left: 10px; + } + + #reset-button svg { + width: 18px; + height: 18px; + } + + .chat-input { + padding: 12px 16px; + gap: 12px; + } } -.suggestions-container { - display: flex; - align-items: stretch; - gap: 10px; - list-style: none; - padding-left: 0; - margin: 0; - width: 100%; - min-height: 56px; +.action-button:hover { + box-shadow: 0 4px 12px rgba(0,0,0,0.25); + background: var(--falkor-accent); + transform: translateY(-1px); } -.suggestion-item { +.action-button{ + z-index: 100; + width: 48px; + height: 48px; + border: none; + border-radius: 50%; + background: var(--falkor-quaternary); + color: var(--text-primary); + box-shadow: 0 2px 8px rgba(0,0,0,0.15); + cursor: pointer; + transition: all 0.2s ease; display: flex; align-items: center; justify-content: center; - min-height: 56px; - height: 100%; - max-height: none; - width: calc(33.333% - 7px); - min-width: calc(33.333% - 7px); - max-width: calc(33.333% - 7px); - overflow: hidden; - transition: max-height 0.3s ease-out, overflow 0.3s ease-out; - border: 1px solid var(--falkor-border-tertiary); - padding: 8px; - border-radius: 6px; + padding: 0; + position: relative; } -.suggestion-item button { - display: flex; +#reset-button:disabled { + opacity: 0.5; + cursor: not-allowed; +} + +.pg-modal { + display: none; + position: fixed; + top: 0; + left: 0; + width: 100vw; + height: 100vh; + background: rgba(0,0,0,0.5); + z-index: 3000; align-items: center; justify-content: center; +} +.pg-modal-content { + background: var(--falkor-secondary); + padding: 2em 5em; + border-radius: 10px; + box-shadow: 0 2px 16px rgba(0,0,0,0.2); + text-align: center; + min-width: 680px; +} + +/* Mobile responsive modal */ +@media (max-width: 768px) { + .pg-modal-content { + padding: 1.5em 2em; + min-width: 90vw; + margin: 0 5vw; + } +} +.pg-modal-title { + margin-bottom: 1em; + color: var(--text-primary); +} +.pg-modal-input { width: 100%; - height: 100%; + padding: 0.6em; + font-size: 1em; + border: 1px solid var(--border-color); + border-radius: 4px; + margin-bottom: 
1.5em; + color: var(--text-primary); + background: var(--falkor-quaternary); +} +.pg-modal-actions { + display: flex; + justify-content: space-between; + gap: 1em; +} +.pg-modal-btn { + flex: 1; + padding: 0.5em 0; border: none; - background: transparent; + border-radius: 4px; + font-size: 1em; + font-weight: bold; cursor: pointer; + transition: background 0.2s; +} +.pg-modal-connect { + background: #4285F4; + color: #fff; +} +.pg-modal-connect:hover { + background: #3367d6; +} +.pg-modal-connect:disabled { + background: #6c8db8; + cursor: not-allowed; +} +.pg-modal-loading-spinner { + display: flex; + align-items: center; + justify-content: center; + gap: 8px; +} +.spinner { + width: 16px; + height: 16px; + border: 2px solid #ffffff40; + border-top: 2px solid #ffffff; + border-radius: 50%; + animation: spin 1s linear infinite; +} +@keyframes spin { + 0% { transform: rotate(0deg); } + 100% { transform: rotate(360deg); } +} +.pg-modal-cancel { + background: var(--bg-tertiary); color: var(--text-primary); + border: 1px solid var(--border-color); } - -.suggestion-item button p { - margin: 0; +.pg-modal-cancel:hover { + background: var(--border-color); +} +.pg-modal-cancel:disabled { + background: var(--bg-tertiary); + color: var(--text-secondary); + cursor: not-allowed; + border: 1px solid var(--text-secondary); +} +.pg-modal-input:disabled { + background: var(--bg-tertiary); + color: var(--text-secondary); + cursor: not-allowed; +} +.google-login-modal { + display: none; + position: fixed; + top: 0; + left: 0; + width: 100vw; + height: 100vh; + background: rgba(0,0,0,0.6); + z-index: 1000; + align-items: center; + justify-content: center; +} +.google-login-modal-content { + background: var(--falkor-secondary); + padding: 2em 3em; + border-radius: 10px; + box-shadow: 0 2px 16px rgba(0,0,0,0.2); text-align: center; - width: 100%; - white-space: normal; - -webkit-line-clamp: unset; - line-clamp: unset; - -webkit-box-orient: unset; - word-wrap: break-word; - overflow-wrap: break-word; - hyphens: auto; - line-height: 1.3; - max-width: 100%; - font-size: 14px; - display: -webkit-box; - -webkit-line-clamp: 2; - -webkit-box-orient: vertical; - line-clamp: 2; - overflow: hidden; - text-overflow: ellipsis; + color: var(--text-primary); } -/* Loading state for suggestions */ -.suggestion-item.loading { - border: 1px solid var(--falkor-border-secondary); - background: linear-gradient(90deg, - rgba(136, 135, 135, 0.1) 0%, - rgba(136, 135, 135, 0.3) 50%, - rgba(136, 135, 135, 0.1) 100%); - background-size: 200% 100%; - animation: suggestion-loading 1.5s ease-in-out infinite; - min-height: 56px; +.google-login-modal-content h2 { + color: var(--text-primary); + margin-bottom: 0.5em; + font-size: 1.5em; } -.suggestion-item.loading button { - cursor: default; - opacity: 0.7; +.google-login-modal-content p { + color: var(--text-secondary); + margin-bottom: 1em; + font-size: 1em; } - -.suggestion-item.loading button p { - min-height: 20px; - width: 80%; - background: rgba(136, 135, 135, 0.2); +.google-login-btn { + display: flex; + align-items: center; + justify-content: flex-start; + gap: 10px; + margin-top: 1em; + padding: 0.7em 2em; + background: #4285F4; + color: #fff; border-radius: 4px; - animation: pulse 1.5s ease-in-out infinite; - font-size: 14px; + font-size: 1.1em; + text-decoration: none; + font-weight: 500; + transition: background 0.2s; } - -/* Loaded state animation */ -.suggestion-item.loaded { - animation: suggestion-fade-in 0.5s ease-out; - min-height: 56px; - max-height: none; - overflow: 
visible; - align-items: flex-start; +.google-login-btn:hover { + background: #3367d6; } - -.suggestion-item.loaded.active, -.suggestion-item.loaded:hover { - border: 1px solid var(--falkor-border-primary); +.google-login-logo { + height: 20px; + margin-right: 10px; + margin-left: 0; + vertical-align: middle; + display: inline-block; } -.suggestion-item.loaded button { +.github-login-btn { + display: flex; align-items: center; - padding: 8px 4px; + justify-content: flex-start; + gap: 10px; + padding: 0.7em 2em; + background: #24292f; + color: #fff; + border-radius: 4px; + font-size: 1.1em; + text-decoration: none; + font-weight: 500; + transition: background 0.2s; } - -.suggestion-item.loaded button:focus { - outline: none; +.github-login-btn:hover { + background: #171a1d; } - -.suggestion-item.loaded button p { - -webkit-line-clamp: none; - line-clamp: none; - display: block; - white-space: normal; - max-height: none; +.github-login-logo { + height: 20px; + margin-right: 10px; + margin-left: 0; + vertical-align: middle; + display: inline-block; } - @keyframes shadow-fade { 0% { - box-shadow: 0 0 20px 3px var(--falkor-tertiary) + box-shadow: 0 0 8px 1px var(--falkor-primary) } 50% { - box-shadow: 0 0 40px 6px var(--falkor-tertiary) + box-shadow: 0 0 12px 2px var(--falkor-primary) } 100% { - box-shadow: 0 0 20px 3px var(--falkor-tertiary) + box-shadow: 0 0 8px 1px var(--falkor-primary) } } @@ -616,16 +950,6 @@ body { } } -@keyframes suggestion-loading { - 0% { - background-position: -200% 0; - } - - 100% { - background-position: 200% 0; - } -} - @keyframes pulse { 0%, @@ -638,18 +962,6 @@ body { } } -@keyframes suggestion-fade-in { - 0% { - opacity: 0; - transform: translateY(10px); - } - - 100% { - opacity: 1; - transform: translateY(0); - } -} - ::-webkit-scrollbar { width: 20px; height: 20px; @@ -668,4 +980,344 @@ body { ::-webkit-scrollbar-thumb:hover { background-color: #a8bbbf; +} + +.pg-connect-btn { + margin-left: 12px; + padding: 0.4em 0.8em; + background: #f5f5f5; + border: 1px solid #ccc; + border-radius: 4px; + font-size: 1em; + color: #333; + cursor: pointer; + transition: background 0.2s, border 0.2s; +} + +.pg-connect-btn:hover { + background: #eaeaea; + border-color: #888; +} + +.logout-btn { + position: fixed; + top: 20px; + right: 30px; + z-index: 2000; + padding: 0.5em 1.2em; + background: #e74c3c; + color: #fff; + border-radius: 5px; + text-decoration: none; + font-weight: bold; + box-shadow: 0 2px 8px rgba(0,0,0,0.08); + transition: background 0.2s; +} + +.logout-btn:hover { + background: #c0392b; +} + +/* User Profile Button Styles */ +.user-profile-btn { + position: fixed; + top: 20px; + right: 30px; + z-index: 2000; + width: 48px; + height: 48px; + border: none; + border-radius: 50%; + background: #fff; + box-shadow: 0 2px 8px rgba(0,0,0,0.15); + cursor: pointer; + transition: all 0.2s ease; + padding: 2px; +} + +.user-profile-btn:hover { + box-shadow: 0 4px 12px rgba(0,0,0,0.25); + transform: translateY(-1px); +} + +.user-profile-img { + width: 100%; + height: 100%; + border-radius: 50%; + object-fit: cover; + font-weight: 500; + font-size: 22px; +} + +/* User Profile Dropdown */ +.user-profile-dropdown { + position: fixed; + top: 80px; + right: 30px; + z-index: 1999; + background: var(--falkor-secondary); + border: 1px solid var(--falkor-border-primary); + border-radius: 8px; + box-shadow: 0 4px 12px rgba(0,0,0,0.25); + min-width: 200px; + display: none; +} + +/* Mobile responsive dropdown */ +@media (max-width: 768px) { + .user-profile-dropdown { + top: 65px; + 
right: 15px; + min-width: 180px; + } +} + +.user-profile-dropdown.show { + display: block; +} + +.user-profile-info { + padding: 15px; + border-bottom: 1px solid var(--falkor-quaternary); +} + +.user-profile-name { + color: var(--text-primary); + font-weight: bold; + margin-bottom: 5px; +} + +.user-profile-email { + color: var(--text-secondary); + font-size: 0.9em; +} + +.user-profile-actions { + padding: 10px; +} + +.user-profile-logout { + width: 100%; + padding: 10px; + background: #D32F2F; + color: #fff; + border: none; + border-radius: 5px; + cursor: pointer; + font-weight: bold; + transition: background 0.2s; +} + +.user-profile-logout:hover { + background: #B71C1C; +} + +/* Theme Toggle Button Styles */ +#theme-toggle-btn { + position: fixed; + top: 20px; + right: 90px; +} + +.theme-icon { + width: 20px; + height: 20px; + color: var(--text-primary); + transition: all 0.3s ease; +} + +/* Mobile responsive user profile and theme toggle */ +@media (max-width: 768px) { + .action-button { + width: 40px; + height: 40px; + } + .user-profile-btn { + top: 15px; + right: 15px; + width: 40px; + height: 40px; + } + + #theme-toggle-btn { + top: 15px; + right: 80px; + } + + .theme-icon { + width: 18px; + height: 18px; + } +} + +/* Theme icon states */ +[data-theme="dark"] .theme-icon .sun, +[data-theme="system"] .theme-icon .sun { + display: none; +} + +[data-theme="dark"] .theme-icon .moon { + display: block; +} + +[data-theme="dark"] .theme-icon .system, +[data-theme="light"] .theme-icon .system { + display: none; +} + +[data-theme="light"] .theme-icon .sun { + display: block; +} + +[data-theme="light"] .theme-icon .moon, +[data-theme="system"] .theme-icon .moon { + display: none; +} + +[data-theme="system"] .theme-icon .system { + display: block; +} + +[data-theme="system"] .theme-icon .sun { + display: none; +} + +/* Destructive Confirmation Styles */ +.destructive-confirmation-container { + transition: all 0.3s ease; +} + +.destructive-confirmation-message { + border: 1px solid #D32F2F; + border-radius: 8px; + background: linear-gradient(135deg, var(--bg-tertiary), var(--falkor-quaternary)); +} + +.destructive-confirmation { + padding: 20px; + transition: all 0.3s ease; +} + +.confirmation-text { + margin-bottom: 20px; + line-height: 1.6; + color: #ffffff; + font-size: 14px; +} + +.confirmation-text strong { + color: #FFCDD2; +} + +.confirmation-buttons { + display: flex; + gap: 15px; + justify-content: center; + margin-top: 20px; +} + +.confirm-btn, .cancel-btn { + padding: 12px 24px; + border: none; + border-radius: 6px; + font-family: 'Fira Code', monospace; + font-size: 14px; + font-weight: bold; + cursor: pointer; + transition: all 0.3s ease; + min-width: 140px; +} + +.confirm-btn { + background: #D32F2F; + color: white; + border: 1px solid #D32F2F; +} + +.confirm-btn:hover { + background: #B71C1C; + border-color: #B71C1C; + transform: translateY(-1px); + box-shadow: 0 2px 8px rgba(211, 47, 47, 0.3); +} + +.cancel-btn { + background: transparent; + color: var(--text-primary); + border: 1px solid var(--border-color); +} + +.cancel-btn:hover { + background: var(--bg-tertiary); + border-color: var(--text-secondary); + transform: translateY(-1px); + box-shadow: 0 2px 8px rgba(0, 0, 0, 0.2); +} + +.confirm-btn:active, .cancel-btn:active { + transform: translateY(0); + box-shadow: 0 2px 6px rgba(0, 0, 0, 0.3); +} + +.confirm-btn:disabled, .cancel-btn:disabled { + background: #cccccc; + color: #888888; + border-color: #cccccc; + cursor: not-allowed; + transform: none; + box-shadow: none; +} + 
+.confirm-btn:disabled:hover, .cancel-btn:disabled:hover { + background: #cccccc; + color: #888888; + border-color: #cccccc; + transform: none; + box-shadow: none; +} + +/* Reset Confirmation Modal */ +.reset-confirmation-modal { + display: none; + position: fixed; + top: 0; + left: 0; + width: 100vw; + height: 100vh; + background: rgba(0,0,0,0.6); + z-index: 2000; + align-items: center; + justify-content: center; +} + +.reset-confirmation-modal-content { + background: var(--falkor-secondary); + padding: 2em 3em; + border-radius: 10px; + box-shadow: 0 2px 16px rgba(0,0,0,0.2); + text-align: center; + color: var(--text-primary); + min-width: 400px; +} + +/* Mobile responsive reset modal */ +@media (max-width: 768px) { + .reset-confirmation-modal-content { + padding: 1.5em 2em; + min-width: 90vw; + margin: 0 5vw; + } +} + +.reset-confirmation-modal-content h3 { + color: var(--text-primary); + margin-bottom: 0.5em; + font-size: 1.3em; +} + +.reset-confirmation-modal-content p { + color: var(--text-secondary); + margin-bottom: 1.5em; + font-size: 1em; + line-height: 1.5; } \ No newline at end of file diff --git a/api/static/js/chat.js b/api/static/js/chat.js index 2d8f19ba..cda533e0 100644 --- a/api/static/js/chat.js +++ b/api/static/js/chat.js @@ -13,10 +13,8 @@ const sideMenuButton = document.getElementById('side-menu-button'); const menuButton = document.getElementById('menu-button'); const menuContainer = document.getElementById('menu-container'); const chatContainer = document.getElementById('chat-container'); -const suggestionsContainer = document.getElementById('suggestions-container'); const expInstructions = document.getElementById('instructions-textarea'); const inputContainer = document.getElementById('input-container'); -const suggestionItems = document.querySelectorAll('.suggestion-item'); let questions_history = []; let result_history = []; @@ -27,29 +25,37 @@ const MESSAGE_DELIMITER = '|||FALKORDB_MESSAGE_BOUNDARY|||'; const urlParams = new URLSearchParams(window.location.search); -const TOKEN = urlParams.get('token'); - -function addMessage(message, isUser = false, isFollowup = false, isFinalResult = false, isLoading = false) { +function addMessage(message, isUser = false, isFollowup = false, isFinalResult = false, isLoading = false, userInfo = null) { const messageDiv = document.createElement('div'); const messageDivContainer = document.createElement('div'); messageDiv.className = "message"; messageDivContainer.className = "message-container"; + let userAvatar = null; + if (isFollowup) { messageDivContainer.className += " followup-message-container"; messageDiv.className += " followup-message"; messageDiv.textContent = message; } else if (isUser) { - suggestionsContainer.style.display = 'none'; messageDivContainer.className += " user-message-container"; messageDiv.className += " user-message"; + + // Prepare user avatar if userInfo is provided + if (userInfo && userInfo.picture) { + userAvatar = document.createElement('img'); + userAvatar.src = userInfo.picture; + userAvatar.alt = userInfo.name?.charAt(0).toUpperCase() || 'User'; + userAvatar.className = 'user-message-avatar'; + messageDivContainer.classList.add('has-avatar'); + } + questions_history.push(message); } else if (isFinalResult) { result_history.push(message); messageDivContainer.className += " final-result-message-container"; messageDiv.className += " final-result-message"; - // messageDiv.textContent = message; } else { messageDivContainer.className += " bot-message-container"; messageDiv.className += " 
bot-message"; @@ -59,7 +65,7 @@ function addMessage(message, isUser = false, isFollowup = false, isFinalResult = } } - const block = formatBlock(message) + const block = formatBlock(message); if (block) { block.forEach(lineDiv => { @@ -71,9 +77,14 @@ function addMessage(message, isUser = false, isFollowup = false, isFinalResult = if (!isLoading) { messageDivContainer.appendChild(messageDiv); + if (userAvatar) { + messageDivContainer.appendChild(userAvatar); + } } + chatMessages.appendChild(messageDivContainer); chatMessages.scrollTop = chatMessages.scrollHeight; + return messageDiv; } @@ -132,35 +143,23 @@ function formatBlock(text) { return lineDiv; }); } - if (text.includes('\n')) { - return text.split('\n').map((line, i) => { - const lineDiv = document.createElement('div'); - lineDiv.className = 'plain-line'; - lineDiv.textContent = line; - return lineDiv; - }); - } - if (text.includes('\n')) { - return text.split('\n').map((line, i) => { - const lineDiv = document.createElement('div'); - lineDiv.className = 'plain-line'; - lineDiv.textContent = line; - return lineDiv; - }); - } } function initChat() { messageInput.value = ''; - suggestionItems.forEach(item => { - item.classList.remove('active'); - }); chatMessages.innerHTML = ''; - [confValue, expValue, missValue, ambValue].forEach((element) => { + [confValue, expValue, missValue].forEach((element) => { element.innerHTML = ''; }); - addMessage('Hello! How can I help you today?', false); - suggestionsContainer.style.display = 'flex'; + + // Check if we have graphs available + const graphSelect = document.getElementById("graph-select"); + if (graphSelect && graphSelect.options.length > 0 && graphSelect.options[0].value) { + addMessage('Hello! How can I help you today?', false); + } else { + addMessage('Hello! 
Please select a graph from the dropdown above or upload a schema to get started.', false); + } + questions_history = []; result_history = []; } @@ -180,13 +179,20 @@ async function sendMessage() { const message = messageInput.value.trim(); if (!message) return; + // Check if a graph is selected + const selectedValue = document.getElementById("graph-select").value; + if (!selectedValue) { + addMessage("Please select a graph from the dropdown before sending a message.", false, true); + return; + } + // Cancel any ongoing request if (currentRequestController) { currentRequestController.abort(); } // Add user message to chat - addMessage(message, true); + addMessage(message, true, false, false, false, window.currentUser); messageInput.value = ''; // Show typing indicator @@ -201,13 +207,12 @@ async function sendMessage() { }); try { - const selectedValue = document.getElementById("graph-select").value; // Create an AbortController for this request currentRequestController = new AbortController(); // Use fetch with streaming response (GET method) - const response = await fetch('/graphs/' + selectedValue + '?q=' + encodeURIComponent(message) + '&token=' + TOKEN, { + const response = await fetch('/graphs/' + selectedValue + '?q=' + encodeURIComponent(message), { method: 'POST', headers: { 'Content-Type': 'application/json' @@ -310,6 +315,22 @@ async function sendMessage() { ambValue.textContent = "N/A"; // graph.Labels.findIndex(l => l.name === cat.name)(step.message, false, true); addMessage(step.message, false, true); + } else if (step.type === 'query_result') { + // Handle query result + if (step.data) { + addMessage(`Query Result: ${JSON.stringify(step.data)}`, false, false, true); + } else { + addMessage("No results found for the query.", false); + } + } else if (step.type === 'ai_response') { + // Handle AI-generated user-friendly response + addMessage(step.message, false, false, true); + } else if (step.type === 'destructive_confirmation') { + // Handle destructive operation confirmation request + addDestructiveConfirmationMessage(step); + } else if (step.type === 'operation_cancelled') { + // Handle cancelled operation + addMessage(step.message, false, true); } else { // Default handling addMessage(step.message || JSON.stringify(step), false); @@ -348,16 +369,27 @@ async function sendMessage() { } function toggleMenu() { + // Check if we're on mobile (768px breakpoint to match CSS) + const isMobile = window.innerWidth <= 768; + if (!menuContainer.classList.contains('open')) { menuContainer.classList.add('open'); sideMenuButton.style.display = 'none'; - chatContainer.style.paddingRight = '10%'; - chatContainer.style.paddingLeft = '10%'; + + // Only adjust padding on desktop, not mobile (mobile uses overlay) + if (!isMobile) { + chatContainer.style.paddingRight = '10%'; + chatContainer.style.paddingLeft = '10%'; + } } else { menuContainer.classList.remove('open'); sideMenuButton.style.display = 'block'; - chatContainer.style.paddingRight = '20%'; - chatContainer.style.paddingLeft = '20%'; + + // Only adjust padding on desktop, not mobile (mobile uses overlay) + if (!isMobile) { + chatContainer.style.paddingRight = '20%'; + chatContainer.style.paddingLeft = '20%'; + } } } @@ -378,9 +410,145 @@ function pauseRequest() { // Add a message indicating the request was paused addMessage("Request was paused by user.", false, true); + } +} + +function addDestructiveConfirmationMessage(step) { + const messageDiv = document.createElement('div'); + const messageDivContainer = 
document.createElement('div'); + + messageDivContainer.className = "message-container bot-message-container destructive-confirmation-container"; + messageDiv.className = "message bot-message destructive-confirmation-message"; + + // Generate a unique ID for this confirmation dialog + const confirmationId = 'confirmation-' + Date.now(); + + // Create the confirmation UI + const confirmationHTML = ` +
+
${step.message.replace(/\n/g, '<br>')}
+
+ + +
+
+ `; + + messageDiv.innerHTML = confirmationHTML; + + messageDivContainer.appendChild(messageDiv); + chatMessages.appendChild(messageDivContainer); + chatMessages.scrollTop = chatMessages.scrollHeight; + + // Disable the main input while waiting for confirmation + messageInput.disabled = true; + submitButton.disabled = true; +} + +async function handleDestructiveConfirmation(confirmation, sqlQuery, confirmationId) { + // Find the specific confirmation dialog using the unique ID + const confirmationDialog = document.querySelector(`[data-confirmation-id="${confirmationId}"]`); + if (confirmationDialog) { + // Disable both confirmation buttons within this specific dialog + const confirmBtn = confirmationDialog.querySelector('.confirm-btn'); + const cancelBtn = confirmationDialog.querySelector('.cancel-btn'); + if (confirmBtn) confirmBtn.disabled = true; + if (cancelBtn) cancelBtn.disabled = true; + } + + // Re-enable the input + messageInput.disabled = false; + submitButton.disabled = false; + + // Add user's choice as a message + addMessage(`User choice: ${confirmation}`, true, false, false, false, window.currentUser); + + if (confirmation === 'CANCEL') { + addMessage("Operation cancelled. The destructive SQL query was not executed.", false, true); + return; + } + + // If confirmed, send confirmation to server + try { + const selectedValue = document.getElementById("graph-select").value; + + const response = await fetch('/graphs/' + selectedValue + '/confirm', { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + confirmation: confirmation, + sql_query: sqlQuery, + chat: questions_history + }) + }); - // Show suggestions again since we're ready for new input - suggestionsContainer.style.display = 'flex'; + if (!response.ok) { + throw new Error(`Server responded with ${response.status}`); + } + + // Process the streaming response + const reader = response.body.getReader(); + let decoder = new TextDecoder(); + let buffer = ''; + + while (true) { + const { done, value } = await reader.read(); + + if (done) { + if (buffer.trim()) { + try { + const step = JSON.parse(buffer); + addMessage(step.message || JSON.stringify(step), false); + } catch (e) { + addMessage(buffer, false); + } + } + break; + } + + const chunk = decoder.decode(value, { stream: true }); + buffer += chunk; + + let delimiterIndex; + while ((delimiterIndex = buffer.indexOf(MESSAGE_DELIMITER)) !== -1) { + const message = buffer.slice(0, delimiterIndex).trim(); + buffer = buffer.slice(delimiterIndex + MESSAGE_DELIMITER.length); + + if (!message) continue; + + try { + const step = JSON.parse(message); + + if (step.type === 'reasoning_step') { + addMessage(step.message, false); + } else if (step.type === 'query_result') { + if (step.data) { + addMessage(`Query Result: ${JSON.stringify(step.data)}`, false, false, true); + } else { + addMessage("No results found for the query.", false); + } + } else if (step.type === 'ai_response') { + addMessage(step.message, false, false, true); + } else if (step.type === 'error') { + addMessage(`Error: ${step.message}`, false, true); + } else { + addMessage(step.message || JSON.stringify(step), false); + } + } catch (e) { + addMessage("Failed: " + message, false); + } + } + } + + } catch (error) { + console.error('Error:', error); + addMessage('Sorry, there was an error processing the confirmation: ' + error.message, false); } } @@ -393,35 +561,47 @@ messageInput.addEventListener('keypress', (e) => { } }); -messageInput.addEventListener('input', (e) => 
{ - suggestionItems.forEach(item => { - if (e.target.value && item.querySelector('p').textContent === e.target.value) { - item.classList.add('active'); - } else { - item.classList.remove('active'); - } - }); -}) - menuButton.addEventListener('click', toggleMenu); sideMenuButton.addEventListener('click', toggleMenu); -newChatButton.addEventListener('click', initChat); - -// Add event listener to each suggestion item -suggestionItems.forEach(item => { - item.addEventListener('click', () => { - // Set the value of the message input to the text of the clicked suggestion item - const text = item.querySelector('p').textContent; - messageInput.value = text; - // Remove 'active' from all suggestion items - document.querySelectorAll('.suggestion-item.active').forEach(item => { - item.classList.remove('active'); - }); - // Add 'active' to the clicked suggestion item - item.classList.add('active'); - }); +// Reset confirmation modal elements +const resetConfirmationModal = document.getElementById('reset-confirmation-modal'); +const resetConfirmBtn = document.getElementById('reset-confirm-btn'); +const resetCancelBtn = document.getElementById('reset-cancel-btn'); + +// Show reset confirmation modal instead of directly resetting +newChatButton.addEventListener('click', () => { + resetConfirmationModal.style.display = 'flex'; + // Focus the Reset Session button when modal opens + setTimeout(() => { + resetConfirmBtn.focus(); + }, 100); // Small delay to ensure modal is fully rendered +}); + +// Handle reset confirmation +resetConfirmBtn.addEventListener('click', () => { + resetConfirmationModal.style.display = 'none'; + initChat(); +}); + +// Handle reset cancellation +resetCancelBtn.addEventListener('click', () => { + resetConfirmationModal.style.display = 'none'; +}); + +// Close modal when clicking outside of it +resetConfirmationModal.addEventListener('click', (e) => { + if (e.target === resetConfirmationModal) { + resetConfirmationModal.style.display = 'none'; + } +}); + +// Close modal with Escape key +document.addEventListener('keydown', (e) => { + if (e.key === 'Escape' && resetConfirmationModal.style.display === 'flex') { + resetConfirmationModal.style.display = 'none'; + } }); document.addEventListener("DOMContentLoaded", function () { @@ -429,10 +609,36 @@ document.addEventListener("DOMContentLoaded", function () { const graphSelect = document.getElementById("graph-select"); // Fetch available graphs - fetch("/graphs?token=" + TOKEN) - .then(response => response.json()) + fetch("/graphs") + .then(response => { + if (!response.ok) { + if (response.status === 401) { + throw new Error("Authentication required. Please log in to access graphs."); + } + throw new Error(`HTTP ${response.status}: ${response.statusText}`); + } + return response.json(); + }) .then(data => { graphSelect.innerHTML = ""; + + if (!data || data.length === 0) { + // No graphs available + const option = document.createElement("option"); + option.value = ""; + option.textContent = "No graphs available"; + option.disabled = true; + graphSelect.appendChild(option); + + // Disable chat input when no graphs are available + messageInput.disabled = true; + submitButton.disabled = true; + messageInput.placeholder = "Please upload a schema or connect a database to start chatting"; + + addMessage("No graphs are available. 
Please upload a schema file or connect to a database to get started.", false); + return; + } + data.forEach(graph => { const option = document.createElement("option"); option.value = graph; @@ -441,92 +647,36 @@ document.addEventListener("DOMContentLoaded", function () { graphSelect.appendChild(option); }); - // Fetch suggestions for the first graph (if any) - if (data.length > 0) { - fetchSuggestions(); - } + // Re-enable chat input when graphs are available + messageInput.disabled = false; + submitButton.disabled = false; + messageInput.placeholder = "Describe the SQL query you want..."; }) .catch(error => { console.error("Error fetching graphs:", error); - addMessage("Sorry, there was an error fetching the available graphs: " + error.message, false); - }); - - // Function to fetch suggestions based on selected graph - function fetchSuggestions() { - const graphSelect = document.getElementById("graph-select"); - const selectedGraph = graphSelect.value; - - if (!selectedGraph) { - // Hide suggestions if no graph is selected - suggestionItems.forEach(item => { - item.style.display = 'none'; - }); - return; - } - - suggestionItems.forEach(item => { - item.style.display = 'flex'; - item.classList.remove('loaded'); - item.classList.add('loading'); - const button = item.querySelector('button'); - const p = item.querySelector('p'); - button.title = "Loading suggestion..."; - p.textContent = ""; + + // Show appropriate error message and disable input + if (error.message.includes("Authentication required")) { + addMessage("Authentication required. Please log in to access your graphs.", false); + // Don't disable input for auth errors as user needs to log in + } else { + addMessage("Sorry, there was an error fetching the available graphs: " + error.message, false); + messageInput.disabled = true; + submitButton.disabled = true; + messageInput.placeholder = "Cannot connect to server"; + } + + // Add a placeholder option to show the error state + graphSelect.innerHTML = ""; + const option = document.createElement("option"); + option.value = ""; + option.textContent = error.message.includes("Authentication") ? 
"Please log in" : "Error loading graphs"; + option.disabled = true; + graphSelect.appendChild(option); }); - // Fetch suggestions for the selected graph - fetch(`/suggestions?token=${TOKEN}&graph_id=${selectedGraph}`) - .then(response => response.json()) - .then(suggestions => { - // If no suggestions for this graph, hide the suggestions - if (!suggestions || suggestions.length === 0) { - suggestionItems.forEach(item => { - item.style.display = 'none'; - }); - return; - } - - // Hide unused suggestion slots - for (let i = suggestions.length; i < suggestionItems.length; i++) { - suggestionItems[i].style.display = 'none'; - } - - // Update each suggestion with fetched data and add loaded styling - suggestions.forEach((suggestion, index) => { - if (suggestionItems[index]) { - const item = suggestionItems[index]; - const button = item.querySelector('button'); - const p = item.querySelector('p'); - - // Add a slight delay for staggered animation - setTimeout(() => { - // Remove loading state and add content - item.classList.remove('loading'); - item.classList.add('loaded'); - - // Update content - p.textContent = suggestion; - button.title = suggestion; - - // Enable click functionality - button.style.cursor = 'pointer'; - }, index * 300); // 300ms delay between each suggestion - } - }); - }) - .catch(error => { - console.error("Error fetching suggestions:", error); - - // Hide suggestions on error - suggestionItems.forEach(item => { - item.style.display = 'none'; - }); - }); - } - graphSelect.addEventListener("change", function () { initChat(); - fetchSuggestions(); // Fetch new suggestions when graph changes }); }); @@ -540,7 +690,7 @@ fileUpload.addEventListener('change', function (e) { const formData = new FormData(); formData.append('file', file); - fetch("/graphs?token=" + TOKEN, { + fetch("/graphs", { method: 'POST', body: formData, // ✅ Correct, no need to set Content-Type manually }).then(response => { @@ -551,4 +701,216 @@ fileUpload.addEventListener('change', function (e) { console.error('Error uploading file:', error); addMessage('Sorry, there was an error uploading your file: ' + error.message, false); }); +}); + +document.addEventListener('DOMContentLoaded', function() { + // Authentication modal logic + var isAuthenticated = window.isAuthenticated !== undefined ? 
window.isAuthenticated : false; + var googleLoginModal = document.getElementById('google-login-modal'); + var container = document.getElementById('container'); + if (googleLoginModal && container) { + if (!isAuthenticated) { + googleLoginModal.style.display = 'flex'; + container.style.filter = 'blur(2px)'; + } else { + googleLoginModal.style.display = 'none'; + container.style.filter = ''; + } + } + // Postgres modal logic + var pgModal = document.getElementById('pg-modal'); + var openPgModalBtn = document.getElementById('open-pg-modal'); + var cancelPgModalBtn = document.getElementById('pg-modal-cancel'); + var connectPgModalBtn = document.getElementById('pg-modal-connect'); + var pgUrlInput = document.getElementById('pg-url-input'); + if (openPgModalBtn && pgModal) { + openPgModalBtn.addEventListener('click', function() { + pgModal.style.display = 'flex'; + // Focus the input field when modal opens + if (pgUrlInput) { + setTimeout(() => { + pgUrlInput.focus(); + }, 100); // Small delay to ensure modal is fully rendered + } + }); + } + if (cancelPgModalBtn && pgModal) { + cancelPgModalBtn.addEventListener('click', function() { + pgModal.style.display = 'none'; + }); + } + // Allow closing Postgres modal with Escape key + document.addEventListener('keydown', function(e) { + if (pgModal && pgModal.style.display === 'flex' && e.key === 'Escape') { + pgModal.style.display = 'none'; + } + }); + // Do NOT allow closing Google login modal with Escape or any other means except successful login + + // Handle Connect button for Postgres modal + if (connectPgModalBtn && pgUrlInput && pgModal) { + // Add Enter key support for the input field + pgUrlInput.addEventListener('keypress', function(e) { + if (e.key === 'Enter') { + connectPgModalBtn.click(); + } + }); + + connectPgModalBtn.addEventListener('click', function() { + const pgUrl = pgUrlInput.value.trim(); + if (!pgUrl) { + alert('Please enter a Postgres URL.'); + return; + } + + // Show loading state + const connectText = connectPgModalBtn.querySelector('.pg-modal-connect-text'); + const loadingSpinner = connectPgModalBtn.querySelector('.pg-modal-loading-spinner'); + const cancelBtn = document.getElementById('pg-modal-cancel'); + + connectText.style.display = 'none'; + loadingSpinner.style.display = 'flex'; + connectPgModalBtn.disabled = true; + cancelBtn.disabled = true; + pgUrlInput.disabled = true; + + fetch('/database', { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ url: pgUrl }) + }) + .then(response => response.json()) + .then(data => { + // Reset loading state + connectText.style.display = 'inline'; + loadingSpinner.style.display = 'none'; + connectPgModalBtn.disabled = false; + cancelBtn.disabled = false; + pgUrlInput.disabled = false; + + if (data.success) { + pgModal.style.display = 'none'; // Close modal on success + // Refresh the graph list to show the new database + location.reload(); + } else { + alert('Failed to connect: ' + (data.error || 'Unknown error')); + } + }) + .catch(error => { + // Reset loading state on error + connectText.style.display = 'inline'; + loadingSpinner.style.display = 'none'; + connectPgModalBtn.disabled = false; + cancelBtn.disabled = false; + pgUrlInput.disabled = false; + + alert('Error connecting to database: ' + error.message); + }); + }); + } +}); + +// User Profile Dropdown Functionality +document.addEventListener('DOMContentLoaded', function() { + const userProfileBtn = document.getElementById('user-profile-btn'); + const userProfileDropdown = 
document.getElementById('user-profile-dropdown'); + + if (userProfileBtn && userProfileDropdown) { + // Toggle dropdown when profile button is clicked + userProfileBtn.addEventListener('click', function(e) { + e.stopPropagation(); + userProfileDropdown.classList.toggle('show'); + }); + + // Close dropdown when clicking outside + document.addEventListener('click', function(e) { + if (!userProfileBtn.contains(e.target) && !userProfileDropdown.contains(e.target)) { + userProfileDropdown.classList.remove('show'); + } + }); + + // Close dropdown with Escape key + document.addEventListener('keydown', function(e) { + if (e.key === 'Escape' && userProfileDropdown.classList.contains('show')) { + userProfileDropdown.classList.remove('show'); + } + }); + + // Prevent dropdown from closing when clicking inside it + userProfileDropdown.addEventListener('click', function(e) { + e.stopPropagation(); + }); + } +}); + +// Theme Toggle Functionality +document.addEventListener('DOMContentLoaded', function() { + const themeToggleBtn = document.getElementById('theme-toggle-btn'); + + // Get theme from localStorage or default to 'system' + const currentTheme = localStorage.getItem('theme') || 'system'; + document.documentElement.setAttribute('data-theme', currentTheme); + + if (themeToggleBtn) { + themeToggleBtn.addEventListener('click', function() { + const currentTheme = document.documentElement.getAttribute('data-theme'); + let newTheme; + + // Cycle through themes: dark -> light -> system -> dark + switch (currentTheme) { + case 'dark': + newTheme = 'light'; + break; + case 'light': + newTheme = 'system'; + break; + case 'system': + default: + newTheme = 'dark'; + break; + } + + document.documentElement.setAttribute('data-theme', newTheme); + localStorage.setItem('theme', newTheme); + + // Update button title + const titles = { + 'dark': 'Switch to Light Mode', + 'light': 'Switch to System Mode', + 'system': 'Switch to Dark Mode' + }; + themeToggleBtn.title = titles[newTheme]; + }); + + // Set initial title + const titles = { + 'dark': 'Switch to Light Mode', + 'light': 'Switch to System Mode', + 'system': 'Switch to Dark Mode' + }; + themeToggleBtn.title = titles[currentTheme]; + } +}); + +// Handle window resize to ensure proper menu behavior across breakpoints +window.addEventListener('resize', function() { + const isMobile = window.innerWidth <= 768; + + // If menu is open and we switch to mobile, remove any desktop padding + if (isMobile && menuContainer.classList.contains('open')) { + chatContainer.style.paddingRight = ''; + chatContainer.style.paddingLeft = ''; + } + // If menu is open and we switch to desktop, apply desktop padding + else if (!isMobile && menuContainer.classList.contains('open')) { + chatContainer.style.paddingRight = '10%'; + chatContainer.style.paddingLeft = '10%'; + } + // If menu is closed and we're on desktop, ensure default desktop padding + else if (!isMobile && !menuContainer.classList.contains('open')) { + chatContainer.style.paddingRight = '20%'; + chatContainer.style.paddingLeft = '20%'; + } }); \ No newline at end of file diff --git a/api/static/public/icons/github.svg b/api/static/public/icons/github.svg new file mode 100644 index 00000000..97f58554 --- /dev/null +++ b/api/static/public/icons/github.svg @@ -0,0 +1,3 @@ + + + \ No newline at end of file diff --git a/api/static/public/icons/google.svg b/api/static/public/icons/google.svg new file mode 100644 index 00000000..fb354fdb --- /dev/null +++ b/api/static/public/icons/google.svg @@ -0,0 +1 @@ + \ No newline at end 
of file diff --git a/api/static/public/icons/logo.svg b/api/static/public/icons/logo.svg index 3ae7f5cd..60ebfd33 100644 --- a/api/static/public/icons/logo.svg +++ b/api/static/public/icons/logo.svg @@ -3,30 +3,25 @@ - - - - - - - - - + + + + + + + + + - + - - - - - diff --git a/api/static/public/icons/menu.svg b/api/static/public/icons/menu.svg deleted file mode 100644 index 99a5c97e..00000000 --- a/api/static/public/icons/menu.svg +++ /dev/null @@ -1,3 +0,0 @@ - - - diff --git a/api/templates/chat.html b/api/templates/chat.html deleted file mode 100644 index 6f778f88..00000000 --- a/api/templates/chat.html +++ /dev/null @@ -1,105 +0,0 @@ - - - - - - - Chatbot Interface - - - - - -
- -
- -
- -

Text-to-SQL(Natural Language to SQL Generator)

-
- -
- - - -
-
-
-
-
-
-
    -
  • - -
  • -
  • - -
  • -
  • - -
  • -
-
- - - - -
-
-
-
-
-
- - - - \ No newline at end of file diff --git a/api/templates/chat.j2 b/api/templates/chat.j2 new file mode 100644 index 00000000..6856dec5 --- /dev/null +++ b/api/templates/chat.j2 @@ -0,0 +1,189 @@ + + + + + + + Chatbot Interface + + + + + + + + {% if is_authenticated and user_info %} + + + {% endif %} +
+ +
+ +
+ +

Text-to-SQL (Natural Language to SQL Generator)

+
+ +
+ + +
+ +
+
+
+
+
+
+ + + + +
+
+
+
+
+ +
+
+

Connect to Postgres

+ +
+ + +
+
+
+
+
+

Reset Session

+

Are you sure you want to reset the current session? This will clear all chat history and start a new conversation.

+
+ + +
+
+
+ {# Set authentication state for JS before loading chat.js #} + + + + + \ No newline at end of file diff --git a/api/utils.py b/api/utils.py index 60d8c7b1..65932588 100644 --- a/api/utils.py +++ b/api/utils.py @@ -1,11 +1,20 @@ +"""Utility functions for the text2sql API.""" + import json from typing import List, Tuple + from litellm import completion + from api.config import Config from api.constants import BENCHMARK -def generate_db_description(db_name: str, table_names: List[str], temperature: float = 0.5, - max_tokens: int = 150) -> str: + +def generate_db_description( + db_name: str, + table_names: List[str], + temperature: float = 0.5, + max_tokens: int = 150, +) -> str: """ Generates a short and concise description of a database. @@ -20,14 +29,14 @@ def generate_db_description(db_name: str, table_names: List[str], temperature: f """ if not isinstance(db_name, str): raise TypeError("database_name must be a string.") - + if not isinstance(table_names, list): raise TypeError("table_names must be a list of strings.") - + # Ensure all table names are strings if not all(isinstance(table, str) for table in table_names): raise ValueError("All items in table_names must be strings.") - + if not table_names: return f"{db_name} is a database with no tables." @@ -40,25 +49,38 @@ def generate_db_description(db_name: str, table_names: List[str], temperature: f tables_formatted = ", ".join(table_names[:-1]) + f", and {table_names[-1]}" prompt = ( - f"You are a helpful assistant. Generate a concise description of the database named '{db_name}' " - f"which contains the following tables: {tables_formatted}.\n\n" - f"Description:" + f"You are a helpful assistant. Generate a concise description of " + f"the database named '{db_name}' which contains the following tables: " + f"{tables_formatted}.\n\nDescription:" ) - - response = completion(model=Config.COMPLETION_MODEL, - messages=[ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": prompt} - ], - temperature=temperature, - max_tokens=max_tokens, - n=1, - stop=None, - ) - description = response.choices[0].message['content'] + + response = completion( + model=Config.COMPLETION_MODEL, + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ], + temperature=temperature, + max_tokens=max_tokens, + n=1, + stop=None, + ) + description = response.choices[0].message["content"] return description + def llm_answer_validator(question: str, answer: str, expected_answer: str = None) -> str: + """ + Validate an answer using LLM. + + Args: + question: The original question + answer: The generated answer + expected_answer: The expected answer for comparison + + Returns: + JSON string with validation results + """ prompt = """ You are evaluating an answer generated by a text-to-sql RAG-based system. Assess how well the Generated Answer (generated sql) addresses the Question based on the Expected Answer. 
@@ -76,18 +98,37 @@ def llm_answer_validator(question: str, answer: str, expected_answer: str = None Output Json format: {{"relevance_score": float, "explanation": "Your assessment here."}} """ - response = completion(model=Config.VALIDTOR_MODEL, - messages=[ - {"role": "system", "content": "You are a Validator assistant."}, - {"role": "user", "content": prompt.format(question=question, expected_answer=expected_answer, generated_answer=answer)} - ], - response_format={"type": "json_object"}, - - ) - validation_set = response.choices[0].message['content'].strip() + response = completion( + model=Config.VALIDATOR_MODEL, + messages=[ + {"role": "system", "content": "You are a Validator assistant."}, + { + "role": "user", + "content": prompt.format( + question=question, + expected_answer=expected_answer, + generated_answer=answer, + ), + }, + ], + response_format={"type": "json_object"}, + ) + validation_set = response.choices[0].message["content"].strip() return validation_set + def llm_table_validator(question: str, answer: str, tables: List[str]) -> Tuple[float, str]: + """ + Validate table relevance using LLM. + + Args: + question: The original question + answer: The generated answer + tables: List of available tables + + Returns: + Tuple of relevance score and explanation + """ prompt = """ You are evaluating an answer generated by a text-to-sql RAG-based system. Assess how well the retrived Tables relevant to the question and supports the Generated Answer (generated sql). - The tables are with the following structure: @@ -106,19 +147,23 @@ def llm_table_validator(question: str, answer: str, tables: List[str]) -> Tuple[ Output Json format: {{"relevance_score": float, "explanation": "Your assessment here."}} """ - response = completion(model=Config.VALIDTOR_MODEL, - messages=[ - {"role": "system", "content": "You are a Validator assistant."}, - {"role": "user", "content": prompt.format(question=question, tables=tables, generated_answer=answer)} - ], - response_format={"type": "json_object"}, - ) - validation_set = response.choices[0].message['content'].strip() + response = completion( + model=Config.VALIDATOR_MODEL, + messages=[ + {"role": "system", "content": "You are a Validator assistant."}, + { + "role": "user", + "content": prompt.format(question=question, tables=tables, generated_answer=answer), + }, + ], + response_format={"type": "json_object"}, + ) + validation_set = response.choices[0].message["content"].strip() try: val_res = json.loads(validation_set) - score = val_res['relevance_score'] - explanation = val_res['explanation'] - except Exception as e: + score = val_res["relevance_score"] + explanation = val_res["explanation"] + except (json.JSONDecodeError, KeyError) as e: print(f"Error: {e}") score = 0.0 explanation = "Error: Unable to parse the response." 
@@ -138,8 +183,7 @@ def run_benchmark(): for data in benchmark_data: success, result = generate_db_description( - db_name=data['database'], - table_names=list(data['tables'].keys()) + db_name=data["database"], table_names=list(data["tables"].keys()) ) if success: @@ -147,4 +191,4 @@ def run_benchmark(): else: results.append(f"Error: {result}") - return results \ No newline at end of file + return results diff --git a/docs/postgres_loader.md b/docs/postgres_loader.md new file mode 100644 index 00000000..d4024fd1 --- /dev/null +++ b/docs/postgres_loader.md @@ -0,0 +1,240 @@ +# PostgreSQL Schema Loader + +This loader connects to a PostgreSQL database and extracts the complete schema information, including tables, columns, relationships, and constraints. The extracted schema is then loaded into a graph database for further analysis and query generation. + +## Features + +- **Complete Schema Extraction**: Retrieves all tables, columns, data types, constraints, and relationships +- **Foreign Key Relationships**: Automatically discovers and maps foreign key relationships between tables +- **Column Metadata**: Extracts column comments, default values, nullability, and key types +- **Batch Processing**: Efficiently processes large schemas with progress tracking +- **Error Handling**: Robust error handling for connection issues and malformed schemas + +## Installation + +{% capture shell_0 %} +poetry add psycopg2-binary +{% endcapture %} + +{% capture shell_1 %} +pip install psycopg2-binary +{% endcapture %} + +{% include code_tabs.html id="install_tabs" shell=shell_0 shell2=shell_1 %} + +## Usage + +### Basic Usage + +{% capture python_0 %} +from api.loaders.postgres_loader import PostgreSQLLoader + +# Connection URL format: postgresql://username:password@host:port/database +connection_url = "postgresql://postgres:password@localhost:5432/mydatabase" +graph_id = "my_schema_graph" + +success, message = PostgreSQLLoader.load(graph_id, connection_url) + +if success: + print(f"Schema loaded successfully: {message}") +else: + print(f"Failed to load schema: {message}") +{% endcapture %} + +{% capture javascript_0 %} +import { PostgreSQLLoader } from 'your-pkg'; + +const connectionUrl = "postgresql://postgres:password@localhost:5432/mydatabase"; +const graphId = "my_schema_graph"; + +const [success, message] = await PostgreSQLLoader.load(graphId, connectionUrl); +if (success) { + console.log(`Schema loaded successfully: ${message}`); +} else { + console.log(`Failed to load schema: ${message}`); +} +{% endcapture %} + +{% capture java_0 %} +String connectionUrl = "postgresql://postgres:password@localhost:5432/mydatabase"; +String graphId = "my_schema_graph"; +Pair result = PostgreSQLLoader.load(graphId, connectionUrl); +if (result.getLeft()) { + System.out.println("Schema loaded successfully: " + result.getRight()); +} else { + System.out.println("Failed to load schema: " + result.getRight()); +} +{% endcapture %} + +{% capture rust_0 %} +let connection_url = "postgresql://postgres:password@localhost:5432/mydatabase"; +let graph_id = "my_schema_graph"; +let (success, message) = postgresql_loader::load(graph_id, connection_url)?; +if success { + println!("Schema loaded successfully: {}", message); +} else { + println!("Failed to load schema: {}", message); +} +{% endcapture %} + +{% include code_tabs.html id="basic_usage_tabs" python=python_0 javascript=javascript_0 java=java_0 rust=rust_0 %} + +### Connection URL Format + +``` +postgresql://[username[:password]@][host[:port]][/database][?options] +``` + 
+**Examples:** +- `postgresql://postgres:password@localhost:5432/mydatabase` +- `postgresql://user:pass@example.com:5432/production_db` +- `postgresql://postgres@127.0.0.1/testdb` + +### Integration with Graph Database + +{% capture python_1 %} +from api.loaders.postgres_loader import PostgreSQLLoader +from api.extensions import db + +# Load PostgreSQL schema into graph +graph_id = "customer_db_schema" +connection_url = "postgresql://postgres:password@localhost:5432/customers" + +success, message = PostgreSQLLoader.load(graph_id, connection_url) + +if success: + # The schema is now available in the graph database + graph = db.select_graph(graph_id) + + # Query for all tables + result = graph.query("MATCH (t:Table) RETURN t.name") + print("Tables:", [record[0] for record in result.result_set]) +{% endcapture %} + +{% capture javascript_1 %} +import { PostgreSQLLoader, db } from 'your-pkg'; + +const graphId = "customer_db_schema"; +const connectionUrl = "postgresql://postgres:password@localhost:5432/customers"; + +const [success, message] = await PostgreSQLLoader.load(graphId, connectionUrl); +if (success) { + const graph = db.selectGraph(graphId); + const result = await graph.query("MATCH (t:Table) RETURN t.name"); + console.log("Tables:", result.map(r => r[0])); +} +{% endcapture %} + +{% capture java_1 %} +String graphId = "customer_db_schema"; +String connectionUrl = "postgresql://postgres:password@localhost:5432/customers"; +Pair result = PostgreSQLLoader.load(graphId, connectionUrl); +if (result.getLeft()) { + Graph graph = db.selectGraph(graphId); + ResultSet rs = graph.query("MATCH (t:Table) RETURN t.name"); + // Print table names + for (Record record : rs) { + System.out.println(record.get(0)); + } +} +{% endcapture %} + +{% capture rust_1 %} +let graph_id = "customer_db_schema"; +let connection_url = "postgresql://postgres:password@localhost:5432/customers"; +let (success, message) = postgresql_loader::load(graph_id, connection_url)?; +if success { + let graph = db.select_graph(graph_id); + let result = graph.query("MATCH (t:Table) RETURN t.name")?; + println!("Tables: {:?}", result.iter().map(|r| &r[0]).collect::>()); +} +{% endcapture %} + +{% include code_tabs.html id="integration_tabs" python=python_1 javascript=javascript_1 java=java_1 rust=rust_1 %} + +## Schema Structure + +The loader extracts the following information: + +### Tables +- Table name +- Table description/comment +- Column information +- Foreign key relationships + +### Columns +- Column name +- Data type +- Nullability +- Default values +- Key type (PRIMARY KEY, FOREIGN KEY, or NONE) +- Column descriptions/comments + +### Relationships +- Foreign key constraints +- Referenced tables and columns +- Constraint names and metadata + +## Graph Database Schema + +The extracted schema is stored in the graph database with the following node types: + +- **Database**: Represents the source database +- **Table**: Represents database tables +- **Column**: Represents table columns + +And the following relationship types: + +- **BELONGS_TO**: Connects columns to their tables +- **REFERENCES**: Connects foreign key columns to their referenced columns + +## Error Handling + +The loader handles various error conditions: + +- **Connection Errors**: Invalid connection URLs or database unavailability +- **Permission Errors**: Insufficient database permissions +- **Schema Errors**: Invalid or corrupt schema information +- **Graph Errors**: Issues with graph database operations + +## Example Output + +{% capture shell_2 %} 
+Extracting table information: 100%|██████████| 15/15 [00:02<00:00, 7.50it/s] +Creating Graph Table Nodes: 100%|██████████| 15/15 [00:05<00:00, 2.80it/s] +Creating embeddings for customers columns: 100%|██████████| 2/2 [00:01<00:00, 1.20it/s] +Creating Graph Columns for customers: 100%|██████████| 8/8 [00:03<00:00, 2.40it/s] +... +Creating Graph Table Relationships: 100%|██████████| 12/12 [00:02<00:00, 5.20it/s] + +PostgreSQL schema loaded successfully. Found 15 tables. +{% endcapture %} + +{% include code_tabs.html id="output_tabs" shell=shell_2 %} + +## Requirements + +- Python 3.12+ +- psycopg2-binary +- Access to a PostgreSQL database +- Existing graph database infrastructure (FalkorDB) + +## Limitations + +- Currently only supports PostgreSQL databases +- Extracts schema from the 'public' schema only +- Requires read permissions on information_schema and pg_* system tables +- Large schemas may take time to process due to embedding generation + +## Troubleshooting + +### Common Issues + +1. **Connection Failed**: Verify the connection URL format and database credentials +2. **Permission Denied**: Ensure the database user has read access to system tables +3. **Schema Not Found**: Check that tables exist in the 'public' schema +4. **Graph Database Error**: Verify that the graph database is running and accessible + +### Debug Mode + +For debugging, you can enable verbose output by modifying the loader to print additional information about the extraction process. diff --git a/examples/crm.sql b/examples/crm.sql new file mode 100644 index 00000000..1c05fd6c --- /dev/null +++ b/examples/crm.sql @@ -0,0 +1,611 @@ +-- SQL Script 1 (Extended): Table Creation (DDL) with Comments +-- This script creates the tables for your CRM database and adds descriptions for each table and column. 
+ +-- Drop existing tables to start fresh +DROP TABLE IF EXISTS SalesOrderItems, SalesOrders, Invoices, Payments, Products, ProductCategories, Leads, Opportunities, Contacts, Customers, Campaigns, CampaignMembers, Tasks, Notes, Attachments, SupportTickets, TicketComments, Users, Roles, UserRoles CASCADE; + +-- Roles for access control +CREATE TABLE Roles ( + RoleID SERIAL PRIMARY KEY, + RoleName VARCHAR(50) UNIQUE NOT NULL +); +COMMENT ON TABLE Roles IS 'Defines user roles for access control within the CRM (e.g., Admin, Sales Manager).'; +COMMENT ON COLUMN Roles.RoleID IS 'Unique identifier for the role.'; +COMMENT ON COLUMN Roles.RoleName IS 'Name of the role (e.g., "Admin", "Sales Representative").'; + +-- Users of the CRM system +CREATE TABLE Users ( + UserID SERIAL PRIMARY KEY, + Username VARCHAR(50) UNIQUE NOT NULL, + PasswordHash VARCHAR(255) NOT NULL, + Email VARCHAR(100) UNIQUE NOT NULL, + FirstName VARCHAR(50), + LastName VARCHAR(50), + CreatedAt TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP +); +COMMENT ON TABLE Users IS 'Stores information about users who can log in to the CRM system.'; +COMMENT ON COLUMN Users.UserID IS 'Unique identifier for the user.'; +COMMENT ON COLUMN Users.Username IS 'The username for logging in.'; +COMMENT ON COLUMN Users.PasswordHash IS 'Hashed password for security.'; +COMMENT ON COLUMN Users.Email IS 'The user''s email address.'; +COMMENT ON COLUMN Users.FirstName IS 'The user''s first name.'; +COMMENT ON COLUMN Users.LastName IS 'The user''s last name.'; +COMMENT ON COLUMN Users.CreatedAt IS 'Timestamp when the user account was created.'; + +-- Junction table for Users and Roles +CREATE TABLE UserRoles ( + UserID INT REFERENCES Users(UserID), + RoleID INT REFERENCES Roles(RoleID), + PRIMARY KEY (UserID, RoleID) +); +COMMENT ON TABLE UserRoles IS 'Maps users to their assigned roles, supporting many-to-many relationships.'; +COMMENT ON COLUMN UserRoles.UserID IS 'Foreign key referencing the Users table.'; +COMMENT ON COLUMN UserRoles.RoleID IS 'Foreign key referencing the Roles table.'; + +-- Customer accounts +CREATE TABLE Customers ( + CustomerID SERIAL PRIMARY KEY, + CustomerName VARCHAR(100) NOT NULL, + Industry VARCHAR(50), + Website VARCHAR(100), + Phone VARCHAR(20), + Address VARCHAR(255), + City VARCHAR(50), + State VARCHAR(50), + ZipCode VARCHAR(20), + Country VARCHAR(50), + AssignedTo INT REFERENCES Users(UserID), + CreatedAt TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP +); +COMMENT ON TABLE Customers IS 'Represents customer accounts or companies.'; +COMMENT ON COLUMN Customers.CustomerID IS 'Unique identifier for the customer.'; +COMMENT ON COLUMN Customers.CustomerName IS 'The name of the customer company.'; +COMMENT ON COLUMN Customers.Industry IS 'The industry the customer belongs to.'; +COMMENT ON COLUMN Customers.Website IS 'The customer''s official website.'; +COMMENT ON COLUMN Customers.Phone IS 'The customer''s primary phone number.'; +COMMENT ON COLUMN Customers.Address IS 'The customer''s physical address.'; +COMMENT ON COLUMN Customers.City IS 'The city part of the address.'; +COMMENT ON COLUMN Customers.State IS 'The state or province part of the address.'; +COMMENT ON COLUMN Customers.ZipCode IS 'The postal or zip code.'; +COMMENT ON COLUMN Customers.Country IS 'The country part of the address.'; +COMMENT ON COLUMN Customers.AssignedTo IS 'The user (sales representative) assigned to this customer account.'; +COMMENT ON COLUMN Customers.CreatedAt IS 'Timestamp when the customer was added.'; + +-- Individual contacts 
associated with customers +CREATE TABLE Contacts ( + ContactID SERIAL PRIMARY KEY, + CustomerID INT REFERENCES Customers(CustomerID), + FirstName VARCHAR(50) NOT NULL, + LastName VARCHAR(50) NOT NULL, + Email VARCHAR(100) UNIQUE, + Phone VARCHAR(20), + JobTitle VARCHAR(50), + CreatedAt TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP +); +COMMENT ON TABLE Contacts IS 'Stores information about individual contacts associated with customer accounts.'; +COMMENT ON COLUMN Contacts.ContactID IS 'Unique identifier for the contact.'; +COMMENT ON COLUMN Contacts.CustomerID IS 'Foreign key linking the contact to a customer account.'; +COMMENT ON COLUMN Contacts.FirstName IS 'The contact''s first name.'; +COMMENT ON COLUMN Contacts.LastName IS 'The contact''s last name.'; +COMMENT ON COLUMN Contacts.Email IS 'The contact''s email address.'; +COMMENT ON COLUMN Contacts.Phone IS 'The contact''s phone number.'; +COMMENT ON COLUMN Contacts.JobTitle IS 'The contact''s job title or position.'; +COMMENT ON COLUMN Contacts.CreatedAt IS 'Timestamp when the contact was created.'; + +-- Potential sales leads +CREATE TABLE Leads ( + LeadID SERIAL PRIMARY KEY, + FirstName VARCHAR(50), + LastName VARCHAR(50), + Email VARCHAR(100), + Phone VARCHAR(20), + Company VARCHAR(100), + Status VARCHAR(50) DEFAULT 'New', + Source VARCHAR(50), + AssignedTo INT REFERENCES Users(UserID), + CreatedAt TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP +); +COMMENT ON TABLE Leads IS 'Represents potential customers or sales prospects (not yet qualified).'; +COMMENT ON COLUMN Leads.LeadID IS 'Unique identifier for the lead.'; +COMMENT ON COLUMN Leads.Status IS 'Current status of the lead (e.g., New, Contacted, Qualified, Lost).'; +COMMENT ON COLUMN Leads.Source IS 'The source from which the lead was generated (e.g., Website, Referral).'; +COMMENT ON COLUMN Leads.AssignedTo IS 'The user assigned to follow up with this lead.'; +COMMENT ON COLUMN Leads.CreatedAt IS 'Timestamp when the lead was created.'; + +-- Sales opportunities +CREATE TABLE Opportunities ( + OpportunityID SERIAL PRIMARY KEY, + CustomerID INT REFERENCES Customers(CustomerID), + OpportunityName VARCHAR(100) NOT NULL, + Stage VARCHAR(50) DEFAULT 'Prospecting', + Amount DECIMAL(12, 2), + CloseDate DATE, + AssignedTo INT REFERENCES Users(UserID), + CreatedAt TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP +); +COMMENT ON TABLE Opportunities IS 'Tracks qualified sales deals with potential revenue.'; +COMMENT ON COLUMN Opportunities.OpportunityID IS 'Unique identifier for the opportunity.'; +COMMENT ON COLUMN Opportunities.CustomerID IS 'Foreign key linking the opportunity to a customer account.'; +COMMENT ON COLUMN Opportunities.OpportunityName IS 'A descriptive name for the sales opportunity.'; +COMMENT ON COLUMN Opportunities.Stage IS 'The current stage in the sales pipeline (e.g., Prospecting, Proposal, Closed Won).'; +COMMENT ON COLUMN Opportunities.Amount IS 'The estimated value of the opportunity.'; +COMMENT ON COLUMN Opportunities.CloseDate IS 'The expected date the deal will close.'; +COMMENT ON COLUMN Opportunities.AssignedTo IS 'The user responsible for this opportunity.'; +COMMENT ON COLUMN Opportunities.CreatedAt IS 'Timestamp when the opportunity was created.'; + +-- Product categories +CREATE TABLE ProductCategories ( + CategoryID SERIAL PRIMARY KEY, + CategoryName VARCHAR(50) NOT NULL, + Description TEXT +); +COMMENT ON TABLE ProductCategories IS 'Used to group products into categories (e.g., Software, Hardware).'; +COMMENT ON COLUMN 
ProductCategories.CategoryID IS 'Unique identifier for the category.'; +COMMENT ON COLUMN ProductCategories.CategoryName IS 'Name of the product category.'; +COMMENT ON COLUMN ProductCategories.Description IS 'A brief description of the category.'; + +-- Products or services offered +CREATE TABLE Products ( + ProductID SERIAL PRIMARY KEY, + ProductName VARCHAR(100) NOT NULL, + CategoryID INT REFERENCES ProductCategories(CategoryID), + Description TEXT, + Price DECIMAL(10, 2) NOT NULL, + StockQuantity INT DEFAULT 0 +); +COMMENT ON TABLE Products IS 'Stores details of the products or services the company sells.'; +COMMENT ON COLUMN Products.ProductID IS 'Unique identifier for the product.'; +COMMENT ON COLUMN Products.ProductName IS 'Name of the product.'; +COMMENT ON COLUMN Products.CategoryID IS 'Foreign key linking the product to a category.'; +COMMENT ON COLUMN Products.Description IS 'Detailed description of the product.'; +COMMENT ON COLUMN Products.Price IS 'The unit price of the product.'; +COMMENT ON COLUMN Products.StockQuantity IS 'The quantity of the product available in stock.'; + +-- Sales orders +CREATE TABLE SalesOrders ( + OrderID SERIAL PRIMARY KEY, + CustomerID INT REFERENCES Customers(CustomerID), + OpportunityID INT REFERENCES Opportunities(OpportunityID), + OrderDate DATE NOT NULL, + Status VARCHAR(50) DEFAULT 'Pending', + TotalAmount DECIMAL(12, 2), + AssignedTo INT REFERENCES Users(UserID) +); +COMMENT ON TABLE SalesOrders IS 'Records of confirmed sales to customers.'; +COMMENT ON COLUMN SalesOrders.OrderID IS 'Unique identifier for the sales order.'; +COMMENT ON COLUMN SalesOrders.CustomerID IS 'Foreign key linking the order to a customer.'; +COMMENT ON COLUMN SalesOrders.OpportunityID IS 'Foreign key linking the order to the sales opportunity it came from.'; +COMMENT ON COLUMN SalesOrders.OrderDate IS 'The date the order was placed.'; +COMMENT ON COLUMN SalesOrders.Status IS 'The current status of the order (e.g., Pending, Shipped, Canceled).'; +COMMENT ON COLUMN SalesOrders.TotalAmount IS 'The total calculated amount for the order.'; +COMMENT ON COLUMN SalesOrders.AssignedTo IS 'The user who processed the order.'; + +-- Items within a sales order +CREATE TABLE SalesOrderItems ( + OrderItemID SERIAL PRIMARY KEY, + OrderID INT REFERENCES SalesOrders(OrderID) ON DELETE CASCADE, + ProductID INT REFERENCES Products(ProductID), + Quantity INT NOT NULL, + UnitPrice DECIMAL(10, 2) NOT NULL +); +COMMENT ON TABLE SalesOrderItems IS 'Line items for each product within a sales order.'; +COMMENT ON COLUMN SalesOrderItems.OrderItemID IS 'Unique identifier for the order item.'; +COMMENT ON COLUMN SalesOrderItems.OrderID IS 'Foreign key linking this item to a sales order.'; +COMMENT ON COLUMN SalesOrderItems.ProductID IS 'Foreign key linking to the product being ordered.'; +COMMENT ON COLUMN SalesOrderItems.Quantity IS 'The quantity of the product ordered.'; +COMMENT ON COLUMN SalesOrderItems.UnitPrice IS 'The price per unit at the time of sale.'; + +-- Invoices for sales +CREATE TABLE Invoices ( + InvoiceID SERIAL PRIMARY KEY, + OrderID INT REFERENCES SalesOrders(OrderID), + InvoiceDate DATE NOT NULL, + DueDate DATE, + TotalAmount DECIMAL(12, 2), + Status VARCHAR(50) DEFAULT 'Unpaid' +); +COMMENT ON TABLE Invoices IS 'Represents billing invoices sent to customers.'; +COMMENT ON COLUMN Invoices.InvoiceID IS 'Unique identifier for the invoice.'; +COMMENT ON COLUMN Invoices.OrderID IS 'Foreign key linking the invoice to a sales order.'; +COMMENT ON COLUMN Invoices.InvoiceDate IS 
'The date the invoice was issued.'; +COMMENT ON COLUMN Invoices.DueDate IS 'The date the payment is due.'; +COMMENT ON COLUMN Invoices.TotalAmount IS 'The total amount due on the invoice.'; +COMMENT ON COLUMN Invoices.Status IS 'The payment status of the invoice (e.g., Unpaid, Paid, Overdue).'; + +-- Payment records +CREATE TABLE Payments ( + PaymentID SERIAL PRIMARY KEY, + InvoiceID INT REFERENCES Invoices(InvoiceID), + PaymentDate DATE NOT NULL, + Amount DECIMAL(12, 2), + PaymentMethod VARCHAR(50) +); +COMMENT ON TABLE Payments IS 'Tracks payments received from customers against invoices.'; +COMMENT ON COLUMN Payments.PaymentID IS 'Unique identifier for the payment.'; +COMMENT ON COLUMN Payments.InvoiceID IS 'Foreign key linking the payment to an invoice.'; +COMMENT ON COLUMN Payments.PaymentDate IS 'The date the payment was received.'; +COMMENT ON COLUMN Payments.Amount IS 'The amount that was paid.'; +COMMENT ON COLUMN Payments.PaymentMethod IS 'The method of payment (e.g., Credit Card, Bank Transfer).'; + +-- Marketing campaigns +CREATE TABLE Campaigns ( + CampaignID SERIAL PRIMARY KEY, + CampaignName VARCHAR(100) NOT NULL, + StartDate DATE, + EndDate DATE, + Budget DECIMAL(12, 2), + Status VARCHAR(50), + Owner INT REFERENCES Users(UserID) +); +COMMENT ON TABLE Campaigns IS 'Stores information about marketing campaigns.'; +COMMENT ON COLUMN Campaigns.CampaignID IS 'Unique identifier for the campaign.'; +COMMENT ON COLUMN Campaigns.CampaignName IS 'The name of the marketing campaign.'; +COMMENT ON COLUMN Campaigns.StartDate IS 'The start date of the campaign.'; +COMMENT ON COLUMN Campaigns.EndDate IS 'The end date of the campaign.'; +COMMENT ON COLUMN Campaigns.Budget IS 'The allocated budget for the campaign.'; +COMMENT ON COLUMN Campaigns.Status IS 'The current status of the campaign (e.g., Planned, Active, Completed).'; +COMMENT ON COLUMN Campaigns.Owner IS 'The user responsible for the campaign.'; + +-- Members of a marketing campaign (leads or contacts) +CREATE TABLE CampaignMembers ( + CampaignMemberID SERIAL PRIMARY KEY, + CampaignID INT REFERENCES Campaigns(CampaignID), + LeadID INT REFERENCES Leads(LeadID), + ContactID INT REFERENCES Contacts(ContactID), + Status VARCHAR(50) +); +COMMENT ON TABLE CampaignMembers IS 'Links leads and contacts to the marketing campaigns they are a part of.'; +COMMENT ON COLUMN CampaignMembers.CampaignMemberID IS 'Unique identifier for the campaign member record.'; +COMMENT ON COLUMN CampaignMembers.CampaignID IS 'Foreign key linking to the campaign.'; +COMMENT ON COLUMN CampaignMembers.LeadID IS 'Foreign key linking to a lead (if the member is a lead).'; +COMMENT ON COLUMN CampaignMembers.ContactID IS 'Foreign key linking to a contact (if the member is a contact).'; +COMMENT ON COLUMN CampaignMembers.Status IS 'The status of the member in the campaign (e.g., Sent, Responded).'; + +-- Tasks for users +CREATE TABLE Tasks ( + TaskID SERIAL PRIMARY KEY, + Title VARCHAR(100) NOT NULL, + Description TEXT, + DueDate DATE, + Status VARCHAR(50) DEFAULT 'Not Started', + Priority VARCHAR(20) DEFAULT 'Normal', + AssignedTo INT REFERENCES Users(UserID), + RelatedToEntity VARCHAR(50), + RelatedToID INT +); +COMMENT ON TABLE Tasks IS 'Tracks tasks or to-do items for CRM users.'; +COMMENT ON COLUMN Tasks.TaskID IS 'Unique identifier for the task.'; +COMMENT ON COLUMN Tasks.Title IS 'A short title for the task.'; +COMMENT ON COLUMN Tasks.Description IS 'A detailed description of the task.'; +COMMENT ON COLUMN Tasks.DueDate IS 'The date the task is due to be 
completed.'; +COMMENT ON COLUMN Tasks.Status IS 'The current status of the task (e.g., Not Started, In Progress, Completed).'; +COMMENT ON COLUMN Tasks.Priority IS 'The priority level of the task (e.g., Low, Normal, High).'; +COMMENT ON COLUMN Tasks.AssignedTo IS 'The user the task is assigned to.'; +COMMENT ON COLUMN Tasks.RelatedToEntity IS 'The type of record this task is related to (e.g., ''Lead'', ''Opportunity'').'; +COMMENT ON COLUMN Tasks.RelatedToID IS 'The ID of the related record.'; + +-- Notes related to various records +CREATE TABLE Notes ( + NoteID SERIAL PRIMARY KEY, + Content TEXT NOT NULL, + CreatedBy INT REFERENCES Users(UserID), + RelatedToEntity VARCHAR(50), + RelatedToID INT, + CreatedAt TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP +); +COMMENT ON TABLE Notes IS 'Allows users to add notes to various records (e.g., contacts, opportunities).'; +COMMENT ON COLUMN Notes.NoteID IS 'Unique identifier for the note.'; +COMMENT on COLUMN Notes.Content IS 'The text content of the note.'; +COMMENT ON COLUMN Notes.CreatedBy IS 'The user who created the note.'; +COMMENT ON COLUMN Notes.RelatedToEntity IS 'The type of record this note is related to (e.g., ''Contact'', ''Customer'').'; +COMMENT ON COLUMN Notes.RelatedToID IS 'The ID of the related record.'; +COMMENT ON COLUMN Notes.CreatedAt IS 'Timestamp when the note was created.'; + +-- File attachments +CREATE TABLE Attachments ( + AttachmentID SERIAL PRIMARY KEY, + FileName VARCHAR(255) NOT NULL, + FilePath VARCHAR(255) NOT NULL, + FileSize INT, + FileType VARCHAR(100), + UploadedBy INT REFERENCES Users(UserID), + RelatedToEntity VARCHAR(50), + RelatedToID INT, + CreatedAt TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP +); +COMMENT ON TABLE Attachments IS 'Stores metadata about files attached to records in the CRM.'; +COMMENT ON COLUMN Attachments.AttachmentID IS 'Unique identifier for the attachment.'; +COMMENT ON COLUMN Attachments.FileName IS 'The original name of the uploaded file.'; +COMMENT ON COLUMN Attachments.FilePath IS 'The path where the file is stored on the server.'; +COMMENT ON COLUMN Attachments.FileSize IS 'The size of the file in bytes.'; +COMMENT ON COLUMN Attachments.FileType IS 'The MIME type of the file (e.g., ''application/pdf'').'; +COMMENT ON COLUMN Attachments.UploadedBy IS 'The user who uploaded the file.'; +COMMENT ON COLUMN Attachments.RelatedToEntity IS 'The type of record this attachment is related to.'; +COMMENT ON COLUMN Attachments.RelatedToID IS 'The ID of the related record.'; +COMMENT ON COLUMN Attachments.CreatedAt IS 'Timestamp when the file was uploaded.'; + +-- Customer support tickets +CREATE TABLE SupportTickets ( + TicketID SERIAL PRIMARY KEY, + CustomerID INT REFERENCES Customers(CustomerID), + ContactID INT REFERENCES Contacts(ContactID), + Subject VARCHAR(255) NOT NULL, + Description TEXT, + Status VARCHAR(50) DEFAULT 'Open', + Priority VARCHAR(20) DEFAULT 'Normal', + AssignedTo INT REFERENCES Users(UserID), + CreatedAt TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP +); +COMMENT ON TABLE SupportTickets IS 'Tracks customer service and support requests.'; +COMMENT ON COLUMN SupportTickets.TicketID IS 'Unique identifier for the support ticket.'; +COMMENT ON COLUMN SupportTickets.CustomerID IS 'Foreign key linking the ticket to a customer.'; +COMMENT ON COLUMN SupportTickets.ContactID IS 'Foreign key linking the ticket to a specific contact.'; +COMMENT ON COLUMN SupportTickets.Subject IS 'A brief summary of the support issue.'; +COMMENT ON COLUMN SupportTickets.Description 
IS 'A detailed description of the issue.'; +COMMENT ON COLUMN SupportTickets.Status IS 'The current status of the ticket (e.g., Open, In Progress, Resolved).'; +COMMENT ON COLUMN SupportTickets.Priority IS 'The priority of the ticket (e.g., Low, Normal, High).'; +COMMENT ON COLUMN SupportTickets.AssignedTo IS 'The support agent the ticket is assigned to.'; +COMMENT ON COLUMN SupportTickets.CreatedAt IS 'Timestamp when the ticket was created.'; + +-- Comments on support tickets +CREATE TABLE TicketComments ( + CommentID SERIAL PRIMARY KEY, + TicketID INT REFERENCES SupportTickets(TicketID) ON DELETE CASCADE, + Comment TEXT NOT NULL, + CreatedBy INT REFERENCES Users(UserID), + CreatedAt TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP +); +COMMENT ON TABLE TicketComments IS 'Stores comments and updates related to a support ticket.'; +COMMENT ON COLUMN TicketComments.CommentID IS 'Unique identifier for the comment.'; +COMMENT ON COLUMN TicketComments.TicketID IS 'Foreign key linking the comment to a support ticket.'; +COMMENT ON COLUMN TicketComments.Comment IS 'The text content of the comment.'; +COMMENT ON COLUMN TicketComments.CreatedBy IS 'The user who added the comment.'; +COMMENT ON COLUMN TicketComments.CreatedAt IS 'Timestamp when the comment was added.'; + + +-- SQL Script 2: Data Insertion (DML) +-- This script populates the tables with sample data. + +-- Insert Roles +INSERT INTO Roles (RoleName) VALUES ('Admin'), ('Sales Manager'), ('Sales Representative'), ('Support Agent'); + +-- Insert Users +INSERT INTO Users (Username, PasswordHash, Email, FirstName, LastName) VALUES +('admin', 'hashed_password', 'admin@example.com', 'Admin', 'User'), +('sales_manager', 'hashed_password', 'manager@example.com', 'John', 'Doe'), +('sales_rep1', 'hashed_password', 'rep1@example.com', 'Jane', 'Smith'), +('sales_rep2', 'hashed_password', 'rep2@example.com', 'Peter', 'Jones'), +('support_agent1', 'hashed_password', 'support1@example.com', 'Mary', 'Williams'); + +-- Assign Roles to Users +INSERT INTO UserRoles (UserID, RoleID) VALUES +(1, 1), (2, 2), (3, 3), (4, 3), (5, 4); + +-- Insert Customers +INSERT INTO Customers (CustomerName, Industry, Website, Phone, Address, City, State, ZipCode, Country, AssignedTo) VALUES +('ABC Corporation', 'Technology', 'http://www.abccorp.com', '123-456-7890', '123 Tech Park', 'Techville', 'CA', '90210', 'USA', 3), +('Innovate Inc.', 'Software', 'http://www.innovate.com', '234-567-8901', '456 Innovation Dr', 'Devtown', 'TX', '75001', 'USA', 4), +('Global Solutions', 'Consulting', 'http://www.globalsolutions.com', '345-678-9012', '789 Global Ave', 'Businesston', 'NY', '10001', 'USA', 3), +('Data Dynamics', 'Analytics', 'http://www.datadynamics.com', '456-123-7890', '789 Data Dr', 'Metropolis', 'IL', '60601', 'USA', 4), +('Synergy Solutions', 'HR', 'http://www.synergysolutions.com', '789-456-1230', '101 Synergy Blvd', 'Union City', 'NJ', '07087', 'USA', 3); + +-- Insert Contacts +INSERT INTO Contacts (CustomerID, FirstName, LastName, Email, Phone, JobTitle) VALUES +(1, 'Alice', 'Wonder', 'alice.wonder@abccorp.com', '123-456-7891', 'CTO'), +(1, 'Bob', 'Builder', 'bob.builder@abccorp.com', '123-456-7892', 'Project Manager'), +(2, 'Charlie', 'Chocolate', 'charlie.chocolate@innovate.com', '234-567-8902', 'CEO'), +(3, 'Diana', 'Prince', 'diana.prince@globalsolutions.com', '345-678-9013', 'Consultant'), +(4, 'Leo', 'Lytics', 'leo.lytics@datadynamics.com', '456-123-7891', 'Data Scientist'), +(5, 'Hannah', 'Resources', 'hannah.r@synergysolutions.com', '789-456-1231', 'HR 
Manager'); + +-- Insert Leads +INSERT INTO Leads (FirstName, LastName, Email, Phone, Company, Status, Source, AssignedTo) VALUES +('Eve', 'Apple', 'eve.apple@email.com', '456-789-0123', 'Future Gadgets', 'Qualified', 'Website', 3), +('Frank', 'Stein', 'frank.stein@email.com', '567-890-1234', 'Monster Corp', 'New', 'Referral', 4), +('Grace', 'Hopper', 'grace.hopper@email.com', '678-901-2345', 'Cobol Inc.', 'Contacted', 'Cold Call', 3), +('Ivy', 'Green', 'ivy.g@webmail.com', '890-123-4567', 'Eco Systems', 'New', 'Trade Show', 4), +('Jack', 'Nimble', 'jack.n@fastmail.com', '901-234-5678', 'Quick Corp', 'Qualified', 'Website', 3); + +-- Insert Opportunities +INSERT INTO Opportunities (CustomerID, OpportunityName, Stage, Amount, CloseDate, AssignedTo) VALUES +(1, 'ABC Corp Website Redesign', 'Proposal', 50000.00, '2025-08-30', 3), +(2, 'Innovate Inc. Mobile App', 'Qualification', 75000.00, '2025-09-15', 4), +(3, 'Global Solutions IT Consulting', 'Negotiation', 120000.00, '2025-08-20', 3), +(4, 'Analytics Platform Subscription', 'Proposal', 90000.00, '2025-09-30', 4), +(5, 'HR Software Implementation', 'Prospecting', 65000.00, '2025-10-25', 3); + +-- Insert Product Categories +INSERT INTO ProductCategories (CategoryName, Description) VALUES +('Software', 'Business and productivity software'), +('Hardware', 'Computer hardware and peripherals'), +('Services', 'Consulting and support services'); + +-- Insert Products +INSERT INTO Products (ProductName, CategoryID, Description, Price, StockQuantity) VALUES +('CRM Pro', 1, 'Advanced CRM Software Suite', 1500.00, 100), +('Office Laptop Model X', 2, 'High-performance laptop for business', 1200.00, 50), +('IT Support Package', 3, '24/7 IT support services', 300.00, 200), +('Analytics Dashboard Pro', 1, 'Advanced analytics dashboard', 2500.00, 75), +('Ergonomic Office Chair', 2, 'Comfortable chair for long hours', 350.00, 150); + +-- Insert Sales Orders +INSERT INTO SalesOrders (CustomerID, OpportunityID, OrderDate, Status, TotalAmount, AssignedTo) VALUES +(1, 1, '2025-07-20', 'Shipped', 1500.00, 3), +(2, 2, '2025-07-22', 'Pending', 2400.00, 4), +(3, 3, '2025-07-24', 'Delivered', 300.00, 3), +(4, 4, '2025-07-25', 'Pending', 2500.00, 4); + +-- Insert Sales Order Items +INSERT INTO SalesOrderItems (OrderID, ProductID, Quantity, UnitPrice) VALUES +(1, 1, 1, 1500.00), +(2, 2, 2, 1200.00), +(3, 3, 1, 300.00), +(4, 4, 1, 2500.00); + +-- Insert Invoices +INSERT INTO Invoices (OrderID, InvoiceDate, DueDate, TotalAmount, Status) VALUES +(1, '2025-07-21', '2025-08-20', 1500.00, 'Paid'), +(2, '2025-07-23', '2025-08-22', 2400.00, 'Unpaid'), +(3, '2025-07-24', '2025-08-23', 300.00, 'Paid'), +(4, '2025-07-25', '2025-08-24', 2500.00, 'Unpaid'); + +-- Insert Payments +INSERT INTO Payments (InvoiceID, PaymentDate, Amount, PaymentMethod) VALUES +(1, '2025-07-25', 1500.00, 'Credit Card'), +(3, '2025-07-25', 300.00, 'Bank Transfer'); + +-- Insert Campaigns +INSERT INTO Campaigns (CampaignName, StartDate, EndDate, Budget, Status, Owner) VALUES +('Summer Sale 2025', '2025-06-01', '2025-08-31', 10000.00, 'Active', 2), +('Q4 Product Launch', '2025-10-01', '2025-12-31', 25000.00, 'Planned', 2); + +-- Insert Campaign Members +INSERT INTO CampaignMembers (CampaignID, LeadID, Status) VALUES +(1, 1, 'Responded'), +(1, 2, 'Sent'), +(1, 4, 'Sent'); +INSERT INTO CampaignMembers (CampaignID, ContactID, Status) VALUES +(1, 4, 'Sent'), +(1, 5, 'Responded'); + +-- Insert Tasks +INSERT INTO Tasks (Title, Description, DueDate, Status, Priority, AssignedTo, RelatedToEntity, RelatedToID) 
VALUES +('Follow up with ABC Corp', 'Discuss proposal details', '2025-08-01', 'In Progress', 'High', 3, 'Opportunity', 1), +('Prepare demo for Innovate Inc.', 'Customize demo for their needs', '2025-08-05', 'Not Started', 'Normal', 4, 'Opportunity', 2), +('Send updated proposal to Global Solutions', 'Include new service terms', '2025-07-28', 'Completed', 'High', 3, 'Opportunity', 3), +('Schedule initial call with Synergy Solutions', 'Discuss HR software needs', '2025-08-02', 'Not Started', 'Normal', 3, 'Customer', 5); + +-- Insert Notes +INSERT INTO Notes (Content, CreatedBy, RelatedToEntity, RelatedToID) VALUES +('Alice is very interested in the mobile integration features.', 3, 'Contact', 1), +('Lead from the tech conference last week.', 4, 'Lead', 2), +('Customer is looking for a cloud-based solution.', 4, 'Opportunity', 4), +('Met Ivy at the GreenTech expo. Promising lead.', 4, 'Lead', 4); + +-- Insert Attachments +INSERT INTO Attachments (FileName, FilePath, FileSize, FileType, UploadedBy, RelatedToEntity, RelatedToID) VALUES +('proposal_v1.pdf', '/attachments/proposal_v1.pdf', 102400, 'application/pdf', 3, 'Opportunity', 1), +('analytics_brochure.pdf', '/attachments/analytics_brochure.pdf', 256000, 'application/pdf', 4, 'Opportunity', 4); + +-- Insert Support Tickets +INSERT INTO SupportTickets (CustomerID, ContactID, Subject, Description, Status, Priority, AssignedTo) VALUES +(1, 1, 'Cannot login to portal', 'User Alice Wonder is unable to access the customer portal.', 'Resolved', 'High', 5), +(2, 3, 'Billing question', 'Question about the last invoice.', 'In Progress', 'Normal', 5), +(3, 4, 'Feature Request: Dark Mode', 'Requesting dark mode for the user dashboard.', 'Open', 'Low', 5), +(1, 2, 'Integration issue with calendar', 'Tasks are not syncing with Google Calendar.', 'In Progress', 'High', 5); + +-- Insert Ticket Comments +INSERT INTO TicketComments (TicketID, Comment, CreatedBy) VALUES +(1, 'Have reset the password. Please ask the user to try again.', 5), +(1, 'User confirmed they can now log in. Closing the ticket.', 5), +(2, 'Checking API logs for sync errors.', 5), +(3, 'Feature has been added to the development backlog.', 5); + +-- SQL Script 3: Insert More Demo Data (DML) +-- This script adds more sample data to the CRM database. +-- Run this script AFTER running 1_create_tables.sql and 2_insert_data.sql. 
+ +-- Insert more Customers (starting from CustomerID 6) +INSERT INTO Customers (CustomerName, Industry, Website, Phone, Address, City, State, ZipCode, Country, AssignedTo) VALUES +('Quantum Innovations', 'R&D', 'http://www.quantuminnovate.com', '555-0101', '100 Research Pkwy', 'Quantumville', 'MA', '02139', 'USA', 3), +('HealthFirst Medical', 'Healthcare', 'http://www.healthfirst.com', '555-0102', '200 Health Blvd', 'Wellnesston', 'FL', '33101', 'USA', 4), +('GreenScape Solutions', 'Environmental', 'http://www.greenscape.com', '555-0103', '300 Nature Way', 'Ecoville', 'OR', '97201', 'USA', 3), +('Pinnacle Finance', 'Finance', 'http://www.pinnaclefinance.com', '555-0104', '400 Wall St', 'Financeton', 'NY', '10005', 'USA', 4), +('Creative Minds Agency', 'Marketing', 'http://www.creativeminds.com', '555-0105', '500 Ad Ave', 'Creator City', 'CA', '90028', 'USA', 3); + +-- Insert more Contacts (starting from ContactID 7) +-- Assuming CustomerIDs 6-10 were just created +INSERT INTO Contacts (CustomerID, FirstName, LastName, Email, Phone, JobTitle) VALUES +(6, 'Quentin', 'Physics', 'q.physics@quantuminnovate.com', '555-0101-1', 'Lead Scientist'), +(7, 'Helen', 'Healer', 'h.healer@healthfirst.com', '555-0102-1', 'Hospital Administrator'), +(7, 'Marcus', 'Welby', 'm.welby@healthfirst.com', '555-0102-2', 'Chief of Medicine'), +(8, 'Gary', 'Gardener', 'g.gardener@greenscape.com', '555-0103-1', 'CEO'), +(9, 'Fiona', 'Funds', 'f.funds@pinnaclefinance.com', '555-0104-1', 'Investment Banker'), +(10, 'Chris', 'Creative', 'c.creative@creativeminds.com', '555-0105-1', 'Art Director'), +(1, 'Carol', 'Client', 'c.client@abccorp.com', '123-456-7893', 'IT Director'); -- Contact for existing customer + +-- Insert more Leads (starting from LeadID 6) +INSERT INTO Leads (FirstName, LastName, Email, Phone, Company, Status, Source, AssignedTo) VALUES +('Ken', 'Knowledge', 'ken.k@university.edu', '555-0201', 'State University', 'Contacted', 'Referral', 4), +('Laura', 'Legal', 'laura.l@lawfirm.com', '555-0202', 'Law & Order LLC', 'New', 'Website', 3), +('Mike', 'Mechanic', 'mike.m@autoshop.com', '555-0203', 'Auto Fixers', 'Lost', 'Cold Call', 4), +('Nancy', 'Nurse', 'nancy.n@clinic.com', '555-0204', 'Community Clinic', 'Qualified', 'Trade Show', 3), +('Oscar', 'Organizer', 'oscar.o@events.com', '555-0205', 'Events R Us', 'New', 'Website', 4); + +-- Insert more Opportunities (starting from OpportunityID 6) +-- Assuming CustomerIDs 6-10 were just created +INSERT INTO Opportunities (CustomerID, OpportunityName, Stage, Amount, CloseDate, AssignedTo) VALUES +(6, 'Quantum Computing Simulation Software', 'Qualification', 250000.00, '2025-11-15', 3), +(7, 'Patient Management System Upgrade', 'Proposal', 180000.00, '2025-12-01', 4), +(8, 'Environmental Impact Reporting Tool', 'Negotiation', 75000.00, '2025-10-30', 3), +(9, 'Wealth Management Platform', 'Closed Won', 300000.00, '2025-07-25', 4), +(10, 'Digital Marketing Campaign Analytics', 'Prospecting', 45000.00, '2025-11-20', 3); + +-- Insert a new Product Category first +INSERT INTO ProductCategories (CategoryName, Description) VALUES +('Cloud Solutions', 'Cloud-based infrastructure and platforms'); -- This will be CategoryID 4 + +-- Insert more Products (starting from ProductID 6) +INSERT INTO Products (ProductName, CategoryID, Description, Price, StockQuantity) VALUES +('Wealth Management Suite', 1, 'Comprehensive software for financial advisors', 5000.00, 50), +('Patient Record System', 1, 'EHR system for clinics and hospitals', 4500.00, 80), +('Cloud Storage - 10TB 
Plan', 4, '10TB of enterprise cloud storage', 1000.00, 500); + +-- Insert more Sales Orders (starting from OrderID 5) +-- For the 'Closed Won' opportunity (ID 9) +INSERT INTO SalesOrders (CustomerID, OpportunityID, OrderDate, Status, TotalAmount, AssignedTo) VALUES +(9, 9, '2025-07-26', 'Delivered', 5000.00, 4); + +-- Insert more Sales Order Items (for OrderID 5) +INSERT INTO SalesOrderItems (OrderID, ProductID, Quantity, UnitPrice) VALUES +(5, 6, 1, 5000.00); -- Wealth Management Suite (ProductID 6) + +-- Insert more Invoices (starting from InvoiceID 5) +INSERT INTO Invoices (OrderID, InvoiceDate, DueDate, TotalAmount, Status) VALUES +(5, '2025-07-26', '2025-08-25', 5000.00, 'Paid'); + +-- Insert more Payments (starting from PaymentID 3) +INSERT INTO Payments (InvoiceID, PaymentDate, Amount, PaymentMethod) VALUES +(2, '2025-07-25', 2400.00, 'Bank Transfer'), -- Payment for an existing unpaid invoice +(5, '2025-07-26', 5000.00, 'Credit Card'); + +-- Insert a new Campaign (starting from CampaignID 3) +INSERT INTO Campaigns (CampaignName, StartDate, EndDate, Budget, Status, Owner) VALUES +('Healthcare Solutions Webinar', '2025-09-01', '2025-09-30', 7500.00, 'Planned', 2); + +-- Insert more Campaign Members +INSERT INTO CampaignMembers (CampaignID, LeadID, Status) VALUES +(3, 9, 'Sent'); -- Nancy Nurse (LeadID 9) for Healthcare campaign +INSERT INTO CampaignMembers (CampaignID, ContactID, Status) VALUES +(3, 8, 'Sent'), -- Helen Healer (ContactID 8) +(3, 9, 'Responded'); -- Marcus Welby (ContactID 9) + +-- Insert more Tasks (starting from TaskID 5) +INSERT INTO Tasks (Title, Description, DueDate, Status, Priority, AssignedTo, RelatedToEntity, RelatedToID) VALUES +('Draft contract for Pinnacle Finance', 'Based on the final negotiation terms.', '2025-07-28', 'Completed', 'High', 4, 'Opportunity', 9), +('Schedule webinar with HealthFirst', 'Discuss Patient Management System demo.', '2025-08-10', 'Not Started', 'High', 4, 'Opportunity', 7), +('Research Quantum Innovations needs', 'Prepare for qualification call.', '2025-08-15', 'In Progress', 'Normal', 3, 'Opportunity', 6), +('Call Nancy Nurse to follow up', 'Follow up from trade show conversation.', '2025-08-05', 'Not Started', 'Normal', 3, 'Lead', 9); + +-- Insert more Notes (starting from NoteID 5) +INSERT INTO Notes (Content, CreatedBy, RelatedToEntity, RelatedToID) VALUES +('Pinnacle deal closed! Great work team.', 2, 'Opportunity', 9), +('GreenScape is looking for a solution before year-end for compliance reasons.', 3, 'Opportunity', 8), +('Nancy was very engaged at the booth, good prospect.', 3, 'Lead', 9); + +-- Insert more Support Tickets (starting from TicketID 5) +INSERT INTO SupportTickets (CustomerID, ContactID, Subject, Description, Status, Priority, AssignedTo) VALUES +(4, 5, 'Dashboard data not refreshing', 'The main dashboard widgets are not updating in real-time.', 'Open', 'High', 5), +(5, 6, 'Report generation is slow', 'Generating the quarterly HR report takes over 10 minutes.', 'In Progress', 'Normal', 5), +(9, 11, 'Login issue for new user', 'Fiona Funds cannot log into the new Wealth Management platform.', 'Open', 'High', 5); + +-- Insert more Ticket Comments (starting from CommentID 5) +INSERT INTO TicketComments (TicketID, Comment, CreatedBy) VALUES +(2, 'Invoice has been resent to the customer.', 5), -- Comment on existing ticket +(4, 'The calendar sync issue seems to be related to a recent Google API update. 
Investigating.', 5), -- Comment on existing ticket +(5, 'Escalated to engineering to check the database query performance.', 5), +(6, 'Confirmed the issue is with the real-time data service. Restarting the service.', 5); + +-- Update existing records to show data changes +UPDATE Leads SET Status = 'Contacted' WHERE LeadID = 2; -- Frank Stein +UPDATE Invoices SET Status = 'Paid' WHERE InvoiceID = 2; -- Innovate Inc. invoice diff --git a/onthology.py b/onthology.py index f5f0007c..9ae9a783 100644 --- a/onthology.py +++ b/onthology.py @@ -1,9 +1,11 @@ +"""Ontology generation module for CRM system knowledge graph.""" + from falkordb import FalkorDB from graphrag_sdk import Ontology from graphrag_sdk.models.litellm import LiteModel model = LiteModel(model_name="gemini/gemini-2.0-flash") -db = FalkorDB(host='localhost', port=6379) -kg_name = "crm_system" -ontology = Ontology.from_kg_graph(db.select_graph(kg_name), 1000000000) -ontology.save_to_graph(db.select_graph(f"{{{kg_name}}}_schema")) +db = FalkorDB(host="localhost", port=6379) +KG_NAME = "crm_system" +ontology = Ontology.from_kg_graph(db.select_graph(KG_NAME), 1000000000) +ontology.save_to_graph(db.select_graph(f"{{{KG_NAME}}}_schema")) diff --git a/poetry.lock b/poetry.lock index 486ad416..2ca47e8e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -478,6 +478,32 @@ werkzeug = ">=3.1.0" async = ["asgiref (>=3.2)"] dotenv = ["python-dotenv"] +[[package]] +name = "flask-dance" +version = "7.1.0" +description = "Doing the OAuth dance with style using Flask, requests, and oauthlib" +optional = false +python-versions = ">=3.6" +groups = ["main"] +files = [ + {file = "flask_dance-7.1.0-py3-none-any.whl", hash = "sha256:81599328a2b3604fd4332b3d41a901cf36980c2067e5e38c44ce3b85c4e1ae9c"}, + {file = "flask_dance-7.1.0.tar.gz", hash = "sha256:6d0510e284f3d6ff05af918849791b17ef93a008628ec33f3a80578a44b51674"}, +] + +[package.dependencies] +Flask = ">=2.0.3" +oauthlib = ">=3.2" +requests = ">=2.0" +requests-oauthlib = ">=1.0.0" +urlobject = "*" +Werkzeug = "*" + +[package.extras] +docs = ["Flask-Sphinx-Themes", "betamax", "pillow (<=9.5)", "pytest", "sphinx (>=1.3)", "sphinxcontrib-seqdiag", "sphinxcontrib-spelling", "sqlalchemy (>=1.3.11)"] +signals = ["blinker"] +sqla = ["sqlalchemy (>=1.3.11)"] +test = ["betamax", "coverage", "flask-caching", "flask-login", "flask-sqlalchemy", "freezegun", "oauthlib[signedtoken]", "pytest", "pytest-mock", "responses", "sqlalchemy (>=1.3.11)"] + [[package]] name = "frozenlist" version = "1.7.0" @@ -1219,6 +1245,23 @@ files = [ {file = "multidict-6.6.2.tar.gz", hash = "sha256:c1e8b8b0523c0361a78ce9b99d9850c51cf25e1fa3c5686030ce75df6fdf2918"}, ] +[[package]] +name = "oauthlib" +version = "3.3.1" +description = "A generic, spec-compliant, thorough implementation of the OAuth request-signing logic" +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "oauthlib-3.3.1-py3-none-any.whl", hash = "sha256:88119c938d2b8fb88561af5f6ee0eec8cc8d552b7bb1f712743136eb7523b7a1"}, + {file = "oauthlib-3.3.1.tar.gz", hash = "sha256:0f0f8aa759826a193cf66c12ea1af1637f87b9b4622d46e866952bb022e538c9"}, +] + +[package.extras] +rsa = ["cryptography (>=3.0.0)"] +signals = ["blinker (>=1.4.0)"] +signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"] + [[package]] name = "openai" version = "1.93.0" @@ -1947,6 +1990,25 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "requests-oauthlib" +version = "2.0.0" +description = "OAuthlib 
authentication support for Requests." +optional = false +python-versions = ">=3.4" +groups = ["main"] +files = [ + {file = "requests-oauthlib-2.0.0.tar.gz", hash = "sha256:b3dffaebd884d8cd778494369603a9e7b58d29111bf6b41bdc2dcd87203af4e9"}, + {file = "requests_oauthlib-2.0.0-py2.py3-none-any.whl", hash = "sha256:7dd8a5c40426b779b0868c404bdef9768deccf22749cde15852df527e6269b36"}, +] + +[package.dependencies] +oauthlib = ">=3.0.0" +requests = ">=2.0.0" + +[package.extras] +rsa = ["oauthlib[signedtoken] (>=3.0.0)"] + [[package]] name = "rpds-py" version = "0.25.1" @@ -2276,6 +2338,17 @@ h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "urlobject" +version = "2.4.3" +description = "A utility class for manipulating URLs." +optional = false +python-versions = "*" +groups = ["main"] +files = [ + {file = "URLObject-2.4.3.tar.gz", hash = "sha256:47b2e20e6ab9c8366b2f4a3566b6ff4053025dad311c4bb71279bbcfa2430caa"}, +] + [[package]] name = "werkzeug" version = "3.1.3" @@ -2436,4 +2509,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = ">=3.12,<3.13" -content-hash = "ee04db8da4b24fe0934a010563416a179dc15ee351d46a380f778fb1086a0898" +content-hash = "f883bca3ecc7074ea013650a53f67a0a78dca576a65bd5732a4be3cfc6e4a66a" diff --git a/pyproject.toml b/pyproject.toml index 5c4298ab..7c896dc4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ jsonschema = "^4.23.0" tqdm = "^4.67.1" boto3 = "^1.37.29" psycopg2-binary = "^2.9.9" +flask-dance = "^7.1.0" [tool.poetry.group.test.dependencies] pytest = "^8.2.0" diff --git a/requirements.txt b/requirements.txt index f5cdb6c5..5bffe872 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,61 +1,62 @@ -aiohappyeyeballs==2.6.1 ; python_version == "3.12" -aiohttp==3.12.13 ; python_version == "3.12" -aiosignal==1.3.2 ; python_version == "3.12" -annotated-types==0.7.0 ; python_version == "3.12" -anyio==4.9.0 ; python_version == "3.12" -attrs==25.3.0 ; python_version == "3.12" -blinker==1.9.0 ; python_version == "3.12" -boto3==1.38.46 ; python_version == "3.12" -botocore==1.38.46 ; python_version == "3.12" -certifi==2025.6.15 ; python_version == "3.12" -charset-normalizer==3.4.2 ; python_version == "3.12" -click==8.2.1 ; python_version == "3.12" -colorama==0.4.6 ; python_version == "3.12" and platform_system == "Windows" -distro==1.9.0 ; python_version == "3.12" -falkordb==1.1.2 ; python_version == "3.12" -filelock==3.18.0 ; python_version == "3.12" -flask==3.1.1 ; python_version == "3.12" -frozenlist==1.7.0 ; python_version == "3.12" -fsspec==2025.5.1 ; python_version == "3.12" -h11==0.16.0 ; python_version == "3.12" -hf-xet==1.1.5 ; python_version == "3.12" and (platform_machine == "x86_64" or platform_machine == "amd64" or platform_machine == "arm64" or platform_machine == "aarch64") -httpcore==1.0.9 ; python_version == "3.12" -httpx==0.28.1 ; python_version == "3.12" -huggingface-hub==0.33.1 ; python_version == "3.12" -idna==3.10 ; python_version == "3.12" -importlib-metadata==8.7.0 ; python_version == "3.12" -itsdangerous==2.2.0 ; python_version == "3.12" -jinja2==3.1.6 ; python_version == "3.12" -jiter==0.10.0 ; python_version == "3.12" -jmespath==1.0.1 ; python_version == "3.12" -jsonschema-specifications==2025.4.1 ; python_version == "3.12" -jsonschema==4.24.0 ; python_version == "3.12" -litellm==1.73.6 ; python_version == "3.12" -markupsafe==3.0.2 ; python_version == "3.12" -multidict==6.6.2 ; python_version == "3.12" -openai==1.93.0 ; python_version == 
"3.12" -packaging==25.0 ; python_version == "3.12" -propcache==0.3.2 ; python_version == "3.12" -pydantic-core==2.33.2 ; python_version == "3.12" -pydantic==2.11.7 ; python_version == "3.12" -pyjwt==2.9.0 ; python_version == "3.12" -python-dateutil==2.9.0.post0 ; python_version == "3.12" -python-dotenv==1.1.1 ; python_version == "3.12" -pyyaml==6.0.2 ; python_version == "3.12" -redis==5.3.0 ; python_version == "3.12" -referencing==0.36.2 ; python_version == "3.12" -regex==2024.11.6 ; python_version == "3.12" -requests==2.32.4 ; python_version == "3.12" -rpds-py==0.25.1 ; python_version == "3.12" -s3transfer==0.13.0 ; python_version == "3.12" -six==1.17.0 ; python_version == "3.12" -sniffio==1.3.1 ; python_version == "3.12" -tiktoken==0.9.0 ; python_version == "3.12" -tokenizers==0.21.2 ; python_version == "3.12" -tqdm==4.67.1 ; python_version == "3.12" -typing-extensions==4.14.0 ; python_version == "3.12" -typing-inspection==0.4.1 ; python_version == "3.12" -urllib3==2.5.0 ; python_version == "3.12" -werkzeug==3.1.3 ; python_version == "3.12" -yarl==1.20.1 ; python_version == "3.12" -zipp==3.23.0 ; python_version == "3.12" +aiohappyeyeballs==2.6.1 +aiohttp==3.12.13 +aiosignal==1.3.2 +annotated-types==0.7.0 +anyio==4.9.0 +attrs==25.3.0 +blinker==1.9.0 +boto3==1.38.46 +botocore==1.38.46 +certifi==2025.6.15 +charset-normalizer==3.4.2 +click==8.2.1 +distro==1.9.0 +falkordb==1.1.2 +filelock==3.18.0 +flask==3.1.1 +flask-dance==7.1.0 +frozenlist==1.7.0 +fsspec==2025.5.1 +h11==0.16.0 +hf-xet==1.1.5 +httpcore==1.0.9 +httpx==0.28.1 +huggingface-hub==0.33.1 +idna==3.10 +importlib-metadata==8.7.0 +itsdangerous==2.2.0 +jinja2==3.1.6 +jiter==0.10.0 +jmespath==1.0.1 +jsonschema-specifications==2025.4.1 +jsonschema==4.24.0 +litellm==1.73.6 +markupsafe==3.0.2 +multidict==6.6.2 +openai==1.93.0 +packaging==25.0 +propcache==0.3.2 +pydantic-core==2.33.2 +pydantic==2.11.7 +pyjwt==2.9.0 +python-dateutil==2.9.0.post0 +python-dotenv==1.1.1 +pyyaml==6.0.2 +redis==5.3.0 +referencing==0.36.2 +regex==2024.11.6 +requests==2.32.4 +rpds-py==0.25.1 +s3transfer==0.13.0 +six==1.17.0 +sniffio==1.3.1 +tiktoken==0.9.0 +tokenizers==0.21.2 +tqdm==4.67.1 +typing-extensions==4.14.0 +typing-inspection==0.4.1 +urllib3==2.5.0 +werkzeug==3.1.3 +yarl==1.20.1 +zipp==3.23.0 +psycopg2-binary==2.9.9 diff --git a/start.sh b/start.sh new file mode 100644 index 00000000..c0db7ef6 --- /dev/null +++ b/start.sh @@ -0,0 +1,21 @@ +#!/bin/bash +set -e + + +# Set default values if not set +FALKORDB_HOST="${FALKORDB_HOST:-localhost}" +FALKORDB_PORT="${FALKORDB_PORT:-6379}" + +# Start FalkorDB Redis server in background +redis-server --loadmodule /var/lib/falkordb/bin/falkordb.so & + +# Wait until FalkorDB is ready +echo "Waiting for FalkorDB to start on $FALKORDB_HOST:$FALKORDB_PORT..." + +while ! nc -z "$FALKORDB_HOST" "$FALKORDB_PORT"; do + sleep 0.5 +done + + +echo "FalkorDB is up - launching Flask..." +exec python3 -m flask --app api.index run --host=0.0.0.0 --port=5000 \ No newline at end of file