Skip to content

Commit 043667a

Browse files
npelikan and jcheng5 authored
feat(r+py): Generic datasources (#28)
Co-authored-by: Joe Cheng <[email protected]> closes #29
1 parent ba8dda8 commit 043667a

38 files changed

+2201
-427
lines changed

.github/workflows/py-test.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ jobs:
3737
- name: 📦 Install the project
3838
run: uv sync --python ${{matrix.config.python-version }} --all-extras --all-groups
3939

40-
# - name: 🧪 Check tests
41-
# run: make py-check-tests
40+
- name: 🧪 Check tests
41+
run: make py-check-tests
4242

4343
- name: 📝 Check types
4444
run: make py-check-types

.gitignore

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,8 +250,13 @@ po/*~
250250

251251
# RStudio Connect folder
252252
rsconnect/
253+
python-package/CLAUDE.md
253254

254255
uv.lock
255256
_dev
256257

258+
# R ignores
257259
/.quarto/
260+
.Rprofile
261+
renv/
262+
renv.lock

Makefile

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -123,12 +123,11 @@ py-check-tox: ## [py] Run python 3.9 - 3.12 checks with tox
123123
@echo "🔄 Running tests and type checking with tox for Python 3.9--3.12"
124124
uv run tox run-parallel
125125

126-
# .PHONY: py-check-tests
127-
# py-check-tests: ## [py] Run python tests
128-
# @echo ""
129-
# @echo "🧪 Running tests with pytest"
130-
# uv run playwright install
131-
# uv run pytest
126+
.PHONY: py-check-tests
127+
py-check-tests: ## [py] Run python tests
128+
@echo ""
129+
@echo "🧪 Running tests with pytest"
130+
uv run pytest
132131

133132
.PHONY: py-check-types
134133
py-check-types: ## [py] Run python type checks

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,11 @@ querychat does not have direct access to the raw data; it can _only_ read or fil
3636
- **Transparency:** querychat always displays the SQL to the user, so it can be vetted instead of blindly trusted.
3737
- **Reproducibility:** The SQL query can be easily copied and reused.
3838

39-
Currently, querychat uses DuckDB for its SQL engine. It's extremely fast and has a surprising number of statistical functions.
39+
Currently, querychat uses DuckDB for its SQL engine when working with data frames. For database sources, it uses the native SQL dialect of the connected database.
4040

4141
## Language-specific Documentation
4242

4343
For detailed information on how to use querychat in your preferred language, see the language-specific READMEs:
4444

4545
- [R Documentation](pkg-r/README.md)
46-
- [Python Documentation](pkg-py/README.md)
46+
- [Python Documentation](pkg-py/README.md)

pkg-py/examples/app.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,4 @@ def data_table():
4949

5050

5151
# Create Shiny app
52-
app = App(app_ui, server)
52+
app = App(app_ui, server)

pkg-py/src/querychat/__init__.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
1-
from querychat.querychat import init, sidebar, system_prompt
2-
from querychat.querychat import mod_server as server
3-
from querychat.querychat import mod_ui as ui
1+
from querychat.querychat import (
2+
init,
3+
sidebar,
4+
system_prompt,
5+
)
6+
from querychat.querychat import (
7+
mod_server as server,
8+
)
9+
from querychat.querychat import (
10+
mod_ui as ui,
11+
)
412

513
__all__ = ["init", "server", "sidebar", "system_prompt", "ui"]

pkg-py/src/querychat/datasource.py

Lines changed: 98 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def __init__(self, engine: Engine, table_name: str):
178178
if not inspector.has_table(table_name):
179179
raise ValueError(f"Table '{table_name}' not found in database")
180180

181-
def get_schema(self, *, categorical_threshold: int) -> str:
181+
def get_schema(self, *, categorical_threshold: int) -> str: # noqa: PLR0912
182182
"""
183183
Generate schema information from database table.
184184
@@ -191,12 +191,15 @@ def get_schema(self, *, categorical_threshold: int) -> str:
191191

192192
schema = [f"Table: {self._table_name}", "Columns:"]
193193

194+
# Build a single query to get all column statistics
195+
select_parts = []
196+
numeric_columns = []
197+
text_columns = []
198+
194199
for col in columns:
195-
# Get SQL type name
196-
sql_type = self._get_sql_type_name(col["type"])
197-
column_info = [f"- {col['name']} ({sql_type})"]
200+
col_name = col["name"]
198201

199-
# For numeric columns, try to get range
202+
# Check if column is numeric
200203
if isinstance(
201204
col["type"],
202205
(
@@ -208,44 +211,103 @@ def get_schema(self, *, categorical_threshold: int) -> str:
208211
sqltypes.DateTime,
209212
sqltypes.BigInteger,
210213
sqltypes.SmallInteger,
211-
# sqltypes.Interval,
212214
),
213215
):
214-
try:
215-
query = text(
216-
f"SELECT MIN({col['name']}), MAX({col['name']}) FROM {self._table_name}",
217-
)
218-
with self._get_connection() as conn:
219-
result = conn.execute(query).fetchone()
220-
if result and result[0] is not None and result[1] is not None:
221-
column_info.append(f" Range: {result[0]} to {result[1]}")
222-
except Exception: # noqa: S110
223-
pass # Silently skip range info if query fails
224-
225-
# For string/text columns, check if categorical
216+
numeric_columns.append(col_name)
217+
select_parts.extend(
218+
[
219+
f"MIN({col_name}) as {col_name}__min",
220+
f"MAX({col_name}) as {col_name}__max",
221+
],
222+
)
223+
224+
# Check if column is text/string
226225
elif isinstance(
227226
col["type"],
228227
(sqltypes.String, sqltypes.Text, sqltypes.Enum),
229228
):
230-
try:
231-
count_query = text(
232-
f"SELECT COUNT(DISTINCT {col['name']}) FROM {self._table_name}",
233-
)
229+
text_columns.append(col_name)
230+
select_parts.append(
231+
f"COUNT(DISTINCT {col_name}) as {col_name}__distinct_count",
232+
)
233+
234+
# Execute single query to get all statistics
235+
column_stats = {}
236+
if select_parts:
237+
try:
238+
stats_query = text(
239+
f"SELECT {', '.join(select_parts)} FROM {self._table_name}",
240+
)
241+
with self._get_connection() as conn:
242+
result = conn.execute(stats_query).fetchone()
243+
if result:
244+
# Convert result to dict for easier access
245+
column_stats = dict(zip(result._fields, result))
246+
except Exception: # noqa: S110
247+
pass # Fall back to no statistics if query fails
248+
249+
# Get categorical values for text columns that are below threshold
250+
categorical_values = {}
251+
text_cols_to_query = []
252+
for col_name in text_columns:
253+
distinct_count_key = f"{col_name}__distinct_count"
254+
if (
255+
distinct_count_key in column_stats
256+
and column_stats[distinct_count_key]
257+
and column_stats[distinct_count_key] <= categorical_threshold
258+
):
259+
text_cols_to_query.append(col_name)
260+
261+
# Get categorical values in a single query if needed
262+
if text_cols_to_query:
263+
try:
264+
# Build UNION query for all categorical columns
265+
union_parts = [
266+
f"SELECT '{col_name}' as column_name, {col_name} as value "
267+
f"FROM {self._table_name} WHERE {col_name} IS NOT NULL "
268+
f"GROUP BY {col_name}"
269+
for col_name in text_cols_to_query
270+
]
271+
272+
if union_parts:
273+
categorical_query = text(" UNION ALL ".join(union_parts))
234274
with self._get_connection() as conn:
235-
distinct_count = conn.execute(count_query).scalar()
236-
if distinct_count and distinct_count <= categorical_threshold:
237-
values_query = text(
238-
f"SELECT DISTINCT {col['name']} FROM {self._table_name} "
239-
f"WHERE {col['name']} IS NOT NULL",
240-
)
241-
values = [
242-
str(row[0])
243-
for row in conn.execute(values_query).fetchall()
244-
]
245-
values_str = ", ".join([f"'{v}'" for v in values])
246-
column_info.append(f" Categorical values: {values_str}")
247-
except Exception: # noqa: S110
248-
pass # Silently skip categorical info if query fails
275+
results = conn.execute(categorical_query).fetchall()
276+
for row in results:
277+
col_name, value = row
278+
if col_name not in categorical_values:
279+
categorical_values[col_name] = []
280+
categorical_values[col_name].append(str(value))
281+
except Exception: # noqa: S110
282+
pass # Skip categorical values if query fails
283+
284+
# Build schema description using collected statistics
285+
for col in columns:
286+
col_name = col["name"]
287+
sql_type = self._get_sql_type_name(col["type"])
288+
column_info = [f"- {col_name} ({sql_type})"]
289+
290+
# Add range info for numeric columns
291+
if col_name in numeric_columns:
292+
min_key = f"{col_name}__min"
293+
max_key = f"{col_name}__max"
294+
if (
295+
min_key in column_stats
296+
and max_key in column_stats
297+
and column_stats[min_key] is not None
298+
and column_stats[max_key] is not None
299+
):
300+
column_info.append(
301+
f" Range: {column_stats[min_key]} to {column_stats[max_key]}",
302+
)
303+
304+
# Add categorical values for text columns
305+
elif col_name in categorical_values:
306+
values = categorical_values[col_name]
307+
# Remove duplicates and sort
308+
unique_values = sorted(set(values))
309+
values_str = ", ".join([f"'{v}'" for v in unique_values])
310+
column_info.append(f" Categorical values: {values_str}")
249311

250312
schema.extend(column_info)
251313

pkg-py/src/querychat/querychat.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -118,19 +118,12 @@ def __getitem__(self, key: str) -> Any:
118118
backwards compatibility only; new code should use the attributes
119119
directly instead.
120120
"""
121-
if key == "chat": # noqa: SIM116
122-
return self.chat
123-
elif key == "sql":
124-
return self.sql
125-
elif key == "title":
126-
return self.title
127-
elif key == "df":
128-
return self.df
129-
130-
raise KeyError(
131-
f"`QueryChat` does not have a key `'{key}'`. "
132-
"Use the attributes `chat`, `sql`, `title`, or `df` instead.",
133-
)
121+
return {
122+
"chat": self.chat,
123+
"sql": self.sql,
124+
"title": self.title,
125+
"df": self.df,
126+
}.get(key)
134127

135128

136129
def system_prompt(

pkg-py/tests/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)