diff --git a/autointent/_pipeline/_pipeline.py b/autointent/_pipeline/_pipeline.py
index 5d28c142b..8468185a3 100644
--- a/autointent/_pipeline/_pipeline.py
+++ b/autointent/_pipeline/_pipeline.py
@@ -9,8 +9,8 @@
 import yaml
 
 from autointent import Context, Dataset
-from autointent.configs import InferenceNodeConfig, LoggingConfig, VectorIndexConfig
-from autointent.custom_types import ListOfGenericLabels, NodeType, SamplerType, ValidationScheme
+from autointent.configs import DataConfig, InferenceNodeConfig, LoggingConfig, VectorIndexConfig
+from autointent.custom_types import ListOfGenericLabels, NodeType, SamplerType
 from autointent.metrics import DECISION_METRICS
 from autointent.nodes import InferenceNode, NodeOptimizer
 from autointent.nodes.schemes import OptimizationConfig
@@ -43,11 +43,12 @@ def __init__(
         if isinstance(nodes[0], NodeOptimizer):
             self.logging_config = LoggingConfig(dump_dir=None)
             self.vector_index_config = VectorIndexConfig()
+            self.data_config = DataConfig()
         elif not isinstance(nodes[0], InferenceNode):
             msg = "Pipeline should be initialized with list of NodeOptimizers or InferenceNodes"
             raise TypeError(msg)
 
-    def set_config(self, config: LoggingConfig | VectorIndexConfig) -> None:
+    def set_config(self, config: LoggingConfig | VectorIndexConfig | DataConfig) -> None:
         """
         Set configuration for the optimizer.
 
@@ -57,6 +58,8 @@ def set_config(self, config: LoggingConfig | VectorIndexConfig) -> None:
             self.logging_config = config
         elif isinstance(config, VectorIndexConfig):
             self.vector_index_config = config
+        elif isinstance(config, DataConfig):
+            self.data_config = config
         else:
             msg = "unknown config type"
             raise TypeError(msg)
@@ -119,8 +122,6 @@ def _is_inference(self) -> bool:
     def fit(
         self,
         dataset: Dataset,
-        scheme: ValidationScheme = "ho",
-        n_folds: int = 3,
         refit_after: bool = False,
         sampler: SamplerType = "brute",
     ) -> Context:
@@ -135,7 +136,7 @@
             raise RuntimeError(msg)
 
         context = Context()
-        context.set_dataset(dataset, scheme, n_folds)
+        context.set_dataset(dataset, self.data_config)
         context.configure_logging(self.logging_config)
         context.configure_vector_index(self.vector_index_config)
 
diff --git a/autointent/configs/_optimization.py b/autointent/configs/_optimization.py
index 8d925f8ae..b4b8370b5 100644
--- a/autointent/configs/_optimization.py
+++ b/autointent/configs/_optimization.py
@@ -12,12 +12,12 @@
 class DataConfig(BaseModel):
     """Configuration for the data used in the optimization process."""
 
-    train_path: str | Path
-    """Path to the training data. Can be local path or HF repo."""
-    scheme: ValidationScheme
+    scheme: ValidationScheme = "ho"
     """Hold-out or cross-validation."""
     n_folds: PositiveInt = 3
     """Number of folds in cross-validation."""
+    separate_nodes: bool = True
+    """Whether to use a separate data split for the decision node."""
 
 
 class TaskConfig(BaseModel):
diff --git a/autointent/context/_context.py b/autointent/context/_context.py
index 246f05b72..75de47f83 100644
--- a/autointent/context/_context.py
+++ b/autointent/context/_context.py
@@ -14,9 +14,8 @@
     LoggingConfig,
     VectorIndexConfig,
 )
-from autointent.custom_types import ValidationScheme
 
-from ._utils import NumpyEncoder, load_dataset
+from ._utils import NumpyEncoder
 from .data_handler import DataHandler
 from .optimization_info import OptimizationInfo
 
@@ -60,28 +59,13 @@ def configure_vector_index(self, config: VectorIndexConfig) -> None:
         """
         self.vector_index_config = config
 
-    def configure_data(self, config: DataConfig) -> None:
-        """
-        Configure data handling.
-
-        :param config: Configuration for the data handling process.
-        """
-        self.data_handler = DataHandler(
-            dataset=load_dataset(config.train_path), random_seed=self.seed, scheme=config.scheme
-        )
-
-    def set_dataset(self, dataset: Dataset, scheme: ValidationScheme = "ho", n_folds: int = 3) -> None:
+    def set_dataset(self, dataset: Dataset, config: DataConfig) -> None:
         """
         Set the datasets for training, validation and testing.
 
         :param dataset: Dataset.
         """
-        self.data_handler = DataHandler(
-            dataset=dataset,
-            random_seed=self.seed,
-            scheme=scheme,
-            n_folds=n_folds,
-        )
+        self.data_handler = DataHandler(dataset=dataset, random_seed=self.seed, **config.model_dump())
 
     def get_inference_config(self) -> dict[str, Any]:
         """
diff --git a/autointent/context/data_handler/_data_handler.py b/autointent/context/data_handler/_data_handler.py
index be8cb59c4..507d8e0f5 100644
--- a/autointent/context/data_handler/_data_handler.py
+++ b/autointent/context/data_handler/_data_handler.py
@@ -33,7 +33,7 @@ def __init__(
         self,
         dataset: Dataset,
         scheme: ValidationScheme = "ho",
-        split_train: bool = True,
+        separate_nodes: bool = True,
         random_seed: int = 0,
         n_folds: int = 3,
     ) -> None:
@@ -42,7 +42,7 @@
 
         :param dataset: Training dataset.
         :param random_seed: Seed for random number generation.
-        :param split_train: Perform or not splitting of train (default to split to be used in scoring and
+        :param separate_nodes: Whether to split the train set into separate parts (used for scoring and
             threshold search).
         """
         set_seed(random_seed)
@@ -55,7 +55,7 @@
         self.n_folds = n_folds
 
         if scheme == "ho":
-            self._split_ho(split_train)
+            self._split_ho(separate_nodes)
         elif scheme == "cv":
             self._split_cv()
 
@@ -82,6 +82,15 @@ def multilabel(self) -> bool:
         """
         return self.dataset.multilabel
 
+    def _choose_split(self, split_name: str, idx: int | None = None) -> str:
+        if idx is not None:
+            split = f"{split_name}_{idx}"
+            if split not in self.dataset:
+                split = split_name
+        else:
+            split = split_name
+        return split
+
     def train_utterances(self, idx: int | None = None) -> list[str]:
         """
         Retrieve training utterances from the dataset.
 
         :param idx: Optional index for a specific training split.
         :return: List of training utterances.
         """
-        split = f"{Split.TRAIN}_{idx}" if idx is not None else Split.TRAIN
+        split = self._choose_split(Split.TRAIN, idx)
         return cast(list[str], self.dataset[split][self.dataset.utterance_feature])
 
     def train_labels(self, idx: int | None = None) -> ListOfGenericLabels:
         """
         Retrieve training labels from the dataset.
 
         :param idx: Optional index for a specific training split.
         :return: List of training labels.
         """
-        split = f"{Split.TRAIN}_{idx}" if idx is not None else Split.TRAIN
+        split = self._choose_split(Split.TRAIN, idx)
         return cast(ListOfGenericLabels, self.dataset[split][self.dataset.label_feature])
 
     def train_labels_folded(self) -> list[ListOfGenericLabels]:
@@ -124,7 +133,7 @@ def validation_utterances(self, idx: int | None = None) -> list[str]:
         :param idx: Optional index for a specific validation split.
         :return: List of validation utterances.
""" - split = f"{Split.VALIDATION}_{idx}" if idx is not None else Split.VALIDATION + split = self._choose_split(Split.VALIDATION, idx) return cast(list[str], self.dataset[split][self.dataset.utterance_feature]) def validation_labels(self, idx: int | None = None) -> ListOfGenericLabels: @@ -138,10 +147,10 @@ def validation_labels(self, idx: int | None = None) -> ListOfGenericLabels: :param idx: Optional index for a specific validation split. :return: List of validation labels. """ - split = f"{Split.VALIDATION}_{idx}" if idx is not None else Split.VALIDATION + split = self._choose_split(Split.VALIDATION, idx) return cast(ListOfGenericLabels, self.dataset[split][self.dataset.label_feature]) - def test_utterances(self, idx: int | None = None) -> list[str]: + def test_utterances(self) -> list[str]: """ Retrieve test utterances from the dataset. @@ -152,10 +161,9 @@ def test_utterances(self, idx: int | None = None) -> list[str]: :param idx: Optional index for a specific test split. :return: List of test utterances. """ - split = f"{Split.TEST}_{idx}" if idx is not None else Split.TEST - return cast(list[str], self.dataset[split][self.dataset.utterance_feature]) + return cast(list[str], self.dataset[Split.TEST][self.dataset.utterance_feature]) - def test_labels(self, idx: int | None = None) -> ListOfGenericLabels: + def test_labels(self) -> ListOfGenericLabels: """ Retrieve test labels from the dataset. @@ -166,8 +174,7 @@ def test_labels(self, idx: int | None = None) -> ListOfGenericLabels: :param idx: Optional index for a specific test split. :return: List of test labels. """ - split = f"{Split.TEST}_{idx}" if idx is not None else Split.TEST - return cast(ListOfGenericLabels, self.dataset[split][self.dataset.label_feature]) + return cast(ListOfGenericLabels, self.dataset[Split.TEST][self.dataset.label_feature]) def validation_iterator(self) -> Generator[tuple[list[str], ListOfLabels, list[str], ListOfLabels]]: if self.scheme == "ho": @@ -186,27 +193,20 @@ def validation_iterator(self) -> Generator[tuple[list[str], ListOfLabels, list[s train_labels = [lab for lab in train_labels if lab is not None] yield train_utterances, train_labels, val_utterances, val_labels # type: ignore[misc] - def _split_ho(self, split_train: bool) -> None: + def _split_ho(self, separate_nodes: bool) -> None: has_validation_split = any(split.startswith(Split.VALIDATION) for split in self.dataset) - if split_train and Split.TRAIN in self.dataset: + if separate_nodes and Split.TRAIN in self.dataset: self._split_train() - if Split.TEST not in self.dataset: - test_size = 0.1 if has_validation_split else 0.2 - self._split_test(test_size) - if not has_validation_split: self._split_validation_from_train() - elif Split.VALIDATION in self.dataset: - self._split_validation() for split in self.dataset: - n_classes_split = self.dataset.get_n_classes(split) - if n_classes_split != self.n_classes: + n_classes_in_split = self.dataset.get_n_classes(split) + if n_classes_in_split != self.n_classes: message = ( - f"Number of classes in split '{split}' doesn't match initial number of classes " - f"({n_classes_split} != {self.n_classes})" + f"{n_classes_in_split=} for '{split=}' doesn't match initial number of classes ({self.n_classes})" ) raise ValueError(message) @@ -225,30 +225,6 @@ def _split_train(self) -> None: ) self.dataset.pop(Split.TRAIN) - def _split_validation(self) -> None: - """ - Split on two sets. - - One is for scoring node optimizaton, one is for decision node. 
- """ - self.dataset[f"{Split.VALIDATION}_0"], self.dataset[f"{Split.VALIDATION}_1"] = split_dataset( - self.dataset, - split=Split.VALIDATION, - test_size=0.5, - random_seed=self.random_seed, - allow_oos_in_train=False, # only val data for decision node should contain OOS - ) - self.dataset.pop(Split.VALIDATION) - - def _split_validation_from_test(self) -> None: - self.dataset[Split.TEST], self.dataset[Split.VALIDATION] = split_dataset( - self.dataset, - split=Split.TEST, - test_size=0.5, - random_seed=self.random_seed, - allow_oos_in_train=True, # both test and validation splits can contain OOS - ) - def _split_cv(self) -> None: extra_splits = [split_name for split_name in self.dataset if split_name not in [Split.TRAIN, Split.TEST]] if extra_splits: @@ -290,27 +266,6 @@ def _split_validation_from_train(self) -> None: allow_oos_in_train=idx == 1, # for decision node it's ok to have oos in train ) - def _split_test(self, test_size: float) -> None: - """Obtain test set from train.""" - self.dataset[f"{Split.TRAIN}_0"], self.dataset[f"{Split.TEST}_0"] = split_dataset( - self.dataset, - split=f"{Split.TRAIN}_0", - test_size=test_size, - random_seed=self.random_seed, - ) - self.dataset[f"{Split.TRAIN}_1"], self.dataset[f"{Split.TEST}_1"] = split_dataset( - self.dataset, - split=f"{Split.TRAIN}_1", - test_size=test_size, - random_seed=self.random_seed, - allow_oos_in_train=True, - ) - self.dataset[Split.TEST] = concatenate_datasets( - [self.dataset[f"{Split.TEST}_0"], self.dataset[f"{Split.TEST}_1"]], - ) - self.dataset.pop(f"{Split.TEST}_0") - self.dataset.pop(f"{Split.TEST}_1") - def prepare_for_refit(self) -> None: if self.scheme == "ho": return diff --git a/tests/assets/configs/light.yaml b/tests/assets/configs/light.yaml new file mode 100644 index 000000000..f36c2606c --- /dev/null +++ b/tests/assets/configs/light.yaml @@ -0,0 +1,15 @@ +- node_type: embedding + target_metric: retrieval_hit_rate + search_space: + - module_name: retrieval + k: [10] + embedder_config: + - model_name: sentence-transformers/all-MiniLM-L6-v2 +- node_type: scoring + target_metric: scoring_roc_auc + search_space: + - module_name: linear +- node_type: decision + target_metric: decision_accuracy + search_space: + - module_name: argmax diff --git a/tests/assets/data/clinc_no_oos.json b/tests/assets/data/clinc_no_oos.json new file mode 100644 index 000000000..4c07a65f0 --- /dev/null +++ b/tests/assets/data/clinc_no_oos.json @@ -0,0 +1,202 @@ +{ + "train": [ + { + "utterance": "do they take reservations at mcdonalds", + "label": 0 + }, + { + "utterance": "i would like an update on the progress of my credit card application", + "label": 3 + }, + { + "utterance": "can you tell me why is my bank account frozen", + "label": 1 + }, + { + "utterance": "why in the world am i locked out of my bank account", + "label": 1 + }, + { + "utterance": "who froze my bank account", + "label": 1 + }, + { + "utterance": "has my amex application gone through yet", + "label": 3 + }, + { + "utterance": "does michael's accept reservations", + "label": 0 + }, + { + "utterance": "why cannot i take any money out from my bank account", + "label": 1 + }, + { + "utterance": "does cowgirl creamery in san francisco take reservations", + "label": 0 + }, + { + "utterance": "why am i locked out of my bank account", + "label": 1 + }, + { + "utterance": "find out if la tour d'argent in paris takes reservations", + "label": 0 + }, + { + "utterance": "please set two alarms, one at 12 pm and the next at 1 pm", + "label": 2 + }, + { + "utterance": "will i be 
notified when my application has been processed", + "label": 3 + }, + { + "utterance": "wake me up at noon tomorrow", + "label": 2 + }, + { + "utterance": "was my application approved or not for a credit card at chase bank", + "label": 3 + }, + { + "utterance": "set my alarm for getting up", + "label": 2 + }, + { + "utterance": "has there been any notice that my card app has been looked at", + "label": 3 + }, + { + "utterance": "i need you to schedule an alarm", + "label": 2 + }, + { + "utterance": "why am i seeing a hold on my boa account", + "label": 1 + }, + { + "utterance": "set up an alarm", + "label": 2 + }, + { + "utterance": "did my american express card application go through yet", + "label": 3 + }, + { + "utterance": "please create an alarm for 12 noon", + "label": 2 + }, + { + "utterance": "does moes in la except rerservations", + "label": 0 + }, + { + "utterance": "do you know whether ihop does reservations", + "label": 0 + } + ], + "validation": [ + { + "utterance": "how come a hold was placed on my 401k account", + "label": 1 + }, + { + "utterance": "what is the status of my new credit card application", + "label": 3 + }, + { + "utterance": "do they take reservations at carrabbas", + "label": 0 + }, + { + "utterance": "make an alarm for tomorrow at twilight", + "label": 2 + }, + { + "utterance": "where can i find out the status of my credit card application", + "label": 3 + }, + { + "utterance": "is it possible to make reservations with famous dave's restaurant", + "label": 0 + }, + { + "utterance": "what's with the block on my bank account", + "label": 1 + }, + { + "utterance": "set alarm for 5 am", + "label": 2 + } + ], + "test": [ + { + "utterance": "does pho king in ceres take reservations", + "label": 0 + }, + { + "utterance": "make an alarm 6am", + "label": 2 + }, + { + "utterance": "i need you to set an alarm for 8am tomorrow", + "label": 2 + }, + { + "utterance": "does the steakhouse on main st take reservations", + "label": 0 + }, + { + "utterance": "has there been any changes in the status of my credit card application", + "label": 3 + }, + { + "utterance": "look to see if my application for the barclay's card has gone through yet", + "label": 3 + }, + { + "utterance": "can i talk to someone about why there is a hold on my checking account", + "label": 1 + }, + { + "utterance": "help me please, my account is blocked", + "label": 1 + } + ], + "intents": [ + { + "id": 0, + "name": null, + "tags": [], + "regexp_full_match": [], + "regexp_partial_match": [], + "description": null + }, + { + "id": 1, + "name": null, + "tags": [], + "regexp_full_match": [], + "regexp_partial_match": [], + "description": null + }, + { + "id": 2, + "name": null, + "tags": [], + "regexp_full_match": [], + "regexp_partial_match": [], + "description": null + }, + { + "id": 3, + "name": null, + "tags": [], + "regexp_full_match": [], + "regexp_partial_match": [], + "description": null + } + ] +} \ No newline at end of file diff --git a/tests/callback/test_callback.py b/tests/callback/test_callback.py index c9931a6b7..6cb03aa8f 100644 --- a/tests/callback/test_callback.py +++ b/tests/callback/test_callback.py @@ -5,7 +5,7 @@ from autointent import Context, Pipeline from autointent._callbacks import CallbackHandler, OptimizerCallback -from autointent.configs import LoggingConfig, VectorIndexConfig +from autointent.configs import DataConfig, LoggingConfig, VectorIndexConfig from tests.conftest import setup_environment @@ -86,7 +86,7 @@ def test_pipeline_callbacks(dataset): 
     context.configure_vector_index(VectorIndexConfig(save_db=True))
     context.configure_logging(LoggingConfig(run_name="dummy_run_name", project_dir=project_dir, dump_modules=False))
     context.callback_handler = CallbackHandler([DummyCallback])
-    context.set_dataset(dataset)
+    context.set_dataset(dataset, DataConfig(scheme="ho", separate_nodes=True))
 
     pipeline_optimizer._fit(context)
 
diff --git a/tests/conftest.py b/tests/conftest.py
index 002812907..fb1ed3a4a 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -27,7 +27,13 @@ def dataset_unsplitted():
     return Dataset.from_json(path)
 
 
-TaskType = Literal["multiclass", "multilabel", "description", "optuna"]
+@pytest.fixture
+def dataset_no_oos():
+    path = ires.files("tests.assets.data").joinpath("clinc_no_oos.json")
+    return Dataset.from_json(path)
+
+
+TaskType = Literal["multiclass", "multilabel", "description", "optuna", "light"]
 
 
 def get_search_space_path(task_type: TaskType):
diff --git a/tests/nodes/conftest.py b/tests/nodes/conftest.py
index c5669498d..c1d572a4c 100644
--- a/tests/nodes/conftest.py
+++ b/tests/nodes/conftest.py
@@ -1,10 +1,7 @@
 import pytest
 
 from autointent import Context, Dataset
-from autointent.configs import (
-    LoggingConfig,
-    VectorIndexConfig,
-)
+from autointent.configs import DataConfig, LoggingConfig, VectorIndexConfig
 from autointent.nodes import NodeOptimizer
 from tests.conftest import get_dataset_path, setup_environment
 
@@ -79,7 +76,7 @@ def get_context(multilabel):
     dataset = Dataset.from_json(get_dataset_path())
     if multilabel:
         dataset = dataset.to_multilabel()
-    res.set_dataset(dataset)
+    res.set_dataset(dataset, DataConfig(scheme="ho", separate_nodes=True))
     res.configure_logging(LoggingConfig(project_dir=project_dir, dump_modules=True))
     res.configure_vector_index(VectorIndexConfig())
     return res
diff --git a/tests/pipeline/test_optimization.py b/tests/pipeline/test_optimization.py
index de34bac5d..310db986b 100644
--- a/tests/pipeline/test_optimization.py
+++ b/tests/pipeline/test_optimization.py
@@ -3,13 +3,23 @@
 import pytest
 
 from autointent import Pipeline
-from autointent.configs import (
-    LoggingConfig,
-    VectorIndexConfig,
-)
+from autointent.configs import DataConfig, LoggingConfig, VectorIndexConfig
 from tests.conftest import get_search_space, setup_environment
 
 
+def test_no_node_separation(dataset_no_oos):
+    project_dir = setup_environment()
+    search_space = get_search_space("light")
+
+    pipeline_optimizer = Pipeline.from_search_space(search_space)
+
+    pipeline_optimizer.set_config(LoggingConfig(project_dir=project_dir, dump_modules=True, clear_ram=True))
+    pipeline_optimizer.set_config(VectorIndexConfig())
+    pipeline_optimizer.set_config(DataConfig(scheme="ho", separate_nodes=False))
+
+    pipeline_optimizer.fit(dataset_no_oos, refit_after=False)
+
+
 @pytest.mark.parametrize(
     "sampler",
     ["tpe", "random"],
@@ -22,8 +32,9 @@ def test_bayes(dataset, sampler):
 
     pipeline_optimizer.set_config(LoggingConfig(project_dir=project_dir, dump_modules=True, clear_ram=True))
     pipeline_optimizer.set_config(VectorIndexConfig())
+    pipeline_optimizer.set_config(DataConfig(scheme="ho", separate_nodes=True))
 
-    pipeline_optimizer.fit(dataset, scheme="ho", refit_after=False, sampler=sampler)
+    pipeline_optimizer.fit(dataset, refit_after=False, sampler=sampler)
 
 
 @pytest.mark.parametrize(
@@ -38,11 +49,12 @@ def test_cv(dataset, task_type):
 
     pipeline_optimizer.set_config(LoggingConfig(project_dir=project_dir, dump_modules=True, clear_ram=True))
     pipeline_optimizer.set_config(VectorIndexConfig())
+    pipeline_optimizer.set_config(DataConfig(scheme="cv", separate_nodes=True))
 
     if task_type == "multilabel":
         dataset = dataset.to_multilabel()
 
-    context = pipeline_optimizer.fit(dataset, scheme="cv", refit_after=True)
+    context = pipeline_optimizer.fit(dataset, refit_after=True)
     context.dump()
 
     assert os.listdir(pipeline_optimizer.logging_config.dump_dir)
 
 
 @pytest.mark.parametrize(
@@ -60,6 +72,7 @@ def test_no_context_optimization(dataset, task_type):
 
     pipeline_optimizer.set_config(LoggingConfig(project_dir=project_dir, dump_modules=False, clear_ram=False))
     pipeline_optimizer.set_config(VectorIndexConfig(save_db=True))
+    pipeline_optimizer.set_config(DataConfig(scheme="ho", separate_nodes=True))
 
     if task_type == "multilabel":
         dataset = dataset.to_multilabel()