Skip to content

Commit 43ee050

Browse files
authored
Optimizing Schema Queries (#26)
* this should significantly speed up schema generation
* another speedup
* ruff formatting
* updating so formatting checks pass
1 parent a08764b commit 43ee050

File tree

5 files changed

+310
-49
lines changed

5 files changed

+310
-49
lines changed

python-package/pyproject.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,12 @@ packages = ["src/querychat"]
4343
include = ["src/querychat", "LICENSE", "README.md"]
4444

4545
[tool.uv]
46-
dev-dependencies = ["ruff>=0.6.5", "pyright>=1.1.401", "tox-uv>=1.11.4"]
46+
dev-dependencies = [
47+
"ruff>=0.6.5",
48+
"pyright>=1.1.401",
49+
"tox-uv>=1.11.4",
50+
"pytest>=8.4.0",
51+
]
4752

4853
[tool.ruff]
4954
src = ["src/querychat"]

python-package/src/querychat/datasource.py

Lines changed: 104 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
from __future__ import annotations
22

3-
from typing import ClassVar, Protocol
3+
from typing import TYPE_CHECKING, ClassVar, Protocol
44

55
import duckdb
66
import narwhals as nw
77
import pandas as pd
88
from sqlalchemy import inspect, text
9-
from sqlalchemy.engine import Connection, Engine
109
from sqlalchemy.sql import sqltypes
1110

11+
if TYPE_CHECKING:
12+
from sqlalchemy.engine import Connection, Engine
13+
1214

1315
class DataSource(Protocol):
1416
db_engine: ClassVar[str]
@@ -176,7 +178,7 @@ def __init__(self, engine: Engine, table_name: str):
176178
if not inspector.has_table(table_name):
177179
raise ValueError(f"Table '{table_name}' not found in database")
178180

179-
def get_schema(self, *, categorical_threshold: int) -> str:
181+
def get_schema(self, *, categorical_threshold: int) -> str: # noqa: PLR0912
180182
"""
181183
Generate schema information from database table.
182184
@@ -189,12 +191,15 @@ def get_schema(self, *, categorical_threshold: int) -> str:
189191

190192
schema = [f"Table: {self._table_name}", "Columns:"]
191193

194+
# Build a single query to get all column statistics
195+
select_parts = []
196+
numeric_columns = []
197+
text_columns = []
198+
192199
for col in columns:
193-
# Get SQL type name
194-
sql_type = self._get_sql_type_name(col["type"])
195-
column_info = [f"- {col['name']} ({sql_type})"]
200+
col_name = col["name"]
196201

197-
# For numeric columns, try to get range
202+
# Check if column is numeric
198203
if isinstance(
199204
col["type"],
200205
(
@@ -206,44 +211,103 @@ def get_schema(self, *, categorical_threshold: int) -> str:
206211
sqltypes.DateTime,
207212
sqltypes.BigInteger,
208213
sqltypes.SmallInteger,
209-
# sqltypes.Interval,
210214
),
211215
):
212-
try:
213-
query = text(
214-
f"SELECT MIN({col['name']}), MAX({col['name']}) FROM {self._table_name}",
215-
)
216-
with self._get_connection() as conn:
217-
result = conn.execute(query).fetchone()
218-
if result and result[0] is not None and result[1] is not None:
219-
column_info.append(f" Range: {result[0]} to {result[1]}")
220-
except Exception:
221-
pass # Skip range info if query fails
222-
223-
# For string/text columns, check if categorical
216+
numeric_columns.append(col_name)
217+
select_parts.extend(
218+
[
219+
f"MIN({col_name}) as {col_name}_min",
220+
f"MAX({col_name}) as {col_name}_max",
221+
],
222+
)
223+
224+
# Check if column is text/string
224225
elif isinstance(
225226
col["type"],
226227
(sqltypes.String, sqltypes.Text, sqltypes.Enum),
227228
):
228-
try:
229-
count_query = text(
230-
f"SELECT COUNT(DISTINCT {col['name']}) FROM {self._table_name}",
231-
)
229+
text_columns.append(col_name)
230+
select_parts.append(
231+
f"COUNT(DISTINCT {col_name}) as {col_name}_distinct_count",
232+
)
233+
234+
# Execute single query to get all statistics
235+
column_stats = {}
236+
if select_parts:
237+
try:
238+
stats_query = text(
239+
f"SELECT {', '.join(select_parts)} FROM {self._table_name}", # noqa: S608
240+
)
241+
with self._get_connection() as conn:
242+
result = conn.execute(stats_query).fetchone()
243+
if result:
244+
# Convert result to dict for easier access
245+
column_stats = dict(zip(result._fields, result))
246+
except Exception: # noqa: S110
247+
pass # Fall back to no statistics if query fails
248+
249+
# Get categorical values for text columns that are below threshold
250+
categorical_values = {}
251+
text_cols_to_query = []
252+
for col_name in text_columns:
253+
distinct_count_key = f"{col_name}_distinct_count"
254+
if (
255+
distinct_count_key in column_stats
256+
and column_stats[distinct_count_key]
257+
and column_stats[distinct_count_key] <= categorical_threshold
258+
):
259+
text_cols_to_query.append(col_name)
260+
261+
# Get categorical values in a single query if needed
262+
if text_cols_to_query:
263+
try:
264+
# Build UNION query for all categorical columns
265+
union_parts = [
266+
f"SELECT '{col_name}' as column_name, {col_name} as value " # noqa: S608
267+
f"FROM {self._table_name} WHERE {col_name} IS NOT NULL "
268+
f"GROUP BY {col_name}"
269+
for col_name in text_cols_to_query
270+
]
271+
272+
if union_parts:
273+
categorical_query = text(" UNION ALL ".join(union_parts))
232274
with self._get_connection() as conn:
233-
distinct_count = conn.execute(count_query).scalar()
234-
if distinct_count and distinct_count <= categorical_threshold:
235-
values_query = text(
236-
f"SELECT DISTINCT {col['name']} FROM {self._table_name} "
237-
f"WHERE {col['name']} IS NOT NULL",
238-
)
239-
values = [
240-
str(row[0])
241-
for row in conn.execute(values_query).fetchall()
242-
]
243-
values_str = ", ".join([f"'{v}'" for v in values])
244-
column_info.append(f" Categorical values: {values_str}")
245-
except Exception:
246-
pass # Skip categorical info if query fails
275+
results = conn.execute(categorical_query).fetchall()
276+
for row in results:
277+
col_name, value = row
278+
if col_name not in categorical_values:
279+
categorical_values[col_name] = []
280+
categorical_values[col_name].append(str(value))
281+
except Exception: # noqa: S110
282+
pass # Skip categorical values if query fails
283+
284+
# Build schema description using collected statistics
285+
for col in columns:
286+
col_name = col["name"]
287+
sql_type = self._get_sql_type_name(col["type"])
288+
column_info = [f"- {col_name} ({sql_type})"]
289+
290+
# Add range info for numeric columns
291+
if col_name in numeric_columns:
292+
min_key = f"{col_name}_min"
293+
max_key = f"{col_name}_max"
294+
if (
295+
min_key in column_stats
296+
and max_key in column_stats
297+
and column_stats[min_key] is not None
298+
and column_stats[max_key] is not None
299+
):
300+
column_info.append(
301+
f" Range: {column_stats[min_key]} to {column_stats[max_key]}",
302+
)
303+
304+
# Add categorical values for text columns
305+
elif col_name in categorical_values:
306+
values = categorical_values[col_name]
307+
# Remove duplicates and sort
308+
unique_values = sorted(set(values))
309+
values_str = ", ".join([f"'{v}'" for v in unique_values])
310+
column_info.append(f" Categorical values: {values_str}")
247311

248312
schema.extend(column_info)
249313

@@ -271,9 +335,9 @@ def get_data(self) -> pd.DataFrame:
271335
The complete dataset as a pandas DataFrame
272336
273337
"""
274-
return self.execute_query(f"SELECT * FROM {self._table_name}")
338+
return self.execute_query(f"SELECT * FROM {self._table_name}") # noqa: S608
275339

276-
def _get_sql_type_name(self, type_: sqltypes.TypeEngine) -> str:
340+
def _get_sql_type_name(self, type_: sqltypes.TypeEngine) -> str: # noqa: PLR0911
277341
"""Convert SQLAlchemy type to SQL type name."""
278342
if isinstance(type_, sqltypes.Integer):
279343
return "INTEGER"

python-package/src/querychat/querychat.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -126,14 +126,12 @@ def __getitem__(self, key: str) -> Any:
126126
backwards compatibility only; new code should use the attributes
127127
directly instead.
128128
"""
129-
if key == "chat":
130-
return self.chat
131-
elif key == "sql":
132-
return self.sql
133-
elif key == "title":
134-
return self.title
135-
elif key == "df":
136-
return self.df
129+
return {
130+
"chat": self.chat,
131+
"sql": self.sql,
132+
"title": self.title,
133+
"df": self.df,
134+
}.get(key)
137135

138136

139137
def system_prompt(

python-package/tests/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)