Skip to content

Commit a63c010

Browse files
committed
DynamicDataset: check split stratification
1 parent 72dd50f commit a63c010

File tree

1 file changed

+136
-0
lines changed

1 file changed

+136
-0
lines changed

tests/unit/dataset_classes/testDynamicDataset.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,142 @@ def test_get_train_val_splits_given_test_consistency(self) -> None:
216216
obj="Validation sets should be identical for the same seed.",
217217
)
218218

219+
def test_get_test_split_stratification(self) -> None:
220+
"""
221+
Test that the split into train and test sets maintains the stratification of labels.
222+
"""
223+
self.dataset.train_split = 0.5
224+
train_df, test_df = self.dataset.get_test_split(self.data_df, seed=0)
225+
226+
number_of_labels = len(self.data_df["labels"][0])
227+
228+
# Check the label distribution in the original dataset
229+
original_pos_count, original_neg_count = (
230+
self.get_positive_negative_labels_counts(self.data_df)
231+
)
232+
total_count = len(self.data_df) * number_of_labels
233+
234+
# Calculate the expected proportions
235+
original_pos_proportion = original_pos_count / total_count
236+
original_neg_proportion = original_neg_count / total_count
237+
238+
# Check the label distribution in the train set
239+
train_pos_count, train_neg_count = self.get_positive_negative_labels_counts(
240+
train_df
241+
)
242+
train_total_count = len(train_df) * number_of_labels
243+
244+
# Calculate the train set proportions
245+
train_pos_proportion = train_pos_count / train_total_count
246+
train_neg_proportion = train_neg_count / train_total_count
247+
248+
# Assert that the proportions are similar to the original dataset
249+
self.assertAlmostEqual(
250+
train_pos_proportion,
251+
original_pos_proportion,
252+
places=1,
253+
msg="Train set labels should maintain original positive label proportion.",
254+
)
255+
self.assertAlmostEqual(
256+
train_neg_proportion,
257+
original_neg_proportion,
258+
places=1,
259+
msg="Train set labels should maintain original negative label proportion.",
260+
)
261+
262+
# Check the label distribution in the test set
263+
test_pos_count, test_neg_count = self.get_positive_negative_labels_counts(
264+
test_df
265+
)
266+
test_total_count = len(test_df) * number_of_labels
267+
268+
# Calculate the test set proportions
269+
test_pos_proportion = test_pos_count / test_total_count
270+
test_neg_proportion = test_neg_count / test_total_count
271+
272+
# Assert that the proportions are similar to the original dataset
273+
self.assertAlmostEqual(
274+
test_pos_proportion,
275+
original_pos_proportion,
276+
places=1,
277+
msg="Test set labels should maintain original positive label proportion.",
278+
)
279+
self.assertAlmostEqual(
280+
test_neg_proportion,
281+
original_neg_proportion,
282+
places=1,
283+
msg="Test set labels should maintain original negative label proportion.",
284+
)
285+
286+
def test_get_train_val_splits_given_test_stratification(self) -> None:
287+
"""
288+
Test that the split into train and validation sets maintains the stratification of labels.
289+
"""
290+
self.dataset.use_inner_cross_validation = False
291+
self.dataset.train_split = 0.5
292+
df_train_main, test_df = self.dataset.get_test_split(self.data_df, seed=0)
293+
train_df, val_df = self.dataset.get_train_val_splits_given_test(
294+
df_train_main, test_df, seed=42
295+
)
296+
297+
number_of_labels = len(self.data_df["labels"][0])
298+
299+
# Check the label distribution in the original dataset
300+
original_pos_count, original_neg_count = (
301+
self.get_positive_negative_labels_counts(self.data_df)
302+
)
303+
total_count = len(self.data_df) * number_of_labels
304+
305+
# Calculate the expected proportions
306+
original_pos_proportion = original_pos_count / total_count
307+
original_neg_proportion = original_neg_count / total_count
308+
309+
# Check the label distribution in the train set
310+
train_pos_count, train_neg_count = self.get_positive_negative_labels_counts(
311+
train_df
312+
)
313+
train_total_count = len(train_df) * number_of_labels
314+
315+
# Calculate the train set proportions
316+
train_pos_proportion = train_pos_count / train_total_count
317+
train_neg_proportion = train_neg_count / train_total_count
318+
319+
# Assert that the proportions are similar to the original dataset
320+
self.assertAlmostEqual(
321+
train_pos_proportion,
322+
original_pos_proportion,
323+
places=1,
324+
msg="Train set labels should maintain original positive label proportion.",
325+
)
326+
self.assertAlmostEqual(
327+
train_neg_proportion,
328+
original_neg_proportion,
329+
places=1,
330+
msg="Train set labels should maintain original negative label proportion.",
331+
)
332+
333+
# Check the label distribution in the validation set
334+
val_pos_count, val_neg_count = self.get_positive_negative_labels_counts(val_df)
335+
val_total_count = len(val_df) * number_of_labels
336+
337+
# Calculate the validation set proportions
338+
val_pos_proportion = val_pos_count / val_total_count
339+
val_neg_proportion = val_neg_count / val_total_count
340+
341+
# Assert that the proportions are similar to the original dataset
342+
self.assertAlmostEqual(
343+
val_pos_proportion,
344+
original_pos_proportion,
345+
places=1,
346+
msg="Validation set labels should maintain original positive label proportion.",
347+
)
348+
self.assertAlmostEqual(
349+
val_neg_proportion,
350+
original_neg_proportion,
351+
places=1,
352+
msg="Validation set labels should maintain original negative label proportion.",
353+
)
354+
219355
@staticmethod
220356
def get_positive_negative_labels_counts(df: pd.DataFrame) -> Tuple[int, int]:
221357
"""

0 commit comments

Comments
 (0)