Skip to content

Commit ba50a37

Browse files
authored
feat: Improve query security and integrity checks (#180)
* feat(update_dashboard): Check that all columns are returned from query * feat: Deny queries that contain potentially undesirable keywords * `devtools::document()` (GitHub Actions) * chore: Check query before running in `test_query()` too * chore(py): Export UnsafeQueryError from types * chore: Add news/changelog item * docs(py): Document raises unsafe query error --------- Co-authored-by: gadenbuie <[email protected]>
1 parent 750de80 commit ba50a37

File tree

14 files changed

+931
-8
lines changed

14 files changed

+931
-8
lines changed

pkg-py/CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1717

1818
* The tools used in a `QueryChat` chatbot are now configurable. Use the new `tools` parameter of `QueryChat()` to select either or both `"query"` or `"update"` tools. Choose `tools=["update"]` if you only want QueryChat to be able to update the dashboard (useful when you want to be 100% certain that the LLM will not see _any_ raw data). (#168)
1919

20+
### Improvements
21+
22+
* The update tool now requires that the SQL query returns all columns from the original data source, ensuring that the dashboard can display the complete data frame after filtering or sorting. If the query does not return all columns, an informative error message will be provided. (#180)
23+
24+
* Queries that begin with an SQL keyword that modifies data or schema (e.g., `INSERT`, `UPDATE`, `DELETE`, `DROP`, etc.) are now rejected by both the query tool and the update tool, to prevent accidental data changes. If such a keyword is detected, an informative error message will be provided. (#180)
25+
2026
## [0.3.0] - 2025-12-10
2127

2228
### Breaking Changes

pkg-py/src/querychat/_datasource.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,17 @@
99
from sqlalchemy import inspect, text
1010
from sqlalchemy.sql import sqltypes
1111

12+
from ._utils import check_query
13+
1214
if TYPE_CHECKING:
1315
from narwhals.stable.v1.typing import IntoFrame
1416
from sqlalchemy.engine import Connection, Engine
1517

1618

19+
class MissingColumnsError(ValueError):
    """
    Raised when a query result is missing required columns.

    Subclasses ``ValueError``, so handlers that already catch ``ValueError``
    also catch this error. The ``test_query()`` implementations in this module
    raise it when ``require_all_columns=True`` and the result lacks one or
    more of the data source's original columns.
    """
21+
22+
1723
class DataSource(ABC):
1824
"""
1925
An abstract class defining the interface for data sources used by QueryChat.
@@ -70,6 +76,34 @@ def execute_query(self, query: str) -> pd.DataFrame:
7076
"""
7177
...
7278

79+
@abstractmethod
def test_query(
    self, query: str, *, require_all_columns: bool = False
) -> pd.DataFrame:
    """
    Test SQL query by fetching only one row.

    This is the abstract interface; concrete data sources supply the
    implementation.

    Parameters
    ----------
    query
        SQL query to test
    require_all_columns
        If True, validates that result includes all original table columns.
        Additional computed columns are allowed.

    Returns
    -------
    :
        Query results as a pandas DataFrame with at most one row

    Raises
    ------
    UnsafeQueryError
        Raised by the implementations in this module (via ``check_query``)
        when the query starts with a disallowed SQL operation
    MissingColumnsError
        If require_all_columns is True and result is missing required columns

    """
    ...
106+
73107
@abstractmethod
74108
def get_data(self) -> pd.DataFrame:
75109
"""
@@ -136,6 +170,9 @@ def __init__(self, df: IntoFrame, table_name: str):
136170
SET lock_configuration = true;
137171
""")
138172

173+
# Store original column names for validation
174+
self._colnames = list(self._df.columns)
175+
139176
def get_db_type(self) -> str:
140177
"""
141178
Get the database type.
@@ -225,9 +262,60 @@ def execute_query(self, query: str) -> pd.DataFrame:
225262
:
226263
Query results as pandas DataFrame
227264
265+
Raises
266+
------
267+
UnsafeQueryError
268+
If the query starts with a disallowed SQL operation
269+
228270
"""
271+
check_query(query)
229272
return self._conn.execute(query).df()
230273

274+
def test_query(
    self, query: str, *, require_all_columns: bool = False
) -> pd.DataFrame:
    """
    Test query by fetching only one row.

    The query is wrapped in a subquery and limited to a single row, so it
    is validated the same way whether or not it already ends with ``;`` or
    carries its own ``LIMIT``/``ORDER BY`` clause.

    Parameters
    ----------
    query
        SQL query to test
    require_all_columns
        If True, validates that result includes all original table columns.
        Additional computed columns are allowed.

    Returns
    -------
    :
        Query results with at most one row

    Raises
    ------
    UnsafeQueryError
        If the query starts with a disallowed SQL operation
    MissingColumnsError
        If require_all_columns is True and result is missing required columns

    """
    check_query(query)

    # A trailing ";" is valid at top level but is a syntax error once the
    # query is embedded in a subquery, so drop terminators and whitespace.
    inner_query = query.strip().rstrip("; \t\r\n")

    # Wrap instead of appending " LIMIT 1": appending breaks any query that
    # already has a LIMIT clause. The query sits on its own lines so that a
    # trailing "--" comment cannot swallow the closing parenthesis.
    limit_query = f"SELECT * FROM (\n{inner_query}\n) AS subquery LIMIT 1"
    result = self._conn.execute(limit_query).df()

    if require_all_columns:
        missing_columns = set(self._colnames) - set(result.columns)

        if missing_columns:
            # Missing names are sorted for a stable message; the original
            # columns are listed in their table order.
            missing_list = ", ".join(f"'{col}'" for col in sorted(missing_columns))
            original_list = ", ".join(f"'{col}'" for col in self._colnames)
            raise MissingColumnsError(
                f"Query result missing required columns: {missing_list}. "
                f"The query must return all original table columns. "
                f"Original columns: {original_list}"
            )

    return result
318+
231319
def get_data(self) -> pd.DataFrame:
232320
"""
233321
Return the unfiltered data as a DataFrame.
@@ -283,6 +371,10 @@ def __init__(self, engine: Engine, table_name: str):
283371
if not inspector.has_table(table_name):
284372
raise ValueError(f"Table '{table_name}' not found in database")
285373

374+
# Store original column names for validation
375+
columns_info = inspector.get_columns(table_name)
376+
self._colnames = [col["name"] for col in columns_info]
377+
286378
def get_db_type(self) -> str:
287379
"""
288380
Get the database type.
@@ -441,10 +533,70 @@ def execute_query(self, query: str) -> pd.DataFrame:
441533
:
442534
Query results as pandas DataFrame
443535
536+
Raises
537+
------
538+
UnsafeQueryError
539+
If the query starts with a disallowed SQL operation
540+
444541
"""
542+
check_query(query)
445543
with self._get_connection() as conn:
446544
return pd.read_sql_query(text(query), conn)
447545

546+
def test_query(
    self, query: str, *, require_all_columns: bool = False
) -> pd.DataFrame:
    """
    Test query by fetching only one row.

    The query is first run wrapped in ``SELECT * FROM (...) LIMIT 1``. If
    that statement fails for any reason (for example, a dialect without
    ``LIMIT``), the original query is run unmodified and only its first row
    is kept.

    Parameters
    ----------
    query
        SQL query to test
    require_all_columns
        If True, validates that result includes all original table columns.
        Additional computed columns are allowed.

    Returns
    -------
    :
        Query results with at most one row

    Raises
    ------
    UnsafeQueryError
        If the query starts with a disallowed SQL operation
    MissingColumnsError
        If require_all_columns is True and result is missing required columns

    """
    check_query(query)

    # A trailing ";" would make the wrapped statement invalid and silently
    # push every such query onto the unlimited fallback path below.
    inner_query = query.strip().rstrip("; \t\r\n")

    # The query sits on its own lines so that a trailing "--" comment cannot
    # comment out the closing parenthesis and the LIMIT clause.
    limit_query = f"SELECT * FROM (\n{inner_query}\n) AS subquery LIMIT 1"

    with self._get_connection() as conn:
        try:
            df = pd.read_sql_query(text(limit_query), conn)
        except Exception:
            # Deliberate best-effort: if the LIMIT wrapper doesn't work on
            # this backend, run the original query and take the first row.
            # A genuinely invalid query raises again from this second call.
            df = pd.read_sql_query(text(query), conn).head(1)

    if require_all_columns:
        missing_columns = set(self._colnames) - set(df.columns)

        if missing_columns:
            # Missing names are sorted for a stable message; the original
            # columns are listed in their table order.
            missing_list = ", ".join(
                f"'{col}'" for col in sorted(missing_columns)
            )
            original_list = ", ".join(f"'{col}'" for col in self._colnames)
            raise MissingColumnsError(
                f"Query result missing required columns: {missing_list}. "
                f"The query must return all original table columns. "
                f"Original columns: {original_list}"
            )

    return df
599+
448600
def get_data(self) -> pd.DataFrame:
449601
"""
450602
Return the unfiltered data as a DataFrame.

pkg-py/src/querychat/_utils.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,86 @@
11
from __future__ import annotations
22

33
import os
4+
import re
45
import warnings
56
from contextlib import contextmanager
67
from typing import TYPE_CHECKING, Literal, Optional
78

89
import narwhals.stable.v1 as nw
910

11+
12+
class UnsafeQueryError(ValueError):
    """Raised when a query contains an unsafe/write operation."""


def check_query(query: str) -> None:
    """
    Check if a SQL query appears to be a non-read-only (write) operation.

    Raises UnsafeQueryError if the first keyword of the query is dangerous.
    Leading whitespace, empty statements (``;``) and SQL comments
    (``-- ...`` and ``/* ... */``) are skipped before the keyword is
    inspected, so a write statement cannot be hidden behind a comment.

    Two categories of keywords are checked:

    - Always blocked: DELETE, TRUNCATE, CREATE, DROP, ALTER, GRANT, REVOKE,
      EXEC, EXECUTE, CALL
    - Blocked unless QUERYCHAT_ENABLE_UPDATE_QUERIES=true: INSERT, UPDATE,
      MERGE, REPLACE, UPSERT

    This is a keyword check on the start of the query, not a SQL parser:
    text after the leading keyword is not inspected.

    Parameters
    ----------
    query
        The SQL query string to check

    Raises
    ------
    UnsafeQueryError
        If the query starts with a disallowed keyword

    """
    # Skip anything the database ignores before the first real keyword.
    # Without this, "-- note\nDROP TABLE t" would be normalized to
    # "-- NOTE DROP TABLE T" and slip past the start-of-string match below.
    statement = re.sub(
        r"\A(?:\s+|;+|--[^\r\n]*|/\*.*?\*/)+",
        "",
        query,
        count=1,
        flags=re.DOTALL,
    )

    # Normalize: newlines/tabs -> space, collapse multiple spaces, trim, uppercase
    normalized = re.sub(r"[\r\n\t]+", " ", statement)
    normalized = re.sub(r" +", " ", normalized)
    normalized = normalized.strip().upper()

    # Always blocked - destructive/schema/admin operations
    always_blocked = [
        "DELETE",
        "TRUNCATE",
        "CREATE",
        "DROP",
        "ALTER",
        "GRANT",
        "REVOKE",
        "EXEC",
        "EXECUTE",
        "CALL",
    ]

    # Blocked unless escape hatch enabled - data modification
    update_keywords = ["INSERT", "UPDATE", "MERGE", "REPLACE", "UPSERT"]

    # Check always-blocked keywords first
    always_pattern = r"^(" + "|".join(always_blocked) + r")\b"
    match = re.match(always_pattern, normalized)
    if match:
        raise UnsafeQueryError(
            f"Query appears to contain a disallowed operation: {match.group(1)}. "
            "Only SELECT queries are allowed."
        )

    # Check update keywords (can be enabled via envvar)
    enable_updates = os.environ.get("QUERYCHAT_ENABLE_UPDATE_QUERIES", "").lower()
    if enable_updates not in ("true", "1", "yes"):
        update_pattern = r"^(" + "|".join(update_keywords) + r")\b"
        match = re.match(update_pattern, normalized)
        if match:
            raise UnsafeQueryError(
                f"Query appears to contain an update operation: {match.group(1)}. "
                "Only SELECT queries are allowed. "
                "Set QUERYCHAT_ENABLE_UPDATE_QUERIES=true to allow update queries."
            )
82+
83+
1084
if TYPE_CHECKING:
1185
from narwhals.stable.v1.typing import IntoFrame
1286

pkg-py/src/querychat/tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def update_dashboard(query: str, title: str) -> ContentToolResult:
7575

7676
try:
7777
# Test the query but don't execute it yet
78-
data_source.execute_query(query)
78+
data_source.test_query(query, require_all_columns=True)
7979

8080
# Add Apply Filter button
8181
button_html = f"""<button
Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,19 @@
1-
from .._datasource import DataFrameSource, DataSource, SQLAlchemySource # noqa: A005
1+
from .._datasource import ( # noqa: A005
2+
DataFrameSource,
3+
DataSource,
4+
MissingColumnsError,
5+
SQLAlchemySource,
6+
)
27
from .._querychat_module import ServerValues
8+
from .._utils import UnsafeQueryError
39
from ..tools import UpdateDashboardData
410

511
__all__ = (
612
"DataFrameSource",
713
"DataSource",
14+
"MissingColumnsError",
815
"SQLAlchemySource",
916
"ServerValues",
17+
"UnsafeQueryError",
1018
"UpdateDashboardData",
1119
)

0 commit comments

Comments
 (0)