Skip to content

Commit 78fbf93

Browse files
committed
added tests
1 parent 902741c commit 78fbf93

File tree

4 files changed

+58
-2
lines changed

4 files changed

+58
-2
lines changed

.github/workflows/smoke-tests.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,5 +31,5 @@ jobs:
3131
pre-commit install
3232
pre-commit run --all-files
3333
34-
# - name: Run tests
35-
# run: pytest
34+
- name: Run tests
35+
run: pytest

tests/conftest.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
import sys
2+
from pathlib import Path
3+
4+
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))

tests/test_indexing.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from any_chatbot.indexing import _tbl, build_duckdb_and_summary_cards
2+
from pathlib import Path
3+
4+
5+
def test_tbl_cleaning():
6+
assert _tbl("my table!") == "my_table"
7+
assert _tbl("123name") == "t_123name"
8+
assert _tbl("##$") == "t_"
9+
10+
11+
def test_build_duckdb_and_summary_cards_csv(tmp_path: Path):
12+
csv_path = tmp_path / "data.csv"
13+
csv_path.write_text("a,b\n1,2\n3,4")
14+
db_path = tmp_path / "db.duckdb"
15+
16+
cards = build_duckdb_and_summary_cards(tmp_path, db_path)
17+
assert len(cards) == 1
18+
19+
card = cards[0]
20+
assert card.metadata["source_type"] == "table_summary"
21+
assert card.metadata["db_path"] == str(db_path)
22+
assert card.metadata["table"] == "data"
23+
assert "TABLE CARD" in card.page_content
24+
assert "a:BIGINT" in card.page_content

tests/test_tools.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from any_chatbot.tools import initialize_retrieve_tool, is_safe_sql
2+
from langchain.schema import Document
3+
4+
5+
class DummyStore:
6+
def __init__(self):
7+
self.calls = []
8+
9+
def similarity_search(self, query: str, k: int = 5, filter=None):
10+
self.calls.append((query, k, filter))
11+
return [Document(page_content="foo", metadata=filter)]
12+
13+
14+
def test_initialize_retrieve_tool_invokes_vectorstore():
15+
store = DummyStore()
16+
retrieve = initialize_retrieve_tool(store)
17+
text, docs = retrieve.func("hello", "text_chunk")
18+
19+
assert store.calls == [("hello", 5, {"source_type": "text_chunk"})]
20+
assert docs[0].metadata["source_type"] == "text_chunk"
21+
assert "foo" in text
22+
23+
24+
def test_is_safe_sql():
25+
assert is_safe_sql("SELECT * FROM tbl")
26+
assert is_safe_sql("SELECT updated_at FROM tbl")
27+
assert not is_safe_sql("DROP TABLE tbl")
28+
assert not is_safe_sql("UPDATE tbl SET a=1")

0 commit comments

Comments
 (0)