@@ -85,8 +85,10 @@ def test_get_test_split_valid(self) -> None:
8585 """
8686 Test splitting the dataset into train and test sets and verify balance and non-overlap.
8787 """
88- self .dataset .train_split = 0.5
88+ # self.dataset.train_split = 0.5
8989 # Test size will be 0.25 * 16 = 4
90+ self .dataset .test_split = 0.25
91+ self .dataset .validation_split = 0.25
9092 train_df , test_df = self .dataset .get_test_split (self .data_df , seed = 0 )
9193
9294 # Assert the correct number of rows in train and test sets
@@ -149,7 +151,9 @@ def test_get_train_val_splits_given_test(self) -> None:
149151 Test splitting the dataset into train and validation sets and verify balance and non-overlap.
150152 """
151153 self .dataset .use_inner_cross_validation = False
152- self .dataset .train_split = 0.5
154+ # self.dataset.train_split = 0.5
155+ self .dataset .test_split = 0.25
156+ self .dataset .validation_split = 0.25
153157 df_train_main , test_df = self .dataset .get_test_split (self .data_df , seed = 0 )
154158 train_df , val_df = self .dataset .get_train_val_splits_given_test (
155159 df_train_main , test_df , seed = 42
@@ -220,7 +224,9 @@ def test_get_test_split_stratification(self) -> None:
220224 """
221225 Test that the split into train and test sets maintains the stratification of labels.
222226 """
223- self .dataset .train_split = 0.5
227+ # self.dataset.train_split = 0.5
228+ self .dataset .test_split = 0.25
229+ self .dataset .validation_split = 0.25
224230 train_df , test_df = self .dataset .get_test_split (self .data_df , seed = 0 )
225231
226232 number_of_labels = len (self .data_df ["labels" ][0 ])
@@ -288,7 +294,10 @@ def test_get_train_val_splits_given_test_stratification(self) -> None:
288294 Test that the split into train and validation sets maintains the stratification of labels.
289295 """
290296 self .dataset .use_inner_cross_validation = False
291- self .dataset .train_split = 0.5
297+ # self.dataset.train_split = 0.5
298+ self .dataset .test_split = 0.25
299+ self .dataset .validation_split = 0.25
300+
292301 df_train_main , test_df = self .dataset .get_test_split (self .data_df , seed = 0 )
293302 train_df , val_df = self .dataset .get_train_val_splits_given_test (
294303 df_train_main , test_df , seed = 42
0 commit comments