
Commit b055bbc

disable data splitting on demand (#136)
* implement logic
* refactor interface for configuring data
* add test for no node separation optimization
* upd test
* upd test
* fix test
1 parent 8eab2b7 commit b055bbc
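In effect, the change lets a user keep the whole train split for every node instead of carving out a separate slice for the decision node. A minimal usage sketch of the new knob (assuming `pipeline` and `dataset` are constructed as usual; only the config calls below come from this diff):

    from autointent.configs import DataConfig

    # Disable the train/decision-node separation via the new `separate_nodes` flag.
    pipeline.set_config(DataConfig(scheme="ho", separate_nodes=False))
    context = pipeline.fit(dataset)  # validation settings no longer passed to fit()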

File tree

10 files changed: +285 −112 lines changed


autointent/_pipeline/_pipeline.py

Lines changed: 7 additions & 6 deletions
@@ -9,8 +9,8 @@
 import yaml
 
 from autointent import Context, Dataset
-from autointent.configs import InferenceNodeConfig, LoggingConfig, VectorIndexConfig
-from autointent.custom_types import ListOfGenericLabels, NodeType, SamplerType, ValidationScheme
+from autointent.configs import DataConfig, InferenceNodeConfig, LoggingConfig, VectorIndexConfig
+from autointent.custom_types import ListOfGenericLabels, NodeType, SamplerType
 from autointent.metrics import DECISION_METRICS
 from autointent.nodes import InferenceNode, NodeOptimizer
 from autointent.nodes.schemes import OptimizationConfig
@@ -43,11 +43,12 @@ def __init__(
         if isinstance(nodes[0], NodeOptimizer):
             self.logging_config = LoggingConfig(dump_dir=None)
             self.vector_index_config = VectorIndexConfig()
+            self.data_config = DataConfig()
         elif not isinstance(nodes[0], InferenceNode):
             msg = "Pipeline should be initialized with list of NodeOptimizers or InferenceNodes"
             raise TypeError(msg)
 
-    def set_config(self, config: LoggingConfig | VectorIndexConfig) -> None:
+    def set_config(self, config: LoggingConfig | VectorIndexConfig | DataConfig) -> None:
         """
         Set configuration for the optimizer.
 
@@ -57,6 +58,8 @@ def set_config(self, config: LoggingConfig | VectorIndexConfig) -> None:
             self.logging_config = config
         elif isinstance(config, VectorIndexConfig):
             self.vector_index_config = config
+        elif isinstance(config, DataConfig):
+            self.data_config = config
         else:
             msg = "unknown config type"
             raise TypeError(msg)
@@ -119,8 +122,6 @@ def _is_inference(self) -> bool:
     def fit(
         self,
         dataset: Dataset,
-        scheme: ValidationScheme = "ho",
-        n_folds: int = 3,
         refit_after: bool = False,
         sampler: SamplerType = "brute",
     ) -> Context:
@@ -135,7 +136,7 @@
             raise RuntimeError(msg)
 
         context = Context()
-        context.set_dataset(dataset, scheme, n_folds)
+        context.set_dataset(dataset, self.data_config)
         context.configure_logging(self.logging_config)
         context.configure_vector_index(self.vector_index_config)
 
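For callers, the visible break is that `fit` no longer accepts `scheme` and `n_folds`; those settings now travel through `DataConfig`. A migration sketch (the old call matches the removed lines above):

    from autointent.configs import DataConfig

    # before
    context = pipeline.fit(dataset, scheme="cv", n_folds=5)

    # after
    pipeline.set_config(DataConfig(scheme="cv", n_folds=5))
    context = pipeline.fit(dataset)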

autointent/configs/_optimization.py

Lines changed: 3 additions & 3 deletions
@@ -12,12 +12,12 @@
 class DataConfig(BaseModel):
     """Configuration for the data used in the optimization process."""
 
-    train_path: str | Path
-    """Path to the training data. Can be local path or HF repo."""
-    scheme: ValidationScheme
+    scheme: ValidationScheme = "ho"
     """Hold-out or cross-validation."""
     n_folds: PositiveInt = 3
     """Number of folds in cross-validation."""
+    separate_nodes: bool = True
+    """Whether to use separate data for decision node."""
 
 
 class TaskConfig(BaseModel):
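With every field defaulted, `DataConfig()` is now constructible with no arguments, and pydantic still validates whatever is passed. A standalone mirror of the model for illustration (the `Literal` alias for `ValidationScheme` is an assumption; the real definition lives in `autointent.custom_types`):

    from typing import Literal
    from pydantic import BaseModel, PositiveInt, ValidationError

    ValidationScheme = Literal["ho", "cv"]  # assumed definition

    class DataConfig(BaseModel):
        scheme: ValidationScheme = "ho"
        n_folds: PositiveInt = 3
        separate_nodes: bool = True

    print(DataConfig())        # scheme='ho' n_folds=3 separate_nodes=True
    try:
        DataConfig(n_folds=0)  # PositiveInt rejects non-positive values
    except ValidationError as exc:
        print(exc.error_count(), "validation error")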

autointent/context/_context.py

Lines changed: 3 additions & 19 deletions
@@ -14,9 +14,8 @@
     LoggingConfig,
     VectorIndexConfig,
 )
-from autointent.custom_types import ValidationScheme
 
-from ._utils import NumpyEncoder, load_dataset
+from ._utils import NumpyEncoder
 from .data_handler import DataHandler
 from .optimization_info import OptimizationInfo
 
@@ -60,28 +59,13 @@ def configure_vector_index(self, config: VectorIndexConfig) -> None:
         """
         self.vector_index_config = config
 
-    def configure_data(self, config: DataConfig) -> None:
-        """
-        Configure data handling.
-
-        :param config: Configuration for the data handling process.
-        """
-        self.data_handler = DataHandler(
-            dataset=load_dataset(config.train_path), random_seed=self.seed, scheme=config.scheme
-        )
-
-    def set_dataset(self, dataset: Dataset, scheme: ValidationScheme = "ho", n_folds: int = 3) -> None:
+    def set_dataset(self, dataset: Dataset, config: DataConfig) -> None:
         """
         Set the datasets for training, validation and testing.
 
         :param dataset: Dataset.
         """
-        self.data_handler = DataHandler(
-            dataset=dataset,
-            random_seed=self.seed,
-            scheme=scheme,
-            n_folds=n_folds,
-        )
+        self.data_handler = DataHandler(dataset=dataset, random_seed=self.seed, **config.model_dump())
 
     def get_inference_config(self) -> dict[str, Any]:
         """

autointent/context/data_handler/_data_handler.py

Lines changed: 25 additions & 70 deletions
@@ -33,7 +33,7 @@ def __init__(
         self,
         dataset: Dataset,
         scheme: ValidationScheme = "ho",
-        split_train: bool = True,
+        separate_nodes: bool = True,
         random_seed: int = 0,
         n_folds: int = 3,
     ) -> None:
@@ -42,7 +42,7 @@ def __init__(
 
         :param dataset: Training dataset.
         :param random_seed: Seed for random number generation.
-        :param split_train: Perform or not splitting of train (default to split to be used in scoring and
+        :param separate_nodes: Perform or not splitting of train (default to split to be used in scoring and
             threshold search).
         """
         set_seed(random_seed)
@@ -55,7 +55,7 @@ def __init__(
         self.n_folds = n_folds
 
         if scheme == "ho":
-            self._split_ho(split_train)
+            self._split_ho(separate_nodes)
         elif scheme == "cv":
            self._split_cv()
 
@@ -82,6 +82,15 @@ def multilabel(self) -> bool:
         """
         return self.dataset.multilabel
 
+    def _choose_split(self, split_name: str, idx: int | None = None) -> str:
+        if idx is not None:
+            split = f"{split_name}_{idx}"
+            if split not in self.dataset:
+                split = split_name
+        else:
+            split = split_name
+        return split
+
     def train_utterances(self, idx: int | None = None) -> list[str]:
         """
         Retrieve training utterances from the dataset.
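This helper is the heart of the optimization: when node separation is off and `train_0`/`train_1` were never created, indexed lookups quietly fall back to the undivided split. A standalone sketch of the same logic over a plain dict standing in for the dataset:

    def choose_split(dataset: dict, split_name: str, idx: int | None = None) -> str:
        # Mirrors DataHandler._choose_split: prefer "train_{idx}" when present,
        # otherwise fall back to the plain split name.
        if idx is not None:
            split = f"{split_name}_{idx}"
            if split not in dataset:
                split = split_name
        else:
            split = split_name
        return split

    separated = {"train_0": [], "train_1": [], "validation": []}
    merged = {"train": [], "validation": []}
    print(choose_split(separated, "train", 0))  # -> "train_0"
    print(choose_split(merged, "train", 0))     # -> "train" (fallback)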
@@ -93,7 +102,7 @@ def train_utterances(self, idx: int | None = None) -> list[str]:
         :param idx: Optional index for a specific training split.
         :return: List of training utterances.
         """
-        split = f"{Split.TRAIN}_{idx}" if idx is not None else Split.TRAIN
+        split = self._choose_split(Split.TRAIN, idx)
         return cast(list[str], self.dataset[split][self.dataset.utterance_feature])
 
     def train_labels(self, idx: int | None = None) -> ListOfGenericLabels:
@@ -107,7 +116,7 @@ def train_labels(self, idx: int | None = None) -> ListOfGenericLabels:
         :param idx: Optional index for a specific training split.
         :return: List of training labels.
         """
-        split = f"{Split.TRAIN}_{idx}" if idx is not None else Split.TRAIN
+        split = self._choose_split(Split.TRAIN, idx)
         return cast(ListOfGenericLabels, self.dataset[split][self.dataset.label_feature])
 
     def train_labels_folded(self) -> list[ListOfGenericLabels]:
@@ -124,7 +133,7 @@ def validation_utterances(self, idx: int | None = None) -> list[str]:
         :param idx: Optional index for a specific validation split.
         :return: List of validation utterances.
         """
-        split = f"{Split.VALIDATION}_{idx}" if idx is not None else Split.VALIDATION
+        split = self._choose_split(Split.VALIDATION, idx)
         return cast(list[str], self.dataset[split][self.dataset.utterance_feature])
 
     def validation_labels(self, idx: int | None = None) -> ListOfGenericLabels:
@@ -138,10 +147,10 @@ def validation_labels(self, idx: int | None = None) -> ListOfGenericLabels:
         :param idx: Optional index for a specific validation split.
         :return: List of validation labels.
         """
-        split = f"{Split.VALIDATION}_{idx}" if idx is not None else Split.VALIDATION
+        split = self._choose_split(Split.VALIDATION, idx)
         return cast(ListOfGenericLabels, self.dataset[split][self.dataset.label_feature])
 
-    def test_utterances(self, idx: int | None = None) -> list[str]:
+    def test_utterances(self) -> list[str]:
         """
         Retrieve test utterances from the dataset.
 
@@ -152,10 +161,9 @@ def test_utterances(self, idx: int | None = None) -> list[str]:
         :param idx: Optional index for a specific test split.
         :return: List of test utterances.
         """
-        split = f"{Split.TEST}_{idx}" if idx is not None else Split.TEST
-        return cast(list[str], self.dataset[split][self.dataset.utterance_feature])
+        return cast(list[str], self.dataset[Split.TEST][self.dataset.utterance_feature])
 
-    def test_labels(self, idx: int | None = None) -> ListOfGenericLabels:
+    def test_labels(self) -> ListOfGenericLabels:
         """
         Retrieve test labels from the dataset.
 
@@ -166,8 +174,7 @@ def test_labels(self, idx: int | None = None) -> ListOfGenericLabels:
         :param idx: Optional index for a specific test split.
         :return: List of test labels.
         """
-        split = f"{Split.TEST}_{idx}" if idx is not None else Split.TEST
-        return cast(ListOfGenericLabels, self.dataset[split][self.dataset.label_feature])
+        return cast(ListOfGenericLabels, self.dataset[Split.TEST][self.dataset.label_feature])
 
     def validation_iterator(self) -> Generator[tuple[list[str], ListOfLabels, list[str], ListOfLabels]]:
         if self.scheme == "ho":
@@ -186,27 +193,20 @@ def validation_iterator(self) -> Generator[tuple[list[str], ListOfLabels, list[str], ListOfLabels]]:
             train_labels = [lab for lab in train_labels if lab is not None]
             yield train_utterances, train_labels, val_utterances, val_labels  # type: ignore[misc]
 
-    def _split_ho(self, split_train: bool) -> None:
+    def _split_ho(self, separate_nodes: bool) -> None:
         has_validation_split = any(split.startswith(Split.VALIDATION) for split in self.dataset)
 
-        if split_train and Split.TRAIN in self.dataset:
+        if separate_nodes and Split.TRAIN in self.dataset:
             self._split_train()
 
-        if Split.TEST not in self.dataset:
-            test_size = 0.1 if has_validation_split else 0.2
-            self._split_test(test_size)
-
         if not has_validation_split:
             self._split_validation_from_train()
-        elif Split.VALIDATION in self.dataset:
-            self._split_validation()
 
         for split in self.dataset:
-            n_classes_split = self.dataset.get_n_classes(split)
-            if n_classes_split != self.n_classes:
+            n_classes_in_split = self.dataset.get_n_classes(split)
+            if n_classes_in_split != self.n_classes:
                 message = (
-                    f"Number of classes in split '{split}' doesn't match initial number of classes "
-                    f"({n_classes_split} != {self.n_classes})"
+                    f"{n_classes_in_split=} for '{split=}' doesn't match initial number of classes ({self.n_classes})"
                 )
                 raise ValueError(message)
 
@@ -225,30 +225,6 @@ def _split_train(self) -> None:
         )
         self.dataset.pop(Split.TRAIN)
 
-    def _split_validation(self) -> None:
-        """
-        Split on two sets.
-
-        One is for scoring node optimizaton, one is for decision node.
-        """
-        self.dataset[f"{Split.VALIDATION}_0"], self.dataset[f"{Split.VALIDATION}_1"] = split_dataset(
-            self.dataset,
-            split=Split.VALIDATION,
-            test_size=0.5,
-            random_seed=self.random_seed,
-            allow_oos_in_train=False,  # only val data for decision node should contain OOS
-        )
-        self.dataset.pop(Split.VALIDATION)
-
-    def _split_validation_from_test(self) -> None:
-        self.dataset[Split.TEST], self.dataset[Split.VALIDATION] = split_dataset(
-            self.dataset,
-            split=Split.TEST,
-            test_size=0.5,
-            random_seed=self.random_seed,
-            allow_oos_in_train=True,  # both test and validation splits can contain OOS
-        )
-
     def _split_cv(self) -> None:
         extra_splits = [split_name for split_name in self.dataset if split_name not in [Split.TRAIN, Split.TEST]]
         if extra_splits:
@@ -290,27 +266,6 @@ def _split_validation_from_train(self) -> None:
                 allow_oos_in_train=idx == 1,  # for decision node it's ok to have oos in train
             )
 
-    def _split_test(self, test_size: float) -> None:
-        """Obtain test set from train."""
-        self.dataset[f"{Split.TRAIN}_0"], self.dataset[f"{Split.TEST}_0"] = split_dataset(
-            self.dataset,
-            split=f"{Split.TRAIN}_0",
-            test_size=test_size,
-            random_seed=self.random_seed,
-        )
-        self.dataset[f"{Split.TRAIN}_1"], self.dataset[f"{Split.TEST}_1"] = split_dataset(
-            self.dataset,
-            split=f"{Split.TRAIN}_1",
-            test_size=test_size,
-            random_seed=self.random_seed,
-            allow_oos_in_train=True,
-        )
-        self.dataset[Split.TEST] = concatenate_datasets(
-            [self.dataset[f"{Split.TEST}_0"], self.dataset[f"{Split.TEST}_1"]],
-        )
-        self.dataset.pop(f"{Split.TEST}_0")
-        self.dataset.pop(f"{Split.TEST}_1")
-
     def prepare_for_refit(self) -> None:
         if self.scheme == "ho":
             return

tests/assets/configs/light.yaml

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+- node_type: embedding
+  target_metric: retrieval_hit_rate
+  search_space:
+    - module_name: retrieval
+      k: [10]
+      embedder_config:
+        - model_name: sentence-transformers/all-MiniLM-L6-v2
+- node_type: scoring
+  target_metric: scoring_roc_auc
+  search_space:
+    - module_name: linear
+- node_type: decision
+  target_metric: decision_accuracy
+  search_space:
+    - module_name: argmax
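The new asset is a minimal three-node search space (one module per node keeps the test light). A sketch of reading it with pyyaml (the path is the one added above; `yaml` is already a dependency per the pipeline imports):

    import yaml

    with open("tests/assets/configs/light.yaml") as f:
        search_space = yaml.safe_load(f)

    print([node["node_type"] for node in search_space])  # ['embedding', 'scoring', 'decision']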
