11import hashlib
22import unittest
3+ from typing import Any , List , Tuple
34
45import numpy as np
56import pandas as pd
89
910
1011class TestChebiDynamicDataSplits (unittest .TestCase ):
11- """Test dynamic splits implementation's consistency"""
12+ """
13+ Test dynamic splits implementation's consistency for ChEBIOver50 dataset.
14+
15+ Attributes:
16+ chebi_50_v231 (ChEBIOver50): Instance of ChEBIOver50 with ChEBI version 231.
17+ chebi_50_v231_vt200 (ChEBIOver50): Instance of ChEBIOver50 with ChEBI version 231 and train version 200.
18+ """
1219
@classmethod
def setUpClass(cls) -> None:
    """
    Build the two ChEBIOver50 fixtures shared by the tests in this class.

    One fixture is created with ``chebi_version=231`` only; the other with
    ``chebi_version=231`` and ``chebi_version_train=200``. Each fixture is
    then passed once through ``_generate_chebi_class_data``.
    """
    cls.chebi_50_v231 = ChEBIOver50(chebi_version=231)
    cls.chebi_50_v231_vt200 = ChEBIOver50(
        chebi_version=231,
        chebi_version_train=200,
    )

    # Prepare both fixtures up front, in the same order they were created.
    for dataset in (cls.chebi_50_v231, cls.chebi_50_v231_vt200):
        cls._generate_chebi_class_data(dataset)
2131
22- def testDynamicDataSplitsConsistency (self ):
23- """Test Dynamic Data Splits consistency across every run"""
24-
32+ def testDynamicDataSplitsConsistency (self ) -> None :
33+ """
34+ Test Dynamic Data Splits consistency across multiple runs.
35+ """
2536 # Dynamic Data Splits in Run 1
2637 train_hash_1 , val_hash_1 , test_hash_1 = self ._get_hashed_splits ()
2738
@@ -34,9 +45,10 @@ def testDynamicDataSplitsConsistency(self):
3445 self .assertEqual (val_hash_1 , val_hash_2 , "Validation data hashes do not match." )
3546 self .assertEqual (test_hash_1 , test_hash_2 , "Test data hashes do not match." )
3647
37- def test_same_ids_and_in_test_sets (self ):
38- """Check if test sets of both classes have same IDs"""
39-
48+ def test_same_ids_and_in_test_sets (self ) -> None :
49+ """
50+ Check if test sets of both classes have the same IDs.
51+ """
4052 v231_ids = set (self .chebi_50_v231 .dynamic_split_dfs ["test" ]["ident" ])
4153 v231_vt200_ids = set (
4254 self .chebi_50_v231_vt200 .dynamic_split_dfs ["test" ]["ident" ]
@@ -46,9 +58,10 @@ def test_same_ids_and_in_test_sets(self):
4658 v231_ids , v231_vt200_ids , "Test sets do not have the same IDs."
4759 )
4860
49- def test_labels_vector_size_in_test_sets (self ):
50- """Check if test sets of both classes have different size/shape of labels"""
51-
61+ def test_labels_vector_size_in_test_sets (self ) -> None :
62+ """
63+ Check if test sets of both classes have different sizes/shapes of labels.
64+ """
5265 v231_labels_shape = len (
5366 self .chebi_50_v231 .dynamic_split_dfs ["test" ]["labels" ].iloc [0 ]
5467 )
@@ -59,11 +72,13 @@ def test_labels_vector_size_in_test_sets(self):
5972 self .assertEqual (
6073 v231_labels_shape ,
6174 v231_vt200_label_shape ,
62- "Test sets have the different size of labels" ,
75+ "Test sets have different sizes of labels" ,
6376 )
6477
65- def test_no_overlaps_in_chebi_v231_vt200 (self ):
66- """Test the overlaps for the ChEBIOver50(chebi_version=231, chebi_version_train=200)"""
78+ def test_no_overlaps_in_chebi_v231_vt200 (self ) -> None :
79+ """
80+ Test the overlaps for the ChEBIOver50(chebi_version=231, chebi_version_train=200) dataset.
81+ """
6782 train_set = self .chebi_50_v231_vt200 .dynamic_split_dfs ["train" ]
6883 val_set = self .chebi_50_v231_vt200 .dynamic_split_dfs ["validation" ]
6984 test_set = self .chebi_50_v231_vt200 .dynamic_split_dfs ["test" ]
@@ -72,7 +87,7 @@ def test_no_overlaps_in_chebi_v231_vt200(self):
7287 val_set_ids = val_set ["ident" ].tolist ()
7388 test_set_ids = test_set ["ident" ].tolist ()
7489
75- # ----- Get the overlap between data splits based on IDs -----
90+ # Get the overlap between data splits based on IDs
7691 self .overlaps_train_val_ids = self .get_overlaps (train_set_ids , val_set_ids )
7792 self .overlaps_train_test_ids = self .get_overlaps (train_set_ids , test_set_ids )
7893 self .overlaps_val_test_ids = self .get_overlaps (val_set_ids , test_set_ids )
@@ -93,10 +108,13 @@ def test_no_overlaps_in_chebi_v231_vt200(self):
93108 "Duplicate entities present in Validation and Test set based on IDs" ,
94109 )
95110
96- def _get_hashed_splits (self ):
97- """Returns hashed dynamic data splits"""
111+ def _get_hashed_splits (self ) -> Tuple [str , str , str ]:
112+ """
113+ Returns hashed dynamic data splits.
98114
99- # Get the raw/processed data if missing
115+ Returns:
116+ Tuple[str, str, str]: Hashes for train, validation, and test data splits.
117+ """
100118 chebi_class_obj = self .chebi_50_v231
101119
102120 # Get dynamic splits from class variables
@@ -112,16 +130,32 @@ def _get_hashed_splits(self):
112130 return train_hash , val_hash , test_hash
113131
@staticmethod
def compute_hash(data: pd.DataFrame) -> str:
    """
    Compute a single MD5 hex digest summarising a data partition.

    Args:
        data (pd.DataFrame): Partition to hash.

    Returns:
        str: Hex digest of the MD5 computed over the per-row pandas hashes
            (index included).
    """
    # Every cell goes through convert_to_hashable first so that pandas
    # can hash the frame element-wise.
    hashable_frame = data.map(TestChebiDynamicDataSplits.convert_to_hashable)

    # One hash value per row; index=True makes the row index part of it.
    row_hashes = pd.util.hash_pandas_object(hashable_frame, index=True)

    # Collapse the per-row hashes into one digest.
    digest = hashlib.md5(row_hashes.values)
    return digest.hexdigest()
121147
122148 @staticmethod
123- def convert_to_hashable (item ):
124- """To Convert lists and numpy arrays within the DataFrame to tuples for hashing"""
149+ def convert_to_hashable (item : Any ) -> Any :
150+ """
151+ Convert lists and numpy arrays within the DataFrame to tuples for hashing.
152+
153+ Args:
154+ item (Any): Item to convert to a hashable form.
155+
156+ Returns:
157+ Any: Hashable representation of the input item.
158+ """
125159 if isinstance (item , list ):
126160 return tuple (item )
127161 elif isinstance (item , np .ndarray ):
@@ -130,13 +164,28 @@ def convert_to_hashable(item):
130164 return item
131165
@staticmethod
def _generate_chebi_class_data(chebi_class_obj: ChEBIOver50) -> None:
    """
    Run the data-preparation steps on a ChEBIOver50 instance.

    Calls ``prepare_data()`` followed by ``setup()`` on the given object.
    NOTE(review): presumably these skip work when the raw/processed data
    already exist -- confirm against the ChEBIOver50 implementation.

    Args:
        chebi_class_obj (ChEBIOver50): Dataset instance to prepare.
    """
    # Order matters: prepare_data() is invoked before setup().
    chebi_class_obj.prepare_data()
    chebi_class_obj.setup()
137176
138177 @staticmethod
139- def get_overlaps (list_1 , list_2 ):
178+ def get_overlaps (list_1 : List [Any ], list_2 : List [Any ]) -> List [Any ]:
179+ """
180+ Get overlaps between two lists.
181+
182+ Args:
183+ list_1 (List[Any]): First list.
184+ list_2 (List[Any]): Second list.
185+
186+ Returns:
187+ List[Any]: List of elements present in both lists.
188+ """
140189 overlap = []
141190 for element in list_1 :
142191 if element in list_2 :
0 commit comments