11from __future__ import annotations
22
3- from typing import ClassVar , Protocol
3+ from typing import TYPE_CHECKING , ClassVar , Protocol
44
55import duckdb
66import narwhals as nw
77import pandas as pd
88from sqlalchemy import inspect , text
9- from sqlalchemy .engine import Connection , Engine
109from sqlalchemy .sql import sqltypes
1110
11+ if TYPE_CHECKING :
12+ from sqlalchemy .engine import Connection , Engine
13+
1214
1315class DataSource (Protocol ):
1416 db_engine : ClassVar [str ]
@@ -176,7 +178,7 @@ def __init__(self, engine: Engine, table_name: str):
176178 if not inspector .has_table (table_name ):
177179 raise ValueError (f"Table '{ table_name } ' not found in database" )
178180
179- def get_schema (self , * , categorical_threshold : int ) -> str :
181+ def get_schema (self , * , categorical_threshold : int ) -> str : # noqa: PLR0912
180182 """
181183 Generate schema information from database table.
182184
@@ -189,12 +191,15 @@ def get_schema(self, *, categorical_threshold: int) -> str:
189191
190192 schema = [f"Table: { self ._table_name } " , "Columns:" ]
191193
194+ # Build a single query to get all column statistics
195+ select_parts = []
196+ numeric_columns = []
197+ text_columns = []
198+
192199 for col in columns :
193- # Get SQL type name
194- sql_type = self ._get_sql_type_name (col ["type" ])
195- column_info = [f"- { col ['name' ]} ({ sql_type } )" ]
200+ col_name = col ["name" ]
196201
197- # For numeric columns, try to get range
202+ # Check if column is numeric
198203 if isinstance (
199204 col ["type" ],
200205 (
@@ -206,44 +211,103 @@ def get_schema(self, *, categorical_threshold: int) -> str:
206211 sqltypes .DateTime ,
207212 sqltypes .BigInteger ,
208213 sqltypes .SmallInteger ,
209- # sqltypes.Interval,
210214 ),
211215 ):
212- try :
213- query = text (
214- f"SELECT MIN({ col ['name' ]} ), MAX({ col ['name' ]} ) FROM { self ._table_name } " ,
215- )
216- with self ._get_connection () as conn :
217- result = conn .execute (query ).fetchone ()
218- if result and result [0 ] is not None and result [1 ] is not None :
219- column_info .append (f" Range: { result [0 ]} to { result [1 ]} " )
220- except Exception :
221- pass # Skip range info if query fails
222-
223- # For string/text columns, check if categorical
216+ numeric_columns .append (col_name )
217+ select_parts .extend (
218+ [
219+ f"MIN({ col_name } ) as { col_name } _min" ,
220+ f"MAX({ col_name } ) as { col_name } _max" ,
221+ ],
222+ )
223+
224+ # Check if column is text/string
224225 elif isinstance (
225226 col ["type" ],
226227 (sqltypes .String , sqltypes .Text , sqltypes .Enum ),
227228 ):
228- try :
229- count_query = text (
230- f"SELECT COUNT(DISTINCT { col ['name' ]} ) FROM { self ._table_name } " ,
231- )
229+ text_columns .append (col_name )
230+ select_parts .append (
231+ f"COUNT(DISTINCT { col_name } ) as { col_name } _distinct_count" ,
232+ )
233+
234+ # Execute single query to get all statistics
235+ column_stats = {}
236+ if select_parts :
237+ try :
238+ stats_query = text (
239+ f"SELECT { ', ' .join (select_parts )} FROM { self ._table_name } " , # noqa: S608
240+ )
241+ with self ._get_connection () as conn :
242+ result = conn .execute (stats_query ).fetchone ()
243+ if result :
244+ # Convert result to dict for easier access
245+ column_stats = dict (zip (result ._fields , result ))
246+ except Exception : # noqa: S110
247+ pass # Fall back to no statistics if query fails
248+
249+ # Get categorical values for text columns that are below threshold
250+ categorical_values = {}
251+ text_cols_to_query = []
252+ for col_name in text_columns :
253+ distinct_count_key = f"{ col_name } _distinct_count"
254+ if (
255+ distinct_count_key in column_stats
256+ and column_stats [distinct_count_key ]
257+ and column_stats [distinct_count_key ] <= categorical_threshold
258+ ):
259+ text_cols_to_query .append (col_name )
260+
261+ # Get categorical values in a single query if needed
262+ if text_cols_to_query :
263+ try :
264+ # Build UNION query for all categorical columns
265+ union_parts = [
266+ f"SELECT '{ col_name } ' as column_name, { col_name } as value " # noqa: S608
267+ f"FROM { self ._table_name } WHERE { col_name } IS NOT NULL "
268+ f"GROUP BY { col_name } "
269+ for col_name in text_cols_to_query
270+ ]
271+
272+ if union_parts :
273+ categorical_query = text (" UNION ALL " .join (union_parts ))
232274 with self ._get_connection () as conn :
233- distinct_count = conn .execute (count_query ).scalar ()
234- if distinct_count and distinct_count <= categorical_threshold :
235- values_query = text (
236- f"SELECT DISTINCT { col ['name' ]} FROM { self ._table_name } "
237- f"WHERE { col ['name' ]} IS NOT NULL" ,
238- )
239- values = [
240- str (row [0 ])
241- for row in conn .execute (values_query ).fetchall ()
242- ]
243- values_str = ", " .join ([f"'{ v } '" for v in values ])
244- column_info .append (f" Categorical values: { values_str } " )
245- except Exception :
246- pass # Skip categorical info if query fails
275+ results = conn .execute (categorical_query ).fetchall ()
276+ for row in results :
277+ col_name , value = row
278+ if col_name not in categorical_values :
279+ categorical_values [col_name ] = []
280+ categorical_values [col_name ].append (str (value ))
281+ except Exception : # noqa: S110
282+ pass # Skip categorical values if query fails
283+
284+ # Build schema description using collected statistics
285+ for col in columns :
286+ col_name = col ["name" ]
287+ sql_type = self ._get_sql_type_name (col ["type" ])
288+ column_info = [f"- { col_name } ({ sql_type } )" ]
289+
290+ # Add range info for numeric columns
291+ if col_name in numeric_columns :
292+ min_key = f"{ col_name } _min"
293+ max_key = f"{ col_name } _max"
294+ if (
295+ min_key in column_stats
296+ and max_key in column_stats
297+ and column_stats [min_key ] is not None
298+ and column_stats [max_key ] is not None
299+ ):
300+ column_info .append (
301+ f" Range: { column_stats [min_key ]} to { column_stats [max_key ]} " ,
302+ )
303+
304+ # Add categorical values for text columns
305+ elif col_name in categorical_values :
306+ values = categorical_values [col_name ]
307+ # Remove duplicates and sort
308+ unique_values = sorted (set (values ))
309+ values_str = ", " .join ([f"'{ v } '" for v in unique_values ])
310+ column_info .append (f" Categorical values: { values_str } " )
247311
248312 schema .extend (column_info )
249313
@@ -271,9 +335,9 @@ def get_data(self) -> pd.DataFrame:
271335 The complete dataset as a pandas DataFrame
272336
273337 """
274- return self .execute_query (f"SELECT * FROM { self ._table_name } " )
338+ return self .execute_query (f"SELECT * FROM { self ._table_name } " ) # noqa: S608
275339
276- def _get_sql_type_name (self , type_ : sqltypes .TypeEngine ) -> str :
340+ def _get_sql_type_name (self , type_ : sqltypes .TypeEngine ) -> str : # noqa: PLR0911
277341 """Convert SQLAlchemy type to SQL type name."""
278342 if isinstance (type_ , sqltypes .Integer ):
279343 return "INTEGER"
0 commit comments