Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions autointent/_pipeline/_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
import yaml

from autointent import Context, Dataset
from autointent.configs import InferenceNodeConfig, LoggingConfig, VectorIndexConfig
from autointent.custom_types import ListOfGenericLabels, NodeType, SamplerType, ValidationScheme
from autointent.configs import DataConfig, InferenceNodeConfig, LoggingConfig, VectorIndexConfig
from autointent.custom_types import ListOfGenericLabels, NodeType, SamplerType
from autointent.metrics import DECISION_METRICS
from autointent.nodes import InferenceNode, NodeOptimizer
from autointent.nodes.schemes import OptimizationConfig
Expand Down Expand Up @@ -43,11 +43,12 @@ def __init__(
if isinstance(nodes[0], NodeOptimizer):
self.logging_config = LoggingConfig(dump_dir=None)
self.vector_index_config = VectorIndexConfig()
self.data_config = DataConfig()
elif not isinstance(nodes[0], InferenceNode):
msg = "Pipeline should be initialized with list of NodeOptimizers or InferenceNodes"
raise TypeError(msg)

def set_config(self, config: LoggingConfig | VectorIndexConfig) -> None:
def set_config(self, config: LoggingConfig | VectorIndexConfig | DataConfig) -> None:
"""
Set configuration for the optimizer.

Expand All @@ -57,6 +58,8 @@ def set_config(self, config: LoggingConfig | VectorIndexConfig) -> None:
self.logging_config = config
elif isinstance(config, VectorIndexConfig):
self.vector_index_config = config
elif isinstance(config, DataConfig):
self.data_config = config
else:
msg = "unknown config type"
raise TypeError(msg)
Expand Down Expand Up @@ -119,8 +122,6 @@ def _is_inference(self) -> bool:
def fit(
self,
dataset: Dataset,
scheme: ValidationScheme = "ho",
n_folds: int = 3,
refit_after: bool = False,
sampler: SamplerType = "brute",
) -> Context:
Expand All @@ -135,7 +136,7 @@ def fit(
raise RuntimeError(msg)

context = Context()
context.set_dataset(dataset, scheme, n_folds)
context.set_dataset(dataset, self.data_config)
context.configure_logging(self.logging_config)
context.configure_vector_index(self.vector_index_config)

Expand Down
6 changes: 3 additions & 3 deletions autointent/configs/_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
class DataConfig(BaseModel):
"""Configuration for the data used in the optimization process."""

train_path: str | Path
"""Path to the training data. Can be local path or HF repo."""
scheme: ValidationScheme
scheme: ValidationScheme = "ho"
"""Hold-out or cross-validation."""
n_folds: PositiveInt = 3
"""Number of folds in cross-validation."""
separate_nodes: bool = True
"""Whether to use separate data for decision node."""


class TaskConfig(BaseModel):
Expand Down
22 changes: 3 additions & 19 deletions autointent/context/_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@
LoggingConfig,
VectorIndexConfig,
)
from autointent.custom_types import ValidationScheme

from ._utils import NumpyEncoder, load_dataset
from ._utils import NumpyEncoder
from .data_handler import DataHandler
from .optimization_info import OptimizationInfo

Expand Down Expand Up @@ -60,28 +59,13 @@ def configure_vector_index(self, config: VectorIndexConfig) -> None:
"""
self.vector_index_config = config

def configure_data(self, config: DataConfig) -> None:
"""
Configure data handling.

:param config: Configuration for the data handling process.
"""
self.data_handler = DataHandler(
dataset=load_dataset(config.train_path), random_seed=self.seed, scheme=config.scheme
)

def set_dataset(self, dataset: Dataset, scheme: ValidationScheme = "ho", n_folds: int = 3) -> None:
def set_dataset(self, dataset: Dataset, config: DataConfig) -> None:
"""
Set the datasets for training, validation and testing.

:param dataset: Dataset.
"""
self.data_handler = DataHandler(
dataset=dataset,
random_seed=self.seed,
scheme=scheme,
n_folds=n_folds,
)
self.data_handler = DataHandler(dataset=dataset, random_seed=self.seed, **config.model_dump())

def get_inference_config(self) -> dict[str, Any]:
"""
Expand Down
95 changes: 25 additions & 70 deletions autointent/context/data_handler/_data_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(
self,
dataset: Dataset,
scheme: ValidationScheme = "ho",
split_train: bool = True,
separate_nodes: bool = True,
random_seed: int = 0,
n_folds: int = 3,
) -> None:
Expand All @@ -42,7 +42,7 @@ def __init__(

:param dataset: Training dataset.
:param random_seed: Seed for random number generation.
:param split_train: Perform or not splitting of train (default to split to be used in scoring and
:param separate_nodes: Perform or not splitting of train (default to split to be used in scoring and
threshold search).
"""
set_seed(random_seed)
Expand All @@ -55,7 +55,7 @@ def __init__(
self.n_folds = n_folds

if scheme == "ho":
self._split_ho(split_train)
self._split_ho(separate_nodes)
elif scheme == "cv":
self._split_cv()

Expand All @@ -82,6 +82,15 @@ def multilabel(self) -> bool:
"""
return self.dataset.multilabel

def _choose_split(self, split_name: str, idx: int | None = None) -> str:
if idx is not None:
split = f"{split_name}_{idx}"
if split not in self.dataset:
split = split_name
else:
split = split_name
return split

def train_utterances(self, idx: int | None = None) -> list[str]:
"""
Retrieve training utterances from the dataset.
Expand All @@ -93,7 +102,7 @@ def train_utterances(self, idx: int | None = None) -> list[str]:
:param idx: Optional index for a specific training split.
:return: List of training utterances.
"""
split = f"{Split.TRAIN}_{idx}" if idx is not None else Split.TRAIN
split = self._choose_split(Split.TRAIN, idx)
return cast(list[str], self.dataset[split][self.dataset.utterance_feature])

def train_labels(self, idx: int | None = None) -> ListOfGenericLabels:
Expand All @@ -107,7 +116,7 @@ def train_labels(self, idx: int | None = None) -> ListOfGenericLabels:
:param idx: Optional index for a specific training split.
:return: List of training labels.
"""
split = f"{Split.TRAIN}_{idx}" if idx is not None else Split.TRAIN
split = self._choose_split(Split.TRAIN, idx)
return cast(ListOfGenericLabels, self.dataset[split][self.dataset.label_feature])

def train_labels_folded(self) -> list[ListOfGenericLabels]:
Expand All @@ -124,7 +133,7 @@ def validation_utterances(self, idx: int | None = None) -> list[str]:
:param idx: Optional index for a specific validation split.
:return: List of validation utterances.
"""
split = f"{Split.VALIDATION}_{idx}" if idx is not None else Split.VALIDATION
split = self._choose_split(Split.VALIDATION, idx)
return cast(list[str], self.dataset[split][self.dataset.utterance_feature])

def validation_labels(self, idx: int | None = None) -> ListOfGenericLabels:
Expand All @@ -138,10 +147,10 @@ def validation_labels(self, idx: int | None = None) -> ListOfGenericLabels:
:param idx: Optional index for a specific validation split.
:return: List of validation labels.
"""
split = f"{Split.VALIDATION}_{idx}" if idx is not None else Split.VALIDATION
split = self._choose_split(Split.VALIDATION, idx)
return cast(ListOfGenericLabels, self.dataset[split][self.dataset.label_feature])

def test_utterances(self, idx: int | None = None) -> list[str]:
def test_utterances(self) -> list[str]:
"""
Retrieve test utterances from the dataset.

Expand All @@ -152,10 +161,9 @@ def test_utterances(self, idx: int | None = None) -> list[str]:
:param idx: Optional index for a specific test split.
:return: List of test utterances.
"""
split = f"{Split.TEST}_{idx}" if idx is not None else Split.TEST
return cast(list[str], self.dataset[split][self.dataset.utterance_feature])
return cast(list[str], self.dataset[Split.TEST][self.dataset.utterance_feature])

def test_labels(self, idx: int | None = None) -> ListOfGenericLabels:
def test_labels(self) -> ListOfGenericLabels:
"""
Retrieve test labels from the dataset.

Expand All @@ -166,8 +174,7 @@ def test_labels(self, idx: int | None = None) -> ListOfGenericLabels:
:param idx: Optional index for a specific test split.
:return: List of test labels.
"""
split = f"{Split.TEST}_{idx}" if idx is not None else Split.TEST
return cast(ListOfGenericLabels, self.dataset[split][self.dataset.label_feature])
return cast(ListOfGenericLabels, self.dataset[Split.TEST][self.dataset.label_feature])

def validation_iterator(self) -> Generator[tuple[list[str], ListOfLabels, list[str], ListOfLabels]]:
if self.scheme == "ho":
Expand All @@ -186,27 +193,20 @@ def validation_iterator(self) -> Generator[tuple[list[str], ListOfLabels, list[s
train_labels = [lab for lab in train_labels if lab is not None]
yield train_utterances, train_labels, val_utterances, val_labels # type: ignore[misc]

def _split_ho(self, split_train: bool) -> None:
def _split_ho(self, separate_nodes: bool) -> None:
has_validation_split = any(split.startswith(Split.VALIDATION) for split in self.dataset)

if split_train and Split.TRAIN in self.dataset:
if separate_nodes and Split.TRAIN in self.dataset:
self._split_train()

if Split.TEST not in self.dataset:
test_size = 0.1 if has_validation_split else 0.2
self._split_test(test_size)

if not has_validation_split:
self._split_validation_from_train()
elif Split.VALIDATION in self.dataset:
self._split_validation()

for split in self.dataset:
n_classes_split = self.dataset.get_n_classes(split)
if n_classes_split != self.n_classes:
n_classes_in_split = self.dataset.get_n_classes(split)
if n_classes_in_split != self.n_classes:
message = (
f"Number of classes in split '{split}' doesn't match initial number of classes "
f"({n_classes_split} != {self.n_classes})"
f"{n_classes_in_split=} for '{split=}' doesn't match initial number of classes ({self.n_classes})"
)
raise ValueError(message)

Expand All @@ -225,30 +225,6 @@ def _split_train(self) -> None:
)
self.dataset.pop(Split.TRAIN)

def _split_validation(self) -> None:
"""
Split on two sets.

One is for scoring node optimizaton, one is for decision node.
"""
self.dataset[f"{Split.VALIDATION}_0"], self.dataset[f"{Split.VALIDATION}_1"] = split_dataset(
self.dataset,
split=Split.VALIDATION,
test_size=0.5,
random_seed=self.random_seed,
allow_oos_in_train=False, # only val data for decision node should contain OOS
)
self.dataset.pop(Split.VALIDATION)

def _split_validation_from_test(self) -> None:
self.dataset[Split.TEST], self.dataset[Split.VALIDATION] = split_dataset(
self.dataset,
split=Split.TEST,
test_size=0.5,
random_seed=self.random_seed,
allow_oos_in_train=True, # both test and validation splits can contain OOS
)

def _split_cv(self) -> None:
extra_splits = [split_name for split_name in self.dataset if split_name not in [Split.TRAIN, Split.TEST]]
if extra_splits:
Expand Down Expand Up @@ -290,27 +266,6 @@ def _split_validation_from_train(self) -> None:
allow_oos_in_train=idx == 1, # for decision node it's ok to have oos in train
)

def _split_test(self, test_size: float) -> None:
"""Obtain test set from train."""
self.dataset[f"{Split.TRAIN}_0"], self.dataset[f"{Split.TEST}_0"] = split_dataset(
self.dataset,
split=f"{Split.TRAIN}_0",
test_size=test_size,
random_seed=self.random_seed,
)
self.dataset[f"{Split.TRAIN}_1"], self.dataset[f"{Split.TEST}_1"] = split_dataset(
self.dataset,
split=f"{Split.TRAIN}_1",
test_size=test_size,
random_seed=self.random_seed,
allow_oos_in_train=True,
)
self.dataset[Split.TEST] = concatenate_datasets(
[self.dataset[f"{Split.TEST}_0"], self.dataset[f"{Split.TEST}_1"]],
)
self.dataset.pop(f"{Split.TEST}_0")
self.dataset.pop(f"{Split.TEST}_1")

def prepare_for_refit(self) -> None:
if self.scheme == "ho":
return
Expand Down
15 changes: 15 additions & 0 deletions tests/assets/configs/light.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
- node_type: embedding
target_metric: retrieval_hit_rate
search_space:
- module_name: retrieval
k: [10]
embedder_config:
- model_name: sentence-transformers/all-MiniLM-L6-v2
- node_type: scoring
target_metric: scoring_roc_auc
search_space:
- module_name: linear
- node_type: decision
target_metric: decision_accuracy
search_space:
- module_name: argmax
Loading
Loading