Skip to content

Commit 71d714f

Browse files
authored
Feat/configurable split sizes (#141)
* refactor data_handler constructor
* finish refactoring data handler
* fix codestyle
* update tests
1 parent 6cb6fe5 commit 71d714f

File tree

6 files changed

+41
-49
lines changed

6 files changed

+41
-49
lines changed

autointent/configs/_optimization.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from pydantic import BaseModel, Field, PositiveInt
66

7-
from autointent.custom_types import SamplerType, ValidationScheme
7+
from autointent.custom_types import FloatFromZeroToOne, SamplerType, ValidationScheme
88

99
from ._name import get_run_name
1010

@@ -16,8 +16,10 @@ class DataConfig(BaseModel):
1616
"""Hold-out or cross-validation."""
1717
n_folds: PositiveInt = 3
1818
"""Number of folds in cross-validation."""
19-
separate_nodes: bool = True
20-
"""Whether to use separate data for decision node."""
19+
validation_size: FloatFromZeroToOne = 0.2
20+
"""Fraction of train samples to allocate for validation (if input dataset doesn't contain validation split)."""
21+
separation_ratio: FloatFromZeroToOne | None = 0.5
22+
"""Set to float to prevent data leak between scoring and decision nodes."""
2123

2224

2325
class TaskConfig(BaseModel):

autointent/context/_context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def set_dataset(self, dataset: Dataset, config: DataConfig) -> None:
6565
6666
:param dataset: Dataset.
6767
"""
68-
self.data_handler = DataHandler(dataset=dataset, random_seed=self.seed, **config.model_dump())
68+
self.data_handler = DataHandler(dataset=dataset, random_seed=self.seed, config=config)
6969

7070
def get_inference_config(self) -> dict[str, Any]:
7171
"""

autointent/context/data_handler/_data_handler.py

Lines changed: 28 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
from transformers import set_seed
99

1010
from autointent import Dataset
11-
from autointent.custom_types import ListOfGenericLabels, ListOfLabels, Split, ValidationScheme
11+
from autointent.configs import DataConfig
12+
from autointent.custom_types import FloatFromZeroToOne, ListOfGenericLabels, ListOfLabels, Split
1213

1314
from ._stratification import split_dataset
1415

@@ -32,31 +33,27 @@ class DataHandler: # TODO rename to Validator
3233
def __init__(
3334
self,
3435
dataset: Dataset,
35-
scheme: ValidationScheme = "ho",
36-
separate_nodes: bool = True,
36+
config: DataConfig | None = None,
3737
random_seed: int = 0,
38-
n_folds: int = 3,
3938
) -> None:
4039
"""
4140
Initialize the data handler.
4241
4342
:param dataset: Training dataset.
4443
:param random_seed: Seed for random number generation.
45-
:param separate_nodes: Perform or not splitting of train (default to split to be used in scoring and
46-
threshold search).
44+
:param config: config
4745
"""
4846
set_seed(random_seed)
4947
self.random_seed = random_seed
5048

5149
self.dataset = dataset
50+
self.config = config if config is not None else DataConfig()
5251

5352
self.n_classes = self.dataset.n_classes
54-
self.scheme = scheme
55-
self.n_folds = n_folds
5653

57-
if scheme == "ho":
58-
self._split_ho(separate_nodes)
59-
elif scheme == "cv":
54+
if self.config.scheme == "ho":
55+
self._split_ho(self.config.separation_ratio, self.config.validation_size)
56+
elif self.config.scheme == "cv":
6057
self._split_cv()
6158

6259
self.regex_patterns = [
@@ -120,7 +117,7 @@ def train_labels(self, idx: int | None = None) -> ListOfGenericLabels:
120117
return cast(ListOfGenericLabels, self.dataset[split][self.dataset.label_feature])
121118

122119
def train_labels_folded(self) -> list[ListOfGenericLabels]:
123-
return [self.train_labels(j) for j in range(self.n_folds)]
120+
return [self.train_labels(j) for j in range(self.config.n_folds)]
124121

125122
def validation_utterances(self, idx: int | None = None) -> list[str]:
126123
"""
@@ -179,14 +176,14 @@ def test_labels(self) -> ListOfGenericLabels:
179176
return cast(ListOfGenericLabels, self.dataset[Split.TEST][self.dataset.label_feature])
180177

181178
def validation_iterator(self) -> Generator[tuple[list[str], ListOfLabels, list[str], ListOfLabels]]:
182-
if self.scheme == "ho":
179+
if self.config.scheme == "ho":
183180
msg = "Cannot call cross-validation on hold-out DataHandler"
184181
raise RuntimeError(msg)
185182

186-
for j in range(self.n_folds):
183+
for j in range(self.config.n_folds):
187184
val_utterances = self.train_utterances(j)
188185
val_labels = self.train_labels(j)
189-
train_folds = [i for i in range(self.n_folds) if i != j]
186+
train_folds = [i for i in range(self.config.n_folds) if i != j]
190187
train_utterances = [ut for i_fold in train_folds for ut in self.train_utterances(i_fold)]
191188
train_labels = [lab for i_fold in train_folds for lab in self.train_labels(i_fold)]
192189

@@ -195,14 +192,14 @@ def validation_iterator(self) -> Generator[tuple[list[str], ListOfLabels, list[s
195192
train_labels = [lab for lab in train_labels if lab is not None]
196193
yield train_utterances, train_labels, val_utterances, val_labels # type: ignore[misc]
197194

198-
def _split_ho(self, separate_nodes: bool) -> None:
195+
def _split_ho(self, separation_ratio: FloatFromZeroToOne | None, validation_size: FloatFromZeroToOne) -> None:
199196
has_validation_split = any(split.startswith(Split.VALIDATION) for split in self.dataset)
200197

201-
if separate_nodes and Split.TRAIN in self.dataset:
202-
self._split_train()
198+
if separation_ratio is not None and Split.TRAIN in self.dataset:
199+
self._split_train(separation_ratio)
203200

204201
if not has_validation_split:
205-
self._split_validation_from_train()
202+
self._split_validation_from_train(validation_size)
206203

207204
for split in self.dataset:
208205
n_classes_in_split = self.dataset.get_n_classes(split)
@@ -212,7 +209,7 @@ def _split_ho(self, separate_nodes: bool) -> None:
212209
)
213210
raise ValueError(message)
214211

215-
def _split_train(self) -> None:
212+
def _split_train(self, ratio: FloatFromZeroToOne) -> None:
216213
"""
217214
Split on two sets.
218215
@@ -221,40 +218,32 @@ def _split_train(self) -> None:
221218
self.dataset[f"{Split.TRAIN}_0"], self.dataset[f"{Split.TRAIN}_1"] = split_dataset(
222219
self.dataset,
223220
split=Split.TRAIN,
224-
test_size=0.5,
221+
test_size=ratio,
225222
random_seed=self.random_seed,
226223
allow_oos_in_train=False, # only train data for decision node should contain OOS
227224
)
228225
self.dataset.pop(Split.TRAIN)
229226

230227
def _split_cv(self) -> None:
231-
extra_splits = [split_name for split_name in self.dataset if split_name not in [Split.TRAIN, Split.TEST]]
232-
if extra_splits:
233-
self.dataset[Split.TRAIN] = concatenate_datasets(
234-
[self.dataset.pop(split_name) for split_name in extra_splits]
235-
)
236-
237-
if Split.TEST not in self.dataset:
238-
self.dataset[Split.TRAIN], self.dataset[Split.TEST] = split_dataset(
239-
self.dataset, split=Split.TRAIN, test_size=0.2, random_seed=self.random_seed, allow_oos_in_train=True
240-
)
228+
extra_splits = [split_name for split_name in self.dataset if split_name != Split.TEST]
229+
self.dataset[Split.TRAIN] = concatenate_datasets([self.dataset.pop(split_name) for split_name in extra_splits])
241230

242-
for j in range(self.n_folds - 1):
231+
for j in range(self.config.n_folds - 1):
243232
self.dataset[Split.TRAIN], self.dataset[f"{Split.TRAIN}_{j}"] = split_dataset(
244233
self.dataset,
245234
split=Split.TRAIN,
246-
test_size=1 / (self.n_folds - j),
235+
test_size=1 / (self.config.n_folds - j),
247236
random_seed=self.random_seed,
248237
allow_oos_in_train=True,
249238
)
250-
self.dataset[f"{Split.TRAIN}_{self.n_folds-1}"] = self.dataset.pop(Split.TRAIN)
239+
self.dataset[f"{Split.TRAIN}_{self.config.n_folds-1}"] = self.dataset.pop(Split.TRAIN)
251240

252-
def _split_validation_from_train(self) -> None:
241+
def _split_validation_from_train(self, size: float) -> None:
253242
if Split.TRAIN in self.dataset:
254243
self.dataset[Split.TRAIN], self.dataset[Split.VALIDATION] = split_dataset(
255244
self.dataset,
256245
split=Split.TRAIN,
257-
test_size=0.2,
246+
test_size=size,
258247
random_seed=self.random_seed,
259248
allow_oos_in_train=True,
260249
)
@@ -263,13 +252,13 @@ def _split_validation_from_train(self) -> None:
263252
self.dataset[f"{Split.TRAIN}_{idx}"], self.dataset[f"{Split.VALIDATION}_{idx}"] = split_dataset(
264253
self.dataset,
265254
split=f"{Split.TRAIN}_{idx}",
266-
test_size=0.2,
255+
test_size=size,
267256
random_seed=self.random_seed,
268257
allow_oos_in_train=idx == 1, # for decision node it's ok to have oos in train
269258
)
270259

271260
def prepare_for_refit(self) -> None:
272-
if self.scheme == "ho":
261+
if self.config.scheme == "ho":
273262
return
274263

275264
train_folds = [split_name for split_name in self.dataset if split_name.startswith(Split.TRAIN)]
@@ -278,7 +267,7 @@ def prepare_for_refit(self) -> None:
278267
self.dataset[f"{Split.TRAIN}_0"], self.dataset[f"{Split.TRAIN}_1"] = split_dataset(
279268
self.dataset,
280269
split=Split.TRAIN,
281-
test_size=0.5,
270+
test_size=self.config.separation_ratio or 0.5,
282271
random_seed=self.random_seed,
283272
allow_oos_in_train=False,
284273
)

autointent/modules/abc/_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ def score(self, context: Context, metrics: list[str]) -> dict[str, float]:
4343
:param split: Split to score on
4444
:return: Computed metrics value for the test set or error code of metrics
4545
"""
46-
if context.data_handler.scheme == "ho":
46+
if context.data_handler.config.scheme == "ho":
4747
return self.score_ho(context, metrics)
48-
if context.data_handler.scheme == "cv":
48+
if context.data_handler.config.scheme == "cv":
4949
return self.score_cv(context, metrics)
5050
msg = "Something's wrong with validation schemas"
5151
raise RuntimeError(msg)

autointent/modules/abc/_decision.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,10 @@ def score_cv(self, context: Context, metrics: list[str]) -> dict[str, float]:
7575
chosen_metrics = {name: fn for name, fn in DECISION_METRICS.items() if name in metrics}
7676
metrics_values: dict[str, list[float]] = {name: [] for name in chosen_metrics}
7777
all_val_decisions = []
78-
for j in range(context.data_handler.n_folds):
78+
for j in range(context.data_handler.config.n_folds):
7979
val_labels = labels[j]
8080
val_scores = scores[j]
81-
train_folds = [i for i in range(context.data_handler.n_folds) if i != j]
81+
train_folds = [i for i in range(context.data_handler.config.n_folds) if i != j]
8282
train_labels = [ut for i_fold in train_folds for ut in labels[i_fold]]
8383
train_scores = np.array([sc for i_fold in train_folds for sc in scores[i_fold]])
8484
self.fit(train_scores, train_labels, context.data_handler.tags) # type: ignore[arg-type]

tests/data/test_data_handler.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22

33
from autointent import Dataset
4+
from autointent.configs import DataConfig
45
from autointent.context.data_handler import DataHandler
56
from autointent.schemas import Sample
67

@@ -180,7 +181,7 @@ def count_oos(split):
180181

181182

182183
def test_cv_folding(dataset):
183-
DataHandler(dataset, scheme="cv", n_folds=3)
184+
DataHandler(dataset, config=DataConfig(scheme="cv", n_folds=3))
184185

185186
desired_specs = {
186187
"test": {"total": 12, "oos": 4},
@@ -199,7 +200,7 @@ def count_oos_labels(split):
199200

200201

201202
def test_cv_iterator(dataset):
202-
dh = DataHandler(dataset, scheme="cv", n_folds=3)
203+
dh = DataHandler(dataset, config=DataConfig(scheme="cv", n_folds=3))
203204

204205
desired_specs = [
205206
{

0 commit comments

Comments
 (0)