Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion drevalpy/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,17 +551,34 @@ def _split_early_stopping_data(

:param validation_dataset: input validation dataset
:param test_mode: LPO, LCO, LTO, LDO
:raises ValueError: if test_mode is not one of the expected values
:returns: the resulting validation and early stopping datasets
"""
validation_dataset.shuffle(random_state=42)

# Determine the number of splits b (default 4,
# but can be less if there are not enough groups)
if test_mode == "LTO":
tissues = validation_dataset.tissue
if tissues is None:
raise ValueError("Tissue information is required for LTO.")
n_splits = min(4, len(np.unique(tissues)))
elif test_mode == "LCO":
n_splits = min(4, len(np.unique(validation_dataset.cell_line_ids)))
elif test_mode == "LDO":
n_splits = min(4, len(np.unique(validation_dataset.drug_ids)))
else:
n_splits = 4

cv_v = validation_dataset.split_dataset(
n_cv_splits=4,
n_cv_splits=n_splits,
mode=test_mode,
split_validation=False,
split_early_stopping=False,
random_state=42,
)
# take the first fold of a 4 cv as the split i.e. 3/4 for validation and 1/4 for early stopping
# when n_groups is less than 4, we splits the validation dataset into 2/3 and 1/3 or 1/2 and 1/2
validation_dataset = cv_v[0]["train"]
early_stopping_dataset = cv_v[0]["test"]
return validation_dataset, early_stopping_dataset
Expand Down