|
6 | 6 |
|
7 | 7 |
|
8 | 8 | def test_get_assets_returns_correct_artifact(): |
9 | | - db_dir, dump_dir, logs_dir = setup_environment() |
10 | | - module = RetrievalEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo", db_dir=db_dir) |
| 9 | + dump_dir, _ = setup_environment() |
| 10 | + module = RetrievalEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo") |
11 | 11 | artifact = module.get_assets() |
12 | 12 | assert artifact.embedder_name == "sergeyzh/rubert-tiny-turbo" |
13 | 13 |
|
14 | 14 |
|
15 | 15 | def test_dump_and_load_preserves_model_state(): |
16 | | - db_dir, dump_dir, logs_dir = setup_environment() |
17 | | - module = RetrievalEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo", db_dir=db_dir) |
| 16 | + dump_dir, _ = setup_environment() |
| 17 | + module = RetrievalEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo") |
18 | 18 |
|
19 | 19 | utterances = ["hello", "goodbye", "hi", "bye", "bye", "hello", "welcome", "hi123", "hiii", "bye-bye", "bye!"] |
20 | 20 | labels = [0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1] |
21 | 21 | module.fit(utterances, labels) |
22 | 22 |
|
23 | 23 | dump_path = Path(dump_dir) |
24 | | - dump_path.mkdir(parents=True, exist_ok=True) |
25 | | - module.dump(str(dump_path)) |
| 24 | + module.dump(dump_path) |
26 | 25 |
|
27 | | - loaded_module = RetrievalEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo", db_dir=db_dir) |
28 | | - loaded_module.load(str(dump_path)) |
| 26 | + loaded_module = RetrievalEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo") |
| 27 | + loaded_module.load(dump_path) |
29 | 28 |
|
30 | 29 | assert loaded_module.embedder_name == module.embedder_name |
31 | 30 |
|
|
0 commit comments