Skip to content

Commit 2bb8133

Browse files
masci, ZanSara, and bogdankostic
authored
feat: add SQLDocumentStore tests (#3517)
* port SQL tests
* cleanup document_store_tests.py from sql tests
* leftover
* Update .github/workflows/tests.yml

  Co-authored-by: Sara Zan <[email protected]>
* review comments
* Update test/document_stores/test_base.py

  Co-authored-by: bogdankostic <[email protected]>

Co-authored-by: Sara Zan <[email protected]>
Co-authored-by: bogdankostic <[email protected]>
1 parent 1a60e21 commit 2bb8133

File tree

6 files changed

+196
-84
lines changed

6 files changed

+196
-84
lines changed

.github/workflows/tests.yml

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,6 @@ jobs:
141141
ES_JAVA_OPTS: "-Xms128m -Xmx256m"
142142
ports:
143143
- 9200:9200
144-
# env:
145-
# ELASTICSEARCH_HOST: "elasticsearch"
146144
steps:
147145
- uses: actions/checkout@v3
148146

@@ -154,7 +152,35 @@ jobs:
154152

155153
- name: Run tests
156154
run: |
157-
pytest -x -m "document_store and integration" test/document_stores/test_elasticsearch.py
155+
pytest --maxfail=5 -m "document_store and integration" test/document_stores/test_elasticsearch.py
156+
157+
- uses: act10ns/slack@v1
158+
with:
159+
status: ${{ job.status }}
160+
channel: '#haystack'
161+
if: failure() && github.repository_owner == 'deepset-ai' && github.ref == 'refs/heads/main'
162+
163+
integration-tests-sql:
164+
name: Integration / SQL / ${{ matrix.os }}
165+
needs:
166+
- unit-tests
167+
strategy:
168+
fail-fast: false
169+
matrix:
170+
os: [ubuntu-latest,macos-latest,windows-latest]
171+
runs-on: ${{ matrix.os }}
172+
steps:
173+
- uses: actions/checkout@v3
174+
175+
- name: Setup Python
176+
uses: ./.github/actions/python_cache/
177+
178+
- name: Install Haystack
179+
run: pip install -U .[sql]
180+
181+
- name: Run tests
182+
run: |
183+
pytest --maxfail=5 -m "document_store and integration" test/document_stores/test_sql.py
158184
159185
- uses: act10ns/slack@v1
160186
with:
@@ -179,8 +205,6 @@ jobs:
179205
ES_JAVA_OPTS: "-Xms128m -Xmx256m"
180206
ports:
181207
- 9200:9200
182-
# env:
183-
# OPENSEARCH_HOST: "opensearch"
184208
steps:
185209
- uses: actions/checkout@v3
186210

@@ -192,7 +216,7 @@ jobs:
192216

193217
- name: Run tests
194218
run: |
195-
pytest -x -m "document_store and integration" test/document_stores/test_opensearch.py
219+
pytest --maxfail=5 -m "document_store and integration" test/document_stores/test_opensearch.py
196220
197221
- uses: act10ns/slack@v1
198222
with:

haystack/document_stores/sql.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -460,16 +460,30 @@ def write_labels(self, labels, index=None, headers: Optional[Dict[str, str]] = N
460460
# self.write_documents(documents=[label.document], index=index, duplicate_documents="skip")
461461

462462
# TODO: Handle label meta data
463+
464+
# Sanitize fields to adhere to SQL constraints
465+
answer = label.answer
466+
if answer is not None:
467+
answer = answer.to_json()
468+
469+
no_answer = label.no_answer
470+
if label.no_answer is None:
471+
no_answer = False
472+
473+
document = label.document
474+
if document is not None:
475+
document = document.to_json()
476+
463477
label_orm = LabelORM(
464478
id=label.id,
465-
no_answer=label.no_answer,
479+
no_answer=no_answer,
466480
# document_id=label.document.id,
467-
document=label.document.to_json(),
481+
document=document,
468482
origin=label.origin,
469483
query=label.query,
470484
is_correct_answer=label.is_correct_answer,
471485
is_correct_document=label.is_correct_document,
472-
answer=label.answer.to_json(),
486+
answer=answer,
473487
pipeline_id=label.pipeline_id,
474488
index=index,
475489
)
@@ -576,11 +590,13 @@ def _convert_sql_row_to_document(self, row) -> Document:
576590
return document
577591

578592
def _convert_sql_row_to_label(self, row) -> Label:
579-
# doc = self._convert_sql_row_to_document(row.document)
593+
answer = row.answer
594+
if answer is not None:
595+
answer = Answer.from_json(answer)
580596

581597
label = Label(
582598
query=row.query,
583-
answer=Answer.from_json(row.answer), # type: ignore
599+
answer=answer,
584600
document=Document.from_json(row.document),
585601
is_correct_answer=row.is_correct_answer,
586602
is_correct_document=row.is_correct_document,

test/conftest.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,10 +1012,7 @@ def get_document_store(
10121012
recreate_index: bool = True,
10131013
): # cosine is default similarity as dot product is not supported by Weaviate
10141014
document_store: BaseDocumentStore
1015-
if document_store_type == "sql":
1016-
document_store = SQLDocumentStore(url=get_sql_url(tmp_path), index=index, isolation_level="AUTOCOMMIT")
1017-
1018-
elif document_store_type == "memory":
1015+
if document_store_type == "memory":
10191016
document_store = InMemoryDocumentStore(
10201017
return_embedding=True,
10211018
embedding_dim=embedding_dim,

test/document_stores/test_base.py

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,9 @@ def test_write_documents(self, ds, documents):
7070
ds.write_documents(documents)
7171
docs = ds.get_all_documents()
7272
assert len(docs) == len(documents)
73-
for i, doc in enumerate(docs):
74-
expected = documents[i]
75-
assert doc.id == expected.id
73+
expected_ids = set(doc.id for doc in documents)
74+
ids = set(doc.id for doc in docs)
75+
assert ids == expected_ids
7676

7777
@pytest.mark.integration
7878
def test_write_labels(self, ds, labels):
@@ -142,27 +142,41 @@ def test_get_all_documents_with_incorrect_filter_value(self, ds, documents):
142142
assert len(result) == 0
143143

144144
@pytest.mark.integration
145-
def test_extended_filter(self, ds, documents):
145+
def test_eq_filters(self, ds, documents):
146146
ds.write_documents(documents)
147147

148-
# Test comparison operators individually
149-
150148
result = ds.get_all_documents(filters={"year": {"$eq": "2020"}})
151149
assert len(result) == 3
152150
result = ds.get_all_documents(filters={"year": "2020"})
153151
assert len(result) == 3
154152

153+
@pytest.mark.integration
154+
def test_in_filters(self, ds, documents):
155+
ds.write_documents(documents)
156+
155157
result = ds.get_all_documents(filters={"year": {"$in": ["2020", "2021", "n.a."]}})
156158
assert len(result) == 6
157159
result = ds.get_all_documents(filters={"year": ["2020", "2021", "n.a."]})
158160
assert len(result) == 6
159161

162+
@pytest.mark.integration
163+
def test_ne_filters(self, ds, documents):
164+
ds.write_documents(documents)
165+
160166
result = ds.get_all_documents(filters={"year": {"$ne": "2020"}})
161167
assert len(result) == 6
162168

169+
@pytest.mark.integration
170+
def test_nin_filters(self, ds, documents):
171+
ds.write_documents(documents)
172+
163173
result = ds.get_all_documents(filters={"year": {"$nin": ["2020", "2021", "n.a."]}})
164174
assert len(result) == 3
165175

176+
@pytest.mark.integration
177+
def test_comparison_filters(self, ds, documents):
178+
ds.write_documents(documents)
179+
166180
result = ds.get_all_documents(filters={"numbers": {"$gt": 0}})
167181
assert len(result) == 3
168182

@@ -175,11 +189,17 @@ def test_extended_filter(self, ds, documents):
175189
result = ds.get_all_documents(filters={"numbers": {"$lte": 2.0}})
176190
assert len(result) == 6
177191

178-
# Test compound filters
192+
@pytest.mark.integration
193+
def test_compound_filters(self, ds, documents):
194+
ds.write_documents(documents)
179195

180196
result = ds.get_all_documents(filters={"year": {"$lte": "2021", "$gte": "2020"}})
181197
assert len(result) == 6
182198

199+
@pytest.mark.integration
200+
def test_simplified_filters(self, ds, documents):
201+
ds.write_documents(documents)
202+
183203
filters = {"$and": {"year": {"$lte": "2021", "$gte": "2020"}, "name": {"$in": ["name_0", "name_1"]}}}
184204
result = ds.get_all_documents(filters=filters)
185205
assert len(result) == 4
@@ -188,6 +208,9 @@ def test_extended_filter(self, ds, documents):
188208
result = ds.get_all_documents(filters=filters_simplified)
189209
assert len(result) == 4
190210

211+
@pytest.mark.integration
212+
def test_nested_condition_filters(self, ds, documents):
213+
ds.write_documents(documents)
191214
filters = {
192215
"$and": {
193216
"year": {"$lte": "2021", "$gte": "2020"},
@@ -223,8 +246,12 @@ def test_extended_filter(self, ds, documents):
223246
result = ds.get_all_documents(filters=filters_simplified)
224247
assert len(result) == 5
225248

226-
# Test nested logical operations within "$not", important as we apply De Morgan's laws in WeaviateDocumentstore
227-
249+
@pytest.mark.integration
250+
def test_nested_condition_not_filters(self, ds, documents):
251+
"""
252+
Test nested logical operations within "$not", important as we apply De Morgan's laws in WeaviateDocumentstore
253+
"""
254+
ds.write_documents(documents)
228255
filters = {
229256
"$not": {
230257
"$or": {
@@ -234,8 +261,9 @@ def test_extended_filter(self, ds, documents):
234261
}
235262
}
236263
result = ds.get_all_documents(filters=filters)
237-
docs_meta = result[0].meta["numbers"]
238264
assert len(result) == 3
265+
266+
docs_meta = result[0].meta["numbers"]
239267
assert [2, 4] == docs_meta
240268

241269
# Test same logical operator twice on same level
@@ -289,8 +317,8 @@ def test_duplicate_documents_skip(self, ds, documents):
289317
updated_docs.append(updated_d)
290318

291319
ds.write_documents(updated_docs, duplicate_documents="skip")
292-
result = ds.get_all_documents()
293-
assert result[0].meta["name"] == "name_0"
320+
for d in ds.get_all_documents():
321+
assert d.meta["name"] != "Updated"
294322

295323
@pytest.mark.integration
296324
def test_duplicate_documents_overwrite(self, ds, documents):

test/document_stores/test_document_store.py

Lines changed: 1 addition & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -172,21 +172,6 @@ def test_get_all_documents_with_correct_filters(document_store_with_docs):
172172
assert {d.meta["meta_field"] for d in documents} == {"test1", "test3"}
173173

174174

175-
def test_get_all_documents_with_correct_filters_legacy_sqlite(docs, tmp_path):
176-
document_store_with_docs = get_document_store("sql", tmp_path)
177-
document_store_with_docs.write_documents(docs)
178-
179-
document_store_with_docs.use_windowed_query = False
180-
documents = document_store_with_docs.get_all_documents(filters={"meta_field": ["test2"]})
181-
assert len(documents) == 1
182-
assert documents[0].meta["name"] == "filename2"
183-
184-
documents = document_store_with_docs.get_all_documents(filters={"meta_field": ["test1", "test3"]})
185-
assert len(documents) == 2
186-
assert {d.meta["name"] for d in documents} == {"filename1", "filename3"}
187-
assert {d.meta["meta_field"] for d in documents} == {"test1", "test3"}
188-
189-
190175
def test_get_all_documents_with_incorrect_filter_name(document_store_with_docs):
191176
documents = document_store_with_docs.get_all_documents(filters={"incorrect_meta_field": ["test2"]})
192177
assert len(documents) == 0
@@ -198,7 +183,7 @@ def test_get_all_documents_with_incorrect_filter_value(document_store_with_docs)
198183

199184

200185
# See test_pinecone.py
201-
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch", "sql", "weaviate", "memory"], indirect=True)
186+
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch", "weaviate", "memory"], indirect=True)
202187
def test_extended_filter(document_store_with_docs):
203188
# Test comparison operators individually
204189
documents = document_store_with_docs.get_all_documents(filters={"meta_field": {"$eq": "test1"}})
@@ -410,47 +395,6 @@ def test_write_document_meta(document_store: BaseDocumentStore):
410395
assert document_store.get_document_by_id("4").meta["meta_field"] == "test4"
411396

412397

413-
@pytest.mark.parametrize("document_store", ["sql"], indirect=True)
414-
def test_sql_write_document_invalid_meta(document_store: BaseDocumentStore):
415-
documents = [
416-
{
417-
"content": "dict_with_invalid_meta",
418-
"valid_meta_field": "test1",
419-
"invalid_meta_field": [1, 2, 3],
420-
"name": "filename1",
421-
"id": "1",
422-
},
423-
Document(
424-
content="document_object_with_invalid_meta",
425-
meta={"valid_meta_field": "test2", "invalid_meta_field": [1, 2, 3], "name": "filename2"},
426-
id="2",
427-
),
428-
]
429-
document_store.write_documents(documents)
430-
documents_in_store = document_store.get_all_documents()
431-
assert len(documents_in_store) == 2
432-
433-
assert document_store.get_document_by_id("1").meta == {"name": "filename1", "valid_meta_field": "test1"}
434-
assert document_store.get_document_by_id("2").meta == {"name": "filename2", "valid_meta_field": "test2"}
435-
436-
437-
@pytest.mark.parametrize("document_store", ["sql"], indirect=True)
438-
def test_sql_write_different_documents_same_vector_id(document_store: BaseDocumentStore):
439-
doc1 = {"content": "content 1", "name": "doc1", "id": "1", "vector_id": "vector_id"}
440-
doc2 = {"content": "content 2", "name": "doc2", "id": "2", "vector_id": "vector_id"}
441-
442-
document_store.write_documents([doc1], index="index1")
443-
documents_in_index1 = document_store.get_all_documents(index="index1")
444-
assert len(documents_in_index1) == 1
445-
document_store.write_documents([doc2], index="index2")
446-
documents_in_index2 = document_store.get_all_documents(index="index2")
447-
assert len(documents_in_index2) == 1
448-
449-
document_store.write_documents([doc1], index="index3")
450-
with pytest.raises(Exception, match=r"(?i)unique"):
451-
document_store.write_documents([doc2], index="index3")
452-
453-
454398
def test_write_document_index(document_store: BaseDocumentStore):
455399
document_store.delete_index("haystack_test_one")
456400
document_store.delete_index("haystack_test_two")

0 commit comments

Comments (0)