Skip to content

Commit f0d79f7

Browse files
authored
Merge pull request #366 from FalkorDB/description-improvements
Description improvements - Unified
2 parents 578a6ab + cdea661 commit f0d79f7

File tree

5 files changed

+170
-107
lines changed

5 files changed

+170
-107
lines changed

api/loaders/base_loader.py

Lines changed: 21 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
"""Base loader module providing abstract base class for data loaders."""
22

33
from abc import ABC, abstractmethod
4-
from typing import AsyncGenerator, List, Any, Tuple, TYPE_CHECKING
5-
from api.config import Config
4+
from typing import AsyncGenerator, List, Any, TYPE_CHECKING
65

76

87
class BaseLoader(ABC):
@@ -24,69 +23,45 @@ async def load(_graph_id: str, _data) -> AsyncGenerator[tuple[bool, str], None]:
2423

2524
@staticmethod
2625
@abstractmethod
27-
def _execute_count_query(cursor, table_name: str, col_name: str) -> Tuple[int, int]:
26+
def _execute_sample_query(
27+
cursor, table_name: str, col_name: str, sample_size: int = 3
28+
) -> List[Any]:
2829
"""
29-
Execute query to get total count and distinct count for a column.
30+
Execute query to get random sample values for a column.
3031
3132
Args:
3233
cursor: Database cursor
3334
table_name: Name of the table
3435
col_name: Name of the column
36+
sample_size: Number of random samples to retrieve (default: 3)
3537
3638
Returns:
37-
Tuple of (total_count, distinct_count)
38-
"""
39-
40-
@staticmethod
41-
@abstractmethod
42-
def _execute_distinct_query(cursor, table_name: str, col_name: str) -> List[Any]:
43-
"""
44-
Execute query to get distinct values for a column.
45-
46-
Args:
47-
cursor: Database cursor
48-
table_name: Name of the table
49-
col_name: Name of the column
50-
51-
Returns:
52-
List of distinct values
39+
List of sample values
5340
"""
5441

5542
@classmethod
56-
def extract_distinct_values_for_column(
57-
cls, cursor, table_name: str, col_name: str
58-
) -> List[str]:
43+
def extract_sample_values_for_column(
44+
cls, cursor, table_name: str, col_name: str, sample_size: int = 3
45+
) -> List[Any]:
5946
"""
60-
Extract distinct values for a column if it meets the criteria for inclusion.
47+
Extract random sample values for a column to provide balanced descriptions.
6148
6249
Args:
6350
cursor: Database cursor
6451
table_name: Name of the table
6552
col_name: Name of the column
53+
sample_size: Number of random samples to retrieve (default: 3)
6654
6755
Returns:
68-
List of formatted distinct values to add to description, or empty list
56+
List of sample values (converted to strings), or empty list
6957
"""
70-
# Get row counts using database-specific implementation
71-
rows_count, distinct_count = cls._execute_count_query(
72-
cursor, table_name, col_name
73-
)
74-
75-
max_distinct = Config.DB_MAX_DISTINCT
76-
uniqueness_threshold = Config.DB_UNIQUENESS_THRESHOLD
77-
78-
if 0 < distinct_count < max_distinct and distinct_count < (
79-
uniqueness_threshold * rows_count
80-
):
81-
# Get distinct values using database-specific implementation
82-
distinct_values = cls._execute_distinct_query(cursor, table_name, col_name)
83-
84-
if distinct_values:
85-
# Check first value type to avoid objects like dict/bytes
86-
first_val = distinct_values[0]
87-
if isinstance(first_val, (str, int)):
88-
return [
89-
f"(Optional values: {', '.join(f'({str(v)})' for v in distinct_values)})"
90-
]
58+
# Get sample values using database-specific implementation
59+
sample_values = cls._execute_sample_query(cursor, table_name, col_name, sample_size)
60+
61+
if sample_values:
62+
# Check first value type to avoid objects like dict/bytes
63+
first_val = sample_values[0]
64+
if isinstance(first_val, (str, int, float)):
65+
return [str(v) for v in sample_values]
9166

9267
return []

api/loaders/graph_loader.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from api.config import Config
88
from api.extensions import db
9-
from api.utils import generate_db_description
9+
from api.utils import generate_db_description, create_combined_description
1010

1111

1212
async def load_to_graph( # pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals
@@ -31,6 +31,8 @@ async def load_to_graph( # pylint: disable=too-many-arguments,too-many-position
3131
embedding_model = Config.EMBEDDING_MODEL
3232
vec_len = embedding_model.get_vector_size()
3333

34+
create_combined_description(entities)
35+
3436
try:
3537
# Create vector indices
3638
await graph.query(
@@ -123,6 +125,13 @@ async def load_to_graph( # pylint: disable=too-many-arguments,too-many-position
123125
embed_columns.extend(embedding_result)
124126
idx = 0
125127

128+
# Combine description with sample values after embedding is created
129+
final_description = col_info["description"]
130+
sample_values = col_info.get("sample_values", [])
131+
if sample_values:
132+
sample_values_str = f"(Sample values: {', '.join(f'({v})' for v in sample_values)})"
133+
final_description = f"{final_description} {sample_values_str}"
134+
126135
await graph.query(
127136
"""
128137
MATCH (t:Table {name: $table_name})
@@ -141,7 +150,7 @@ async def load_to_graph( # pylint: disable=too-many-arguments,too-many-position
141150
"type": col_info.get("type", "unknown"),
142151
"nullable": col_info.get("null", "unknown"),
143152
"key": col_info.get("key", "unknown"),
144-
"description": col_info["description"],
153+
"description": final_description,
145154
"embedding": embed_columns[idx],
146155
},
147156
)

api/loaders/mysql_loader.py

Lines changed: 18 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import decimal
55
import logging
66
import re
7-
from typing import AsyncGenerator, Tuple, Dict, Any, List
7+
from typing import AsyncGenerator, Dict, Any, List, Tuple
88

99
import tqdm
1010
import pymysql
@@ -54,33 +54,24 @@ class MySQLLoader(BaseLoader):
5454
]
5555

5656
@staticmethod
57-
def _execute_count_query(cursor, table_name: str, col_name: str) -> Tuple[int, int]:
57+
def _execute_sample_query(
58+
cursor, table_name: str, col_name: str, sample_size: int = 3
59+
) -> List[Any]:
5860
"""
59-
Execute query to get total count and distinct count for a column.
60-
MySQL implementation returning counts from dictionary-style results.
61+
Execute query to get random sample values for a column.
62+
MySQL implementation using ORDER BY RAND() for random sampling.
6163
"""
6264
query = f"""
63-
SELECT COUNT(*) AS total_count,
64-
COUNT(DISTINCT `{col_name}`) AS distinct_count
65-
FROM `{table_name}`;
65+
SELECT DISTINCT `{col_name}`
66+
FROM `{table_name}`
67+
WHERE `{col_name}` IS NOT NULL
68+
ORDER BY RAND()
69+
LIMIT %s;
6670
"""
71+
cursor.execute(query, (sample_size,))
6772

68-
cursor.execute(query)
69-
output = cursor.fetchall()
70-
first_result = output[0]
71-
return first_result['total_count'], first_result['distinct_count']
72-
73-
@staticmethod
74-
def _execute_distinct_query(cursor, table_name: str, col_name: str) -> List[Any]:
75-
"""
76-
Execute query to get distinct values for a column.
77-
MySQL implementation handling dictionary-style results.
78-
"""
79-
query = f"SELECT DISTINCT `{col_name}` FROM `{table_name}`;"
80-
cursor.execute(query)
81-
82-
distinct_results = cursor.fetchall()
83-
return [row[col_name] for row in distinct_results if row[col_name] is not None]
73+
sample_results = cursor.fetchall()
74+
return [row[col_name] for row in sample_results if row[col_name] is not None]
8475

8576
@staticmethod
8677
def _serialize_value(value):
@@ -324,18 +315,18 @@ def extract_columns_info(cursor, db_name: str, table_name: str) -> Dict[str, Any
324315
if column_default is not None:
325316
description_parts.append(f"(Default: {column_default})")
326317

327-
# Add distinct values if applicable
328-
distinct_values_desc = MySQLLoader.extract_distinct_values_for_column(
318+
# Extract sample values for the column (stored separately, not in description)
319+
sample_values = MySQLLoader.extract_sample_values_for_column(
329320
cursor, table_name, col_name
330321
)
331-
description_parts.extend(distinct_values_desc)
332322

333323
columns_info[col_name] = {
334324
'type': data_type,
335325
'null': is_nullable,
336326
'key': key_type,
337327
'description': ' '.join(description_parts),
338-
'default': column_default
328+
'default': column_default,
329+
'sample_values': sample_values
339330
}
340331

341332
return columns_info

api/loaders/postgres_loader.py

Lines changed: 21 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import datetime
55
import decimal
66
import logging
7-
from typing import AsyncGenerator, Tuple, Dict, Any, List
7+
from typing import AsyncGenerator, Dict, Any, List, Tuple
88

99
import psycopg2
1010
from psycopg2 import sql
@@ -51,37 +51,29 @@ class PostgresLoader(BaseLoader):
5151
]
5252

5353
@staticmethod
54-
def _execute_count_query(cursor, table_name: str, col_name: str) -> Tuple[int, int]:
54+
def _execute_sample_query(
55+
cursor, table_name: str, col_name: str, sample_size: int = 3
56+
) -> List[Any]:
5557
"""
56-
Execute query to get total count and distinct count for a column.
57-
PostgreSQL implementation returning counts from tuple-style results.
58+
Execute query to get random sample values for a column.
59+
PostgreSQL implementation using ORDER BY RANDOM() for random sampling.
5860
"""
5961
query = sql.SQL("""
60-
SELECT COUNT(*) AS total_count,
61-
COUNT(DISTINCT {col}) AS distinct_count
62-
FROM {table};
62+
SELECT {col}
63+
FROM (
64+
SELECT DISTINCT {col}
65+
FROM {table}
66+
WHERE {col} IS NOT NULL
67+
) AS distinct_vals
68+
ORDER BY RANDOM()
69+
LIMIT %s;
6370
""").format(
6471
col=sql.Identifier(col_name),
6572
table=sql.Identifier(table_name)
6673
)
67-
cursor.execute(query)
68-
output = cursor.fetchall()
69-
first_result = output[0]
70-
return first_result[0], first_result[1]
71-
72-
@staticmethod
73-
def _execute_distinct_query(cursor, table_name: str, col_name: str) -> List[Any]:
74-
"""
75-
Execute query to get distinct values for a column.
76-
PostgreSQL implementation handling tuple-style results.
77-
"""
78-
query = sql.SQL("SELECT DISTINCT {col} FROM {table};").format(
79-
col=sql.Identifier(col_name),
80-
table=sql.Identifier(table_name)
81-
)
82-
cursor.execute(query)
83-
distinct_results = cursor.fetchall()
84-
return [row[0] for row in distinct_results if row[0] is not None]
74+
cursor.execute(query, (sample_size,))
75+
sample_results = cursor.fetchall()
76+
return [row[0] for row in sample_results if row[0] is not None]
8577

8678
@staticmethod
8779
def _serialize_value(value):
@@ -279,18 +271,18 @@ def extract_columns_info(cursor, table_name: str) -> Dict[str, Any]:
279271
if column_default:
280272
description_parts.append(f"(Default: {column_default})")
281273

282-
# Add distinct values if applicable
283-
distinct_values_desc = PostgresLoader.extract_distinct_values_for_column(
274+
# Extract sample values for the column (stored separately, not in description)
275+
sample_values = PostgresLoader.extract_sample_values_for_column(
284276
cursor, table_name, col_name
285277
)
286-
description_parts.extend(distinct_values_desc)
287278

288279
columns_info[col_name] = {
289280
'type': data_type,
290281
'null': is_nullable,
291282
'key': key_type,
292283
'description': ' '.join(description_parts),
293-
'default': column_default
284+
'default': column_default,
285+
'sample_values': sample_values
294286
}
295287

296288

0 commit comments

Comments
 (0)