Skip to content

Commit d238029

Browse files
committed
feat: Deny queries that contain potentially undesirable keywords
1 parent 0dd5d3c commit d238029

File tree

6 files changed

+399
-0
lines changed

6 files changed

+399
-0
lines changed

pkg-py/src/querychat/_datasource.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from sqlalchemy import inspect, text
1010
from sqlalchemy.sql import sqltypes
1111

12+
from ._utils import check_query
13+
1214
if TYPE_CHECKING:
1315
from narwhals.stable.v1.typing import IntoFrame
1416
from sqlalchemy.engine import Connection, Engine
@@ -261,6 +263,7 @@ def execute_query(self, query: str) -> pd.DataFrame:
261263
Query results as pandas DataFrame
262264
263265
"""
266+
check_query(query)
264267
return self._conn.execute(query).df()
265268

266269
def test_query(
@@ -523,6 +526,7 @@ def execute_query(self, query: str) -> pd.DataFrame:
523526
Query results as pandas DataFrame
524527
525528
"""
529+
check_query(query)
526530
with self._get_connection() as conn:
527531
return pd.read_sql_query(text(query), conn)
528532

pkg-py/src/querychat/_utils.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,86 @@
11
from __future__ import annotations
22

33
import os
4+
import re
45
import warnings
56
from contextlib import contextmanager
67
from typing import TYPE_CHECKING, Literal, Optional
78

89
import narwhals.stable.v1 as nw
910

11+
12+
class UnsafeQueryError(ValueError):
    """Raised when a query contains an unsafe/write operation."""


# Text that may precede the first keyword of a statement: whitespace, empty
# statements (";"), "--" line comments and "/* ... */" block comments.
# DOTALL lets a block comment span several lines.
_LEADING_NOISE = re.compile(r"\A(?:\s+|;|--[^\r\n]*|/\*.*?\*/)+", re.DOTALL)

# Destructive/schema/admin operations: never allowed.
_ALWAYS_BLOCKED = frozenset(
    {
        "DELETE",
        "TRUNCATE",
        "CREATE",
        "DROP",
        "ALTER",
        "GRANT",
        "REVOKE",
        "EXEC",
        "EXECUTE",
        "CALL",
    }
)

# Data-modification operations: allowed only when the escape hatch is enabled.
_UPDATE_KEYWORDS = frozenset({"INSERT", "UPDATE", "MERGE", "REPLACE", "UPSERT"})


def check_query(query: str) -> None:
    """
    Check if a SQL query appears to be a non-read-only (write) operation.

    Only the first keyword of the query is inspected; this is a simple safety
    check, not a SQL parser. Leading whitespace, empty statements (``;``),
    ``--`` line comments and ``/* ... */`` block comments are skipped before
    the keyword is read, so a disallowed statement cannot be hidden behind a
    comment. The comparison is case-insensitive.

    Two categories of keywords are checked:

    - Always blocked: DELETE, TRUNCATE, CREATE, DROP, ALTER, GRANT, REVOKE,
      EXEC, EXECUTE, CALL
    - Blocked unless QUERYCHAT_ENABLE_UPDATE_QUERIES is set to ``true``, ``1``
      or ``yes`` (case-insensitive): INSERT, UPDATE, MERGE, REPLACE, UPSERT

    Parameters
    ----------
    query
        The SQL query string to check

    Raises
    ------
    UnsafeQueryError
        If the query starts with a disallowed keyword

    """
    statement = _LEADING_NOISE.sub("", query, count=1)

    # Take the whole leading word so that identifiers which merely begin with a
    # keyword (e.g. "delete_logs") are not mistaken for that keyword.
    first_word = re.match(r"\w+", statement)
    if first_word is None:
        return

    keyword = first_word.group(0).upper()

    if keyword in _ALWAYS_BLOCKED:
        raise UnsafeQueryError(
            f"Query appears to contain a disallowed operation: {keyword}. "
            "Only SELECT queries are allowed."
        )

    if keyword not in _UPDATE_KEYWORDS:
        return

    # Read the environment on every call so the escape hatch can be toggled at
    # runtime.
    enable_updates = os.environ.get("QUERYCHAT_ENABLE_UPDATE_QUERIES", "").lower()
    if enable_updates not in ("true", "1", "yes"):
        raise UnsafeQueryError(
            f"Query appears to contain an update operation: {keyword}. "
            "Only SELECT queries are allowed. "
            "Set QUERYCHAT_ENABLE_UPDATE_QUERIES=true to allow update queries."
        )
82+
83+
1084
if TYPE_CHECKING:
1185
from narwhals.stable.v1.typing import IntoFrame
1286

pkg-py/tests/test_datasource.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pandas as pd
66
import pytest
77
from querychat._datasource import DataFrameSource, SQLAlchemySource
8+
from querychat._utils import UnsafeQueryError, check_query
89
from querychat.types import MissingColumnsError
910
from sqlalchemy import create_engine, text
1011

@@ -387,3 +388,113 @@ def test_test_query_error_message_format(test_db_engine):
387388
assert "Query result missing required columns" in error_message
388389
assert "The query must return all original table columns" in error_message
389390
assert "Original columns:" in error_message
391+
392+
393+
# Tests for check_query() function
394+
395+
396+
def test_check_query_allows_valid_select():
    """Plain SELECT statements pass regardless of letter case or padding."""
    accepted = (
        "SELECT * FROM test_table",
        "select * from test_table",
        " SELECT * FROM test_table ",
        "\nSELECT * FROM test_table\n",
    )
    for sql in accepted:
        # A rejected query raises, which fails the test.
        check_query(sql)
402+
403+
404+
def test_check_query_blocks_always_blocked_keywords():
    """Every destructive/schema/admin keyword is rejected as a leading word."""
    forbidden = (
        "DELETE",
        "TRUNCATE",
        "CREATE",
        "DROP",
        "ALTER",
        "GRANT",
        "REVOKE",
        "EXEC",
        "EXECUTE",
        "CALL",
    )

    for verb in forbidden:
        statement = f"{verb} something"
        with pytest.raises(UnsafeQueryError, match="disallowed operation"):
            check_query(statement)
422+
423+
424+
def test_check_query_blocks_update_keywords_by_default(monkeypatch):
    """Test that check_query blocks update keywords by default."""
    # "By default" means the escape hatch is unset. Clear it explicitly so a
    # value inherited from the developer's shell or CI cannot flip the result.
    monkeypatch.delenv("QUERYCHAT_ENABLE_UPDATE_QUERIES", raising=False)

    update_keywords = ["INSERT", "UPDATE", "MERGE", "REPLACE", "UPSERT"]

    for keyword in update_keywords:
        with pytest.raises(UnsafeQueryError, match="update operation"):
            check_query(f"{keyword} something")
431+
432+
433+
def test_check_query_normalizes_whitespace_and_case():
    """DELETE is rejected however it is padded, split across lines, or cased."""
    variants = (
        " delete FROM table ",
        "\n\nDELETE\n\nFROM table",
        "\tDELETE\tFROM\ttable",
        "DeLeTe FROM table",
    )
    for sql in variants:
        with pytest.raises(UnsafeQueryError, match="disallowed"):
            check_query(sql)
443+
444+
445+
def test_check_query_escape_hatch_enables_update_keywords(monkeypatch):
    """With the escape hatch set, data-modification statements are accepted."""
    monkeypatch.setenv("QUERYCHAT_ENABLE_UPDATE_QUERIES", "true")

    permitted = (
        "INSERT INTO table VALUES (1)",
        "UPDATE table SET x = 1",
        "MERGE INTO table USING",
        "REPLACE INTO table VALUES (1)",
        "UPSERT INTO table VALUES (1)",
    )
    for sql in permitted:
        # Passing means no exception is raised.
        check_query(sql)
455+
456+
457+
def test_check_query_escape_hatch_does_not_enable_always_blocked(monkeypatch):
    """The escape hatch never unlocks destructive or schema statements."""
    monkeypatch.setenv("QUERYCHAT_ENABLE_UPDATE_QUERIES", "true")

    still_forbidden = (
        "DELETE FROM table",
        "DROP TABLE table",
        "TRUNCATE TABLE table",
    )
    for sql in still_forbidden:
        with pytest.raises(UnsafeQueryError, match="disallowed"):
            check_query(sql)
467+
468+
469+
def test_check_query_integrated_into_execute_query(monkeypatch):
    """Test that check_query is integrated into execute_query()."""
    # The INSERT assertion below relies on the escape hatch being off; clear
    # any value inherited from the surrounding environment.
    monkeypatch.delenv("QUERYCHAT_ENABLE_UPDATE_QUERIES", raising=False)

    test_df = pd.DataFrame(
        {
            "id": [1, 2, 3],
            "name": ["a", "b", "c"],
            "value": [10, 20, 30],
        }
    )

    source = DataFrameSource(test_df, "test_table")

    # Run cleanup even when an assertion fails so the source is not left
    # open for the rest of the test session.
    try:
        with pytest.raises(UnsafeQueryError, match="disallowed operation"):
            source.execute_query("DELETE FROM test_table")

        with pytest.raises(UnsafeQueryError, match="update operation"):
            source.execute_query("INSERT INTO test_table VALUES (1, 'a', 1)")
    finally:
        source.cleanup()
488+
489+
490+
def test_check_query_does_not_block_keywords_in_column_names():
    """Test that keywords in column names or values are not blocked."""
    # Keyword appears only as part of a column name.
    check_query("SELECT update_count FROM table")
    # Keyword appears only as part of a table name.
    check_query("SELECT * FROM delete_logs")
    # Keyword appears only inside a string literal value.
    check_query("SELECT * FROM logs WHERE action = 'DELETE'")
494+
495+
496+
def test_check_query_escape_hatch_accepts_various_values(monkeypatch):
    """Each recognised truthy spelling of the env var unlocks update queries."""
    truthy_spellings = ("true", "TRUE", "1", "yes", "YES")
    for setting in truthy_spellings:
        monkeypatch.setenv("QUERYCHAT_ENABLE_UPDATE_QUERIES", setting)
        # Must not raise for any of the spellings above.
        check_query("INSERT INTO table VALUES (1)")

pkg-r/R/DataSource.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,8 @@ DBISource <- R6::R6Class(
337337
DBI::dbQuoteIdentifier(private$conn, self$table_name)
338338
)
339339
}
340+
341+
check_query(query)
340342
DBI::dbGetQuery(private$conn, query)
341343
},
342344

pkg-r/R/utils-check.R

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,108 @@ check_sql_table_name <- function(
6060

6161
invisible(NULL)
6262
}
63+
64+
65+
# SQL query validation --------------------------------------------------------
66+
67+
#' Check SQL query for disallowed operations
#'
#' Validates that a SQL query does not start with a disallowed operation.
#' Leading whitespace, empty statements (`;`), `--` line comments and
#' `/* ... */` block comments are skipped before the first keyword is
#' inspected, so a disallowed statement cannot be hidden behind a comment.
#' This is a simple safety check, not full query parsing: only the first
#' keyword of the query is examined.
#'
#' @param query The SQL query string to check
#' @param ... These dots are for future extensions and must be empty.
#' @param arg Argument name to use in error messages
#' @param call Calling environment for error messages
#'
#' @details
#' Two categories of keywords are checked:
#'
#' **Always blocked** (no escape hatch):
#' DELETE, TRUNCATE, CREATE, DROP, ALTER, GRANT, REVOKE, EXEC, EXECUTE, CALL
#'
#' **Blocked unless escape hatch enabled**:
#' INSERT, UPDATE, MERGE, REPLACE, UPSERT
#'
#' The escape hatch can be enabled via
#' `options(querychat.enable_update_queries = TRUE)` or by setting the
#' environment variable `QUERYCHAT_ENABLE_UPDATE_QUERIES=true`.
#'
#' @return Invisibly returns `NULL` if validation passes. Otherwise throws an
#'   error.
#'
#' @noRd
check_query <- function(
  query,
  ...,
  arg = caller_arg(query),
  call = caller_env()
) {
  check_dots_empty()
  check_string(query, arg = arg, call = call)

  # Drop anything that may precede the first keyword: whitespace, empty
  # statements, `--` line comments and `/* ... */` block comments. `(?s)` lets
  # a block comment span several lines. This must happen before newlines are
  # normalized because a line comment ends at the newline.
  stripped <- sub(
    "(?s)^(?:\\s+|;|--[^\\r\\n]*|/\\*.*?\\*/)+",
    "",
    query,
    perl = TRUE
  )

  # Normalize: newlines/tabs -> space, collapse multiple spaces, trim, uppercase
  normalized <- stripped
  normalized <- gsub("[\r\n\t]+", " ", normalized)
  normalized <- gsub(" +", " ", normalized)
  normalized <- trimws(normalized)
  normalized <- toupper(normalized)

  # Always blocked - destructive/schema/admin operations
  always_blocked <- c(
    "DELETE",
    "TRUNCATE",
    "CREATE",
    "DROP",
    "ALTER",
    "GRANT",
    "REVOKE",
    "EXEC",
    "EXECUTE",
    "CALL"
  )

  # Blocked unless escape hatch enabled - data modification
  update_keywords <- c("INSERT", "UPDATE", "MERGE", "REPLACE", "UPSERT")

  # Check always-blocked keywords first
  always_pattern <- paste0("^(", paste(always_blocked, collapse = "|"), ")\\b")
  if (grepl(always_pattern, normalized)) {
    matched <- regmatches(normalized, regexpr(always_pattern, normalized))
    cli::cli_abort(
      c(
        "Query appears to contain a disallowed operation: {matched}",
        "i" = "Only SELECT queries are allowed."
      ),
      call = call
    )
  }

  # Check update keywords (can be enabled via option or envvar)
  enable_updates <- isTRUE(getOption("querychat.enable_update_queries", FALSE))
  if (!enable_updates) {
    envvar <- Sys.getenv("QUERYCHAT_ENABLE_UPDATE_QUERIES", "")
    enable_updates <- tolower(envvar) %in% c("true", "1", "yes")
  }

  if (!enable_updates) {
    update_pattern <- paste0(
      "^(",
      paste(update_keywords, collapse = "|"),
      ")\\b"
    )
    if (grepl(update_pattern, normalized)) {
      matched <- regmatches(normalized, regexpr(update_pattern, normalized))
      cli::cli_abort(
        c(
          "Query appears to contain an update operation: {matched}",
          "i" = "Only SELECT queries are allowed.",
          "i" = "Set {.code options(querychat.enable_update_queries = TRUE)} or {.envvar QUERYCHAT_ENABLE_UPDATE_QUERIES=true} to allow update queries."
        ),
        call = call
      )
    }
  }

  invisible(NULL)
}

0 commit comments

Comments
 (0)