Skip to content

Commit 5b85ace

Browse files
committed
Merge branch 'dev' into feature/pyproject.toml
2 parents c3bbdfa + 26329cb commit 5b85ace

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

chebai/trainer/CustomTrainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def __init__(self, *args, **kwargs):
2525
"""
2626
self.init_args = args
2727
self.init_kwargs = kwargs
28-
super().__init__(*args, **kwargs)
28+
super().__init__(*args, **kwargs, deterministic=True)
2929
# instantiation custom logger connector
3030
self._logger_connector.on_trainer_init(self.logger, 1)
3131
# log additional hyperparameters to wandb

tests/unit/dataset_classes/testDynamicDataset.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,10 @@ def test_get_test_split_valid(self) -> None:
8585
"""
8686
Test splitting the dataset into train and test sets and verify balance and non-overlap.
8787
"""
88-
self.dataset.train_split = 0.5
88+
# self.dataset.train_split = 0.5
8989
# Test size will be 0.25 * 16 = 4
90+
self.dataset.test_split = 0.25
91+
self.dataset.validation_split = 0.25
9092
train_df, test_df = self.dataset.get_test_split(self.data_df, seed=0)
9193

9294
# Assert the correct number of rows in train and test sets
@@ -149,7 +151,9 @@ def test_get_train_val_splits_given_test(self) -> None:
149151
Test splitting the dataset into train and validation sets and verify balance and non-overlap.
150152
"""
151153
self.dataset.use_inner_cross_validation = False
152-
self.dataset.train_split = 0.5
154+
# self.dataset.train_split = 0.5
155+
self.dataset.test_split = 0.25
156+
self.dataset.validation_split = 0.25
153157
df_train_main, test_df = self.dataset.get_test_split(self.data_df, seed=0)
154158
train_df, val_df = self.dataset.get_train_val_splits_given_test(
155159
df_train_main, test_df, seed=42
@@ -220,7 +224,9 @@ def test_get_test_split_stratification(self) -> None:
220224
"""
221225
Test that the split into train and test sets maintains the stratification of labels.
222226
"""
223-
self.dataset.train_split = 0.5
227+
# self.dataset.train_split = 0.5
228+
self.dataset.test_split = 0.25
229+
self.dataset.validation_split = 0.25
224230
train_df, test_df = self.dataset.get_test_split(self.data_df, seed=0)
225231

226232
number_of_labels = len(self.data_df["labels"][0])
@@ -288,7 +294,10 @@ def test_get_train_val_splits_given_test_stratification(self) -> None:
288294
Test that the split into train and validation sets maintains the stratification of labels.
289295
"""
290296
self.dataset.use_inner_cross_validation = False
291-
self.dataset.train_split = 0.5
297+
# self.dataset.train_split = 0.5
298+
self.dataset.test_split = 0.25
299+
self.dataset.validation_split = 0.25
300+
292301
df_train_main, test_df = self.dataset.get_test_split(self.data_df, seed=0)
293302
train_df, val_df = self.dataset.get_train_val_splits_given_test(
294303
df_train_main, test_df, seed=42

0 commit comments

Comments
 (0)