Skip to content

Commit 0979234

Browse files
authored
Feat/cross validation (#121)
* define interface * basic ho iterator * move obtaining data for train from node optimizer to modules themselves * stage progress * implement cv iterator * minor bug fix * implement cv iterator for decision node * move cv iteration to base module definition * implement cv iterator for embedding node * add training to `score_ho` of each node * properly define base module * fix codestyle * remove regexp node * remove regexp validator * fix typing problems (except `DataHandler._split_cv`) * add ingore oos decorator * fix codestyle * fix typing * add oos handling to cv iterator * remove `DataHandler.dump()` * minor bug fix * implement splitting to cv folds * fix codestyle * remove regex tests * bug fix * bug fix * update tests * fix typing * big fix * basic test on cv folding * add tests for metrics to ignore oos samples * add tests for cv iterator * fix codestyle * minor bug fix * fix codestyle * add test for cv * bug fix * implement cv iterator for description scorer * refactor cv iterator for description node * fix typing * add cache cleaning before refitting * bug fix * implement refitting the whole pipeline with all train data * fix typing * bug fix * fix typing * respond to samoed * create `ValidationType` in `autointent.custom_types` * fix docstring * properly expose `n_folds` argument * `ValidationType` -> `ValidationScheme` * `make schema`
1 parent 6a478cd commit 0979234

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+545
-629
lines changed

autointent/_pipeline/_pipeline.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from autointent import Context, Dataset
1212
from autointent.configs import CrossEncoderConfig, EmbedderConfig, InferenceNodeConfig, LoggingConfig, VectorIndexConfig
13-
from autointent.custom_types import ListOfGenericLabels, NodeType
13+
from autointent.custom_types import ListOfGenericLabels, NodeType, ValidationScheme
1414
from autointent.metrics import PREDICTION_METRICS_MULTILABEL
1515
from autointent.nodes import InferenceNode, NodeOptimizer
1616
from autointent.nodes.schemes import OptimizationConfig
@@ -122,7 +122,9 @@ def _is_inference(self) -> bool:
122122
"""
123123
return isinstance(self.nodes[NodeType.scoring], InferenceNode)
124124

125-
def fit(self, dataset: Dataset) -> Context:
125+
def fit(
126+
self, dataset: Dataset, scheme: ValidationScheme = "ho", n_folds: int = 3, refit_after: bool = False
127+
) -> Context:
126128
"""
127129
Optimize the pipeline from dataset.
128130
@@ -134,7 +136,7 @@ def fit(self, dataset: Dataset) -> Context:
134136
raise RuntimeError(msg)
135137

136138
context = Context()
137-
context.set_dataset(dataset)
139+
context.set_dataset(dataset, scheme, n_folds)
138140
context.configure_logging(self.logging_config)
139141
context.configure_vector_index(self.vector_index_config, self.embedder_config)
140142
context.configure_cross_encoder(self.cross_encoder_config)
@@ -150,6 +152,9 @@ def fit(self, dataset: Dataset) -> Context:
150152

151153
self.nodes = {node.node_type: node for node in nodes_list}
152154

155+
if refit_after:
156+
self._refit(context)
157+
153158
predictions = self.predict(context.data_handler.test_utterances())
154159
for metric_name, metric in PREDICTION_METRICS_MULTILABEL.items():
155160
context.optimization_info.pipeline_metrics[metric_name] = metric(
@@ -220,6 +225,27 @@ def predict(self, utterances: list[str]) -> ListOfGenericLabels:
220225
scores = scoring_module.predict(utterances)
221226
return decision_module.predict(scores)
222227

228+
def _refit(self, context: Context) -> None:
229+
"""
230+
Fit pipeline of already selected modules with all train data.
231+
232+
:param context: context object to take data from
233+
:return: list of predicted labels
234+
"""
235+
if not self._is_inference():
236+
msg = "Pipeline in optimization mode cannot perform inference"
237+
raise RuntimeError(msg)
238+
239+
scoring_module: ScoringModule = self.nodes[NodeType.scoring].module # type: ignore[assignment,union-attr]
240+
decision_module: DecisionModule = self.nodes[NodeType.decision].module # type: ignore[assignment,union-attr]
241+
242+
context.data_handler.prepare_for_refit()
243+
244+
scoring_module.fit(*scoring_module.get_train_data(context))
245+
scores = scoring_module.predict(context.data_handler.train_utterances(1))
246+
247+
decision_module.fit(scores, context.data_handler.train_labels(1), context.data_handler.tags)
248+
223249
def predict_with_metadata(self, utterances: list[str]) -> InferencePipelineOutput:
224250
"""
225251
Predict the labels for the utterances with metadata.

autointent/_ranker.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
Can be used to rank retrieved sentences by meaning closeness to provided utterance.
44
"""
55

6+
import gc
67
import itertools as it
78
import json
89
import logging
@@ -272,3 +273,9 @@ def load(cls, path: Path) -> "Ranker":
272273
metadata: CrossEncoderMetadata = json.load(file)
273274

274275
return cls(**metadata, classifier_head=clf)
276+
277+
def clear_ram(self) -> None:
278+
self.cross_encoder.model.cpu()
279+
del self.cross_encoder
280+
gc.collect()
281+
torch.cuda.empty_cache()

autointent/configs/_optimization.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
from pydantic import BaseModel, Field
66

7+
from autointent.custom_types import ValidationScheme
8+
79
from ._name import get_run_name
810

911

@@ -12,6 +14,10 @@ class DataConfig(BaseModel):
1214

1315
train_path: str | Path
1416
"""Path to the training data. Can be local path or HF repo."""
17+
scheme: ValidationScheme
18+
"""Hold-out or cross-validation."""
19+
n_folds: int = 3
20+
"""Number of folds in cross-validation."""
1521

1622

1723
class TaskConfig(BaseModel):

autointent/context/_context.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
LoggingConfig,
1717
VectorIndexConfig,
1818
)
19+
from autointent.custom_types import ValidationScheme
1920

2021
from ._utils import NumpyEncoder, load_dataset
2122
from .data_handler import DataHandler
@@ -81,11 +82,10 @@ def configure_data(self, config: DataConfig) -> None:
8182
:param config: Configuration for the data handling process.
8283
"""
8384
self.data_handler = DataHandler(
84-
dataset=load_dataset(config.train_path),
85-
random_seed=self.seed,
85+
dataset=load_dataset(config.train_path), random_seed=self.seed, scheme=config.scheme
8686
)
8787

88-
def set_dataset(self, dataset: Dataset) -> None:
88+
def set_dataset(self, dataset: Dataset, scheme: ValidationScheme = "ho", n_folds: int = 3) -> None:
8989
"""
9090
Set the datasets for training, validation and testing.
9191
@@ -94,6 +94,8 @@ def set_dataset(self, dataset: Dataset) -> None:
9494
self.data_handler = DataHandler(
9595
dataset=dataset,
9696
random_seed=self.seed,
97+
scheme=scheme,
98+
n_folds=n_folds,
9799
)
98100

99101
def get_inference_config(self) -> dict[str, Any]:
@@ -137,7 +139,7 @@ def dump(self) -> None:
137139
# self._logger.info(make_report(optimization_results, nodes=nodes))
138140

139141
# dump train and test data splits
140-
self.data_handler.dump(logs_dir / "dataset.json")
142+
self.data_handler.dataset.to_json(logs_dir / "dataset.json")
141143

142144
self._logger.info("logs and other assets are saved to %s", logs_dir)
143145

autointent/context/data_handler/_data_handler.py

Lines changed: 92 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
"""Data Handler file."""
22

33
import logging
4-
from pathlib import Path
4+
from collections.abc import Generator
55
from typing import TypedDict, cast
66

77
from datasets import concatenate_datasets
88
from transformers import set_seed
99

1010
from autointent import Dataset
11-
from autointent.custom_types import ListOfGenericLabels, Split
11+
from autointent.custom_types import ListOfGenericLabels, ListOfLabels, Split, ValidationScheme
1212

1313
from ._stratification import split_dataset
1414

@@ -26,10 +26,17 @@ class RegexPatterns(TypedDict):
2626
"""Partial match regex patterns."""
2727

2828

29-
class DataHandler:
29+
class DataHandler: # TODO rename to Validator
3030
"""Data handler class."""
3131

32-
def __init__(self, dataset: Dataset, random_seed: int = 0, split_train: bool = True) -> None:
32+
def __init__(
33+
self,
34+
dataset: Dataset,
35+
scheme: ValidationScheme = "ho",
36+
split_train: bool = True,
37+
random_seed: int = 0,
38+
n_folds: int = 3,
39+
) -> None:
3340
"""
3441
Initialize the data handler.
3542
@@ -39,12 +46,18 @@ def __init__(self, dataset: Dataset, random_seed: int = 0, split_train: bool = T
3946
threshold search).
4047
"""
4148
set_seed(random_seed)
49+
self.random_seed = random_seed
4250

4351
self.dataset = dataset
4452

4553
self.n_classes = self.dataset.n_classes
54+
self.scheme = scheme
55+
self.n_folds = n_folds
4656

47-
self._split(random_seed, split_train)
57+
if scheme == "ho":
58+
self._split_ho(split_train)
59+
elif scheme == "cv":
60+
self._split_cv()
4861

4962
self.regexp_patterns = [
5063
RegexPatterns(
@@ -97,6 +110,9 @@ def train_labels(self, idx: int | None = None) -> ListOfGenericLabels:
97110
split = f"{Split.TRAIN}_{idx}" if idx is not None else Split.TRAIN
98111
return cast(ListOfGenericLabels, self.dataset[split][self.dataset.label_feature])
99112

113+
def train_labels_folded(self) -> list[ListOfGenericLabels]:
114+
return [self.train_labels(j) for j in range(self.n_folds)]
115+
100116
def validation_utterances(self, idx: int | None = None) -> list[str]:
101117
"""
102118
Retrieve validation utterances from the dataset.
@@ -153,28 +169,37 @@ def test_labels(self, idx: int | None = None) -> ListOfGenericLabels:
153169
split = f"{Split.TEST}_{idx}" if idx is not None else Split.TEST
154170
return cast(ListOfGenericLabels, self.dataset[split][self.dataset.label_feature])
155171

156-
def dump(self, filepath: str | Path) -> None:
157-
"""
158-
Save the dataset splits and intents to a JSON file.
172+
def validation_iterator(self) -> Generator[tuple[list[str], ListOfLabels, list[str], ListOfLabels]]:
173+
if self.scheme == "ho":
174+
msg = "Cannot call cross-validation on hold-out DataHandler"
175+
raise RuntimeError(msg)
159176

160-
:param filepath: The path to the file where the JSON data will be saved.
161-
"""
162-
self.dataset.to_json(filepath)
177+
for j in range(self.n_folds):
178+
val_utterances = self.train_utterances(j)
179+
val_labels = self.train_labels(j)
180+
train_folds = [i for i in range(self.n_folds) if i != j]
181+
train_utterances = [ut for i_fold in train_folds for ut in self.train_utterances(i_fold)]
182+
train_labels = [lab for i_fold in train_folds for lab in self.train_labels(i_fold)]
163183

164-
def _split(self, random_seed: int, split_train: bool) -> None:
184+
# filter out all OOS samples from train
185+
train_utterances = [ut for ut, lab in zip(train_utterances, train_labels, strict=True) if lab is not None]
186+
train_labels = [lab for lab in train_labels if lab is not None]
187+
yield train_utterances, train_labels, val_utterances, val_labels # type: ignore[misc]
188+
189+
def _split_ho(self, split_train: bool) -> None:
165190
has_validation_split = any(split.startswith(Split.VALIDATION) for split in self.dataset)
166191

167192
if split_train and Split.TRAIN in self.dataset:
168-
self._split_train(random_seed)
193+
self._split_train()
169194

170195
if Split.TEST not in self.dataset:
171196
test_size = 0.1 if has_validation_split else 0.2
172-
self._split_test(test_size, random_seed)
197+
self._split_test(test_size)
173198

174199
if not has_validation_split:
175-
self._split_validation_from_train(random_seed)
200+
self._split_validation_from_train()
176201
elif Split.VALIDATION in self.dataset:
177-
self._split_validation(random_seed)
202+
self._split_validation()
178203

179204
for split in self.dataset:
180205
n_classes_split = self.dataset.get_n_classes(split)
@@ -185,7 +210,7 @@ def _split(self, random_seed: int, split_train: bool) -> None:
185210
)
186211
raise ValueError(message)
187212

188-
def _split_train(self, random_seed: int) -> None:
213+
def _split_train(self) -> None:
189214
"""
190215
Split on two sets.
191216
@@ -195,12 +220,12 @@ def _split_train(self, random_seed: int) -> None:
195220
self.dataset,
196221
split=Split.TRAIN,
197222
test_size=0.5,
198-
random_seed=random_seed,
223+
random_seed=self.random_seed,
199224
allow_oos_in_train=False, # only train data for decision node should contain OOS
200225
)
201226
self.dataset.pop(Split.TRAIN)
202227

203-
def _split_validation(self, random_seed: int) -> None:
228+
def _split_validation(self) -> None:
204229
"""
205230
Split on two sets.
206231
@@ -210,27 +235,49 @@ def _split_validation(self, random_seed: int) -> None:
210235
self.dataset,
211236
split=Split.VALIDATION,
212237
test_size=0.5,
213-
random_seed=random_seed,
238+
random_seed=self.random_seed,
214239
allow_oos_in_train=False, # only val data for decision node should contain OOS
215240
)
216241
self.dataset.pop(Split.VALIDATION)
217242

218-
def _split_validation_from_test(self, random_seed: int) -> None:
243+
def _split_validation_from_test(self) -> None:
219244
self.dataset[Split.TEST], self.dataset[Split.VALIDATION] = split_dataset(
220245
self.dataset,
221246
split=Split.TEST,
222247
test_size=0.5,
223-
random_seed=random_seed,
248+
random_seed=self.random_seed,
224249
allow_oos_in_train=True, # both test and validation splits can contain OOS
225250
)
226251

227-
def _split_validation_from_train(self, random_seed: int) -> None:
252+
def _split_cv(self) -> None:
253+
extra_splits = [split_name for split_name in self.dataset if split_name not in [Split.TRAIN, Split.TEST]]
254+
if extra_splits:
255+
self.dataset[Split.TRAIN] = concatenate_datasets(
256+
[self.dataset.pop(split_name) for split_name in extra_splits]
257+
)
258+
259+
if Split.TEST not in self.dataset:
260+
self.dataset[Split.TRAIN], self.dataset[Split.TEST] = split_dataset(
261+
self.dataset, split=Split.TRAIN, test_size=0.2, random_seed=self.random_seed, allow_oos_in_train=True
262+
)
263+
264+
for j in range(self.n_folds - 1):
265+
self.dataset[Split.TRAIN], self.dataset[f"{Split.TRAIN}_{j}"] = split_dataset(
266+
self.dataset,
267+
split=Split.TRAIN,
268+
test_size=1 / (self.n_folds - j),
269+
random_seed=self.random_seed,
270+
allow_oos_in_train=True,
271+
)
272+
self.dataset[f"{Split.TRAIN}_{self.n_folds-1}"] = self.dataset.pop(Split.TRAIN)
273+
274+
def _split_validation_from_train(self) -> None:
228275
if Split.TRAIN in self.dataset:
229276
self.dataset[Split.TRAIN], self.dataset[Split.VALIDATION] = split_dataset(
230277
self.dataset,
231278
split=Split.TRAIN,
232279
test_size=0.2,
233-
random_seed=random_seed,
280+
random_seed=self.random_seed,
234281
allow_oos_in_train=True,
235282
)
236283
else:
@@ -239,27 +286,44 @@ def _split_validation_from_train(self, random_seed: int) -> None:
239286
self.dataset,
240287
split=f"{Split.TRAIN}_{idx}",
241288
test_size=0.2,
242-
random_seed=random_seed,
289+
random_seed=self.random_seed,
243290
allow_oos_in_train=idx == 1, # for decision node it's ok to have oos in train
244291
)
245292

246-
def _split_test(self, test_size: float, random_seed: int) -> None:
293+
def _split_test(self, test_size: float) -> None:
247294
"""Obtain test set from train."""
248295
self.dataset[f"{Split.TRAIN}_0"], self.dataset[f"{Split.TEST}_0"] = split_dataset(
249296
self.dataset,
250297
split=f"{Split.TRAIN}_0",
251298
test_size=test_size,
252-
random_seed=random_seed,
299+
random_seed=self.random_seed,
253300
)
254301
self.dataset[f"{Split.TRAIN}_1"], self.dataset[f"{Split.TEST}_1"] = split_dataset(
255302
self.dataset,
256303
split=f"{Split.TRAIN}_1",
257304
test_size=test_size,
258-
random_seed=random_seed,
305+
random_seed=self.random_seed,
259306
allow_oos_in_train=True,
260307
)
261308
self.dataset[Split.TEST] = concatenate_datasets(
262309
[self.dataset[f"{Split.TEST}_0"], self.dataset[f"{Split.TEST}_1"]],
263310
)
264311
self.dataset.pop(f"{Split.TEST}_0")
265312
self.dataset.pop(f"{Split.TEST}_1")
313+
314+
def prepare_for_refit(self) -> None:
315+
if self.scheme == "ho":
316+
return
317+
318+
train_folds = [split_name for split_name in self.dataset if split_name.startswith(Split.TRAIN)]
319+
self.dataset[Split.TRAIN] = concatenate_datasets([self.dataset.pop(name) for name in train_folds])
320+
321+
self.dataset[f"{Split.TRAIN}_0"], self.dataset[f"{Split.TRAIN}_1"] = split_dataset(
322+
self.dataset,
323+
split=Split.TRAIN,
324+
test_size=0.5,
325+
random_seed=self.random_seed,
326+
allow_oos_in_train=False,
327+
)
328+
329+
self.dataset.pop(Split.TRAIN)

autointent/context/optimization_info/_data_models.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ class ScorerArtifact(Artifact):
4242
train_scores: NDArray[np.float64] | None = Field(None, description="Scorer outputs for train utterances")
4343
validation_scores: NDArray[np.float64] | None = Field(None, description="Scorer outputs for validation utterances")
4444
test_scores: NDArray[np.float64] | None = Field(None, description="Scorer outputs for test utterances")
45+
folded_scores: list[NDArray[np.float64]] | None = Field(
46+
None, description="Scores for each fold from cross-validation"
47+
)
4548

4649

4750
class DecisionArtifact(Artifact):

0 commit comments

Comments
 (0)