|
5 | 5 | import pandas as pd |
6 | 6 | import pytest |
7 | 7 | from querychat._datasource import DataFrameSource, SQLAlchemySource |
| 8 | +from querychat._utils import UnsafeQueryError, check_query |
8 | 9 | from querychat.types import MissingColumnsError |
9 | 10 | from sqlalchemy import create_engine, text |
10 | 11 |
|
@@ -387,3 +388,113 @@ def test_test_query_error_message_format(test_db_engine): |
387 | 388 | assert "Query result missing required columns" in error_message |
388 | 389 | assert "The query must return all original table columns" in error_message |
389 | 390 | assert "Original columns:" in error_message |
| 391 | + |
| 392 | + |
| 393 | +# Tests for check_query() function |
| 394 | + |
| 395 | + |
| 396 | +def test_check_query_allows_valid_select(): |
| 397 | + """Test that check_query allows valid SELECT queries.""" |
| 398 | + check_query("SELECT * FROM test_table") |
| 399 | + check_query("select * from test_table") |
| 400 | + check_query(" SELECT * FROM test_table ") |
| 401 | + check_query("\nSELECT * FROM test_table\n") |
| 402 | + |
| 403 | + |
| 404 | +def test_check_query_blocks_always_blocked_keywords(): |
| 405 | + """Test that check_query blocks always-blocked keywords.""" |
| 406 | + always_blocked = [ |
| 407 | + "DELETE", |
| 408 | + "TRUNCATE", |
| 409 | + "CREATE", |
| 410 | + "DROP", |
| 411 | + "ALTER", |
| 412 | + "GRANT", |
| 413 | + "REVOKE", |
| 414 | + "EXEC", |
| 415 | + "EXECUTE", |
| 416 | + "CALL", |
| 417 | + ] |
| 418 | + |
| 419 | + for keyword in always_blocked: |
| 420 | + with pytest.raises(UnsafeQueryError, match="disallowed operation"): |
| 421 | + check_query(f"{keyword} something") |
| 422 | + |
| 423 | + |
| 424 | +def test_check_query_blocks_update_keywords_by_default(): |
| 425 | + """Test that check_query blocks update keywords by default.""" |
| 426 | + update_keywords = ["INSERT", "UPDATE", "MERGE", "REPLACE", "UPSERT"] |
| 427 | + |
| 428 | + for keyword in update_keywords: |
| 429 | + with pytest.raises(UnsafeQueryError, match="update operation"): |
| 430 | + check_query(f"{keyword} something") |
| 431 | + |
| 432 | + |
| 433 | +def test_check_query_normalizes_whitespace_and_case(): |
| 434 | + """Test that check_query normalizes whitespace and case.""" |
| 435 | + with pytest.raises(UnsafeQueryError, match="disallowed"): |
| 436 | + check_query(" delete FROM table ") |
| 437 | + with pytest.raises(UnsafeQueryError, match="disallowed"): |
| 438 | + check_query("\n\nDELETE\n\nFROM table") |
| 439 | + with pytest.raises(UnsafeQueryError, match="disallowed"): |
| 440 | + check_query("\tDELETE\tFROM\ttable") |
| 441 | + with pytest.raises(UnsafeQueryError, match="disallowed"): |
| 442 | + check_query("DeLeTe FROM table") |
| 443 | + |
| 444 | + |
| 445 | +def test_check_query_escape_hatch_enables_update_keywords(monkeypatch): |
| 446 | + """Test that escape hatch enables update keywords.""" |
| 447 | + monkeypatch.setenv("QUERYCHAT_ENABLE_UPDATE_QUERIES", "true") |
| 448 | + |
| 449 | + # These should not raise |
| 450 | + check_query("INSERT INTO table VALUES (1)") |
| 451 | + check_query("UPDATE table SET x = 1") |
| 452 | + check_query("MERGE INTO table USING") |
| 453 | + check_query("REPLACE INTO table VALUES (1)") |
| 454 | + check_query("UPSERT INTO table VALUES (1)") |
| 455 | + |
| 456 | + |
| 457 | +def test_check_query_escape_hatch_does_not_enable_always_blocked(monkeypatch): |
| 458 | + """Test that escape hatch does NOT enable always-blocked keywords.""" |
| 459 | + monkeypatch.setenv("QUERYCHAT_ENABLE_UPDATE_QUERIES", "true") |
| 460 | + |
| 461 | + with pytest.raises(UnsafeQueryError, match="disallowed"): |
| 462 | + check_query("DELETE FROM table") |
| 463 | + with pytest.raises(UnsafeQueryError, match="disallowed"): |
| 464 | + check_query("DROP TABLE table") |
| 465 | + with pytest.raises(UnsafeQueryError, match="disallowed"): |
| 466 | + check_query("TRUNCATE TABLE table") |
| 467 | + |
| 468 | + |
| 469 | +def test_check_query_integrated_into_execute_query(): |
| 470 | + """Test that check_query is integrated into execute_query().""" |
| 471 | + test_df = pd.DataFrame( |
| 472 | + { |
| 473 | + "id": [1, 2, 3], |
| 474 | + "name": ["a", "b", "c"], |
| 475 | + "value": [10, 20, 30], |
| 476 | + } |
| 477 | + ) |
| 478 | + |
| 479 | + source = DataFrameSource(test_df, "test_table") |
| 480 | + |
| 481 | + with pytest.raises(UnsafeQueryError, match="disallowed operation"): |
| 482 | + source.execute_query("DELETE FROM test_table") |
| 483 | + |
| 484 | + with pytest.raises(UnsafeQueryError, match="update operation"): |
| 485 | + source.execute_query("INSERT INTO test_table VALUES (1, 'a', 1)") |
| 486 | + |
| 487 | + source.cleanup() |
| 488 | + |
| 489 | + |
| 490 | +def test_check_query_does_not_block_keywords_in_column_names(): |
| 491 | + """Test that keywords in column names or values are not blocked.""" |
| 492 | + check_query("SELECT update_count FROM table") |
| 493 | + check_query("SELECT * FROM delete_logs") |
| 494 | + |
| 495 | + |
| 496 | +def test_check_query_escape_hatch_accepts_various_values(monkeypatch): |
| 497 | + """Test that escape hatch accepts various truthy values.""" |
| 498 | + for value in ["true", "TRUE", "1", "yes", "YES"]: |
| 499 | + monkeypatch.setenv("QUERYCHAT_ENABLE_UPDATE_QUERIES", value) |
| 500 | + check_query("INSERT INTO table VALUES (1)") # Should not raise |
0 commit comments