Skip to content

Commit 19b194a

Browse files
committed
fix - if only one class surpass given selection threshold
- #54 (comment)
1 parent 96d2097 commit 19b194a

File tree

1 file changed

+21
-8
lines changed
  • chebai/preprocessing/datasets

1 file changed

+21
-8
lines changed

chebai/preprocessing/datasets/base.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
)
1515
from lightning.pytorch.core.datamodule import LightningDataModule
1616
from lightning_utilities.core.rank_zero import rank_zero_info
17+
from sklearn.model_selection import StratifiedShuffleSplit
1718
from torch.utils.data import DataLoader
1819

1920
from chebai.preprocessing import reader as dr
@@ -929,11 +930,17 @@ def get_test_split(
929930
labels_list = df["labels"].tolist()
930931

931932
test_size = 1 - self.train_split - (1 - self.train_split) ** 2
932-
msss = MultilabelStratifiedShuffleSplit(
933-
n_splits=1, test_size=test_size, random_state=seed
934-
)
935933

936-
train_indices, test_indices = next(msss.split(labels_list, labels_list))
934+
if len(labels_list[0]) > 1:
935+
splitter = MultilabelStratifiedShuffleSplit(
936+
n_splits=1, test_size=test_size, random_state=seed
937+
)
938+
else:
939+
splitter = StratifiedShuffleSplit(
940+
n_splits=1, test_size=test_size, random_state=seed
941+
)
942+
943+
train_indices, test_indices = next(splitter.split(labels_list, labels_list))
937944

938945
df_train = df.iloc[train_indices]
939946
df_test = df.iloc[test_indices]
@@ -985,12 +992,18 @@ def get_train_val_splits_given_test(
985992

986993
# scale val set size by 1/self.train_split to compensate for (hypothetical) test set size (1-self.train_split)
987994
test_size = ((1 - self.train_split) ** 2) / self.train_split
988-
msss = MultilabelStratifiedShuffleSplit(
989-
n_splits=1, test_size=test_size, random_state=seed
990-
)
995+
996+
if len(labels_list_trainval[0]) > 1:
997+
splitter = MultilabelStratifiedShuffleSplit(
998+
n_splits=1, test_size=test_size, random_state=seed
999+
)
1000+
else:
1001+
splitter = StratifiedShuffleSplit(
1002+
n_splits=1, test_size=test_size, random_state=seed
1003+
)
9911004

9921005
train_indices, validation_indices = next(
993-
msss.split(labels_list_trainval, labels_list_trainval)
1006+
splitter.split(labels_list_trainval, labels_list_trainval)
9941007
)
9951008

9961009
df_validation = df_trainval.iloc[validation_indices]

0 commit comments

Comments
 (0)