@@ -216,6 +216,142 @@ def test_get_train_val_splits_given_test_consistency(self) -> None:
216216 obj = "Validation sets should be identical for the same seed." ,
217217 )
218218
219+ def test_get_test_split_stratification (self ) -> None :
220+ """
221+ Test that the split into train and test sets maintains the stratification of labels.
222+ """
223+ self .dataset .train_split = 0.5
224+ train_df , test_df = self .dataset .get_test_split (self .data_df , seed = 0 )
225+
226+ number_of_labels = len (self .data_df ["labels" ][0 ])
227+
228+ # Check the label distribution in the original dataset
229+ original_pos_count , original_neg_count = (
230+ self .get_positive_negative_labels_counts (self .data_df )
231+ )
232+ total_count = len (self .data_df ) * number_of_labels
233+
234+ # Calculate the expected proportions
235+ original_pos_proportion = original_pos_count / total_count
236+ original_neg_proportion = original_neg_count / total_count
237+
238+ # Check the label distribution in the train set
239+ train_pos_count , train_neg_count = self .get_positive_negative_labels_counts (
240+ train_df
241+ )
242+ train_total_count = len (train_df ) * number_of_labels
243+
244+ # Calculate the train set proportions
245+ train_pos_proportion = train_pos_count / train_total_count
246+ train_neg_proportion = train_neg_count / train_total_count
247+
248+ # Assert that the proportions are similar to the original dataset
249+ self .assertAlmostEqual (
250+ train_pos_proportion ,
251+ original_pos_proportion ,
252+ places = 1 ,
253+ msg = "Train set labels should maintain original positive label proportion." ,
254+ )
255+ self .assertAlmostEqual (
256+ train_neg_proportion ,
257+ original_neg_proportion ,
258+ places = 1 ,
259+ msg = "Train set labels should maintain original negative label proportion." ,
260+ )
261+
262+ # Check the label distribution in the test set
263+ test_pos_count , test_neg_count = self .get_positive_negative_labels_counts (
264+ test_df
265+ )
266+ test_total_count = len (test_df ) * number_of_labels
267+
268+ # Calculate the test set proportions
269+ test_pos_proportion = test_pos_count / test_total_count
270+ test_neg_proportion = test_neg_count / test_total_count
271+
272+ # Assert that the proportions are similar to the original dataset
273+ self .assertAlmostEqual (
274+ test_pos_proportion ,
275+ original_pos_proportion ,
276+ places = 1 ,
277+ msg = "Test set labels should maintain original positive label proportion." ,
278+ )
279+ self .assertAlmostEqual (
280+ test_neg_proportion ,
281+ original_neg_proportion ,
282+ places = 1 ,
283+ msg = "Test set labels should maintain original negative label proportion." ,
284+ )
285+
286+ def test_get_train_val_splits_given_test_stratification (self ) -> None :
287+ """
288+ Test that the split into train and validation sets maintains the stratification of labels.
289+ """
290+ self .dataset .use_inner_cross_validation = False
291+ self .dataset .train_split = 0.5
292+ df_train_main , test_df = self .dataset .get_test_split (self .data_df , seed = 0 )
293+ train_df , val_df = self .dataset .get_train_val_splits_given_test (
294+ df_train_main , test_df , seed = 42
295+ )
296+
297+ number_of_labels = len (self .data_df ["labels" ][0 ])
298+
299+ # Check the label distribution in the original dataset
300+ original_pos_count , original_neg_count = (
301+ self .get_positive_negative_labels_counts (self .data_df )
302+ )
303+ total_count = len (self .data_df ) * number_of_labels
304+
305+ # Calculate the expected proportions
306+ original_pos_proportion = original_pos_count / total_count
307+ original_neg_proportion = original_neg_count / total_count
308+
309+ # Check the label distribution in the train set
310+ train_pos_count , train_neg_count = self .get_positive_negative_labels_counts (
311+ train_df
312+ )
313+ train_total_count = len (train_df ) * number_of_labels
314+
315+ # Calculate the train set proportions
316+ train_pos_proportion = train_pos_count / train_total_count
317+ train_neg_proportion = train_neg_count / train_total_count
318+
319+ # Assert that the proportions are similar to the original dataset
320+ self .assertAlmostEqual (
321+ train_pos_proportion ,
322+ original_pos_proportion ,
323+ places = 1 ,
324+ msg = "Train set labels should maintain original positive label proportion." ,
325+ )
326+ self .assertAlmostEqual (
327+ train_neg_proportion ,
328+ original_neg_proportion ,
329+ places = 1 ,
330+ msg = "Train set labels should maintain original negative label proportion." ,
331+ )
332+
333+ # Check the label distribution in the validation set
334+ val_pos_count , val_neg_count = self .get_positive_negative_labels_counts (val_df )
335+ val_total_count = len (val_df ) * number_of_labels
336+
337+ # Calculate the validation set proportions
338+ val_pos_proportion = val_pos_count / val_total_count
339+ val_neg_proportion = val_neg_count / val_total_count
340+
341+ # Assert that the proportions are similar to the original dataset
342+ self .assertAlmostEqual (
343+ val_pos_proportion ,
344+ original_pos_proportion ,
345+ places = 1 ,
346+ msg = "Validation set labels should maintain original positive label proportion." ,
347+ )
348+ self .assertAlmostEqual (
349+ val_neg_proportion ,
350+ original_neg_proportion ,
351+ places = 1 ,
352+ msg = "Validation set labels should maintain original negative label proportion." ,
353+ )
354+
219355 @staticmethod
220356 def get_positive_negative_labels_counts (df : pd .DataFrame ) -> Tuple [int , int ]:
221357 """
0 commit comments