Skip to content

Commit 01d643d

Browse files
test cases fixes for helpers
1 parent 152feae commit 01d643d

File tree

4 files changed

+25
-12
lines changed

4 files changed

+25
-12
lines changed

tests/services/__init__.py

Whitespace-only changes.
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Sample schema fixtures for testing schema extraction and caching."""
22

3+
import pytest
4+
35
from datu.base.base_connector import SchemaInfo, TableInfo
46
from datu.schema_extractor.schema_cache import SchemaGlossary
57

@@ -52,3 +54,17 @@ def raw_schema_dict():
5254
}
5355
],
5456
}
57+
58+
59+
@pytest.fixture
60+
def sample_schema():
61+
# Return a function that can generate schemas with parameters
62+
def _factory(timestamp: float = 1234567890.0):
63+
return SchemaTestFixtures.sample_schema(timestamp=timestamp)
64+
65+
return _factory
66+
67+
68+
@pytest.fixture
69+
def raw_schema_dict():
70+
return SchemaTestFixtures.raw_schema_dict()

tests/services/test_graph_rag.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99

1010
from datu.services.schema_rag import SchemaGraphBuilder, SchemaRAG, SchemaTripleExtractor
1111

12-
from tests.helpers.sample_schemas import SchemaTestFixtures
13-
1412
TEST_GRAPH_DIR = "test_graph_rag"
1513

1614

@@ -25,18 +23,17 @@ def clean_test_graph_cache():
2523
shutil.rmtree(TEST_GRAPH_DIR)
2624

2725

28-
def test_init_with_dict_schema():
26+
def test_init_with_dict_schema(raw_schema_dict):
2927
"""Test SchemaGraphBuilder initialization with a raw schema dictionary."""
30-
schema_dict = SchemaTestFixtures.raw_schema_dict()
31-
extractor = SchemaTripleExtractor(schema_dict)
28+
extractor = SchemaTripleExtractor(raw_schema_dict)
3229
extractor.create_schema_triples()
3330
assert extractor.timestamp == 1234567890.0
3431
assert len(extractor.schema_profiles) == 1
3532

3633

37-
def test_extract_triples_output():
34+
def test_extract_triples_output(sample_schema):
3835
"""Test that triples are extracted correctly from schema objects."""
39-
schema_profiles = SchemaTestFixtures.sample_schema()
36+
schema_profiles = sample_schema()
4037
extractor = SchemaTripleExtractor(schema_profiles)
4138
extractor.paths["triples"] = os.path.join(TEST_GRAPH_DIR, "triples.json")
4239
extractor.paths["meta"] = os.path.join(TEST_GRAPH_DIR, "meta.json")
@@ -61,9 +58,10 @@ def test_is_graph_outdated_returns_true_for_missing_files(tmp_path):
6158
assert extractor.is_rag_outdated() is True
6259

6360

64-
def test_initialize_graph_rebuild_and_cache():
61+
@pytest.mark.parametrize("timestamp", [9999.0])
62+
def test_initialize_graph_rebuild_and_cache(sample_schema, timestamp):
6563
"""Test graph initialization rebuilds and caches the graph correctly."""
66-
schema = SchemaTestFixtures.sample_schema(timestamp=9999.0)
64+
schema = sample_schema(timestamp=timestamp)
6765
extractor = SchemaTripleExtractor(schema)
6866
extractor.create_schema_triples()
6967
builder = SchemaGraphBuilder(triples=extractor.triples, is_rag_outdated=True)
@@ -84,10 +82,9 @@ def test_initialize_graph_rebuild_and_cache():
8482

8583

8684
@pytest.mark.requires_service
87-
def test_schema_rag_run_query_returns_filtered_schema_dict():
85+
def test_schema_rag_run_query_returns_filtered_schema_dict(sample_schema):
8886
"""Test SchemaRAG end-to-end run_query method returns filtered schema."""
89-
schema = SchemaTestFixtures.sample_schema()
90-
rag = SchemaRAG(schema)
87+
rag = SchemaRAG(sample_schema)
9188
result = rag.run_query(["List all customer orders"])
9289
assert isinstance(result, dict)
9390
assert "schema_info" in result

0 commit comments

Comments
 (0)