From 5b4cd28dbc590b084ac793f43096a21b05682727 Mon Sep 17 00:00:00 2001 From: voorhs Date: Tue, 18 Feb 2025 13:03:54 +0300 Subject: [PATCH 1/4] refactor data_handler constructor --- .../context/data_handler/_data_handler.py | 32 +++++++++---------- autointent/modules/abc/_base.py | 4 +-- autointent/modules/abc/_decision.py | 4 +-- 3 files changed, 19 insertions(+), 21 deletions(-) diff --git a/autointent/context/data_handler/_data_handler.py b/autointent/context/data_handler/_data_handler.py index 507d8e0f5..9858c4ca4 100644 --- a/autointent/context/data_handler/_data_handler.py +++ b/autointent/context/data_handler/_data_handler.py @@ -8,7 +8,8 @@ from transformers import set_seed from autointent import Dataset -from autointent.custom_types import ListOfGenericLabels, ListOfLabels, Split, ValidationScheme +from autointent.configs import DataConfig +from autointent.custom_types import ListOfGenericLabels, ListOfLabels, Split from ._stratification import split_dataset @@ -32,10 +33,8 @@ class DataHandler: # TODO rename to Validator def __init__( self, dataset: Dataset, - scheme: ValidationScheme = "ho", - separate_nodes: bool = True, + config: DataConfig | None = None, random_seed: int = 0, - n_folds: int = 3, ) -> None: """ Initialize the data handler. @@ -49,14 +48,13 @@ def __init__( self.random_seed = random_seed self.dataset = dataset + self.config = config if config is not None else DataConfig() self.n_classes = self.dataset.n_classes - self.scheme = scheme - self.n_folds = n_folds - if scheme == "ho": - self._split_ho(separate_nodes) - elif scheme == "cv": + if self.config.scheme == "ho": + self._split_ho(self.config.separate_nodes) + elif self.config.scheme == "cv": self._split_cv() self.regex_patterns = [ @@ -120,7 +118,7 @@ def train_labels(self, idx: int | None = None) -> ListOfGenericLabels: return cast(ListOfGenericLabels, self.dataset[split][self.dataset.label_feature]) def train_labels_folded(self) -> list[ListOfGenericLabels]: - return [self.train_labels(j) for j in range(self.n_folds)] + return [self.train_labels(j) for j in range(self.config.n_folds)] def validation_utterances(self, idx: int | None = None) -> list[str]: """ @@ -177,14 +175,14 @@ def test_labels(self) -> ListOfGenericLabels: return cast(ListOfGenericLabels, self.dataset[Split.TEST][self.dataset.label_feature]) def validation_iterator(self) -> Generator[tuple[list[str], ListOfLabels, list[str], ListOfLabels]]: - if self.scheme == "ho": + if self.config.scheme == "ho": msg = "Cannot call cross-validation on hold-out DataHandler" raise RuntimeError(msg) - for j in range(self.n_folds): + for j in range(self.config.n_folds): val_utterances = self.train_utterances(j) val_labels = self.train_labels(j) - train_folds = [i for i in range(self.n_folds) if i != j] + train_folds = [i for i in range(self.config.n_folds) if i != j] train_utterances = [ut for i_fold in train_folds for ut in self.train_utterances(i_fold)] train_labels = [lab for i_fold in train_folds for lab in self.train_labels(i_fold)] @@ -237,15 +235,15 @@ def _split_cv(self) -> None: self.dataset, split=Split.TRAIN, test_size=0.2, random_seed=self.random_seed, allow_oos_in_train=True ) - for j in range(self.n_folds - 1): + for j in range(self.config.n_folds - 1): self.dataset[Split.TRAIN], self.dataset[f"{Split.TRAIN}_{j}"] = split_dataset( self.dataset, split=Split.TRAIN, - test_size=1 / (self.n_folds - j), + test_size=1 / (self.config.n_folds - j), random_seed=self.random_seed, allow_oos_in_train=True, ) - self.dataset[f"{Split.TRAIN}_{self.n_folds-1}"] = self.dataset.pop(Split.TRAIN) + self.dataset[f"{Split.TRAIN}_{self.config.n_folds-1}"] = self.dataset.pop(Split.TRAIN) def _split_validation_from_train(self) -> None: if Split.TRAIN in self.dataset: @@ -267,7 +265,7 @@ def _split_validation_from_train(self) -> None: ) def prepare_for_refit(self) -> None: - if self.scheme == "ho": + if self.config.scheme == "ho": return train_folds = [split_name for split_name in self.dataset if split_name.startswith(Split.TRAIN)] diff --git a/autointent/modules/abc/_base.py b/autointent/modules/abc/_base.py index cd3b57142..29fad416b 100644 --- a/autointent/modules/abc/_base.py +++ b/autointent/modules/abc/_base.py @@ -43,9 +43,9 @@ def score(self, context: Context, metrics: list[str]) -> dict[str, float]: :param split: Split to score on :return: Computed metrics value for the test set or error code of metrics """ - if context.data_handler.scheme == "ho": + if context.data_handler.config.scheme == "ho": return self.score_ho(context, metrics) - if context.data_handler.scheme == "cv": + if context.data_handler.config.scheme == "cv": return self.score_cv(context, metrics) msg = "Something's wrong with validation schemas" raise RuntimeError(msg) diff --git a/autointent/modules/abc/_decision.py b/autointent/modules/abc/_decision.py index 72b884dd9..f97541a29 100644 --- a/autointent/modules/abc/_decision.py +++ b/autointent/modules/abc/_decision.py @@ -75,10 +75,10 @@ def score_cv(self, context: Context, metrics: list[str]) -> dict[str, float]: chosen_metrics = {name: fn for name, fn in DECISION_METRICS.items() if name in metrics} metrics_values: dict[str, list[float]] = {name: [] for name in chosen_metrics} all_val_decisions = [] - for j in range(context.data_handler.n_folds): + for j in range(context.data_handler.config.n_folds): val_labels = labels[j] val_scores = scores[j] - train_folds = [i for i in range(context.data_handler.n_folds) if i != j] + train_folds = [i for i in range(context.data_handler.config.n_folds) if i != j] train_labels = [ut for i_fold in train_folds for ut in labels[i_fold]] train_scores = np.array([sc for i_fold in train_folds for sc in scores[i_fold]]) self.fit(train_scores, train_labels, context.data_handler.tags) # type: ignore[arg-type] From b5781d97ef1de00dd8e8838949761660925d2e07 Mon Sep 17 00:00:00 2001 From: voorhs Date: Tue, 18 Feb 2025 13:17:56 +0300 Subject: [PATCH 2/4] finish refactoring data handler --- autointent/configs/_optimization.py | 8 ++-- .../context/data_handler/_data_handler.py | 41 ++++++++----------- 2 files changed, 22 insertions(+), 27 deletions(-) diff --git a/autointent/configs/_optimization.py b/autointent/configs/_optimization.py index b4b8370b5..ba3d8093c 100644 --- a/autointent/configs/_optimization.py +++ b/autointent/configs/_optimization.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, Field, PositiveInt -from autointent.custom_types import SamplerType, ValidationScheme +from autointent.custom_types import FloatFromZeroToOne, SamplerType, ValidationScheme from ._name import get_run_name @@ -16,8 +16,10 @@ class DataConfig(BaseModel): """Hold-out or cross-validation.""" n_folds: PositiveInt = 3 """Number of folds in cross-validation.""" - separate_nodes: bool = True - """Whether to use separate data for decision node.""" + validation_size: FloatFromZeroToOne = 0.2 + """Fraction of train samples to allocate for validation (if input dataset doesn't contain validation split).""" + separation_ratio: FloatFromZeroToOne | None = 0.5 + """Set to float to prevent data leak between scoring and decision nodes.""" class TaskConfig(BaseModel): diff --git a/autointent/context/data_handler/_data_handler.py b/autointent/context/data_handler/_data_handler.py index 9858c4ca4..03b37b644 100644 --- a/autointent/context/data_handler/_data_handler.py +++ b/autointent/context/data_handler/_data_handler.py @@ -9,7 +9,7 @@ from autointent import Dataset from autointent.configs import DataConfig -from autointent.custom_types import ListOfGenericLabels, ListOfLabels, Split +from autointent.custom_types import FloatFromZeroToOne, ListOfGenericLabels, ListOfLabels, Split from ._stratification import split_dataset @@ -41,8 +41,7 @@ def __init__( :param dataset: Training dataset. :param random_seed: Seed for random number generation. - :param separate_nodes: Perform or not splitting of train (default to split to be used in scoring and - threshold search). + :param config: config """ set_seed(random_seed) self.random_seed = random_seed @@ -53,7 +52,7 @@ def __init__( self.n_classes = self.dataset.n_classes if self.config.scheme == "ho": - self._split_ho(self.config.separate_nodes) + self._split_ho(self.config.separation_ratio, self.config.validation_size) elif self.config.scheme == "cv": self._split_cv() @@ -191,14 +190,14 @@ def validation_iterator(self) -> Generator[tuple[list[str], ListOfLabels, list[s train_labels = [lab for lab in train_labels if lab is not None] yield train_utterances, train_labels, val_utterances, val_labels # type: ignore[misc] - def _split_ho(self, separate_nodes: bool) -> None: + def _split_ho(self, separation_ratio: FloatFromZeroToOne | None, validation_size: FloatFromZeroToOne) -> None: has_validation_split = any(split.startswith(Split.VALIDATION) for split in self.dataset) - if separate_nodes and Split.TRAIN in self.dataset: - self._split_train() + if separation_ratio is not None and Split.TRAIN in self.dataset: + self._split_train(separation_ratio) if not has_validation_split: - self._split_validation_from_train() + self._split_validation_from_train(validation_size) for split in self.dataset: n_classes_in_split = self.dataset.get_n_classes(split) @@ -208,7 +207,7 @@ def _split_ho(self, separate_nodes: bool) -> None: ) raise ValueError(message) - def _split_train(self) -> None: + def _split_train(self, ratio: FloatFromZeroToOne) -> None: """ Split on two sets. @@ -217,23 +216,17 @@ def _split_train(self) -> None: self.dataset[f"{Split.TRAIN}_0"], self.dataset[f"{Split.TRAIN}_1"] = split_dataset( self.dataset, split=Split.TRAIN, - test_size=0.5, + test_size=ratio, random_seed=self.random_seed, allow_oos_in_train=False, # only train data for decision node should contain OOS ) self.dataset.pop(Split.TRAIN) def _split_cv(self) -> None: - extra_splits = [split_name for split_name in self.dataset if split_name not in [Split.TRAIN, Split.TEST]] - if extra_splits: - self.dataset[Split.TRAIN] = concatenate_datasets( - [self.dataset.pop(split_name) for split_name in extra_splits] - ) - - if Split.TEST not in self.dataset: - self.dataset[Split.TRAIN], self.dataset[Split.TEST] = split_dataset( - self.dataset, split=Split.TRAIN, test_size=0.2, random_seed=self.random_seed, allow_oos_in_train=True - ) + extra_splits = [split_name for split_name in self.dataset if split_name != Split.TEST] + self.dataset[Split.TRAIN] = concatenate_datasets( + [self.dataset.pop(split_name) for split_name in extra_splits] + ) for j in range(self.config.n_folds - 1): self.dataset[Split.TRAIN], self.dataset[f"{Split.TRAIN}_{j}"] = split_dataset( @@ -245,12 +238,12 @@ def _split_cv(self) -> None: ) self.dataset[f"{Split.TRAIN}_{self.config.n_folds-1}"] = self.dataset.pop(Split.TRAIN) - def _split_validation_from_train(self) -> None: + def _split_validation_from_train(self, size: float) -> None: if Split.TRAIN in self.dataset: self.dataset[Split.TRAIN], self.dataset[Split.VALIDATION] = split_dataset( self.dataset, split=Split.TRAIN, - test_size=0.2, + test_size=size, random_seed=self.random_seed, allow_oos_in_train=True, ) @@ -259,7 +252,7 @@ def _split_validation_from_train(self) -> None: self.dataset[f"{Split.TRAIN}_{idx}"], self.dataset[f"{Split.VALIDATION}_{idx}"] = split_dataset( self.dataset, split=f"{Split.TRAIN}_{idx}", - test_size=0.2, + test_size=size, random_seed=self.random_seed, allow_oos_in_train=idx == 1, # for decision node it's ok to have oos in train ) @@ -274,7 +267,7 @@ def prepare_for_refit(self) -> None: self.dataset[f"{Split.TRAIN}_0"], self.dataset[f"{Split.TRAIN}_1"] = split_dataset( self.dataset, split=Split.TRAIN, - test_size=0.5, + test_size=self.config.separation_ratio or 0.5, random_seed=self.random_seed, allow_oos_in_train=False, ) From 3a6593cb8c01bcf72dd079186792661abfa5b9fd Mon Sep 17 00:00:00 2001 From: voorhs Date: Tue, 18 Feb 2025 13:21:05 +0300 Subject: [PATCH 3/4] fix codestyle --- autointent/context/data_handler/_data_handler.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/autointent/context/data_handler/_data_handler.py b/autointent/context/data_handler/_data_handler.py index 03b37b644..0b80a3adc 100644 --- a/autointent/context/data_handler/_data_handler.py +++ b/autointent/context/data_handler/_data_handler.py @@ -224,9 +224,7 @@ def _split_train(self, ratio: FloatFromZeroToOne) -> None: def _split_cv(self) -> None: extra_splits = [split_name for split_name in self.dataset if split_name != Split.TEST] - self.dataset[Split.TRAIN] = concatenate_datasets( - [self.dataset.pop(split_name) for split_name in extra_splits] - ) + self.dataset[Split.TRAIN] = concatenate_datasets([self.dataset.pop(split_name) for split_name in extra_splits]) for j in range(self.config.n_folds - 1): self.dataset[Split.TRAIN], self.dataset[f"{Split.TRAIN}_{j}"] = split_dataset( From b6e792ab1e8f7fdff62cb5cd7f4058553058dc2d Mon Sep 17 00:00:00 2001 From: voorhs Date: Tue, 18 Feb 2025 13:21:11 +0300 Subject: [PATCH 4/4] update tests --- autointent/context/_context.py | 2 +- tests/data/test_data_handler.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/autointent/context/_context.py b/autointent/context/_context.py index 75de47f83..62d37ab3f 100644 --- a/autointent/context/_context.py +++ b/autointent/context/_context.py @@ -65,7 +65,7 @@ def set_dataset(self, dataset: Dataset, config: DataConfig) -> None: :param dataset: Dataset. """ - self.data_handler = DataHandler(dataset=dataset, random_seed=self.seed, **config.model_dump()) + self.data_handler = DataHandler(dataset=dataset, random_seed=self.seed, config=config) def get_inference_config(self) -> dict[str, Any]: """ diff --git a/tests/data/test_data_handler.py b/tests/data/test_data_handler.py index 2d134fdc5..c730ce318 100644 --- a/tests/data/test_data_handler.py +++ b/tests/data/test_data_handler.py @@ -1,6 +1,7 @@ import pytest from autointent import Dataset +from autointent.configs import DataConfig from autointent.context.data_handler import DataHandler from autointent.schemas import Sample @@ -180,7 +181,7 @@ def count_oos(split): def test_cv_folding(dataset): - DataHandler(dataset, scheme="cv", n_folds=3) + DataHandler(dataset, config=DataConfig(scheme="cv", n_folds=3)) desired_specs = { "test": {"total": 12, "oos": 4}, @@ -199,7 +200,7 @@ def count_oos_labels(split): def test_cv_iterator(dataset): - dh = DataHandler(dataset, scheme="cv", n_folds=3) + dh = DataHandler(dataset, config=DataConfig(scheme="cv", n_folds=3)) desired_specs = [ {