Skip to content

Commit e0a794e

Browse files
committed
add test for tox21molnet
1 parent 4102fe9 commit e0a794e

File tree

1 file changed

+181
-0
lines changed

1 file changed

+181
-0
lines changed
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
import unittest
2+
from typing import List
3+
from unittest.mock import MagicMock, mock_open, patch
4+
5+
import torch
6+
7+
from chebai.preprocessing.datasets.tox21 import Tox21MolNet
8+
from chebai.preprocessing.reader import ChemDataReader
9+
from tests.unit.mock_data.tox_mock_data import Tox21MolNetMockData
10+
11+
12+
class TestTox21MolNet(unittest.TestCase):
    """
    Unit tests for the `Tox21MolNet` data module.

    The individual tests replace `builtins.open`, `torch.save` and/or
    `Tox21MolNet._load_data_from_file` with mocks and feed them data from
    `Tox21MolNetMockData`.
    """

    @classmethod
    @patch("os.makedirs", return_value=None)
    def setUpClass(cls, mock_makedirs: MagicMock) -> None:
        """
        Create the `Tox21MolNet` instance shared by all tests of this class.

        `os.makedirs` is patched while the instance is constructed.

        Args:
            mock_makedirs (MagicMock): Mocked `os.makedirs` function.
        """
        Tox21MolNet.READER = ChemDataReader
        cls.data_module = Tox21MolNet()

    def _assert_three_splits_saved(self, mock_torch_save: MagicMock) -> List[list]:
        """
        Assert that `torch.save` was called once per split and return the saved data.

        The calls are expected in the order test, train, validation: the second
        positional argument of each call must contain the respective split name.

        Args:
            mock_torch_save (MagicMock): Mocked `torch.save` function.

        Returns:
            List[list]: First positional argument of each `torch.save` call,
                ordered as [test, train, validation].
        """
        self.assertEqual(
            mock_torch_save.call_count, 3, "Expected torch.save to be called 3 times."
        )
        call_args_list = mock_torch_save.call_args_list
        self.assertIn("test", call_args_list[0][0][1], "Missing 'test' split.")
        self.assertIn("train", call_args_list[1][0][1], "Missing 'train' split.")
        self.assertIn(
            "validation", call_args_list[2][0][1], "Missing 'validation' split."
        )
        return [call[0][0] for call in call_args_list]

    def _assert_splits_disjoint(
        self, saved_splits: List[list], key: str, suffix: str = ""
    ) -> None:
        """
        Assert that no value of `key` occurs in more than one split.

        Args:
            saved_splits (List[list]): Data of the splits, ordered as
                [test, train, validation]; every entry must be subscriptable by `key`.
            key (str): Entry key whose values are compared across the splits.
            suffix (str): Text appended to the failure messages (before the final dot).
        """
        test_values, train_values, validation_values = (
            {entry[key] for entry in split} for split in saved_splits
        )
        self.assertTrue(
            train_values.isdisjoint(test_values),
            f"Overlap detected between the train and test splits{suffix}.",
        )
        self.assertTrue(
            train_values.isdisjoint(validation_values),
            f"Overlap detected between the train and validation splits{suffix}.",
        )
        self.assertTrue(
            test_values.isdisjoint(validation_values),
            f"Overlap detected between the test and validation splits{suffix}.",
        )

    @patch(
        "builtins.open",
        new_callable=mock_open,
        read_data=Tox21MolNetMockData.get_raw_data(),
    )
    def test_load_data_from_file(self, mock_open_file: MagicMock) -> None:
        """
        Test the `_load_data_from_file` method for correct output.

        Args:
            mock_open_file (MagicMock): Mocked `open` function to simulate file reading.
        """
        actual_data = self.data_module._load_data_from_file("fake/file/path.csv")

        first_instance = next(actual_data)

        # Check for required keys
        required_keys = ["features", "labels", "ident"]
        for key in required_keys:
            self.assertIn(
                key, first_instance, f"'{key}' key is missing in the output data."
            )

        features = first_instance["features"]
        self.assertTrue(
            all(isinstance(feature, int) for feature in features),
            "Not all elements in 'features' are integers.",
        )

        # Check that 'features' can be converted to a tensor. Only the conversion
        # itself is guarded: an assertion inside the `try` would be caught by
        # `except Exception` (AssertionError is an Exception) and re-reported
        # with the misleading "cannot be converted" message.
        try:
            tensor_features = torch.tensor(features)
        except Exception as e:
            self.fail(f"'features' cannot be converted to a tensor: {e}")

        self.assertTrue(
            tensor_features.ndim > 0,
            "'features' should be convertible to a non-empty tensor.",
        )

    @patch(
        "builtins.open",
        new_callable=mock_open,
        read_data=Tox21MolNetMockData.get_raw_data(),
    )
    @patch("torch.save")
    def test_setup_processed_simple_split(
        self,
        mock_torch_save: MagicMock,
        mock_open_file: MagicMock,
    ) -> None:
        """
        Test the `setup_processed` method for basic data splitting and saving.

        Args:
            mock_torch_save (MagicMock): Mocked `torch.save` function to avoid actual file writes.
            mock_open_file (MagicMock): Mocked `open` function to simulate file reading.
        """
        self.data_module.setup_processed()

        # Verify if torch.save was called for each split (test, train, validation)
        saved_splits = self._assert_three_splits_saved(mock_torch_save)

        # Check for non-overlap between train, test, and validation splits
        self._assert_splits_disjoint(saved_splits, "ident")

    @patch.object(
        Tox21MolNet,
        "_load_data_from_file",
        return_value=Tox21MolNetMockData.get_processed_grouped_data(),
    )
    @patch("torch.save")
    def test_setup_processed_with_group_split(
        self, mock_torch_save: MagicMock, mock_load_file: MagicMock
    ) -> None:
        """
        Test the `setup_processed` method for group-based splitting and saving.

        Args:
            mock_torch_save (MagicMock): Mocked `torch.save` function to avoid actual file writes.
            mock_load_file (MagicMock): Mocked `_load_data_from_file` to provide custom data.
        """
        # `data_module` is shared by all tests of this class, so the overridden
        # `train_split` is restored afterwards instead of leaking into other tests.
        with patch.object(self.data_module, "train_split", 0.5):
            self.data_module.setup_processed()

        # Verify if torch.save was called for each split
        saved_splits = self._assert_three_splits_saved(mock_torch_save)

        # Check for non-overlap between train, test, and validation splits (based on 'ident')
        self._assert_splits_disjoint(saved_splits, "ident", " (based on 'ident')")

        # Check for non-overlap between train, test, and validation splits (based on 'group')
        self._assert_splits_disjoint(saved_splits, "group", " (based on 'group')")
178+
179+
180+
# Run the tests of this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)