Merged
8 changes: 5 additions & 3 deletions autointent/configs/_optimization.py
@@ -4,7 +4,7 @@

from pydantic import BaseModel, Field, PositiveInt

-from autointent.custom_types import SamplerType, ValidationScheme
+from autointent.custom_types import FloatFromZeroToOne, SamplerType, ValidationScheme

from ._name import get_run_name

@@ -16,8 +16,10 @@ class DataConfig(BaseModel):
"""Hold-out or cross-validation."""
Member: We'll probably need to add DataConfig to the shared config.

Collaborator (Author): If you mean the OptimizationConfig at the end of this file, it isn't used anywhere at the moment.

If you mean the search space, that's a different thing after all.

Member: I mean DataConfig.

n_folds: PositiveInt = 3
"""Number of folds in cross-validation."""
-separate_nodes: bool = True
-"""Whether to use separate data for decision node."""
+validation_size: FloatFromZeroToOne = 0.2
+"""Fraction of train samples to allocate for validation (if input dataset doesn't contain validation split)."""
+separation_ratio: FloatFromZeroToOne | None = 0.5
+"""Set to float to prevent data leak between scoring and decision nodes."""


class TaskConfig(BaseModel):
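For context, a minimal sketch of how the reworked DataConfig might be constructed; the field names and defaults come from the diff above, and the autointent.configs import path from the test changes below:

from autointent.configs import DataConfig

# Hold-out scheme: carve validation_size off the train split (when the dataset
# ships without a validation split) and divide the rest between scoring and
# decision nodes according to separation_ratio.
ho = DataConfig(scheme="ho", validation_size=0.2, separation_ratio=0.5)

# separation_ratio=None skips the scoring/decision separation entirely.
ho_shared = DataConfig(scheme="ho", separation_ratio=None)

# Cross-validation scheme: the fold count is the relevant knob.
cv = DataConfig(scheme="cv", n_folds=5)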
2 changes: 1 addition & 1 deletion autointent/context/_context.py
@@ -65,7 +65,7 @@ def set_dataset(self, dataset: Dataset, config: DataConfig) -> None:

:param dataset: Dataset.
"""
-self.data_handler = DataHandler(dataset=dataset, random_seed=self.seed, **config.model_dump())
+self.data_handler = DataHandler(dataset=dataset, random_seed=self.seed, config=config)

def get_inference_config(self) -> dict[str, Any]:
"""
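Worth noting: passing the validated DataConfig object instead of splatting **config.model_dump() keeps a single typed source of truth at the call site. A hypothetical usage sketch (the seed value is invented, and dataset is assumed to be loaded elsewhere):

from autointent.configs import DataConfig
from autointent.context.data_handler import DataHandler

config = DataConfig(scheme="cv", n_folds=5)
# before: DataHandler(dataset=dataset, random_seed=seed, **config.model_dump())
# after: the config object arrives intact, with Pydantic validation applied
handler = DataHandler(dataset=dataset, random_seed=42, config=config)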
67 changes: 28 additions & 39 deletions autointent/context/data_handler/_data_handler.py
@@ -8,7 +8,8 @@
from transformers import set_seed

from autointent import Dataset
-from autointent.custom_types import ListOfGenericLabels, ListOfLabels, Split, ValidationScheme
+from autointent.configs import DataConfig
+from autointent.custom_types import FloatFromZeroToOne, ListOfGenericLabels, ListOfLabels, Split

from ._stratification import split_dataset

@@ -32,31 +33,27 @@ class DataHandler: # TODO rename to Validator
def __init__(
self,
dataset: Dataset,
-scheme: ValidationScheme = "ho",
-separate_nodes: bool = True,
+config: DataConfig | None = None,
random_seed: int = 0,
-n_folds: int = 3,
) -> None:
"""
Initialize the data handler.

:param dataset: Training dataset.
:param random_seed: Seed for random number generation.
-:param separate_nodes: Perform or not splitting of train (default to split to be used in scoring and
-threshold search).
+:param config: Data splitting and validation configuration.
"""
set_seed(random_seed)
self.random_seed = random_seed

self.dataset = dataset
+self.config = config if config is not None else DataConfig()

self.n_classes = self.dataset.n_classes
-self.scheme = scheme
-self.n_folds = n_folds

-if scheme == "ho":
-self._split_ho(separate_nodes)
-elif scheme == "cv":
+if self.config.scheme == "ho":
+self._split_ho(self.config.separation_ratio, self.config.validation_size)
+elif self.config.scheme == "cv":
self._split_cv()

self.regex_patterns = [
@@ -120,7 +117,7 @@ def train_labels(self, idx: int | None = None) -> ListOfGenericLabels:
return cast(ListOfGenericLabels, self.dataset[split][self.dataset.label_feature])

def train_labels_folded(self) -> list[ListOfGenericLabels]:
-return [self.train_labels(j) for j in range(self.n_folds)]
+return [self.train_labels(j) for j in range(self.config.n_folds)]

def validation_utterances(self, idx: int | None = None) -> list[str]:
"""
@@ -177,14 +174,14 @@ def test_labels(self) -> ListOfGenericLabels:
return cast(ListOfGenericLabels, self.dataset[Split.TEST][self.dataset.label_feature])

def validation_iterator(self) -> Generator[tuple[list[str], ListOfLabels, list[str], ListOfLabels]]:
-if self.scheme == "ho":
+if self.config.scheme == "ho":
msg = "Cannot call cross-validation on hold-out DataHandler"
raise RuntimeError(msg)

-for j in range(self.n_folds):
+for j in range(self.config.n_folds):
val_utterances = self.train_utterances(j)
val_labels = self.train_labels(j)
-train_folds = [i for i in range(self.n_folds) if i != j]
+train_folds = [i for i in range(self.config.n_folds) if i != j]
train_utterances = [ut for i_fold in train_folds for ut in self.train_utterances(i_fold)]
train_labels = [lab for i_fold in train_folds for lab in self.train_labels(i_fold)]

@@ -193,14 +190,14 @@ def validation_iterator(self) -> Generator[tuple[list[str], ListOfLabels, list[s
train_labels = [lab for lab in train_labels if lab is not None]
yield train_utterances, train_labels, val_utterances, val_labels # type: ignore[misc]

-def _split_ho(self, separate_nodes: bool) -> None:
+def _split_ho(self, separation_ratio: FloatFromZeroToOne | None, validation_size: FloatFromZeroToOne) -> None:
has_validation_split = any(split.startswith(Split.VALIDATION) for split in self.dataset)

-if separate_nodes and Split.TRAIN in self.dataset:
-self._split_train()
+if separation_ratio is not None and Split.TRAIN in self.dataset:
+self._split_train(separation_ratio)

if not has_validation_split:
-self._split_validation_from_train()
+self._split_validation_from_train(validation_size)

for split in self.dataset:
n_classes_in_split = self.dataset.get_n_classes(split)
@@ -210,7 +207,7 @@
)
raise ValueError(message)

-def _split_train(self) -> None:
+def _split_train(self, ratio: FloatFromZeroToOne) -> None:
"""
Split on two sets.

Expand All @@ -219,40 +216,32 @@ def _split_train(self) -> None:
self.dataset[f"{Split.TRAIN}_0"], self.dataset[f"{Split.TRAIN}_1"] = split_dataset(
self.dataset,
split=Split.TRAIN,
-test_size=0.5,
+test_size=ratio,
random_seed=self.random_seed,
allow_oos_in_train=False, # only train data for decision node should contain OOS
)
self.dataset.pop(Split.TRAIN)

def _split_cv(self) -> None:
-extra_splits = [split_name for split_name in self.dataset if split_name not in [Split.TRAIN, Split.TEST]]
-if extra_splits:
-self.dataset[Split.TRAIN] = concatenate_datasets(
-[self.dataset.pop(split_name) for split_name in extra_splits]
-)

if Split.TEST not in self.dataset:
self.dataset[Split.TRAIN], self.dataset[Split.TEST] = split_dataset(
self.dataset, split=Split.TRAIN, test_size=0.2, random_seed=self.random_seed, allow_oos_in_train=True
)
+extra_splits = [split_name for split_name in self.dataset if split_name != Split.TEST]
+self.dataset[Split.TRAIN] = concatenate_datasets([self.dataset.pop(split_name) for split_name in extra_splits])

-for j in range(self.n_folds - 1):
+for j in range(self.config.n_folds - 1):
self.dataset[Split.TRAIN], self.dataset[f"{Split.TRAIN}_{j}"] = split_dataset(
self.dataset,
split=Split.TRAIN,
-test_size=1 / (self.n_folds - j),
+test_size=1 / (self.config.n_folds - j),
random_seed=self.random_seed,
allow_oos_in_train=True,
)
self.dataset[f"{Split.TRAIN}_{self.n_folds-1}"] = self.dataset.pop(Split.TRAIN)
self.dataset[f"{Split.TRAIN}_{self.config.n_folds-1}"] = self.dataset.pop(Split.TRAIN)

-def _split_validation_from_train(self) -> None:
+def _split_validation_from_train(self, size: float) -> None:
if Split.TRAIN in self.dataset:
self.dataset[Split.TRAIN], self.dataset[Split.VALIDATION] = split_dataset(
self.dataset,
split=Split.TRAIN,
-test_size=0.2,
+test_size=size,
random_seed=self.random_seed,
allow_oos_in_train=True,
)
@@ -261,13 +250,13 @@ def _split_validation_from_train(self) -> None:
self.dataset[f"{Split.TRAIN}_{idx}"], self.dataset[f"{Split.VALIDATION}_{idx}"] = split_dataset(
self.dataset,
split=f"{Split.TRAIN}_{idx}",
-test_size=0.2,
+test_size=size,
random_seed=self.random_seed,
allow_oos_in_train=idx == 1, # for decision node it's ok to have oos in train
)

def prepare_for_refit(self) -> None:
-if self.scheme == "ho":
+if self.config.scheme == "ho":
return

train_folds = [split_name for split_name in self.dataset if split_name.startswith(Split.TRAIN)]
@@ -276,7 +265,7 @@ def prepare_for_refit(self) -> None:
self.dataset[f"{Split.TRAIN}_0"], self.dataset[f"{Split.TRAIN}_1"] = split_dataset(
self.dataset,
split=Split.TRAIN,
-test_size=0.5,
+test_size=self.config.separation_ratio or 0.5,
random_seed=self.random_seed,
allow_oos_in_train=False,
)
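A side note on the fold arithmetic in _split_cv above: peeling off 1 / (n_folds - j) of the remaining train pool at step j produces n_folds equal-sized folds. A self-contained sketch of the recurrence in plain Python (toy sample count, not the library's API):

# With n_folds = 3 and 90 samples: step 0 takes 90 // 3 = 30,
# step 1 takes 60 // 2 = 30, and the remaining 30 form the last fold.
n_folds, pool = 3, 90
folds = []
for j in range(n_folds - 1):
    fold_size = pool // (n_folds - j)
    folds.append(fold_size)
    pool -= fold_size
folds.append(pool)  # the leftover pool becomes the final fold
print(folds)  # [30, 30, 30]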
4 changes: 2 additions & 2 deletions autointent/modules/abc/_base.py
@@ -43,9 +43,9 @@ def score(self, context: Context, metrics: list[str]) -> dict[str, float]:
:param split: Split to score on
:return: Computed metrics value for the test set or error code of metrics
"""
-if context.data_handler.scheme == "ho":
+if context.data_handler.config.scheme == "ho":
return self.score_ho(context, metrics)
-if context.data_handler.scheme == "cv":
+if context.data_handler.config.scheme == "cv":
return self.score_cv(context, metrics)
msg = "Something's wrong with validation schemas"
raise RuntimeError(msg)
4 changes: 2 additions & 2 deletions autointent/modules/abc/_decision.py
@@ -75,10 +75,10 @@ def score_cv(self, context: Context, metrics: list[str]) -> dict[str, float]:
chosen_metrics = {name: fn for name, fn in DECISION_METRICS.items() if name in metrics}
metrics_values: dict[str, list[float]] = {name: [] for name in chosen_metrics}
all_val_decisions = []
-for j in range(context.data_handler.n_folds):
+for j in range(context.data_handler.config.n_folds):
val_labels = labels[j]
val_scores = scores[j]
-train_folds = [i for i in range(context.data_handler.n_folds) if i != j]
+train_folds = [i for i in range(context.data_handler.config.n_folds) if i != j]
train_labels = [ut for i_fold in train_folds for ut in labels[i_fold]]
train_scores = np.array([sc for i_fold in train_folds for sc in scores[i_fold]])
self.fit(train_scores, train_labels, context.data_handler.tags) # type: ignore[arg-type]
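The loop above is standard leave-one-fold-out evaluation: fold j serves as validation while the remaining folds are concatenated for fitting. A stripped-down sketch with toy arrays standing in for the real per-fold scores and labels:

import numpy as np

n_folds = 3
scores = [np.random.rand(4, 2) for _ in range(n_folds)]  # toy per-fold score matrices
labels = [[0, 1, 0, 1] for _ in range(n_folds)]  # toy per-fold labels

for j in range(n_folds):
    val_scores, val_labels = scores[j], labels[j]
    train_folds = [i for i in range(n_folds) if i != j]
    train_scores = np.array([row for i in train_folds for row in scores[i]])
    train_labels = [lab for i in train_folds for lab in labels[i]]
    # a decision module would be fit on train_* and evaluated on val_* here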
5 changes: 3 additions & 2 deletions tests/data/test_data_handler.py
@@ -1,6 +1,7 @@
import pytest

from autointent import Dataset
+from autointent.configs import DataConfig
from autointent.context.data_handler import DataHandler
from autointent.schemas import Sample

@@ -180,7 +181,7 @@ def count_oos(split):


def test_cv_folding(dataset):
-DataHandler(dataset, scheme="cv", n_folds=3)
+DataHandler(dataset, config=DataConfig(scheme="cv", n_folds=3))

desired_specs = {
"test": {"total": 12, "oos": 4},
@@ -199,7 +200,7 @@ def count_oos_labels(split):


def test_cv_iterator(dataset):
-dh = DataHandler(dataset, scheme="cv", n_folds=3)
+dh = DataHandler(dataset, config=DataConfig(scheme="cv", n_folds=3))

desired_specs = [
{