Skip to content

Commit 7a62f22

Browse files
committed
CR by David
1 parent 37f6136 commit 7a62f22

File tree

1 file changed

+23
-27
lines changed

1 file changed

+23
-27
lines changed

tests/integration/models/clavaddpm/test_model.py

Lines changed: 23 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
import tempfile
2+
from pathlib import Path
33

44
import pytest
55

@@ -36,43 +36,39 @@
3636

3737

3838
@pytest.mark.integration_test()
39-
def test_train_single_table():
40-
with tempfile.TemporaryDirectory() as save_dir:
41-
os.makedirs(os.path.join(save_dir, "models"))
39+
def test_train_single_table(tmp_path: Path):
40+
os.makedirs(tmp_path / "models")
41+
configs = {"clustering": CLUSTERING_CONFIG, "diffusion": DIFFUSION_CONFIG}
4242

43-
configs = {"clustering": CLUSTERING_CONFIG, "diffusion": DIFFUSION_CONFIG}
43+
tables, relation_order, dataset_meta = load_multi_table("tests/integration/data/single_table/")
44+
models = clava_training(tables, relation_order, tmp_path, configs, device="cpu")
4445

45-
tables, relation_order, dataset_meta = load_multi_table("tests/integration/data/single_table/")
46-
models = clava_training(tables, relation_order, save_dir, configs, device="cpu")
47-
48-
assert models
46+
assert models
4947

5048

5149
@pytest.mark.integration_test()
52-
def test_train_multi_table():
53-
with tempfile.TemporaryDirectory() as save_dir:
54-
os.makedirs(os.path.join(save_dir, "models"))
55-
56-
configs = {"clustering": CLUSTERING_CONFIG, "diffusion": DIFFUSION_CONFIG, "classifier": CLASSIFIER_CONFIG}
50+
def test_train_multi_table(tmp_path: Path):
51+
os.makedirs(tmp_path / "models")
52+
configs = {"clustering": CLUSTERING_CONFIG, "diffusion": DIFFUSION_CONFIG, "classifier": CLASSIFIER_CONFIG}
5753

58-
tables, relation_order, dataset_meta = load_multi_table("tests/integration/data/multi_table/")
59-
tables, all_group_lengths_prob_dicts = clava_clustering(tables, relation_order, save_dir, configs)
60-
models = clava_training(tables, relation_order, save_dir, configs, device="cpu")
54+
tables, relation_order, dataset_meta = load_multi_table("tests/integration/data/multi_table/")
55+
tables, all_group_lengths_prob_dicts = clava_clustering(tables, relation_order, tmp_path, configs)
56+
models = clava_training(tables, relation_order, tmp_path, configs, device="cpu")
6157

62-
assert models
58+
assert models
6359

6460

6561
@pytest.mark.integration_test()
66-
def test_clustering_reload():
67-
with tempfile.TemporaryDirectory() as save_dir:
68-
configs = {"clustering": CLUSTERING_CONFIG}
62+
def test_clustering_reload(tmp_path: Path):
63+
os.makedirs(tmp_path / "models")
64+
configs = {"clustering": CLUSTERING_CONFIG}
6965

70-
tables, relation_order, dataset_meta = load_multi_table("tests/integration/data/multi_table/")
71-
tables, all_group_lengths_prob_dicts = clava_clustering(tables, relation_order, save_dir, configs)
66+
tables, relation_order, dataset_meta = load_multi_table("tests/integration/data/multi_table/")
67+
tables, all_group_lengths_prob_dicts = clava_clustering(tables, relation_order, tmp_path, configs)
7268

73-
assert all_group_lengths_prob_dicts
69+
assert all_group_lengths_prob_dicts
7470

75-
# loading from previously saved clustering
76-
tables, all_group_lengths_prob_dicts = clava_clustering(tables, relation_order, save_dir, configs)
71+
# loading from previously saved clustering
72+
tables, all_group_lengths_prob_dicts = clava_clustering(tables, relation_order, tmp_path, configs)
7773

78-
assert all_group_lengths_prob_dicts
74+
assert all_group_lengths_prob_dicts

0 commit comments

Comments
 (0)