Skip to content

Commit a25afe2

Browse files
committed
fix: fixed retrieval test
1 parent af33c6f commit a25afe2

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

tests/modules/retrieval/test_retrieval.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,26 +6,25 @@
66

77

88
def test_get_assets_returns_correct_artifact():
9-
db_dir, dump_dir, logs_dir = setup_environment()
10-
module = RetrievalEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo", db_dir=db_dir)
9+
dump_dir, _ = setup_environment()
10+
module = RetrievalEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo")
1111
artifact = module.get_assets()
1212
assert artifact.embedder_name == "sergeyzh/rubert-tiny-turbo"
1313

1414

1515
def test_dump_and_load_preserves_model_state():
16-
db_dir, dump_dir, logs_dir = setup_environment()
17-
module = RetrievalEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo", db_dir=db_dir)
16+
dump_dir, _ = setup_environment()
17+
module = RetrievalEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo")
1818

1919
utterances = ["hello", "goodbye", "hi", "bye", "bye", "hello", "welcome", "hi123", "hiii", "bye-bye", "bye!"]
2020
labels = [0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1]
2121
module.fit(utterances, labels)
2222

2323
dump_path = Path(dump_dir)
24-
dump_path.mkdir(parents=True, exist_ok=True)
25-
module.dump(str(dump_path))
24+
module.dump(dump_path)
2625

27-
loaded_module = RetrievalEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo", db_dir=db_dir)
28-
loaded_module.load(str(dump_path))
26+
loaded_module = RetrievalEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo")
27+
loaded_module.load(dump_path)
2928

3029
assert loaded_module.embedder_name == module.embedder_name
3130

0 commit comments

Comments
 (0)