Skip to content

Commit e2e6887

Browse files
authored
Improve TransformersDocumentClassifier tests (#3270)
1 parent 24d4591 commit e2e6887

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

test/nodes/test_document_classifier.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ def test_document_classifier_details(document_classifier):
2929
results = document_classifier.predict(documents=docs)
3030
for doc in results:
3131
assert "details" in doc.meta["classification"]
32-
assert len(doc.meta["classification"]["details"]) == 2 # top_k = 2
32+
if document_classifier.top_k is not None:
33+
assert len(doc.meta["classification"]["details"]) == document_classifier.top_k
3334

3435

3536
@pytest.mark.integration
@@ -82,7 +83,7 @@ def test_zero_shot_document_classifier_details(zero_shot_document_classifier):
8283
results = zero_shot_document_classifier.predict(documents=docs)
8384
for doc in results:
8485
assert "details" in doc.meta["classification"]
85-
assert len(doc.meta["classification"]["details"]) == 2 # n_labels = 2
86+
assert set(doc.meta["classification"]["details"].keys()) == set(zero_shot_document_classifier.labels)
8687

8788

8889
@pytest.mark.integration

0 commit comments

Comments
 (0)