api/loaders/base_loader.py (67 changes: 21 additions & 46 deletions)
@@ -1,8 +1,7 @@
"""Base loader module providing abstract base class for data loaders."""

from abc import ABC, abstractmethod
from typing import AsyncGenerator, List, Any, Tuple, TYPE_CHECKING
from api.config import Config
from typing import AsyncGenerator, List, Any, TYPE_CHECKING


class BaseLoader(ABC):
@@ -24,69 +23,45 @@ async def load(_graph_id: str, _data) -> AsyncGenerator[tuple[bool, str], None]:

     @staticmethod
     @abstractmethod
-    def _execute_count_query(cursor, table_name: str, col_name: str) -> Tuple[int, int]:
+    def _execute_sample_query(
+        cursor, table_name: str, col_name: str, sample_size: int = 3
+    ) -> List[Any]:
         """
-        Execute query to get total count and distinct count for a column.
+        Execute query to get random sample values for a column.
         Args:
             cursor: Database cursor
             table_name: Name of the table
             col_name: Name of the column
+            sample_size: Number of random samples to retrieve (default: 3)
         Returns:
-            Tuple of (total_count, distinct_count)
-        """
-
-    @staticmethod
-    @abstractmethod
-    def _execute_distinct_query(cursor, table_name: str, col_name: str) -> List[Any]:
-        """
-        Execute query to get distinct values for a column.
-        Args:
-            cursor: Database cursor
-            table_name: Name of the table
-            col_name: Name of the column
-        Returns:
-            List of distinct values
+            List of sample values
         """
 
     @classmethod
-    def extract_distinct_values_for_column(
-        cls, cursor, table_name: str, col_name: str
-    ) -> List[str]:
+    def extract_sample_values_for_column(
+        cls, cursor, table_name: str, col_name: str, sample_size: int = 3
+    ) -> List[Any]:
         """
-        Extract distinct values for a column if it meets the criteria for inclusion.
+        Extract random sample values for a column to provide balanced descriptions.
        Args:
            cursor: Database cursor
            table_name: Name of the table
            col_name: Name of the column
+            sample_size: Number of random samples to retrieve (default: 3)
         Returns:
-            List of formatted distinct values to add to description, or empty list
+            List of sample values (converted to strings), or empty list
         """
-        # Get row counts using database-specific implementation
-        rows_count, distinct_count = cls._execute_count_query(
-            cursor, table_name, col_name
-        )
-
-        max_distinct = Config.DB_MAX_DISTINCT
-        uniqueness_threshold = Config.DB_UNIQUENESS_THRESHOLD
-
-        if 0 < distinct_count < max_distinct and distinct_count < (
-            uniqueness_threshold * rows_count
-        ):
-            # Get distinct values using database-specific implementation
-            distinct_values = cls._execute_distinct_query(cursor, table_name, col_name)
-
-            if distinct_values:
-                # Check first value type to avoid objects like dict/bytes
-                first_val = distinct_values[0]
-                if isinstance(first_val, (str, int)):
-                    return [
-                        f"(Optional values: {', '.join(f'({str(v)})' for v in distinct_values)})"
-                    ]
+        # Get sample values using database-specific implementation
+        sample_values = cls._execute_sample_query(cursor, table_name, col_name, sample_size)
+
+        if sample_values:
+            # Check first value type to avoid objects like dict/bytes
+            first_val = sample_values[0]
+            if isinstance(first_val, (str, int, float)):
+                return [str(v) for v in sample_values]
 
         return []
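The abstract contract above leaves the actual sampling query to each database-specific loader. As a rough illustration of what a concrete subclass is expected to provide, here is a minimal sketch for a hypothetical SQLite loader (SQLite support is not part of this PR; the class, query, and identifiers are assumptions for illustration only):

    import sqlite3
    from typing import Any, List

    from api.loaders.base_loader import BaseLoader

    class SQLiteLoader(BaseLoader):  # hypothetical subclass, not in this PR
        @staticmethod
        def _execute_sample_query(
            cursor, table_name: str, col_name: str, sample_size: int = 3
        ) -> List[Any]:
            # SQLite also exposes RANDOM(); identifiers are quoted with double quotes
            query = (
                f'SELECT DISTINCT "{col_name}" FROM "{table_name}" '
                f'WHERE "{col_name}" IS NOT NULL ORDER BY RANDOM() LIMIT ?'
            )
            cursor.execute(query, (sample_size,))
            return [row[0] for row in cursor.fetchall() if row[0] is not None]

With an implementation like this in place, extract_sample_values_for_column returns up to sample_size stringified values, or an empty list when the first sampled value is not a str, int, or float.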
api/loaders/graph_loader.py (13 changes: 11 additions & 2 deletions)
@@ -6,7 +6,7 @@

 from api.config import Config
 from api.extensions import db
-from api.utils import generate_db_description
+from api.utils import generate_db_description, create_combined_description
 
 
 async def load_to_graph( # pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals
@@ -31,6 +31,8 @@ async def load_to_graph( # pylint: disable=too-many-arguments,too-many-position
     embedding_model = Config.EMBEDDING_MODEL
     vec_len = embedding_model.get_vector_size()
 
+    create_combined_description(entities)
+
     try:
         # Create vector indices
         await graph.query(
@@ -123,6 +125,13 @@ async def load_to_graph( # pylint: disable=too-many-arguments,too-many-position
         embed_columns.extend(embedding_result)
         idx = 0
 
+        # Combine description with sample values after embedding is created
+        final_description = col_info["description"]
+        sample_values = col_info.get("sample_values", [])
+        if sample_values:
+            sample_values_str = f"(Sample values: {', '.join(f'({v})' for v in sample_values)})"
+            final_description = f"{final_description} {sample_values_str}"
+
         await graph.query(
             """
             MATCH (t:Table {name: $table_name})
@@ -141,7 +150,7 @@ async def load_to_graph( # pylint: disable=too-many-arguments,too-many-position
"type": col_info.get("type", "unknown"),
"nullable": col_info.get("null", "unknown"),
"key": col_info.get("key", "unknown"),
"description": col_info["description"],
"description": final_description,
"embedding": embed_columns[idx],
},
)
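For reference, the snippet below reproduces the formatting applied above for a hypothetical "status" column, showing the text that ends up as the Column node's description (the column name and values are made up for illustration):

    col_info = {
        "description": "Order fulfillment status. (Default: pending)",
        "sample_values": ["pending", "shipped", "delivered"],
    }
    final_description = col_info["description"]
    sample_values = col_info.get("sample_values", [])
    if sample_values:
        sample_values_str = f"(Sample values: {', '.join(f'({v})' for v in sample_values)})"
        final_description = f"{final_description} {sample_values_str}"
    print(final_description)
    # Order fulfillment status. (Default: pending) (Sample values: (pending), (shipped), (delivered))

Because the combination happens after the embedding is computed (per the comment in the diff), the stored embedding reflects only the base description, while the node's description property additionally carries the sample values.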
api/loaders/mysql_loader.py (45 changes: 18 additions & 27 deletions)
@@ -4,7 +4,7 @@
 import decimal
 import logging
 import re
-from typing import AsyncGenerator, Tuple, Dict, Any, List
+from typing import AsyncGenerator, Dict, Any, List, Tuple
 
 import tqdm
 import pymysql
@@ -54,33 +54,24 @@ class MySQLLoader(BaseLoader):
     ]
 
     @staticmethod
-    def _execute_count_query(cursor, table_name: str, col_name: str) -> Tuple[int, int]:
+    def _execute_sample_query(
+        cursor, table_name: str, col_name: str, sample_size: int = 3
+    ) -> List[Any]:
         """
-        Execute query to get total count and distinct count for a column.
-        MySQL implementation returning counts from dictionary-style results.
+        Execute query to get random sample values for a column.
+        MySQL implementation using ORDER BY RAND() for random sampling.
         """
         query = f"""
-            SELECT COUNT(*) AS total_count,
-                   COUNT(DISTINCT `{col_name}`) AS distinct_count
-            FROM `{table_name}`;
+            SELECT DISTINCT `{col_name}`
+            FROM `{table_name}`
+            WHERE `{col_name}` IS NOT NULL
+            ORDER BY RAND()
+            LIMIT %s;
         """
+        cursor.execute(query, (sample_size,))
 
-        cursor.execute(query)
-        output = cursor.fetchall()
-        first_result = output[0]
-        return first_result['total_count'], first_result['distinct_count']
-
-    @staticmethod
-    def _execute_distinct_query(cursor, table_name: str, col_name: str) -> List[Any]:
-        """
-        Execute query to get distinct values for a column.
-        MySQL implementation handling dictionary-style results.
-        """
-        query = f"SELECT DISTINCT `{col_name}` FROM `{table_name}`;"
-        cursor.execute(query)
-
-        distinct_results = cursor.fetchall()
-        return [row[col_name] for row in distinct_results if row[col_name] is not None]
+        sample_results = cursor.fetchall()
+        return [row[col_name] for row in sample_results if row[col_name] is not None]
 
     @staticmethod
     def _serialize_value(value):
@@ -324,18 +315,18 @@ def extract_columns_info(cursor, db_name: str, table_name: str) -> Dict[str, Any
             if column_default is not None:
                 description_parts.append(f"(Default: {column_default})")
 
-            # Add distinct values if applicable
-            distinct_values_desc = MySQLLoader.extract_distinct_values_for_column(
+            # Extract sample values for the column (stored separately, not in description)
+            sample_values = MySQLLoader.extract_sample_values_for_column(
                 cursor, table_name, col_name
             )
-            description_parts.extend(distinct_values_desc)
 
             columns_info[col_name] = {
                 'type': data_type,
                 'null': is_nullable,
                 'key': key_type,
                 'description': ' '.join(description_parts),
-                'default': column_default
+                'default': column_default,
+                'sample_values': sample_values
             }
 
         return columns_info
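The MySQL implementation indexes fetched rows by column name (row[col_name]), so it assumes the cursor returns dictionary-style rows. A minimal usage sketch with pymysql's DictCursor follows; the connection parameters and the table/column names are placeholders, not values from this PR:

    import pymysql

    from api.loaders.mysql_loader import MySQLLoader

    conn = pymysql.connect(
        host="localhost", user="app", password="secret", database="shop",
        cursorclass=pymysql.cursors.DictCursor,  # rows come back as dicts
    )
    with conn.cursor() as cursor:
        # Returns up to 3 random non-NULL distinct values as strings,
        # e.g. ["pending", "shipped", "delivered"]
        samples = MySQLLoader.extract_sample_values_for_column(cursor, "orders", "status")
    conn.close()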
api/loaders/postgres_loader.py (50 changes: 21 additions & 29 deletions)
@@ -4,7 +4,7 @@
 import datetime
 import decimal
 import logging
-from typing import AsyncGenerator, Tuple, Dict, Any, List
+from typing import AsyncGenerator, Dict, Any, List, Tuple
 
 import psycopg2
 from psycopg2 import sql
@@ -51,37 +51,29 @@ class PostgresLoader(BaseLoader):
     ]
 
     @staticmethod
-    def _execute_count_query(cursor, table_name: str, col_name: str) -> Tuple[int, int]:
+    def _execute_sample_query(
+        cursor, table_name: str, col_name: str, sample_size: int = 3
+    ) -> List[Any]:
         """
-        Execute query to get total count and distinct count for a column.
-        PostgreSQL implementation returning counts from tuple-style results.
+        Execute query to get random sample values for a column.
+        PostgreSQL implementation using ORDER BY RANDOM() for random sampling.
         """
         query = sql.SQL("""
-            SELECT COUNT(*) AS total_count,
-                   COUNT(DISTINCT {col}) AS distinct_count
-            FROM {table};
+            SELECT {col}
+            FROM (
+                SELECT DISTINCT {col}
+                FROM {table}
+                WHERE {col} IS NOT NULL
+            ) AS distinct_vals
+            ORDER BY RANDOM()
+            LIMIT %s;
         """).format(
             col=sql.Identifier(col_name),
             table=sql.Identifier(table_name)
         )
-        cursor.execute(query)
-        output = cursor.fetchall()
-        first_result = output[0]
-        return first_result[0], first_result[1]
-
-    @staticmethod
-    def _execute_distinct_query(cursor, table_name: str, col_name: str) -> List[Any]:
-        """
-        Execute query to get distinct values for a column.
-        PostgreSQL implementation handling tuple-style results.
-        """
-        query = sql.SQL("SELECT DISTINCT {col} FROM {table};").format(
-            col=sql.Identifier(col_name),
-            table=sql.Identifier(table_name)
-        )
-        cursor.execute(query)
-        distinct_results = cursor.fetchall()
-        return [row[0] for row in distinct_results if row[0] is not None]
+        cursor.execute(query, (sample_size,))
+        sample_results = cursor.fetchall()
+        return [row[0] for row in sample_results if row[0] is not None]
 
     @staticmethod
     def _serialize_value(value):
@@ -279,18 +271,18 @@ def extract_columns_info(cursor, table_name: str) -> Dict[str, Any]:
             if column_default:
                 description_parts.append(f"(Default: {column_default})")
 
-            # Add distinct values if applicable
-            distinct_values_desc = PostgresLoader.extract_distinct_values_for_column(
+            # Extract sample values for the column (stored separately, not in description)
+            sample_values = PostgresLoader.extract_sample_values_for_column(
                 cursor, table_name, col_name
            )
-            description_parts.extend(distinct_values_desc)
 
             columns_info[col_name] = {
                 'type': data_type,
                 'null': is_nullable,
                 'key': key_type,
                 'description': ' '.join(description_parts),
-                'default': column_default
+                'default': column_default,
+                'sample_values': sample_values
             }


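The PostgreSQL implementation composes the query with psycopg2's sql.Identifier, so table and column names are quoted safely rather than interpolated as strings, and it reads tuple-style rows (row[0]). A minimal usage sketch follows; the DSN and the table/column names are placeholders, not values from this PR:

    import psycopg2

    from api.loaders.postgres_loader import PostgresLoader

    conn = psycopg2.connect("dbname=shop user=app password=secret host=localhost")
    with conn.cursor() as cursor:
        # Returns up to 5 random non-NULL distinct values from orders.status as strings
        samples = PostgresLoader.extract_sample_values_for_column(
            cursor, "orders", "status", sample_size=5
        )
    conn.close()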