Skip to content

Commit 89c6406

Browse files
committed
implement refitting the whole pipeline with all train data
1 parent 5cbf83e commit 89c6406

File tree

3 files changed

+68
-24
lines changed

3 files changed

+68
-24
lines changed

autointent/_pipeline/_pipeline.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def _is_inference(self) -> bool:
122122
"""
123123
return isinstance(self.nodes[NodeType.scoring], InferenceNode)
124124

125-
def fit(self, dataset: Dataset, scheme: Literal["ho", "cv"] = "ho") -> Context:
125+
def fit(self, dataset: Dataset, scheme: Literal["ho", "cv"] = "ho", refit_after: bool = False) -> Context:
126126
"""
127127
Optimize the pipeline from dataset.
128128
@@ -150,6 +150,9 @@ def fit(self, dataset: Dataset, scheme: Literal["ho", "cv"] = "ho") -> Context:
150150

151151
self.nodes = {node.node_type: node for node in nodes_list}
152152

153+
if refit_after:
154+
self._refit(context)
155+
153156
predictions = self.predict(context.data_handler.test_utterances())
154157
for metric_name, metric in PREDICTION_METRICS_MULTILABEL.items():
155158
context.optimization_info.pipeline_metrics[metric_name] = metric(
@@ -210,6 +213,27 @@ def predict(self, utterances: list[str]) -> ListOfGenericLabels:
210213
scores = scoring_module.predict(utterances)
211214
return decision_module.predict(scores)
212215

216+
def _refit(self, context: Context) -> None:
217+
"""
218+
Fit pipeline of already selected modules with all train data.
219+
220+
:param utterances: list of utterances
221+
:return: list of predicted labels
222+
"""
223+
if not self._is_inference():
224+
msg = "Pipeline in optimization mode cannot be refitted"
225+
raise RuntimeError(msg)
226+
227+
scoring_module: ScoringModule = self.nodes[NodeType.scoring].module # type: ignore[assignment,union-attr]
228+
decision_module: DecisionModule = self.nodes[NodeType.decision].module # type: ignore[assignment,union-attr]
229+
230+
context.data_handler.prepare_for_refit()
231+
232+
scoring_module.fit(context.data_handler.train_utterances(0), context.data_handler.train_labels(0))
233+
scores = scoring_module.predict(context.data_handler.train_utterances(1))
234+
235+
decision_module.fit(scores, context.data_handler.train_labels(1))
236+
213237
def predict_with_metadata(self, utterances: list[str]) -> InferencePipelineOutput:
214238
"""
215239
Predict the labels for the utterances with metadata.

autointent/context/data_handler/_data_handler.py

Lines changed: 42 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def __init__(
4646
threshold search).
4747
"""
4848
set_seed(random_seed)
49+
self.random_seed = random_seed
4950

5051
self.dataset = dataset
5152

@@ -54,9 +55,9 @@ def __init__(
5455
self.n_folds = n_folds
5556

5657
if scheme == "ho":
57-
self._split_ho(random_seed, split_train)
58+
self._split_ho(split_train)
5859
elif scheme == "cv":
59-
self._split_cv(random_seed)
60+
self._split_cv()
6061

6162
self.regexp_patterns = [
6263
RegexPatterns(
@@ -185,20 +186,20 @@ def validation_iterator(self) -> Generator[tuple[list[str], ListOfLabels, list[s
185186
train_labels = [lab for lab in train_labels if lab is not None]
186187
yield train_utterances, train_labels, val_utterances, val_labels # type: ignore[misc]
187188

188-
def _split_ho(self, random_seed: int, split_train: bool) -> None:
189+
def _split_ho(self, split_train: bool) -> None:
189190
has_validation_split = any(split.startswith(Split.VALIDATION) for split in self.dataset)
190191

191192
if split_train and Split.TRAIN in self.dataset:
192-
self._split_train(random_seed)
193+
self._split_train()
193194

194195
if Split.TEST not in self.dataset:
195196
test_size = 0.1 if has_validation_split else 0.2
196-
self._split_test(test_size, random_seed)
197+
self._split_test(test_size)
197198

198199
if not has_validation_split:
199-
self._split_validation_from_train(random_seed)
200+
self._split_validation_from_train()
200201
elif Split.VALIDATION in self.dataset:
201-
self._split_validation(random_seed)
202+
self._split_validation()
202203

203204
for split in self.dataset:
204205
n_classes_split = self.dataset.get_n_classes(split)
@@ -209,7 +210,7 @@ def _split_ho(self, random_seed: int, split_train: bool) -> None:
209210
)
210211
raise ValueError(message)
211212

212-
def _split_train(self, random_seed: int) -> None:
213+
def _split_train(self) -> None:
213214
"""
214215
Split on two sets.
215216
@@ -219,12 +220,12 @@ def _split_train(self, random_seed: int) -> None:
219220
self.dataset,
220221
split=Split.TRAIN,
221222
test_size=0.5,
222-
random_seed=random_seed,
223+
random_seed=self.random_seed,
223224
allow_oos_in_train=False, # only train data for decision node should contain OOS
224225
)
225226
self.dataset.pop(Split.TRAIN)
226227

227-
def _split_validation(self, random_seed: int) -> None:
228+
def _split_validation(self) -> None:
228229
"""
229230
Split on two sets.
230231
@@ -234,21 +235,21 @@ def _split_validation(self, random_seed: int) -> None:
234235
self.dataset,
235236
split=Split.VALIDATION,
236237
test_size=0.5,
237-
random_seed=random_seed,
238+
random_seed=self.random_seed,
238239
allow_oos_in_train=False, # only val data for decision node should contain OOS
239240
)
240241
self.dataset.pop(Split.VALIDATION)
241242

242-
def _split_validation_from_test(self, random_seed: int) -> None:
243+
def _split_validation_from_test(self) -> None:
243244
self.dataset[Split.TEST], self.dataset[Split.VALIDATION] = split_dataset(
244245
self.dataset,
245246
split=Split.TEST,
246247
test_size=0.5,
247-
random_seed=random_seed,
248+
random_seed=self.random_seed,
248249
allow_oos_in_train=True, # both test and validation splits can contain OOS
249250
)
250251

251-
def _split_cv(self, random_seed: int) -> None:
252+
def _split_cv(self) -> None:
252253
extra_splits = [split_name for split_name in self.dataset if split_name not in [Split.TRAIN, Split.TEST]]
253254
if extra_splits:
254255
self.dataset[Split.TRAIN] = concatenate_datasets(
@@ -257,26 +258,26 @@ def _split_cv(self, random_seed: int) -> None:
257258

258259
if Split.TEST not in self.dataset:
259260
self.dataset[Split.TRAIN], self.dataset[Split.TEST] = split_dataset(
260-
self.dataset, split=Split.TRAIN, test_size=0.2, random_seed=random_seed, allow_oos_in_train=True
261+
self.dataset, split=Split.TRAIN, test_size=0.2, random_seed=self.random_seed, allow_oos_in_train=True
261262
)
262263

263264
for j in range(self.n_folds - 1):
264265
self.dataset[Split.TRAIN], self.dataset[f"{Split.TRAIN}_{j}"] = split_dataset(
265266
self.dataset,
266267
split=Split.TRAIN,
267268
test_size=1 / (self.n_folds - j),
268-
random_seed=random_seed,
269+
random_seed=self.random_seed,
269270
allow_oos_in_train=True,
270271
)
271272
self.dataset[f"{Split.TRAIN}_{self.n_folds-1}"] = self.dataset.pop(Split.TRAIN)
272273

273-
def _split_validation_from_train(self, random_seed: int) -> None:
274+
def _split_validation_from_train(self) -> None:
274275
if Split.TRAIN in self.dataset:
275276
self.dataset[Split.TRAIN], self.dataset[Split.VALIDATION] = split_dataset(
276277
self.dataset,
277278
split=Split.TRAIN,
278279
test_size=0.2,
279-
random_seed=random_seed,
280+
random_seed=self.random_seed,
280281
allow_oos_in_train=True,
281282
)
282283
else:
@@ -285,27 +286,46 @@ def _split_validation_from_train(self, random_seed: int) -> None:
285286
self.dataset,
286287
split=f"{Split.TRAIN}_{idx}",
287288
test_size=0.2,
288-
random_seed=random_seed,
289+
random_seed=self.random_seed,
289290
allow_oos_in_train=idx == 1, # for decision node it's ok to have oos in train
290291
)
291292

292-
def _split_test(self, test_size: float, random_seed: int) -> None:
293+
def _split_test(self, test_size: float) -> None:
293294
"""Obtain test set from train."""
294295
self.dataset[f"{Split.TRAIN}_0"], self.dataset[f"{Split.TEST}_0"] = split_dataset(
295296
self.dataset,
296297
split=f"{Split.TRAIN}_0",
297298
test_size=test_size,
298-
random_seed=random_seed,
299+
random_seed=self.random_seed,
299300
)
300301
self.dataset[f"{Split.TRAIN}_1"], self.dataset[f"{Split.TEST}_1"] = split_dataset(
301302
self.dataset,
302303
split=f"{Split.TRAIN}_1",
303304
test_size=test_size,
304-
random_seed=random_seed,
305+
random_seed=self.random_seed,
305306
allow_oos_in_train=True,
306307
)
307308
self.dataset[Split.TEST] = concatenate_datasets(
308309
[self.dataset[f"{Split.TEST}_0"], self.dataset[f"{Split.TEST}_1"]],
309310
)
310311
self.dataset.pop(f"{Split.TEST}_0")
311312
self.dataset.pop(f"{Split.TEST}_1")
313+
314+
def prepare_for_refit(self) -> None:
315+
if self.scheme == "ho":
316+
return
317+
318+
train_folds = [split_name for split_name in self.dataset if split_name.startswith("train")]
319+
self.dataset[Split.TRAIN] = concatenate_datasets([self.dataset[name] for name in train_folds])
320+
for name in train_folds:
321+
self.dataset.pop(name)
322+
323+
self.dataset[f"{Split.TRAIN}_0"], self.dataset[f"{Split.TRAIN}_1"] = split_dataset(
324+
self.dataset,
325+
split=Split.TRAIN,
326+
test_size=0.5,
327+
random_seed=self.random_seed,
328+
allow_oos_in_train=False,
329+
)
330+
331+
self.dataset.pop(Split.TRAIN)

tests/pipeline/test_optimization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def test_cv(dataset, task_type):
3131
if task_type == "multilabel":
3232
dataset = dataset.to_multilabel()
3333

34-
context = pipeline_optimizer.fit(dataset, scheme="cv")
34+
context = pipeline_optimizer.fit(dataset, scheme="cv", refit_after=True)
3535
context.dump()
3636

3737
assert os.listdir(pipeline_optimizer.logging_config.dump_dir)

0 commit comments

Comments
 (0)