Skip to content

Commit adedc09

Browse files
committed
test single label split scenario implemented in #54
1 parent b479d5a commit adedc09

File tree

1 file changed

+66
-0
lines changed

1 file changed

+66
-0
lines changed

tests/unit/dataset_classes/testChebiOverXPartial.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,72 @@ def test_extract_class_hierarchy_with_bottom_cls(
104104
f"The graph nodes do not match the expected nodes for top class {self.chebi_extractor.top_class_id} hierarchy.",
105105
)
106106

107+
@patch("pandas.DataFrame.to_csv")
108+
@patch("pandas.read_pickle")
109+
@patch.object(ChEBIOverXPartial, "_get_data_size", return_value=4.0)
110+
@patch("torch.load")
111+
@patch(
112+
"builtins.open",
113+
new_callable=mock_open,
114+
read_data=ChebiMockOntology.get_raw_data(),
115+
)
116+
def test_single_label_data_split(
117+
self, mock_open, mock_load, mock_get_data_size, mock_read_pickle, mock_to_csv
118+
) -> None:
119+
"""
120+
Test the single-label data splitting functionality of the ChebiExtractor class.
121+
122+
This test mocks several key methods (file operations, torch loading, and pandas functions)
123+
to ensure that the class hierarchy is properly extracted, data is processed into a raw dataset,
124+
and the data splitting logic works as intended without actual file I/O.
125+
126+
It also verifies that there is no overlap between training, validation, and test sets.
127+
"""
128+
self.chebi_extractor.top_class_id = 11111
129+
self.chebi_extractor.THRESHOLD = 3
130+
self.chebi_extractor.chebi_version_train = None
131+
132+
graph: nx.DiGraph = self.chebi_extractor._extract_class_hierarchy("fake_path")
133+
data_df = self.chebi_extractor._graph_to_raw_dataset(graph)
134+
135+
mock_read_pickle.return_value = data_df
136+
data_pt = self.chebi_extractor._load_data_from_file("fake/path")
137+
138+
# Verify that the data contains only 1 label
139+
self.assertEqual(len(data_pt[0]["labels"]), 1)
140+
141+
mock_load.return_value = data_pt
142+
143+
# Retrieve the data splits (train, validation, and test)
144+
train_split = self.chebi_extractor.dynamic_split_dfs["train"]
145+
validation_split = self.chebi_extractor.dynamic_split_dfs["validation"]
146+
test_split = self.chebi_extractor.dynamic_split_dfs["test"]
147+
148+
train_idents = set(train_split["ident"])
149+
val_idents = set(validation_split["ident"])
150+
test_idents = set(test_split["ident"])
151+
152+
# Ensure there is no overlap between train and test sets
153+
self.assertEqual(
154+
len(train_idents.intersection(test_idents)),
155+
0,
156+
"Train and test sets should not overlap.",
157+
)
158+
159+
# Ensure there is no overlap between validation and test sets
160+
self.assertEqual(
161+
len(val_idents.intersection(test_idents)),
162+
0,
163+
"Validation and test sets should not overlap.",
164+
)
165+
166+
# Ensure there is no overlap between train and validation sets
167+
self.assertEqual(
168+
len(train_idents.intersection(val_idents)),
169+
0,
170+
"Train and validation sets should not overlap.",
171+
)
172+
107173

108174
if __name__ == "__main__":
109175
unittest.main()

0 commit comments

Comments
 (0)