Skip to content

Commit d18dd7a

Browse files
committed
Merge branch 'dev' into fix/save_out_dim_to_checkpoint
2 parents ea28280 + 26329cb commit d18dd7a

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

tests/unit/dataset_classes/testDynamicDataset.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,10 @@ def test_get_test_split_valid(self) -> None:
8585
"""
8686
Test splitting the dataset into train and test sets and verify balance and non-overlap.
8787
"""
88-
self.dataset.train_split = 0.5
88+
# self.dataset.train_split = 0.5
8989
# Test size will be 0.25 * 16 = 4
90+
self.dataset.test_split = 0.25
91+
self.dataset.validation_split = 0.25
9092
train_df, test_df = self.dataset.get_test_split(self.data_df, seed=0)
9193

9294
# Assert the correct number of rows in train and test sets
@@ -149,7 +151,9 @@ def test_get_train_val_splits_given_test(self) -> None:
149151
Test splitting the dataset into train and validation sets and verify balance and non-overlap.
150152
"""
151153
self.dataset.use_inner_cross_validation = False
152-
self.dataset.train_split = 0.5
154+
# self.dataset.train_split = 0.5
155+
self.dataset.test_split = 0.25
156+
self.dataset.validation_split = 0.25
153157
df_train_main, test_df = self.dataset.get_test_split(self.data_df, seed=0)
154158
train_df, val_df = self.dataset.get_train_val_splits_given_test(
155159
df_train_main, test_df, seed=42
@@ -220,7 +224,9 @@ def test_get_test_split_stratification(self) -> None:
220224
"""
221225
Test that the split into train and test sets maintains the stratification of labels.
222226
"""
223-
self.dataset.train_split = 0.5
227+
# self.dataset.train_split = 0.5
228+
self.dataset.test_split = 0.25
229+
self.dataset.validation_split = 0.25
224230
train_df, test_df = self.dataset.get_test_split(self.data_df, seed=0)
225231

226232
number_of_labels = len(self.data_df["labels"][0])
@@ -288,7 +294,10 @@ def test_get_train_val_splits_given_test_stratification(self) -> None:
288294
Test that the split into train and validation sets maintains the stratification of labels.
289295
"""
290296
self.dataset.use_inner_cross_validation = False
291-
self.dataset.train_split = 0.5
297+
# self.dataset.train_split = 0.5
298+
self.dataset.test_split = 0.25
299+
self.dataset.validation_split = 0.25
300+
292301
df_train_main, test_df = self.dataset.get_test_split(self.data_df, seed=0)
293302
train_df, val_df = self.dataset.get_train_val_splits_given_test(
294303
df_train_main, test_df, seed=42

0 commit comments

Comments
 (0)