From 19958a5498d7b958113e07072c52d9a38d2b9236 Mon Sep 17 00:00:00 2001 From: Gal Shubeli Date: Wed, 7 Jan 2026 15:09:06 +0200 Subject: [PATCH 1/6] unified-descriptions --- api/loaders/base_loader.py | 65 ++++++++++------------------------ api/loaders/graph_loader.py | 13 +++++-- api/loaders/mysql_loader.py | 43 +++++++++------------- api/loaders/postgres_loader.py | 48 ++++++++++--------------- api/utils.py | 61 +++++++++++++++++++++++++++++-- 5 files changed, 123 insertions(+), 107 deletions(-) diff --git a/api/loaders/base_loader.py b/api/loaders/base_loader.py index 91141606..39c822d1 100644 --- a/api/loaders/base_loader.py +++ b/api/loaders/base_loader.py @@ -1,8 +1,7 @@ """Base loader module providing abstract base class for data loaders.""" from abc import ABC, abstractmethod -from typing import AsyncGenerator, List, Any, Tuple, TYPE_CHECKING -from api.config import Config +from typing import AsyncGenerator, List, Any, TYPE_CHECKING class BaseLoader(ABC): @@ -24,69 +23,43 @@ async def load(_graph_id: str, _data) -> AsyncGenerator[tuple[bool, str], None]: @staticmethod @abstractmethod - def _execute_count_query(cursor, table_name: str, col_name: str) -> Tuple[int, int]: + def _execute_sample_query(cursor, table_name: str, col_name: str, sample_size: int = 3) -> List[Any]: """ - Execute query to get total count and distinct count for a column. + Execute query to get random sample values for a column. Args: cursor: Database cursor table_name: Name of the table col_name: Name of the column + sample_size: Number of random samples to retrieve (default: 3) Returns: - Tuple of (total_count, distinct_count) - """ - - @staticmethod - @abstractmethod - def _execute_distinct_query(cursor, table_name: str, col_name: str) -> List[Any]: - """ - Execute query to get distinct values for a column. - - Args: - cursor: Database cursor - table_name: Name of the table - col_name: Name of the column - - Returns: - List of distinct values + List of sample values """ @classmethod - def extract_distinct_values_for_column( - cls, cursor, table_name: str, col_name: str - ) -> List[str]: + def extract_sample_values_for_column( + cls, cursor, table_name: str, col_name: str, sample_size: int = 3 + ) -> List[Any]: """ - Extract distinct values for a column if it meets the criteria for inclusion. + Extract random sample values for a column to provide balanced descriptions. Args: cursor: Database cursor table_name: Name of the table col_name: Name of the column + sample_size: Number of random samples to retrieve (default: 3) Returns: - List of formatted distinct values to add to description, or empty list + List of sample values (raw values, not formatted), or empty list """ - # Get row counts using database-specific implementation - rows_count, distinct_count = cls._execute_count_query( - cursor, table_name, col_name - ) - - max_distinct = Config.DB_MAX_DISTINCT - uniqueness_threshold = Config.DB_UNIQUENESS_THRESHOLD - - if 0 < distinct_count < max_distinct and distinct_count < ( - uniqueness_threshold * rows_count - ): - # Get distinct values using database-specific implementation - distinct_values = cls._execute_distinct_query(cursor, table_name, col_name) - - if distinct_values: - # Check first value type to avoid objects like dict/bytes - first_val = distinct_values[0] - if isinstance(first_val, (str, int)): - return [ - f"(Optional values: {', '.join(f'({str(v)})' for v in distinct_values)})" - ] + # Get sample values using database-specific implementation + sample_values = cls._execute_sample_query(cursor, table_name, col_name, sample_size) + + if sample_values: + # Check first value type to avoid objects like dict/bytes + first_val = sample_values[0] + if isinstance(first_val, (str, int, float)): + return [str(v) for v in sample_values] return [] diff --git a/api/loaders/graph_loader.py b/api/loaders/graph_loader.py index d67b06a4..12799688 100644 --- a/api/loaders/graph_loader.py +++ b/api/loaders/graph_loader.py @@ -6,7 +6,7 @@ from api.config import Config from api.extensions import db -from api.utils import generate_db_description +from api.utils import generate_db_description, create_combined_description async def load_to_graph( # pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals @@ -31,6 +31,8 @@ async def load_to_graph( # pylint: disable=too-many-arguments,too-many-position embedding_model = Config.EMBEDDING_MODEL vec_len = embedding_model.get_vector_size() + enitites = create_combined_description(entities) + try: # Create vector indices await graph.query( @@ -123,6 +125,13 @@ async def load_to_graph( # pylint: disable=too-many-arguments,too-many-position embed_columns.extend(embedding_result) idx = 0 + # Combine description with sample values after embedding is created + final_description = col_info["description"] + sample_values = col_info.get("sample_values", []) + if sample_values: + sample_values_str = f"(Sample values: {', '.join(f'({v})' for v in sample_values)})" + final_description = f"{final_description} {sample_values_str}" + await graph.query( """ MATCH (t:Table {name: $table_name}) @@ -141,7 +150,7 @@ async def load_to_graph( # pylint: disable=too-many-arguments,too-many-position "type": col_info.get("type", "unknown"), "nullable": col_info.get("null", "unknown"), "key": col_info.get("key", "unknown"), - "description": col_info["description"], + "description": final_description, "embedding": embed_columns[idx], }, ) diff --git a/api/loaders/mysql_loader.py b/api/loaders/mysql_loader.py index 2938b263..33282a1c 100644 --- a/api/loaders/mysql_loader.py +++ b/api/loaders/mysql_loader.py @@ -4,7 +4,7 @@ import decimal import logging import re -from typing import AsyncGenerator, Tuple, Dict, Any, List +from typing import AsyncGenerator, Dict, Any, List, Tuple import tqdm import pymysql @@ -54,33 +54,22 @@ class MySQLLoader(BaseLoader): ] @staticmethod - def _execute_count_query(cursor, table_name: str, col_name: str) -> Tuple[int, int]: + def _execute_sample_query(cursor, table_name: str, col_name: str, sample_size: int = 3) -> List[Any]: """ - Execute query to get total count and distinct count for a column. - MySQL implementation returning counts from dictionary-style results. + Execute query to get random sample values for a column. + MySQL implementation using ORDER BY RAND() for random sampling. """ query = f""" - SELECT COUNT(*) AS total_count, - COUNT(DISTINCT `{col_name}`) AS distinct_count - FROM `{table_name}`; + SELECT DISTINCT `{col_name}` + FROM `{table_name}` + WHERE `{col_name}` IS NOT NULL + ORDER BY RAND() + LIMIT %s; """ + cursor.execute(query, (sample_size,)) - cursor.execute(query) - output = cursor.fetchall() - first_result = output[0] - return first_result['total_count'], first_result['distinct_count'] - - @staticmethod - def _execute_distinct_query(cursor, table_name: str, col_name: str) -> List[Any]: - """ - Execute query to get distinct values for a column. - MySQL implementation handling dictionary-style results. - """ - query = f"SELECT DISTINCT `{col_name}` FROM `{table_name}`;" - cursor.execute(query) - - distinct_results = cursor.fetchall() - return [row[col_name] for row in distinct_results if row[col_name] is not None] + sample_results = cursor.fetchall() + return [row[col_name] for row in sample_results if row[col_name] is not None] @staticmethod def _serialize_value(value): @@ -324,18 +313,18 @@ def extract_columns_info(cursor, db_name: str, table_name: str) -> Dict[str, Any if column_default is not None: description_parts.append(f"(Default: {column_default})") - # Add distinct values if applicable - distinct_values_desc = MySQLLoader.extract_distinct_values_for_column( + # Extract sample values for the column (stored separately, not in description) + sample_values = MySQLLoader.extract_sample_values_for_column( cursor, table_name, col_name ) - description_parts.extend(distinct_values_desc) columns_info[col_name] = { 'type': data_type, 'null': is_nullable, 'key': key_type, 'description': ' '.join(description_parts), - 'default': column_default + 'default': column_default, + 'sample_values': sample_values } return columns_info diff --git a/api/loaders/postgres_loader.py b/api/loaders/postgres_loader.py index cd44e77f..a33801a6 100644 --- a/api/loaders/postgres_loader.py +++ b/api/loaders/postgres_loader.py @@ -4,7 +4,7 @@ import datetime import decimal import logging -from typing import AsyncGenerator, Tuple, Dict, Any, List +from typing import AsyncGenerator, Dict, Any, List, Tuple import psycopg2 from psycopg2 import sql @@ -51,37 +51,27 @@ class PostgresLoader(BaseLoader): ] @staticmethod - def _execute_count_query(cursor, table_name: str, col_name: str) -> Tuple[int, int]: + def _execute_sample_query(cursor, table_name: str, col_name: str, sample_size: int = 3) -> List[Any]: """ - Execute query to get total count and distinct count for a column. - PostgreSQL implementation returning counts from tuple-style results. + Execute query to get random sample values for a column. + PostgreSQL implementation using ORDER BY RANDOM() for random sampling. """ query = sql.SQL(""" - SELECT COUNT(*) AS total_count, - COUNT(DISTINCT {col}) AS distinct_count - FROM {table}; + SELECT {col} + FROM ( + SELECT DISTINCT {col} + FROM {table} + WHERE {col} IS NOT NULL + ) AS distinct_vals + ORDER BY RANDOM() + LIMIT %s; """).format( col=sql.Identifier(col_name), table=sql.Identifier(table_name) ) - cursor.execute(query) - output = cursor.fetchall() - first_result = output[0] - return first_result[0], first_result[1] - - @staticmethod - def _execute_distinct_query(cursor, table_name: str, col_name: str) -> List[Any]: - """ - Execute query to get distinct values for a column. - PostgreSQL implementation handling tuple-style results. - """ - query = sql.SQL("SELECT DISTINCT {col} FROM {table};").format( - col=sql.Identifier(col_name), - table=sql.Identifier(table_name) - ) - cursor.execute(query) - distinct_results = cursor.fetchall() - return [row[0] for row in distinct_results if row[0] is not None] + cursor.execute(query, (sample_size,)) + sample_results = cursor.fetchall() + return [row[0] for row in sample_results if row[0] is not None] @staticmethod def _serialize_value(value): @@ -279,18 +269,18 @@ def extract_columns_info(cursor, table_name: str) -> Dict[str, Any]: if column_default: description_parts.append(f"(Default: {column_default})") - # Add distinct values if applicable - distinct_values_desc = PostgresLoader.extract_distinct_values_for_column( + # Extract sample values for the column (stored separately, not in description) + sample_values = PostgresLoader.extract_sample_values_for_column( cursor, table_name, col_name ) - description_parts.extend(distinct_values_desc) columns_info[col_name] = { 'type': data_type, 'null': is_nullable, 'key': key_type, 'description': ' '.join(description_parts), - 'default': column_default + 'default': column_default, + 'sample_values': sample_values } diff --git a/api/utils.py b/api/utils.py index 14dd09d6..f3b0e52a 100644 --- a/api/utils.py +++ b/api/utils.py @@ -1,11 +1,66 @@ """Utility functions for the text2sql API.""" +import json +from typing import Any, Dict, List +from api.config import Config +from litellm import completion, batch_completion -from typing import List +def create_combined_description(table_info: Dict[str, Dict[str, Any]], batch_size: int = 10) -> Dict[str, Dict[str, Any]]: + """ + Create a combined description from a dictionary of table descriptions. -from litellm import completion + Args: + table_info (Dict[str, Dict[str, Any]]): Mapping of table names to their metadata. + Returns: + Dict[str, Dict[str, Any]]: Updated mapping containing descriptions. + """ + if not isinstance(table_info, dict): + raise TypeError("table_info must be a dictionary keyed by table name.") -from api.config import Config + messages_list = [] + table_keys = [] + + system_prompt = ( + "You are a database table description generator. " + "Generate ONE concise sentence starting with the table name, describing what the table stores, " + "using present tense. Do not add explanations." + ) + + user_prompt_template = ( + "Table Name: {table_name}\n" + "Table Schema: {table_prop}\n" + "Provide a concise description of this table." + ) + + for table_name, table_prop in table_info.items(): + table_prop.pop("col_descriptions") # The col_descriptions property is duplicated in the schema (columns has it) + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt_template.format(table_name=table_name, table_prop=json.dumps(table_prop))}, + ] + + messages_list.append(messages) + table_keys.append(table_name) + for batch_start in range(0, len(messages_list), batch_size): + batch_messages = messages_list[batch_start : batch_start + batch_size] + response = batch_completion( + model=Config.COMPLETION_MODEL, + messages=batch_messages, + temperature=0, + max_tokens=50, + ) + + for offset, batch_response in enumerate(response): + table_index = batch_start + offset + if table_index >= len(table_keys): + break + table_name = table_keys[table_index] + if isinstance(batch_response, Exception): + table_info[table_name]["description"] = table_name + else: + table_info[table_name]["description"] = batch_response.choices[0].message["content"].strip() + + return table_info def generate_db_description( db_name: str, From db206cc09e6f70b3212c9b746a85a7b257d09c51 Mon Sep 17 00:00:00 2001 From: Gal Shubeli Date: Wed, 7 Jan 2026 15:30:02 +0200 Subject: [PATCH 2/6] fix-lint --- api/loaders/base_loader.py | 4 +++- api/loaders/graph_loader.py | 2 +- api/loaders/mysql_loader.py | 4 +++- api/loaders/postgres_loader.py | 4 +++- api/utils.py | 37 +++++++++++++++++++++++----------- 5 files changed, 35 insertions(+), 16 deletions(-) diff --git a/api/loaders/base_loader.py b/api/loaders/base_loader.py index 39c822d1..08d6aad0 100644 --- a/api/loaders/base_loader.py +++ b/api/loaders/base_loader.py @@ -23,7 +23,9 @@ async def load(_graph_id: str, _data) -> AsyncGenerator[tuple[bool, str], None]: @staticmethod @abstractmethod - def _execute_sample_query(cursor, table_name: str, col_name: str, sample_size: int = 3) -> List[Any]: + def _execute_sample_query( + cursor, table_name: str, col_name: str, sample_size: int = 3 + ) -> List[Any]: """ Execute query to get random sample values for a column. diff --git a/api/loaders/graph_loader.py b/api/loaders/graph_loader.py index 12799688..855b2201 100644 --- a/api/loaders/graph_loader.py +++ b/api/loaders/graph_loader.py @@ -31,7 +31,7 @@ async def load_to_graph( # pylint: disable=too-many-arguments,too-many-position embedding_model = Config.EMBEDDING_MODEL vec_len = embedding_model.get_vector_size() - enitites = create_combined_description(entities) + create_combined_description(entities) try: # Create vector indices diff --git a/api/loaders/mysql_loader.py b/api/loaders/mysql_loader.py index 33282a1c..9825b2e4 100644 --- a/api/loaders/mysql_loader.py +++ b/api/loaders/mysql_loader.py @@ -54,7 +54,9 @@ class MySQLLoader(BaseLoader): ] @staticmethod - def _execute_sample_query(cursor, table_name: str, col_name: str, sample_size: int = 3) -> List[Any]: + def _execute_sample_query( + cursor, table_name: str, col_name: str, sample_size: int = 3 + ) -> List[Any]: """ Execute query to get random sample values for a column. MySQL implementation using ORDER BY RAND() for random sampling. diff --git a/api/loaders/postgres_loader.py b/api/loaders/postgres_loader.py index a33801a6..be0be497 100644 --- a/api/loaders/postgres_loader.py +++ b/api/loaders/postgres_loader.py @@ -51,7 +51,9 @@ class PostgresLoader(BaseLoader): ] @staticmethod - def _execute_sample_query(cursor, table_name: str, col_name: str, sample_size: int = 3) -> List[Any]: + def _execute_sample_query( + cursor, table_name: str, col_name: str, sample_size: int = 3 + ) -> List[Any]: """ Execute query to get random sample values for a column. PostgreSQL implementation using ORDER BY RANDOM() for random sampling. diff --git a/api/utils.py b/api/utils.py index f3b0e52a..1c938184 100644 --- a/api/utils.py +++ b/api/utils.py @@ -1,10 +1,15 @@ """Utility functions for the text2sql API.""" import json from typing import Any, Dict, List -from api.config import Config + from litellm import completion, batch_completion -def create_combined_description(table_info: Dict[str, Dict[str, Any]], batch_size: int = 10) -> Dict[str, Dict[str, Any]]: +from api.config import Config + + +def create_combined_description( + table_info: Dict[str, Dict[str, Any]], batch_size: int = 10 +) -> Dict[str, Dict[str, Any]]: """ Create a combined description from a dictionary of table descriptions. @@ -18,24 +23,31 @@ def create_combined_description(table_info: Dict[str, Dict[str, Any]], batch_siz messages_list = [] table_keys = [] - + system_prompt = ( "You are a database table description generator. " - "Generate ONE concise sentence starting with the table name, describing what the table stores, " - "using present tense. Do not add explanations." + "Generate ONE concise sentence starting with the table name, " + "describing what the table stores, using present tense. " + "Do not add explanations." ) - + user_prompt_template = ( "Table Name: {table_name}\n" "Table Schema: {table_prop}\n" "Provide a concise description of this table." ) - + for table_name, table_prop in table_info.items(): - table_prop.pop("col_descriptions") # The col_descriptions property is duplicated in the schema (columns has it) + # The col_descriptions property is duplicated in the schema (columns has it) + table_prop.pop("col_descriptions") messages = [ {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt_template.format(table_name=table_name, table_prop=json.dumps(table_prop))}, + { + "role": "user", + "content": user_prompt_template.format( + table_name=table_name, table_prop=json.dumps(table_prop) + ), + }, ] messages_list.append(messages) @@ -49,7 +61,7 @@ def create_combined_description(table_info: Dict[str, Dict[str, Any]], batch_siz temperature=0, max_tokens=50, ) - + for offset, batch_response in enumerate(response): table_index = batch_start + offset if table_index >= len(table_keys): @@ -58,8 +70,9 @@ def create_combined_description(table_info: Dict[str, Dict[str, Any]], batch_siz if isinstance(batch_response, Exception): table_info[table_name]["description"] = table_name else: - table_info[table_name]["description"] = batch_response.choices[0].message["content"].strip() - + content = batch_response.choices[0].message["content"].strip() + table_info[table_name]["description"] = content + return table_info def generate_db_description( From 6aeb6993dbf0aecf9c80af4116b61286711f5cc8 Mon Sep 17 00:00:00 2001 From: Gal Shubeli Date: Wed, 7 Jan 2026 15:33:27 +0200 Subject: [PATCH 3/6] Update api/utils.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- api/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/utils.py b/api/utils.py index 1c938184..fdf9d1c7 100644 --- a/api/utils.py +++ b/api/utils.py @@ -15,6 +15,7 @@ def create_combined_description( Args: table_info (Dict[str, Dict[str, Any]]): Mapping of table names to their metadata. + batch_size (int): Number of tables to process per batch when calling the LLM (default: 10). Returns: Dict[str, Dict[str, Any]]: Updated mapping containing descriptions. """ From a3d8d95fad9811ea2f5702e169863a26e2dc1755 Mon Sep 17 00:00:00 2001 From: Gal Shubeli Date: Wed, 7 Jan 2026 15:43:24 +0200 Subject: [PATCH 4/6] fix-any-usage --- api/utils.py | 38 ++++++++++++++++++++++++++++++++------ 1 file changed, 32 insertions(+), 6 deletions(-) diff --git a/api/utils.py b/api/utils.py index fdf9d1c7..530ca646 100644 --- a/api/utils.py +++ b/api/utils.py @@ -1,23 +1,49 @@ """Utility functions for the text2sql API.""" import json -from typing import Any, Dict, List +from typing import Dict, List, Optional, TypedDict from litellm import completion, batch_completion from api.config import Config -def create_combined_description( - table_info: Dict[str, Dict[str, Any]], batch_size: int = 10 -) -> Dict[str, Dict[str, Any]]: +class ForeignKeyInfo(TypedDict): + """Foreign key constraint information.""" + constraint_name: str + column: str + referenced_table: str + referenced_column: str + + +class ColumnInfo(TypedDict): + """Column metadata information.""" + type: str + null: str + key: str + description: str + default: Optional[str] + sample_values: List[str] + + +class TableInfo(TypedDict): + """Table metadata information.""" + description: str + columns: Dict[str, ColumnInfo] + foreign_keys: List[ForeignKeyInfo] + col_descriptions: List[str] + + +def create_combined_description( # pylint: disable=too-many-locals + table_info: Dict[str, TableInfo], batch_size: int = 10 +) -> Dict[str, TableInfo]: """ Create a combined description from a dictionary of table descriptions. Args: - table_info (Dict[str, Dict[str, Any]]): Mapping of table names to their metadata. + table_info (Dict[str, TableInfo]): Mapping of table names to their metadata. batch_size (int): Number of tables to process per batch when calling the LLM (default: 10). Returns: - Dict[str, Dict[str, Any]]: Updated mapping containing descriptions. + Dict[str, TableInfo]: Updated mapping containing descriptions. """ if not isinstance(table_info, dict): raise TypeError("table_info must be a dictionary keyed by table name.") From 3f45eb1f44ec0d3fa1f6403d44fcd93be5c98243 Mon Sep 17 00:00:00 2001 From: Gal Shubeli Date: Wed, 7 Jan 2026 15:46:52 +0200 Subject: [PATCH 5/6] fix-docs --- api/loaders/base_loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/loaders/base_loader.py b/api/loaders/base_loader.py index 08d6aad0..55e18ec7 100644 --- a/api/loaders/base_loader.py +++ b/api/loaders/base_loader.py @@ -53,7 +53,7 @@ def extract_sample_values_for_column( sample_size: Number of random samples to retrieve (default: 3) Returns: - List of sample values (raw values, not formatted), or empty list + List of sample values (converted to strings), or empty list """ # Get sample values using database-specific implementation sample_values = cls._execute_sample_query(cursor, table_name, col_name, sample_size) From cdea661cf0a24e2632b5cc274ded0f649df7bc3a Mon Sep 17 00:00:00 2001 From: Gal Shubeli Date: Wed, 7 Jan 2026 15:51:42 +0200 Subject: [PATCH 6/6] fix-mu-error --- api/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/api/utils.py b/api/utils.py index 530ca646..845ca604 100644 --- a/api/utils.py +++ b/api/utils.py @@ -66,7 +66,8 @@ def create_combined_description( # pylint: disable=too-many-locals for table_name, table_prop in table_info.items(): # The col_descriptions property is duplicated in the schema (columns has it) - table_prop.pop("col_descriptions") + table_prop = table_prop.copy() + table_prop.pop("col_descriptions", None) messages = [ {"role": "system", "content": system_prompt}, {