|
17 | 17 | import torch |
18 | 18 | import pandas as pd |
19 | 19 | import tempfile |
20 | | -import pytest |
21 | 20 |
|
22 | | -import graphium |
23 | 21 | from graphium.utils.fs import rm, exists, get_size |
24 | 22 | from graphium.data import GraphOGBDataModule, MultitaskFromSmilesDataModule |
25 | 23 |
|
|
29 | 27 |
|
30 | 28 | class test_DataModule(ut.TestCase): |
31 | 29 |
|
32 | | - @pytest.fixture |
33 | | - def _setup_tmp_path(self, tmp_path): |
34 | | - self.tmp_path = tmp_path |
35 | | - |
36 | 30 | def test_ogb_datamodule(self): |
37 | 31 | # other datasets are too large to be tested |
38 | 32 | dataset_names = ["ogbg-molhiv", "ogbg-molpcba", "ogbg-moltox21", "ogbg-molfreesolv"] |
@@ -386,7 +380,6 @@ def test_datamodule_multiple_data_files(self): |
386 | 380 |
|
387 | 381 | self.assertEqual(len(ds.train_ds), 20) |
388 | 382 |
|
389 | | - @pytest.mark.usefixtures("_setup_tmp_path") |
390 | 383 | def test_splits_file(self): |
391 | 384 | # Test single CSV files |
392 | 385 | csv_file = "tests/data/micro_ZINC_shard_1.csv" |
@@ -432,7 +425,7 @@ def test_splits_file(self): |
432 | 425 |
|
433 | 426 | try: |
434 | 427 | # Create a TemporaryFile to save the splits, and test the datamodule |
435 | | - temp_file = tempfile.NamedTemporaryFile(suffix=".pt", dir=self.tmp_path) |
| 428 | + temp_file = tempfile.NamedTemporaryFile(suffix=".pt", delete=False) |
436 | 429 |
|
437 | 430 | # Save the splits |
438 | 431 | torch.save(splits, temp_file) |
@@ -479,7 +472,8 @@ def test_splits_file(self): |
479 | 472 | np.testing.assert_array_equal(ds.test_ds.smiles_offsets_tensor, ds2.test_ds.smiles_offsets_tensor) |
480 | 473 |
|
481 | 474 | finally: |
482 | | - temp_file.close() |
| 475 | + temp_file.close() |
| 476 | + os.unlink(temp_file.name) |
483 | 477 |
|
484 | 478 |
|
485 | 479 | if __name__ == "__main__": |
|
0 commit comments