Skip to content

Commit 64efd5a

Browse files
committed
implement logic
1 parent 9900d43 commit 64efd5a

File tree

1 file changed

+25
-70
lines changed

1 file changed

+25
-70
lines changed

autointent/context/data_handler/_data_handler.py

Lines changed: 25 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def __init__(
3333
self,
3434
dataset: Dataset,
3535
scheme: ValidationScheme = "ho",
36-
split_train: bool = True,
36+
separate_nodes: bool = True,
3737
random_seed: int = 0,
3838
n_folds: int = 3,
3939
) -> None:
@@ -42,7 +42,7 @@ def __init__(
4242
4343
:param dataset: Training dataset.
4444
:param random_seed: Seed for random number generation.
45-
:param split_train: Perform or not splitting of train (default to split to be used in scoring and
45+
:param separate_nodes: Whether to split the train set (defaults to splitting; the parts are used in scoring and
4646
threshold search).
4747
"""
4848
set_seed(random_seed)
@@ -55,7 +55,7 @@ def __init__(
5555
self.n_folds = n_folds
5656

5757
if scheme == "ho":
58-
self._split_ho(split_train)
58+
self._split_ho(separate_nodes)
5959
elif scheme == "cv":
6060
self._split_cv()
6161

@@ -82,6 +82,15 @@ def multilabel(self) -> bool:
8282
"""
8383
return self.dataset.multilabel
8484

85+
def _choose_split(self, split_name: str, idx: int | None = None) -> str:
86+
if idx is not None:
87+
split = f"{split_name}_{idx}"
88+
if split not in self.dataset:
89+
split = split_name
90+
else:
91+
split = split_name
92+
return split
93+
8594
def train_utterances(self, idx: int | None = None) -> list[str]:
8695
"""
8796
Retrieve training utterances from the dataset.
@@ -93,7 +102,7 @@ def train_utterances(self, idx: int | None = None) -> list[str]:
93102
:param idx: Optional index for a specific training split.
94103
:return: List of training utterances.
95104
"""
96-
split = f"{Split.TRAIN}_{idx}" if idx is not None else Split.TRAIN
105+
split = self._choose_split(Split.TRAIN, idx)
97106
return cast(list[str], self.dataset[split][self.dataset.utterance_feature])
98107

99108
def train_labels(self, idx: int | None = None) -> ListOfGenericLabels:
@@ -107,7 +116,7 @@ def train_labels(self, idx: int | None = None) -> ListOfGenericLabels:
107116
:param idx: Optional index for a specific training split.
108117
:return: List of training labels.
109118
"""
110-
split = f"{Split.TRAIN}_{idx}" if idx is not None else Split.TRAIN
119+
split = self._choose_split(Split.TRAIN, idx)
111120
return cast(ListOfGenericLabels, self.dataset[split][self.dataset.label_feature])
112121

113122
def train_labels_folded(self) -> list[ListOfGenericLabels]:
@@ -124,7 +133,7 @@ def validation_utterances(self, idx: int | None = None) -> list[str]:
124133
:param idx: Optional index for a specific validation split.
125134
:return: List of validation utterances.
126135
"""
127-
split = f"{Split.VALIDATION}_{idx}" if idx is not None else Split.VALIDATION
136+
split = self._choose_split(Split.VALIDATION, idx)
128137
return cast(list[str], self.dataset[split][self.dataset.utterance_feature])
129138

130139
def validation_labels(self, idx: int | None = None) -> ListOfGenericLabels:
@@ -138,10 +147,10 @@ def validation_labels(self, idx: int | None = None) -> ListOfGenericLabels:
138147
:param idx: Optional index for a specific validation split.
139148
:return: List of validation labels.
140149
"""
141-
split = f"{Split.VALIDATION}_{idx}" if idx is not None else Split.VALIDATION
150+
split = self._choose_split(Split.VALIDATION, idx)
142151
return cast(ListOfGenericLabels, self.dataset[split][self.dataset.label_feature])
143152

144-
def test_utterances(self) -> list[str]:
    """
    Retrieve test utterances from the dataset.

    :return: List of test utterances.
    """
    return cast(list[str], self.dataset[Split.TEST][self.dataset.utterance_feature])
157165

158-
def test_labels(self) -> ListOfGenericLabels:
    """
    Retrieve test labels from the dataset.

    :return: List of test labels.
    """
    return cast(ListOfGenericLabels, self.dataset[Split.TEST][self.dataset.label_feature])
171178

172179
def validation_iterator(self) -> Generator[tuple[list[str], ListOfLabels, list[str], ListOfLabels]]:
173180
if self.scheme == "ho":
@@ -186,27 +193,20 @@ def validation_iterator(self) -> Generator[tuple[list[str], ListOfLabels, list[s
186193
train_labels = [lab for lab in train_labels if lab is not None]
187194
yield train_utterances, train_labels, val_utterances, val_labels # type: ignore[misc]
188195

189-
def _split_ho(self, split_train: bool) -> None:
196+
def _split_ho(self, separate_nodes: bool) -> None:
190197
has_validation_split = any(split.startswith(Split.VALIDATION) for split in self.dataset)
191198

192-
if split_train and Split.TRAIN in self.dataset:
199+
if separate_nodes and Split.TRAIN in self.dataset:
193200
self._split_train()
194201

195-
if Split.TEST not in self.dataset:
196-
test_size = 0.1 if has_validation_split else 0.2
197-
self._split_test(test_size)
198-
199202
if not has_validation_split:
200203
self._split_validation_from_train()
201-
elif Split.VALIDATION in self.dataset:
202-
self._split_validation()
203204

204205
for split in self.dataset:
205-
n_classes_split = self.dataset.get_n_classes(split)
206-
if n_classes_split != self.n_classes:
206+
n_classes_in_split = self.dataset.get_n_classes(split)
207+
if n_classes_in_split != self.n_classes:
207208
message = (
208-
f"Number of classes in split '{split}' doesn't match initial number of classes "
209-
f"({n_classes_split} != {self.n_classes})"
209+
f"{n_classes_in_split=} for '{split=}' doesn't match initial number of classes ({self.n_classes})"
210210
)
211211
raise ValueError(message)
212212

@@ -225,30 +225,6 @@ def _split_train(self) -> None:
225225
)
226226
self.dataset.pop(Split.TRAIN)
227227

228-
def _split_validation(self) -> None:
229-
"""
230-
Split into two sets.
231-
232-
One is for scoring node optimization, one is for decision node.
233-
"""
234-
self.dataset[f"{Split.VALIDATION}_0"], self.dataset[f"{Split.VALIDATION}_1"] = split_dataset(
235-
self.dataset,
236-
split=Split.VALIDATION,
237-
test_size=0.5,
238-
random_seed=self.random_seed,
239-
allow_oos_in_train=False, # only val data for decision node should contain OOS
240-
)
241-
self.dataset.pop(Split.VALIDATION)
242-
243-
def _split_validation_from_test(self) -> None:
244-
self.dataset[Split.TEST], self.dataset[Split.VALIDATION] = split_dataset(
245-
self.dataset,
246-
split=Split.TEST,
247-
test_size=0.5,
248-
random_seed=self.random_seed,
249-
allow_oos_in_train=True, # both test and validation splits can contain OOS
250-
)
251-
252228
def _split_cv(self) -> None:
253229
extra_splits = [split_name for split_name in self.dataset if split_name not in [Split.TRAIN, Split.TEST]]
254230
if extra_splits:
@@ -290,27 +266,6 @@ def _split_validation_from_train(self) -> None:
290266
allow_oos_in_train=idx == 1, # for decision node it's ok to have oos in train
291267
)
292268

293-
def _split_test(self, test_size: float) -> None:
294-
"""Obtain test set from train."""
295-
self.dataset[f"{Split.TRAIN}_0"], self.dataset[f"{Split.TEST}_0"] = split_dataset(
296-
self.dataset,
297-
split=f"{Split.TRAIN}_0",
298-
test_size=test_size,
299-
random_seed=self.random_seed,
300-
)
301-
self.dataset[f"{Split.TRAIN}_1"], self.dataset[f"{Split.TEST}_1"] = split_dataset(
302-
self.dataset,
303-
split=f"{Split.TRAIN}_1",
304-
test_size=test_size,
305-
random_seed=self.random_seed,
306-
allow_oos_in_train=True,
307-
)
308-
self.dataset[Split.TEST] = concatenate_datasets(
309-
[self.dataset[f"{Split.TEST}_0"], self.dataset[f"{Split.TEST}_1"]],
310-
)
311-
self.dataset.pop(f"{Split.TEST}_0")
312-
self.dataset.pop(f"{Split.TEST}_1")
313-
314269
def prepare_for_refit(self) -> None:
315270
if self.scheme == "ho":
316271
return

0 commit comments

Comments
 (0)