|
1 | 1 | import os |
2 | | -import tempfile |
| 2 | +from pathlib import Path |
3 | 3 |
|
4 | 4 | import pytest |
5 | 5 |
|
|
36 | 36 |
|
37 | 37 |
|
38 | 38 | @pytest.mark.integration_test() |
39 | | -def test_train_single_table(): |
40 | | - with tempfile.TemporaryDirectory() as save_dir: |
41 | | - os.makedirs(os.path.join(save_dir, "models")) |
| 39 | +def test_train_single_table(tmp_path: Path): |
| 40 | + os.makedirs(tmp_path / "models") |
| 41 | + configs = {"clustering": CLUSTERING_CONFIG, "diffusion": DIFFUSION_CONFIG} |
42 | 42 |
|
43 | | - configs = {"clustering": CLUSTERING_CONFIG, "diffusion": DIFFUSION_CONFIG} |
| 43 | + tables, relation_order, dataset_meta = load_multi_table("tests/integration/data/single_table/") |
| 44 | + models = clava_training(tables, relation_order, tmp_path, configs, device="cpu") |
44 | 45 |
|
45 | | - tables, relation_order, dataset_meta = load_multi_table("tests/integration/data/single_table/") |
46 | | - models = clava_training(tables, relation_order, save_dir, configs, device="cpu") |
47 | | - |
48 | | - assert models |
| 46 | + assert models |
49 | 47 |
|
50 | 48 |
|
51 | 49 | @pytest.mark.integration_test() |
52 | | -def test_train_multi_table(): |
53 | | - with tempfile.TemporaryDirectory() as save_dir: |
54 | | - os.makedirs(os.path.join(save_dir, "models")) |
55 | | - |
56 | | - configs = {"clustering": CLUSTERING_CONFIG, "diffusion": DIFFUSION_CONFIG, "classifier": CLASSIFIER_CONFIG} |
| 50 | +def test_train_multi_table(tmp_path: Path): |
| 51 | + os.makedirs(tmp_path / "models") |
| 52 | + configs = {"clustering": CLUSTERING_CONFIG, "diffusion": DIFFUSION_CONFIG, "classifier": CLASSIFIER_CONFIG} |
57 | 53 |
|
58 | | - tables, relation_order, dataset_meta = load_multi_table("tests/integration/data/multi_table/") |
59 | | - tables, all_group_lengths_prob_dicts = clava_clustering(tables, relation_order, save_dir, configs) |
60 | | - models = clava_training(tables, relation_order, save_dir, configs, device="cpu") |
| 54 | + tables, relation_order, dataset_meta = load_multi_table("tests/integration/data/multi_table/") |
| 55 | + tables, all_group_lengths_prob_dicts = clava_clustering(tables, relation_order, tmp_path, configs) |
| 56 | + models = clava_training(tables, relation_order, tmp_path, configs, device="cpu") |
61 | 57 |
|
62 | | - assert models |
| 58 | + assert models |
63 | 59 |
|
64 | 60 |
|
65 | 61 | @pytest.mark.integration_test() |
66 | | -def test_clustering_reload(): |
67 | | - with tempfile.TemporaryDirectory() as save_dir: |
68 | | - configs = {"clustering": CLUSTERING_CONFIG} |
| 62 | +def test_clustering_reload(tmp_path: Path): |
| 63 | + os.makedirs(tmp_path / "models") |
| 64 | + configs = {"clustering": CLUSTERING_CONFIG} |
69 | 65 |
|
70 | | - tables, relation_order, dataset_meta = load_multi_table("tests/integration/data/multi_table/") |
71 | | - tables, all_group_lengths_prob_dicts = clava_clustering(tables, relation_order, save_dir, configs) |
| 66 | + tables, relation_order, dataset_meta = load_multi_table("tests/integration/data/multi_table/") |
| 67 | + tables, all_group_lengths_prob_dicts = clava_clustering(tables, relation_order, tmp_path, configs) |
72 | 68 |
|
73 | | - assert all_group_lengths_prob_dicts |
| 69 | + assert all_group_lengths_prob_dicts |
74 | 70 |
|
75 | | - # loading from previously saved clustering |
76 | | - tables, all_group_lengths_prob_dicts = clava_clustering(tables, relation_order, save_dir, configs) |
| 71 | + # loading from previously saved clustering |
| 72 | + tables, all_group_lengths_prob_dicts = clava_clustering(tables, relation_order, tmp_path, configs) |
77 | 73 |
|
78 | | - assert all_group_lengths_prob_dicts |
| 74 | + assert all_group_lengths_prob_dicts |
0 commit comments