@@ -178,7 +178,7 @@ def __init__(self, engine: Engine, table_name: str):
178178 if not inspector .has_table (table_name ):
179179 raise ValueError (f"Table '{ table_name } ' not found in database" )
180180
181- def get_schema (self , * , categorical_threshold : int ) -> str :
181+ def get_schema (self , * , categorical_threshold : int ) -> str : # noqa: PLR0912
182182 """
183183 Generate schema information from database table.
184184
@@ -191,12 +191,15 @@ def get_schema(self, *, categorical_threshold: int) -> str:
191191
192192 schema = [f"Table: { self ._table_name } " , "Columns:" ]
193193
194+ # Build a single query to get all column statistics
195+ select_parts = []
196+ numeric_columns = []
197+ text_columns = []
198+
194199 for col in columns :
195- # Get SQL type name
196- sql_type = self ._get_sql_type_name (col ["type" ])
197- column_info = [f"- { col ['name' ]} ({ sql_type } )" ]
200+ col_name = col ["name" ]
198201
199- # For numeric columns, try to get range
202+ # Check if column is numeric
200203 if isinstance (
201204 col ["type" ],
202205 (
@@ -208,44 +211,103 @@ def get_schema(self, *, categorical_threshold: int) -> str:
208211 sqltypes .DateTime ,
209212 sqltypes .BigInteger ,
210213 sqltypes .SmallInteger ,
211- # sqltypes.Interval,
212214 ),
213215 ):
214- try :
215- query = text (
216- f"SELECT MIN({ col ['name' ]} ), MAX({ col ['name' ]} ) FROM { self ._table_name } " ,
217- )
218- with self ._get_connection () as conn :
219- result = conn .execute (query ).fetchone ()
220- if result and result [0 ] is not None and result [1 ] is not None :
221- column_info .append (f" Range: { result [0 ]} to { result [1 ]} " )
222- except Exception : # noqa: S110
223- pass # Silently skip range info if query fails
224-
225- # For string/text columns, check if categorical
216+ numeric_columns .append (col_name )
217+ select_parts .extend (
218+ [
219+ f"MIN({ col_name } ) as { col_name } __min" ,
220+ f"MAX({ col_name } ) as { col_name } __max" ,
221+ ],
222+ )
223+
224+ # Check if column is text/string
226225 elif isinstance (
227226 col ["type" ],
228227 (sqltypes .String , sqltypes .Text , sqltypes .Enum ),
229228 ):
230- try :
231- count_query = text (
232- f"SELECT COUNT(DISTINCT { col ['name' ]} ) FROM { self ._table_name } " ,
233- )
229+ text_columns .append (col_name )
230+ select_parts .append (
231+ f"COUNT(DISTINCT { col_name } ) as { col_name } __distinct_count" ,
232+ )
233+
234+ # Execute single query to get all statistics
235+ column_stats = {}
236+ if select_parts :
237+ try :
238+ stats_query = text (
239+ f"SELECT { ', ' .join (select_parts )} FROM { self ._table_name } " ,
240+ )
241+ with self ._get_connection () as conn :
242+ result = conn .execute (stats_query ).fetchone ()
243+ if result :
244+ # Convert result to dict for easier access
245+ column_stats = dict (zip (result ._fields , result ))
246+ except Exception : # noqa: S110
247+ pass # Fall back to no statistics if query fails
248+
249+ # Get categorical values for text columns that are below threshold
250+ categorical_values = {}
251+ text_cols_to_query = []
252+ for col_name in text_columns :
253+ distinct_count_key = f"{ col_name } __distinct_count"
254+ if (
255+ distinct_count_key in column_stats
256+ and column_stats [distinct_count_key ]
257+ and column_stats [distinct_count_key ] <= categorical_threshold
258+ ):
259+ text_cols_to_query .append (col_name )
260+
261+ # Get categorical values in a single query if needed
262+ if text_cols_to_query :
263+ try :
264+ # Build UNION query for all categorical columns
265+ union_parts = [
266+ f"SELECT '{ col_name } ' as column_name, { col_name } as value "
267+ f"FROM { self ._table_name } WHERE { col_name } IS NOT NULL "
268+ f"GROUP BY { col_name } "
269+ for col_name in text_cols_to_query
270+ ]
271+
272+ if union_parts :
273+ categorical_query = text (" UNION ALL " .join (union_parts ))
234274 with self ._get_connection () as conn :
235- distinct_count = conn .execute (count_query ).scalar ()
236- if distinct_count and distinct_count <= categorical_threshold :
237- values_query = text (
238- f"SELECT DISTINCT { col ['name' ]} FROM { self ._table_name } "
239- f"WHERE { col ['name' ]} IS NOT NULL" ,
240- )
241- values = [
242- str (row [0 ])
243- for row in conn .execute (values_query ).fetchall ()
244- ]
245- values_str = ", " .join ([f"'{ v } '" for v in values ])
246- column_info .append (f" Categorical values: { values_str } " )
247- except Exception : # noqa: S110
248- pass # Silently skip categorical info if query fails
275+ results = conn .execute (categorical_query ).fetchall ()
276+ for row in results :
277+ col_name , value = row
278+ if col_name not in categorical_values :
279+ categorical_values [col_name ] = []
280+ categorical_values [col_name ].append (str (value ))
281+ except Exception : # noqa: S110
282+ pass # Skip categorical values if query fails
283+
284+ # Build schema description using collected statistics
285+ for col in columns :
286+ col_name = col ["name" ]
287+ sql_type = self ._get_sql_type_name (col ["type" ])
288+ column_info = [f"- { col_name } ({ sql_type } )" ]
289+
290+ # Add range info for numeric columns
291+ if col_name in numeric_columns :
292+ min_key = f"{ col_name } __min"
293+ max_key = f"{ col_name } __max"
294+ if (
295+ min_key in column_stats
296+ and max_key in column_stats
297+ and column_stats [min_key ] is not None
298+ and column_stats [max_key ] is not None
299+ ):
300+ column_info .append (
301+ f" Range: { column_stats [min_key ]} to { column_stats [max_key ]} " ,
302+ )
303+
304+ # Add categorical values for text columns
305+ elif col_name in categorical_values :
306+ values = categorical_values [col_name ]
307+ # Remove duplicates and sort
308+ unique_values = sorted (set (values ))
309+ values_str = ", " .join ([f"'{ v } '" for v in unique_values ])
310+ column_info .append (f" Categorical values: { values_str } " )
249311
250312 schema .extend (column_info )
251313
0 commit comments