@@ -104,6 +104,72 @@ def test_extract_class_hierarchy_with_bottom_cls(
104104 f"The graph nodes do not match the expected nodes for top class { self .chebi_extractor .top_class_id } hierarchy." ,
105105 )
106106
@patch("pandas.DataFrame.to_csv")
@patch("pandas.read_pickle")
@patch.object(ChEBIOverXPartial, "_get_data_size", return_value=4.0)
@patch("torch.load")
@patch(
    "builtins.open",
    new_callable=mock_open,
    read_data=ChebiMockOntology.get_raw_data(),
)
def test_single_label_data_split(
    self, mock_open, mock_load, mock_get_data_size, mock_read_pickle, mock_to_csv
) -> None:
    """
    Check that single-label data is split into disjoint train/validation/test sets.

    File opening, ``torch.load``, ``pandas.read_pickle``, ``DataFrame.to_csv`` and
    ``_get_data_size`` are all patched, so the test performs no real file I/O.
    The patch decorators are applied bottom-up, which is why the mock arguments
    appear in the reverse order of the decorators.

    The test builds the class hierarchy from the mocked ontology, converts it to a
    raw dataset, feeds that dataset back through the patched loaders, and then
    asserts that no identifier appears in more than one split.
    """
    extractor = self.chebi_extractor

    # Configure the extractor for this scenario.
    extractor.top_class_id = 11111
    extractor.THRESHOLD = 3
    extractor.chebi_version_train = None

    # Build the raw dataset from the mocked ontology ("fake_path" is never read from disk).
    hierarchy = extractor._extract_class_hierarchy("fake_path")
    raw_df = extractor._graph_to_raw_dataset(hierarchy)

    # The patched read_pickle hands the raw dataset back to the loader.
    mock_read_pickle.return_value = raw_df
    loaded = extractor._load_data_from_file("fake/path")

    # Each entry is expected to carry exactly one label.
    self.assertEqual(len(loaded[0]["labels"]), 1)

    # The patched torch.load must be primed before the splits are requested.
    mock_load.return_value = loaded

    # Collect the identifiers of every split (train, validation, test).
    idents = {
        split_name: set(extractor.dynamic_split_dfs[split_name]["ident"])
        for split_name in ("train", "validation", "test")
    }

    # Pairs of splits that must be disjoint, with the message reported on failure.
    overlap_checks = (
        ("train", "test", "Train and test sets should not overlap."),
        ("validation", "test", "Validation and test sets should not overlap."),
        ("train", "validation", "Train and validation sets should not overlap."),
    )
    for left, right, message in overlap_checks:
        self.assertEqual(len(idents[left] & idents[right]), 0, message)
172+
107173
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
0 commit comments