diff --git a/autointent/_callbacks/__init__.py b/autointent/_callbacks/__init__.py index 376b12b46..35c1d6abc 100644 --- a/autointent/_callbacks/__init__.py +++ b/autointent/_callbacks/__init__.py @@ -11,11 +11,13 @@ def get_callbacks(reporters: list[str] | None) -> CallbackHandler: - """ - Get the list of callbacks. + """Get the list of callbacks. + + Args: + reporters: List of reporters to use. - :param reporters: List of reporters to use. - :return: Callback handler. + Returns: + CallbackHandler: Callback handler. """ if not reporters: return CallbackHandler() diff --git a/autointent/_callbacks/base.py b/autointent/_callbacks/base.py index 01bd073c8..b021ad472 100644 --- a/autointent/_callbacks/base.py +++ b/autointent/_callbacks/base.py @@ -17,37 +17,37 @@ def __init__(self) -> None: @abstractmethod def start_run(self, run_name: str, dirpath: Path) -> None: - """ - Start a new run. + """Start a new run. - :param run_name: Name of the run. - :param dirpath: Path to the directory where the logs will be saved. + Args: + run_name: Name of the run. + dirpath: Path to the directory where the logs will be saved. """ @abstractmethod def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]) -> None: - """ - Start a new module. + """Start a new module. - :param module_name: Name of the module. - :param num: Number of the module. - :param module_kwargs: Module parameters. + Args: + module_name: Name of the module. + num: Number of the module. + module_kwargs: Module parameters. """ @abstractmethod def log_value(self, **kwargs: dict[str, Any]) -> None: - """ - Log data. + """Log data. - :param kwargs: Data to log. + Args: + kwargs: Data to log. """ @abstractmethod def log_metrics(self, metrics: dict[str, Any]) -> None: - """ - Log metrics during training. + """Log metrics during training. - :param metrics: Metrics to log. + Args: + metrics: Metrics to log. """ @abstractmethod @@ -60,8 +60,8 @@ def end_run(self) -> None: @abstractmethod def log_final_metrics(self, metrics: dict[str, Any]) -> None: - """ - Log final metrics. + """Log final metrics. - :param metrics: Final metrics. + Args: + metrics: Final metrics. """ diff --git a/autointent/_callbacks/callback_handler.py b/autointent/_callbacks/callback_handler.py index 6a3d6af65..aff8d1a23 100644 --- a/autointent/_callbacks/callback_handler.py +++ b/autointent/_callbacks/callback_handler.py @@ -10,7 +10,11 @@ class CallbackHandler(OptimizerCallback): callbacks: list[OptimizerCallback] def __init__(self, callbacks: list[type[OptimizerCallback]] | None = None) -> None: - """Initialize the callback handler.""" + """Initialize the callback handler. + + Args: + callbacks: List of callback classes. + """ if not callbacks: self.callbacks = [] return @@ -18,37 +22,37 @@ def __init__(self, callbacks: list[type[OptimizerCallback]] | None = None) -> No self.callbacks = [cb() for cb in callbacks] def start_run(self, run_name: str, dirpath: Path) -> None: - """ - Start a new run. + """Start a new run. - :param run_name: Name of the run. - :param dirpath: Path to the directory where the logs will be saved. + Args: + run_name: Name of the run. + dirpath: Path to the directory where the logs will be saved. """ self.call_events("start_run", run_name=run_name, dirpath=dirpath) def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]) -> None: - """ - Start a new module. + """Start a new module. - :param module_name: Name of the module. - :param num: Number of the module. - :param module_kwargs: Module parameters. + Args: + module_name: Name of the module. + num: Number of the module. + module_kwargs: Module parameters. """ self.call_events("start_module", module_name=module_name, num=num, module_kwargs=module_kwargs) def log_value(self, **kwargs: dict[str, Any]) -> None: - """ - Log data. + """Log data. - :param kwargs: Data to log. + Args: + kwargs: Data to log. """ self.call_events("log_value", **kwargs) def log_metrics(self, metrics: dict[str, Any]) -> None: - """ - Log metrics during training. + """Log metrics during training. - :param metrics: Metrics to log. + Args: + metrics: Metrics to log. """ self.call_events("log_metrics", metrics=metrics) @@ -61,13 +65,19 @@ def end_run(self) -> None: self.call_events("end_run") def log_final_metrics(self, metrics: dict[str, Any]) -> None: - """ - Log final metrics. + """Log final metrics. - :param metrics: Final metrics. + Args: + metrics: Final metrics. """ self.call_events("log_final_metrics", metrics=metrics) def call_events(self, event: str, **kwargs: Any) -> None: # noqa: ANN401 + """Call events for all callbacks. + + Args: + event: Event name. + kwargs: Event parameters. + """ for callback in self.callbacks: getattr(callback, event)(**kwargs) diff --git a/autointent/_callbacks/tensorboard.py b/autointent/_callbacks/tensorboard.py index 2da29053f..1639d6e45 100644 --- a/autointent/_callbacks/tensorboard.py +++ b/autointent/_callbacks/tensorboard.py @@ -5,16 +5,16 @@ class TensorBoardCallback(OptimizerCallback): - """ - TensorBoard callback. - - This callback logs the optimization process to TensorBoard. - """ + """TensorBoard callback for logging the optimization process.""" name = "tensorboard" def __init__(self) -> None: - """Initialize the callback.""" + """Initializes the TensorBoard callback. + + Attempts to import `torch.utils.tensorboard` first. If unavailable, tries to import `tensorboardX`. + Raises an ImportError if neither are installed. + """ try: from torch.utils.tensorboard import SummaryWriter # type: ignore[attr-defined] @@ -32,22 +32,22 @@ def __init__(self) -> None: raise ImportError(msg) from None def start_run(self, run_name: str, dirpath: Path) -> None: - """ - Start a new run. + """Starts a new run and sets the directory for storing logs. - :param run_name: Name of the run. - :param dirpath: Path to the directory where the logs will be saved. + Args: + run_name: Name of the run. + dirpath: Path to the directory where logs will be saved. """ self.run_name = run_name self.dirpath = dirpath def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]) -> None: - """ - Start a new module. + """Starts a new module and initializes a TensorBoard writer for it. - :param module_name: Name of the module. - :param num: Number of the module. - :param module_kwargs: Module parameters. + Args: + module_name: Name of the module. + num: Identifier number of the module. + module_kwargs: Dictionary containing module parameters. """ module_run_name = f"{self.run_name}_{module_name}_{num}" log_dir = Path(self.dirpath) / module_run_name @@ -58,10 +58,10 @@ def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any] self.module_writer.add_text(f"module_params/{key}", str(value)) # type: ignore[no-untyped-call] def log_value(self, **kwargs: dict[str, int | float | Any]) -> None: - """ - Log data. + """Logs scalar or text values. - :param kwargs: Data to log. + Args: + **kwargs: Key-value pairs of data to log. Scalars will be logged as numerical values, others as text. """ for key, value in kwargs.items(): if isinstance(value, int | float): @@ -70,10 +70,10 @@ def log_value(self, **kwargs: dict[str, int | float | Any]) -> None: self.module_writer.add_text(key, str(value)) # type: ignore[no-untyped-call] def log_metrics(self, metrics: dict[str, Any]) -> None: - """ - Log metrics during training. + """Logs training metrics. - :param metrics: Metrics to log. + Args: + metrics: Dictionary of metrics to log. """ for key, value in metrics.items(): if isinstance(value, int | float): @@ -82,10 +82,13 @@ def log_metrics(self, metrics: dict[str, Any]) -> None: self.module_writer.add_text(key, str(value)) # type: ignore[no-untyped-call] def log_final_metrics(self, metrics: dict[str, Any]) -> None: - """ - Log final metrics. + """Logs final metrics at the end of training. + + Args: + metrics: Dictionary of final metrics. - :param metrics: Final metrics. + Raises: + RuntimeError: If `start_run` has not been called before logging final metrics. """ if self.module_writer is None: msg = "start_run must be called before log_final_metrics." @@ -101,7 +104,11 @@ def log_final_metrics(self, metrics: dict[str, Any]) -> None: self.module_writer.add_text(key, str(value)) # type: ignore[no-untyped-call] def end_module(self) -> None: - """End a module.""" + """Ends the current module and closes the TensorBoard writer. + + Raises: + RuntimeError: If `start_run` has not been called before ending the module. + """ if self.module_writer is None: msg = "start_run must be called before end_module." raise RuntimeError(msg) @@ -110,4 +117,4 @@ def end_module(self) -> None: self.module_writer.close() # type: ignore[no-untyped-call] def end_run(self) -> None: - pass + """Ends the current run. This method is currently a placeholder.""" diff --git a/autointent/_callbacks/wandb.py b/autointent/_callbacks/wandb.py index cf38c530c..3349e02c9 100644 --- a/autointent/_callbacks/wandb.py +++ b/autointent/_callbacks/wandb.py @@ -6,17 +6,25 @@ class WandbCallback(OptimizerCallback): - """ - Wandb callback. + """Wandb callback for logging the optimization process to Weights & Biases (W&B). + + This callback integrates with W&B to track training runs, log metrics, and store + configurations. - This callback logs the optimization process to W&B. - To specify the project name, set the `WANDB_PROJECT` environment variable. Default is `autointent`. + To specify the project name, set the `WANDB_PROJECT` environment variable. If not set, + it defaults to `autointent`. """ name = "wandb" def __init__(self) -> None: - """Initialize the callback.""" + """Initializes the Wandb callback. + + Ensures that `wandb` is installed before using this callback. + + Raises: + ImportError: If `wandb` is not installed. + """ try: import wandb except ImportError: @@ -26,23 +34,29 @@ def __init__(self) -> None: self.wandb = wandb def start_run(self, run_name: str, dirpath: Path) -> None: - """ - Start a new run. + """Starts a new W&B run. + + Initializes the project name and run group. The directory path argument is not + used in this callback. - :param run_name: Name of the run. - :param dirpath: Path to the directory where the logs will be saved. (Not used for this callback) + Args: + run_name: Name of the run (used as a W&B group). + dirpath: Path to store logs (not utilized in W&B logging). """ self.project_name = os.getenv("WANDB_PROJECT", "autointent") self.group = run_name self.dirpath = dirpath def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]) -> None: - """ - Start a new module. + """Starts a new module within the W&B logging system. + + This initializes a W&B run with the specified module name, unique identifier, + and configuration parameters. - :param module_name: Name of the module. - :param num: Number of the module. - :param module_kwargs: Module parameters. + Args: + module_name: The name of the module being logged. + num: A numerical identifier for the module instance. + module_kwargs: Dictionary containing module parameters. """ self.wandb.init( project=self.project_name, @@ -52,26 +66,30 @@ def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any] ) def log_value(self, **kwargs: dict[str, Any]) -> None: - """ - Log data. + """Logs scalar or textual values to W&B. + + This function logs the provided key-value pairs to W&B. - :param kwargs: Data to log. + Args: + **kwargs: Key-value pairs of data to log. """ self.wandb.log(kwargs) def log_metrics(self, metrics: dict[str, Any]) -> None: - """ - Log metrics during training. + """Logs training metrics to W&B. - :param metrics: Metrics to log. + Args: + metrics: A dictionary containing metric names and values. """ self.wandb.log(metrics) def log_final_metrics(self, metrics: dict[str, Any]) -> None: - """ - Log final metrics. + """Logs final evaluation metrics to W&B. + + A new W&B run named `final_metrics` is created to store the final performance metrics. - :param metrics: Final metrics. + Args: + metrics: A dictionary of final performance metrics. """ self.wandb.init( project=self.project_name, @@ -84,8 +102,14 @@ def log_final_metrics(self, metrics: dict[str, Any]) -> None: self.wandb.finish() def end_module(self) -> None: - """End a module.""" + """Ends the current W&B module. + + This closes the W&B run associated with the current module. + """ self.wandb.finish() def end_run(self) -> None: - pass + """Ends the W&B run. + + This method is currently a placeholder and does not perform additional operations. + """ diff --git a/autointent/_dataset/_dataset.py b/autointent/_dataset/_dataset.py index e6b93e19e..12f3fe0d5 100644 --- a/autointent/_dataset/_dataset.py +++ b/autointent/_dataset/_dataset.py @@ -1,4 +1,4 @@ -"""File with Dataset definition.""" +"""Defines the Dataset class and related utilities for handling datasets.""" import json import logging @@ -17,11 +17,11 @@ class Sample(TypedDict): - """ - Typed dictionary representing a dataset sample. + """Represents a sample in the dataset. - :param utterance: The text of the utterance. - :param label: The label associated with the utterance, or None if out-of-scope. + Attributes: + utterance: The text of the utterance. + label: The label associated with the utterance, or None if it is out-of-scope. """ utterance: str @@ -29,12 +29,15 @@ class Sample(TypedDict): class Dataset(dict[str, HFDataset]): - """ - Represents a dataset with associated metadata and utilities for processing. + """Represents a dataset with associated metadata and utilities for processing. + + This class extends a dictionary where the keys represent dataset splits (e.g., 'train', 'test'), + and the values are Hugging Face datasets. - :param args: Positional arguments to initialize the dataset. - :param intents: List of intents associated with the dataset. - :param kwargs: Additional keyword arguments to initialize the dataset. + Attributes: + label_feature: The feature name corresponding to labels in the dataset. + utterance_feature: The feature name corresponding to utterances in the dataset. + has_descriptions: Whether the dataset includes descriptions for intents. """ label_feature = "label" @@ -42,45 +45,46 @@ class Dataset(dict[str, HFDataset]): has_descriptions: bool def __init__(self, *args: Any, intents: list[Intent], **kwargs: Any) -> None: # noqa: ANN401 - """ - Initialize the dataset. + """Initializes the dataset. - :param args: Positional arguments to initialize the dataset. - :param intents: List of intents associated with the dataset. - :param kwargs: Additional keyword arguments to initialize the dataset. + Args: + *args: Positional arguments used for dataset initialization. + intents: A list of intents associated with the dataset. + **kwargs: Additional keyword arguments used for dataset initialization. """ super().__init__(*args, **kwargs) self.intents = intents - self.has_descriptions = self.validate_descriptions() @property def multilabel(self) -> bool: - """ - Check if the dataset is multilabel. + """Checks if the dataset is multilabel. - :return: True if the dataset is multilabel, False otherwise. + Returns: + True if the dataset supports multilabel classification, False otherwise. """ split = Split.TRAIN if Split.TRAIN in self else f"{Split.TRAIN}_0" return isinstance(self[split].features[self.label_feature], Sequence) @cached_property def n_classes(self) -> int: - """ - Get the number of classes in the training split. + """Returns the number of classes in the dataset. - :return: Number of classes. + Returns: + The number of unique classes in the training split. """ return len(self.intents) @classmethod def from_dict(cls, mapping: dict[str, Any]) -> "Dataset": - """ - Load a dataset from a dictionary mapping. + """Creates a dataset from a dictionary mapping. + + Args: + mapping: A dictionary representation of the dataset. - :param mapping: Dictionary representing the dataset. - :return: Initialized Dataset object. + Returns: + A `Dataset` instance initialized from the dictionary. """ from ._reader import DictReader @@ -88,11 +92,13 @@ def from_dict(cls, mapping: dict[str, Any]) -> "Dataset": @classmethod def from_json(cls, filepath: str | Path) -> "Dataset": - """ - Load a dataset from a JSON file. + """Loads a dataset from a JSON file. + + Args: + filepath: Path to the JSON file. - :param filepath: Path to the JSON file. - :return: Initialized Dataset object. + Returns: + A `Dataset` instance initialized from the JSON file. """ from ._reader import JsonReader @@ -100,11 +106,13 @@ def from_json(cls, filepath: str | Path) -> "Dataset": @classmethod def from_hub(cls, repo_id: str) -> "Dataset": - """ - Load a dataset from a Hugging Face repository. + """Loads a dataset from the Hugging Face Hub. - :param repo_id: ID of the Hugging Face repository. - :return: Initialized Dataset object. + Args: + repo_id: The ID of the Hugging Face repository. + + Returns: + A `Dataset` instance initialized from the Hugging Face dataset repository. """ from ._reader import DictReader @@ -116,30 +124,30 @@ def from_hub(cls, repo_id: str) -> "Dataset": return DictReader().read(mapping) def to_multilabel(self) -> "Dataset": - """ - Convert dataset labels to multilabel format. + """Converts dataset labels to multilabel format. - :return: Self, with labels converted to multilabel. + Returns: + The dataset with labels converted to multilabel format. """ for split_name, split in self.items(): self[split_name] = split.map(self._to_multilabel) return self def to_dict(self) -> dict[str, list[dict[str, Any]]]: - """ - Convert the dataset splits and intents to a dictionary of lists. + """Converts the dataset into a dictionary format. - :return: A dictionary containing dataset splits and intents as lists of dictionaries. + Returns: + A dictionary where the keys are dataset splits and the values are lists of samples. """ mapping = {split_name: split.to_list() for split_name, split in self.items()} mapping[Split.INTENTS] = [intent.model_dump() for intent in self.intents] return mapping def to_json(self, filepath: str | Path) -> None: - """ - Save the dataset splits and intents to a JSON file. + """Saves the dataset to a JSON file. - :param filepath: The path to the file where the JSON data will be saved. + Args: + filepath: The file path where the dataset should be saved. """ path = Path(filepath) if not path.parent.exists(): @@ -148,11 +156,11 @@ def to_json(self, filepath: str | Path) -> None: json.dump(self.to_dict(), file, indent=4, ensure_ascii=False) def push_to_hub(self, repo_id: str, private: bool = False) -> None: - """ - Push dataset splits to a Hugging Face repository. + """Uploads the dataset to the Hugging Face Hub. - :param repo_id: ID of the Hugging Face repository. - :param private: Whether the repository is private + Args: + repo_id: The ID of the Hugging Face repository. + private: Whether to make the repository private. """ for split_name, split in self.items(): split.push_to_hub(repo_id, split=split_name, private=private) @@ -162,10 +170,10 @@ def push_to_hub(self, repo_id: str, private: bool = False) -> None: intents.push_to_hub(repo_id, config_name=Split.INTENTS, split=Split.INTENTS) def get_tags(self) -> list[Tag]: - """ - Extract unique tags from the dataset's intents. + """Extracts unique tags from the dataset's intents. - :return: List of tags with their associated intent IDs. + Returns: + A list of `Tag` objects containing unique tag names and associated intent IDs. """ tag_mapping = defaultdict(list) for intent in self.intents: @@ -174,11 +182,13 @@ def get_tags(self) -> list[Tag]: return [Tag(name=tag, intent_ids=intent_ids) for tag, intent_ids in tag_mapping.items()] def get_n_classes(self, split: str) -> int: - """ - Calculate the number of unique classes in a given split. + """Calculates the number of unique classes in a dataset split. + + Args: + split: The dataset split to analyze. - :param split: The split to analyze. - :return: Number of unique classes. + Returns: + The number of unique classes in the split. """ classes = set() for label in self[split][self.label_feature]: @@ -192,11 +202,13 @@ def get_n_classes(self, split: str) -> int: return len(classes) def _to_multilabel(self, sample: Sample) -> Sample: - """ - Convert a sample's label to multilabel format. + """Converts a sample's label to multilabel format. - :param sample: The sample to process. - :return: Sample with label in multilabel format. + Args: + sample: A sample from the dataset. + + Returns: + The sample with its label converted to a multilabel format. """ if isinstance(sample["label"], int): ohe_vector = [0] * self.n_classes @@ -205,16 +217,16 @@ def _to_multilabel(self, sample: Sample) -> Sample: return sample def validate_descriptions(self) -> bool: - """ - Check whether the dataset contains text descriptions for each intent. + """Validates whether all intents in the dataset contain descriptions. - :return: True if all intents have description field + Returns: + True if all intents have descriptions, False otherwise. """ has_any = any(intent.description is not None for intent in self.intents) has_all = all(intent.description is not None for intent in self.intents) if has_any and not has_all: - msg = "Some intents have text descriptions, but some of them not." + msg = "Some intents have text descriptions, but some do not." logger.warning(msg) return has_all diff --git a/autointent/_dataset/_reader.py b/autointent/_dataset/_reader.py index 936316c49..6eeaf9a61 100644 --- a/autointent/_dataset/_reader.py +++ b/autointent/_dataset/_reader.py @@ -12,21 +12,29 @@ class BaseReader(ABC): - """ - Abstract base class for dataset readers. Defines the interface for reading datasets. + """Abstract base class for dataset readers. - Subclasses must implement the `_read` method to specify how the dataset is read. + This class defines the interface for reading datasets from various sources. + Subclasses must implement the `_read` method to specify how a dataset should + be read and processed. - :raises NotImplementedError: If `_read` is not implemented by the subclass. + Raises: + NotImplementedError: If `_read` is not implemented in a subclass. """ def read(self, *args: Any, **kwargs: Any) -> Dataset: # noqa: ANN401 - """ - Read and validate the dataset, converting it to the standard `Dataset` format. + """Reads and validates the dataset, converting it into a standardized `Dataset` object. + + This method first calls the `_read` method (implemented by subclasses) + to retrieve the dataset, then validates it using `DatasetValidator`. + The validated dataset is converted into the standard `Dataset` format. + + Args: + *args: Positional arguments passed to the `_read` method. + **kwargs: Keyword arguments passed to the `_read` method. - :param args: Positional arguments for the `_read` method. - :param kwargs: Keyword arguments for the `_read` method. - :return: A `Dataset` object containing the dataset splits and intents. + Returns: + Dataset: A standardized dataset object containing the dataset splits and intents. """ dataset_reader = DatasetValidator.validate(self._read(*args, **kwargs)) splits = dataset_reader.model_dump(exclude={"intents"}, exclude_defaults=True) @@ -37,41 +45,55 @@ def read(self, *args: Any, **kwargs: Any) -> Dataset: # noqa: ANN401 @abstractmethod def _read(self, *args: Any, **kwargs: Any) -> DatasetReader: # noqa: ANN401 - """ - Abstract method for reading a dataset. + """Abstract method for reading a dataset. + + This method must be implemented by subclasses to define the specific logic + for reading datasets from different sources (e.g., dictionaries, JSON files). - This must be implemented by subclasses to provide specific reading logic. + Args: + *args: Positional arguments for dataset reading. + **kwargs: Keyword arguments for dataset reading. - :param args: Positional arguments for dataset reading. - :param kwargs: Keyword arguments for dataset reading. - :return: A `DatasetReader` instance representing the dataset. + Returns: + DatasetReader: A dataset representation that will be validated and processed. """ ... class DictReader(BaseReader): - """Dataset reader that processes datasets provided as Python dictionaries.""" + """Dataset reader that processes datasets provided as Python dictionaries. + + This reader expects datasets in a dictionary format and validates the dataset + structure before converting it into a standardized `Dataset` object. + """ def _read(self, mapping: dict[str, Any]) -> DatasetReader: - """ - Read a dataset from a dictionary and validate it. + """Reads and validates a dataset from a dictionary. - :param mapping: A dictionary representing the dataset. - :return: A validated `DatasetReader` instance. + Args: + mapping: A dictionary representing the dataset. + + Returns: + DatasetReader: A validated dataset representation. """ return DatasetReader.model_validate(mapping) class JsonReader(BaseReader): - """Dataset reader that processes datasets from JSON files.""" + """Dataset reader that loads and processes datasets from JSON files. + + This reader reads datasets stored as JSON files and validates them before + converting them into a standardized `Dataset` object. + """ def _read(self, filepath: str | Path) -> DatasetReader: - """ - Read a dataset from a JSON file and validate it. + """Reads and validates a dataset from a JSON file. + + Args: + filepath: Path to the JSON file containing the dataset. - :param filepath: Path to the JSON file containing the dataset. - :type filepath: str or Path - :return: A validated `DatasetReader` instance. + Returns: + DatasetReader: A validated dataset representation. """ with Path(filepath).open() as file: return DatasetReader.model_validate(json.load(file)) diff --git a/autointent/_dataset/_validation.py b/autointent/_dataset/_validation.py index 7d6ce6e12..27436dd8b 100644 --- a/autointent/_dataset/_validation.py +++ b/autointent/_dataset/_validation.py @@ -1,4 +1,4 @@ -"""File with definitions of DatasetReader and DatasetValidator.""" +"""File containing definitions of DatasetReader and DatasetValidator for handling dataset operations.""" from pydantic import BaseModel, ConfigDict, model_validator @@ -6,17 +6,17 @@ class DatasetReader(BaseModel): - """ - A class to represent a dataset reader for handling training, validation, and test data. - - :param train: List of samples for training. Defaults to an empty list. - :param train_0: List of samples for scoring module training. Defaults to an empty list. - :param train_1: List of samples for decision module training. Defaults to an empty list. - :param validation: List of samples for validation. Defaults to an empty list. - :param validation_0: List of samples for scoring module validation. Defaults to an empty list. - :param validation_1: List of samples for decision module validation. Defaults to an empty list. - :param test: List of samples for testing. Defaults to an empty list. - :param intents: List of intents associated with the dataset. + """Represents a dataset reader for handling training, validation, and test data splits. + + Attributes: + train: List of samples for training. Defaults to an empty list. + train_0: List of samples for scoring module training. Defaults to an empty list. + train_1: List of samples for decision module training. Defaults to an empty list. + validation: List of samples for validation. Defaults to an empty list. + validation_0: List of samples for scoring module validation. Defaults to an empty list. + validation_1: List of samples for decision module validation. Defaults to an empty list. + test: List of samples for testing. Defaults to an empty list. + intents: List of intents associated with the dataset. """ train: list[Sample] = [] @@ -32,11 +32,13 @@ class DatasetReader(BaseModel): @model_validator(mode="after") def validate_dataset(self) -> "DatasetReader": - """ - Validate the dataset by ensuring intents and data splits are consistent. + """Validates dataset integrity by ensuring consistent data splits and intent mappings. + + Raises: + ValueError: If data splits are inconsistent or intent mappings are incorrect. - :raises ValueError: If intents or samples are not properly validated. - :return: The validated DatasetReader instance. + Returns: + DatasetReader: The validated dataset reader instance. """ if self.train and (self.train_0 or self.train_1): message = "If `train` is provided, `train_0` and `train_1` should be empty." @@ -75,11 +77,13 @@ def validate_dataset(self) -> "DatasetReader": return self def _get_n_classes(self, split: list[Sample]) -> int: - """ - Get the number of classes in a dataset split. + """Determines the number of unique classes in a dataset split. + + Args: + split (list[Sample]): List of samples in a dataset split. - :param split: List of samples in a dataset split (train, validation, or test). - :return: The number of classes. + Returns: + int: The number of unique classes. """ classes = set() for sample in split: @@ -92,7 +96,17 @@ def _get_n_classes(self, split: list[Sample]) -> int: return len(classes) def _validate_classes(self, splits: list[list[Sample]]) -> int: - """Validate that each split has all classes.""" + """Ensures that all dataset splits have the same number of classes. + + Args: + splits (list[list[Sample]]): List of dataset splits. + + Raises: + ValueError: If the number of classes is inconsistent across splits or if no classes are found. + + Returns: + int: The number of unique classes. + """ n_classes = [self._get_n_classes(split) for split in splits] if len(set(n_classes)) != 1: message = ( @@ -106,12 +120,16 @@ def _validate_classes(self, splits: list[list[Sample]]) -> int: return n_classes[0] def _validate_intents(self, n_classes: int) -> "DatasetReader": - """ - Validate the intents by checking their IDs for sequential order. + """Ensures intent IDs are sequential and match the number of classes. + + Args: + n_classes (int): The expected number of classes based on dataset splits. - :param n_classes: The number of classes in the dataset. - :raises ValueError: If intent IDs are not sequential starting from 0. - :return: The DatasetReader instance after validation. + Raises: + ValueError: If intent IDs are not sequential or valid. + + Returns: + DatasetReader: The validated dataset reader instance. """ if not self.intents: self.intents = [Intent(id=idx) for idx in range(n_classes)] @@ -126,12 +144,16 @@ def _validate_intents(self, n_classes: int) -> "DatasetReader": return self def _validate_split(self, split: list[Sample]) -> "DatasetReader": - """ - Validate a dataset split to ensure all sample labels reference valid intent IDs. + """Validate a dataset split to ensure all sample labels reference valid intent IDs. - :param split: List of samples in a dataset split (train, validation, or test). - :raises ValueError: If a sample references an invalid or non-existent intent ID. - :return: The DatasetReader instance after validation. + Args: + split: List of samples in a dataset split. + + Raises: + ValueError: If a sample references an invalid or non-existent intent ID. + + Returns: + DatasetReader: The validated dataset reader instance. """ intent_ids = {intent.id for intent in self.intents} for sample in split: @@ -147,14 +169,16 @@ def _validate_split(self, split: list[Sample]) -> "DatasetReader": class DatasetValidator: - """A utility class for validating a DatasetReader instance.""" + """Utility class for validating a DatasetReader instance.""" @staticmethod def validate(dataset_reader: DatasetReader) -> DatasetReader: - """ - Validate a DatasetReader instance. + """Validates a DatasetReader instance. + + Args: + dataset_reader (DatasetReader): The dataset reader instance to validate. - :param dataset_reader: An instance of DatasetReader to validate. - :return: The validated DatasetReader instance. + Returns: + DatasetReader: The validated dataset reader instance. """ return dataset_reader diff --git a/autointent/_dump_tools.py b/autointent/_dump_tools.py index 7e52c5ba2..dc932c850 100644 --- a/autointent/_dump_tools.py +++ b/autointent/_dump_tools.py @@ -36,6 +36,11 @@ class Dumper: @staticmethod def make_subdirectories(path: Path) -> None: + """Make subdirectories for dumping. + + Args: + path: Path to make subdirectories in + """ subdirectories = [ path / Dumper.tags, path / Dumper.embedders, @@ -49,7 +54,12 @@ def make_subdirectories(path: Path) -> None: @staticmethod def dump(obj: Any, path: Path) -> None: # noqa: ANN401, C901 - """Dump modules attributes to filestystem.""" + """Dump modules attributes to filestystem. + + Args: + obj: Object to dump + path: Path to dump to + """ attrs: dict[str, ModuleAttributes] = vars(obj) simple_attrs = {} arrays: dict[str, npt.NDArray[Any]] = {} diff --git a/autointent/_embedder.py b/autointent/_embedder.py index ccb53275b..fddd65149 100644 --- a/autointent/_embedder.py +++ b/autointent/_embedder.py @@ -21,17 +21,18 @@ def get_embeddings_path(filename: str) -> Path: - """ - Get the path to the embeddings file. + """Get the path to the embeddings file. This function constructs the full path to an embeddings file stored in a specific directory under the user's home directory. The embeddings file is named based on the provided filename, with the `.npy` extension added. - :param filename: The name of the embeddings file (without extension). + Args: + filename: The name of the embeddings file (without extension). - :return: The full path to the embeddings file. + Returns: + The full path to the embeddings file. """ return Path(user_cache_dir("autointent")) / "embeddings" / f"{filename}.npy" @@ -52,8 +53,7 @@ class EmbedderDumpMetadata(TypedDict): class Embedder: - """ - A wrapper for managing embedding models using Sentence Transformers. + """A wrapper for managing embedding models using Sentence Transformers. This class handles initialization, saving, loading, and clearing of embedding models, as well as calculating embeddings for input texts. @@ -63,10 +63,10 @@ class Embedder: dump_dir: Path | None = None def __init__(self, embedder_config: EmbedderConfig) -> None: - """ - Initialize the Embedder. + """Initialize the Embedder. - :param embedder_config: Config of embedder. + Args: + embedder_config: Config of embedder. """ self.model_name = embedder_config.model_name self.device = embedder_config.device @@ -82,10 +82,10 @@ def __init__(self, embedder_config: EmbedderConfig) -> None: self.logger = logging.getLogger(__name__) def __hash__(self) -> int: - """ - Compute a hash value for the Embedder. + """Compute a hash value for the Embedder. - :returns: The hash value of the Embedder. + Returns: + The hash value of the Embedder. """ hasher = Hasher() for parameter in self.embedding_model.parameters(): @@ -107,10 +107,10 @@ def delete(self) -> None: shutil.rmtree(self.dump_dir) def dump(self, path: Path) -> None: - """ - Save the embedding model and metadata to disk. + """Save the embedding model and metadata to disk. - :param path: Path to the directory where the model will be saved. + Args: + path: Path to the directory where the model will be saved. """ self.dump_dir = path metadata = EmbedderDumpMetadata( @@ -126,10 +126,10 @@ def dump(self, path: Path) -> None: @classmethod def load(cls, path: Path | str) -> "Embedder": - """ - Load the embedding model and metadata from disk. + """Load the embedding model and metadata from disk. - :param path: Path to the directory where the model is stored. + Args: + path: Path to the directory where the model is stored. """ with (Path(path) / cls.metadata_dict_name).open() as file: metadata: EmbedderDumpMetadata = json.load(file) @@ -145,12 +145,14 @@ def load(cls, path: Path | str) -> "Embedder": ) def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) -> npt.NDArray[np.float32]: - """ - Calculate embeddings for a list of utterances. + """Calculate embeddings for a list of utterances. + + Args: + utterances: List of input texts to calculate embeddings for. + task_type: Type of task for which embeddings are calculated. - :param utterances: List of input texts to calculate embeddings for. - :param task_type: Type of task for which embeddings are calculated. - :return: A numpy array of embeddings. + Returns: + A numpy array of embeddings. """ if self.use_cache: hasher = Hasher() diff --git a/autointent/_hash.py b/autointent/_hash.py index 270984ac3..9c84e9624 100644 --- a/autointent/_hash.py +++ b/autointent/_hash.py @@ -7,8 +7,7 @@ class Hasher: - """ - A class that provides methods for hashing data using xxhash. + """A class that provides methods for hashing data using xxhash. This class supports both a class-level method for generating hashes from any given value, as well as an instance-level method for progressively @@ -16,8 +15,7 @@ class Hasher: """ def __init__(self) -> None: - """ - Initialize the Hasher instance and sets up the internal xxhash state. + """Initialize the Hasher instance and sets up the internal xxhash state. This state will be used for progressively hashing values using the `update` method and obtaining the final digest using `hexdigest`. @@ -26,47 +24,48 @@ def __init__(self) -> None: @classmethod def hash(cls, value: Any) -> int: # noqa: ANN401 - """ - Generate a hash for the given value using xxhash. + """Generate a hash for the given value using xxhash. - :param value: The value to be hashed. This can be any Python object. + Args: + value: The value to be hashed. This can be any Python object. - :return: The resulting hash digest as a hexadecimal string. + Returns: + The resulting hash digest as a hexadecimal string. """ if hasattr(value, "__hash__") and value.__hash__ not in {None, object.__hash__}: return hash(value) return xxhash.xxh64(pickle.dumps(value)).intdigest() def update(self, value: Any) -> None: # noqa: ANN401 - """ - Update the internal hash state with the provided value. + """Update the internal hash state with the provided value. This method will first hash the type of the value, then hash the value itself, and update the internal state accordingly. - :param value: The value to update the hash state with. + Args: + value: The value to be hashed and added to the internal state. """ self._state.update(str(type(value)).encode()) self._state.update(str(self.hash(value)).encode()) def hexdigest(self) -> str: - """ - Return the current hash digest as a hexadecimal string. + """Return the current hash digest as a hexadecimal string. This method should be called after one or more `update` calls to get the final hash result. - :return: The resulting hash digest as a hexadecimal string. + Returns: + The resulting hash digest as a hexadecimal string. """ return self._state.hexdigest() def intdigest(self) -> int: - """ - Return the current hash digest as an integer. + """Return the current hash digest as an integer. This method should be called after one or more `update` calls to get the final hash result. - :return: The resulting hash digest as an integer. + Returns: + The resulting hash digest as an integer. """ return self._state.intdigest() diff --git a/autointent/_logging/setup.py b/autointent/_logging/setup.py index cbbc6ca4b..6de4c80fd 100644 --- a/autointent/_logging/setup.py +++ b/autointent/_logging/setup.py @@ -9,15 +9,15 @@ def setup_logging(level: LogLevel | str, log_filename: Path | str | None = None) -> None: - """ - Set stdout and file handlers for logging autointent internal actions. + """Set stdout and file handlers for logging autointent internal actions. The first parameter affects the logs to the standard output stream. The second parameter is optional. If it is specified, then the "DEBUG" messages are logged to the file, regardless of what is specified by the first parameter. - :param level: one of "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL" - :param log_to_filepath: specify location of logfile, omit extension as suffix ``.log.jsonl`` will be appended. + Args: + level: one of "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL" + log_filename: specify location of logfile, omit extension as suffix ``.log.jsonl`` will be appended. """ config_file = ires.files("autointent._logging").joinpath("config.yaml") with config_file.open() as f_in: diff --git a/autointent/_pipeline/_pipeline.py b/autointent/_pipeline/_pipeline.py index 48c29d00a..096d29224 100644 --- a/autointent/_pipeline/_pipeline.py +++ b/autointent/_pipeline/_pipeline.py @@ -31,7 +31,7 @@ from ._schemas import InferencePipelineOutput, InferencePipelineUtteranceOutput if TYPE_CHECKING: - from autointent.modules.abc import BaseDecision, BaseScorer + from autointent.modules.base import BaseDecision, BaseScorer class Pipeline: @@ -43,12 +43,12 @@ def __init__( sampler: SamplerType = "brute", seed: int = 42, ) -> None: - """ - Initialize the pipeline optimizer. + """Initialize the pipeline optimizer. - :param nodes: list of nodes - :param sampler: sampler type - :param seed: random seed + Args: + nodes: List of nodes. + sampler: Sampler type. + seed: Random seed. """ self._logger = logging.getLogger(__name__) self.nodes = {node.node_type: node for node in nodes} @@ -68,10 +68,10 @@ def __init__( assert_never(nodes) def set_config(self, config: LoggingConfig | EmbedderConfig | CrossEncoderConfig | DataConfig) -> None: - """ - Set configuration for the optimizer. + """Set the configuration for the pipeline. - :param config: Configuration + Args: + config: Configuration object. """ if isinstance(config, LoggingConfig): self.logging_config = config @@ -86,11 +86,14 @@ def set_config(self, config: LoggingConfig | EmbedderConfig | CrossEncoderConfig @classmethod def from_search_space(cls, search_space: list[dict[str, Any]] | Path | str, seed: int = 42) -> "Pipeline": - """ - Create pipeline optimizer from dictionary search space. + """Search space to pipeline optimizer. + + Args: + search_space: Search space. + seed: Random seed. - :param search_space: Dictionary config - :param seed: random seed + Returns: + Pipeline optimizer. """ if not isinstance(search_space, list): search_space = load_search_space(search_space) @@ -105,8 +108,7 @@ def from_preset(cls, name: SearchSpacePresets, seed: int = 42) -> "Pipeline": @classmethod def from_optimization_config(cls, config: dict[str, Any] | Path | str | OptimizationConfig) -> "Pipeline": - """ - Create pipeline optimizer from optimization config. + """Create pipeline optimizer from optimization config. :param config: Optimization config :return: @@ -133,10 +135,11 @@ def from_optimization_config(cls, config: dict[str, Any] | Path | str | Optimiza return pipeline def _fit(self, context: Context, sampler: SamplerType) -> None: - """ - Optimize the pipeline. + """Optimize the pipeline. - :param context: Context + Args: + context: Context object. + sampler: Sampler type. """ self.context = context self._logger.info("starting pipeline optimization...") @@ -151,10 +154,10 @@ def _fit(self, context: Context, sampler: SamplerType) -> None: self.context.callback_handler.end_run() def _is_inference(self) -> bool: - """ - Check the mode in which pipeline is. + """Check the mode in which pipeline is. - :return: True if pipeline is in inference mode, False if in optimization mode. + Returns: + True if pipeline is in inference mode, False otherwise. """ return isinstance(self.nodes[NodeType.scoring], InferenceNode) @@ -165,11 +168,19 @@ def fit( sampler: SamplerType | None = None, incompatible_search_space: SearchSpaceValidationMode = "filter", ) -> Context: - """ - Optimize the pipeline from dataset. + """Optimize the pipeline from dataset. + + Args: + dataset: Dataset for optimization. + refit_after: Whether to refit after optimization. + sampler: Sampler type to use. + incompatible_search_space: How to handle incompatible search space. + + Returns: + Context object. - :param dataset: Dataset for optimization - :return: Context + Raises: + RuntimeError: If pipeline is in inference mode. """ if self._is_inference(): msg = "Pipeline in inference mode cannot be fitted" @@ -219,10 +230,11 @@ def fit( return context def validate_modules(self, dataset: Dataset, mode: SearchSpaceValidationMode) -> None: - """ - Validate modules with dataset. + """Validate modules with dataset. - :param dataset: dataset to validate with + Args: + dataset: Dataset for validation. + mode: Validation mode. """ for node in self.nodes.values(): if isinstance(node, NodeOptimizer): @@ -230,43 +242,53 @@ def validate_modules(self, dataset: Dataset, mode: SearchSpaceValidationMode) -> @classmethod def from_dict_config(cls, nodes_configs: list[dict[str, Any]]) -> "Pipeline": - """ - Create inference pipeline from dictionary config. + """Create inference pipeline from dictionary config. + + Args: + nodes_configs: list of config for nodes - :param nodes_configs: list of dictionary config for nodes - :return: pipeline ready for inference + Returns: + Inference pipeline """ return cls.from_config([InferenceNodeConfig(**cfg) for cfg in nodes_configs]) @classmethod def from_config(cls, nodes_configs: list[InferenceNodeConfig]) -> "Pipeline": - """ - Create inference pipeline from config. + """Create inference pipeline from config. + + Args: + nodes_configs: list of config for nodes - :param nodes_configs: list of config for nodes + Returns: + Inference pipeline """ nodes = [InferenceNode.from_config(cfg) for cfg in nodes_configs] return cls(nodes) @classmethod def load(cls, path: str | Path) -> "Pipeline": - """ - Load pipeline in inference mode. + """Load pipeline in inference mode. This method loads fitted modules and tuned hyperparameters. - :path: path to optimization run directory - :return: initialized pipeline, ready for inference + + Args: + path: Path to load + + Returns: + Inference pipeline """ with (Path(path) / "inference_config.yaml").open() as file: inference_dict_config = yaml.safe_load(file) return cls.from_dict_config(inference_dict_config["nodes_configs"]) def predict(self, utterances: list[str]) -> ListOfGenericLabels: - """ - Predict the labels for the utterances. + """Predict the labels for the utterances. - :param utterances: list of utterances - :return: list of predicted labels + Args: + utterances: list of utterances + + Returns: + list of predicted labels """ if not self._is_inference(): msg = "Pipeline in optimization mode cannot perform inference" @@ -279,11 +301,13 @@ def predict(self, utterances: list[str]) -> ListOfGenericLabels: return decision_module.predict(scores) def _refit(self, context: Context) -> None: - """ - Fit pipeline of already selected modules with all train data. + """Fit pipeline of already selected modules with all train data. + + Args: + context: Context object. - :param context: context object to take data from - :return: list of predicted labels + Raises: + RuntimeError: If pipeline is in optimization mode. """ if not self._is_inference(): msg = "Pipeline in optimization mode cannot perform inference" @@ -300,11 +324,12 @@ def _refit(self, context: Context) -> None: decision_module.fit(scores, context.data_handler.train_labels(1), context.data_handler.tags) def predict_with_metadata(self, utterances: list[str]) -> InferencePipelineOutput: - """ - Predict the labels for the utterances with metadata. + """Predict the labels for the utterances with metadata. - :param utterances: list of utterances - :return: prediction output + Args: + utterances: list of utterances + Returns: + Inference pipeline output """ if not self._is_inference(): msg = "Pipeline in optimization mode cannot perform inference" @@ -340,12 +365,14 @@ def predict_with_metadata(self, utterances: list[str]) -> InferencePipelineOutpu def make_report(logs: dict[str, Any], nodes: list[NodeType]) -> str: - """ - Generate a report from optimization logs. + """Generate a report from optimization logs. + + Args: + logs: Logs dictionary. + nodes: List of node types. - :param logs: Logs - :param nodes: Nodes - :return: String report + Returns: + String report. """ ids = [np.argmax(logs["metrics"][node]) for node in nodes] configs = [] diff --git a/autointent/_ranker.py b/autointent/_ranker.py index 7569db4c5..fb72ea50a 100644 --- a/autointent/_ranker.py +++ b/autointent/_ranker.py @@ -1,6 +1,7 @@ -"""Ranker class for cross-encoder-based estimation of meaning closeness. +"""Module for cross-encoder-based meaning closeness estimation using ranking. -Can be used to rank retrieved sentences by meaning closeness to provided utterance. +This module provides functionality for ranking retrieved sentences by meaning closeness +to provided utterances using cross-encoder models. """ import gc @@ -26,6 +27,16 @@ class CrossEncoderMetadata(TypedDict): + """Metadata for CrossEncoder model. + + Attributes: + model_name: Name of the model + train_classifier: Whether to train a classifier + device: Device to use for inference + max_length: Maximum sequence length + batch_size: Batch size for inference + """ + model_name: str train_classifier: bool device: str | None @@ -38,13 +49,17 @@ def construct_samples( labels: list[Any], balancing_factor: int | None = None, ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: - """ - Construct balanced samples of text pairs for training. + """Construct balanced samples of text pairs for training. + + Args: + texts: List of texts to create pairs from + labels: List of labels corresponding to the texts + balancing_factor: Factor for balancing positive and negative samples - :param texts: List of texts to create pairs from. - :param labels: List of labels corresponding to the texts. - :param balancing_factor: Factor for balancing the positive and negative samples. If None, no balancing is applied. - :return: Tuple containing lists of text pairs and their corresponding binary labels. + Returns: + Tuple containing: + - List of text pairs + - List of corresponding binary labels """ samples = [[], []] # type: ignore[var-annotated] @@ -70,35 +85,17 @@ def construct_samples( class Ranker: - r""" - Cross-encoder for NLI. - - In the hart this class uses a SentenceTransformers Ranker model to extract features. - Then it uses either the model's clissifier or our custom trained LogisticRegressionCV - (custom classifier layer in the future) to rank documents using similarity score to the query. - - :ivar cross_encoder: The Ranker model used to extract features. - :ivar batch_size: Batch size for processing text pairs. - :ivar _clf: The trained LogisticRegressionCV classifier. - :ivar model_subdir: Directory for storing the cross-encoder model files. - - Examples - -------- - Creating and fitting the CrossEncoderWithLogreg: - >>> from autointent import Ranker - >>> scorer = Ranker("cross-encoder-model") - >>> utterances = ["What is your name?", "How old are you?"] - >>> labels = [1, 0] - >>> scorer.fit(utterances, labels) - - Predicting probabilities: - >>> test_pairs = [["What is your name?", "Hello!"], ["How old are you?", "What is your age?"]] - >>> probs = scorer.predict(test_pairs) - >>> print(probs) - - Saving and loading the model: - >>> scorer.save("outputs/") - >>> loaded_scorer = Ranker.load("outputs/") + """Cross-encoder for Natural Language Inference (NLI). + + This class uses a SentenceTransformers Ranker model to extract features. + It can use either the model's classifier or a custom trained LogisticRegressionCV + to rank documents using similarity scores to the query. + + Attributes: + cross_encoder: The Ranker model used to extract features + batch_size: Batch size for processing text pairs + _clf: The trained LogisticRegressionCV classifier + model_subdir: Directory for storing cross-encoder model files """ metadata_file_name = "metadata.json" @@ -109,12 +106,11 @@ def __init__( cross_encoder_config: CrossEncoderConfig | str | dict[str, Any], classifier_head: LogisticRegressionCV | None = None, ) -> None: - """ - Initialize the Ranker. + """Initialize the Ranker. - :param cross_encoder_config: Config of the cross-encoder hugging face model name to use. - :param max_length (int, optional): Max length for input sequences for the cross encoder. - :param classifier_head (LogisticRegressionCV, optional): Classifier (to be used in restore procedure mainly). + Args: + cross_encoder_config: Configuration for the cross-encoder model + classifier_head: Optional pre-trained classifier head """ self.cross_encoder_config = CrossEncoderConfig.from_search_config(cross_encoder_config) self.cross_encoder = st.CrossEncoder( @@ -132,18 +128,24 @@ def __init__( self._hook_handler = self.cross_encoder.model.classifier.register_forward_hook(self._classifier_hook) def _classifier_hook(self, _module, input_tensor, _output_tensor) -> None: # type: ignore[no-untyped-def] # noqa: ANN001 + """Hook to capture classifier activations. + + Args: + _module: Module being hooked + input_tensor: Input tensor to the classifier + _output_tensor: Output tensor from the classifier + """ self._activations_list.append(input_tensor[0].cpu().numpy()) @torch.no_grad() def _get_features_or_predictions(self, pairs: list[tuple[str, str]]) -> npt.NDArray[Any]: - """ - Extract features or get predictions using the Ranker model. + """Extract features or get predictions using the Ranker model. - If :py:attr:`~train_classifier` is ``True``, return raw activations from - cross-encoder transformer. Otherwise, get predictions from cross-encoder head. + Args: + pairs: List of text pairs - :param pairs: List of text pairs. - :return: Numpy array of extracted features. + Returns: + Array of extracted features or predictions """ if not self.train_classifier: return np.array( @@ -154,20 +156,20 @@ def _get_features_or_predictions(self, pairs: list[tuple[str, str]]) -> npt.NDAr ) ) - # put the data through, features will be taken in the hook self.cross_encoder.predict(pairs, batch_size=self.cross_encoder_config.batch_size) - res = np.concatenate(self._activations_list, axis=0) self._activations_list.clear() return res # type: ignore[no-any-return] def _fit(self, pairs: list[tuple[str, str]], labels: ListOfLabels) -> None: - """ - Train the logistic regression model on cross-encoder features. + """Train the logistic regression model on cross-encoder features. + + Args: + pairs: List of text pairs + labels: Binary labels (1 = same class, 0 = different classes) - :param pairs: List of text pairs. - :param labels: Binary labels (1 = same class, 0 = different classes). - :raises ValueError: If the number of pairs and labels do not match. + Raises: + ValueError: If number of pairs and labels don't match """ n_samples = len(pairs) if n_samples != len(labels): @@ -176,33 +178,34 @@ def _fit(self, pairs: list[tuple[str, str]], labels: ListOfLabels) -> None: raise ValueError(msg) features = self._get_features_or_predictions(pairs) - - # TODO: LogisticRegressionCV has class_weight="balanced". Is it better to use it instead of balance_factor in - # construct_samples? clf = LogisticRegressionCV() clf.fit(features, labels) - self._clf = clf def fit(self, utterances: list[str], labels: ListOfLabels) -> None: - """ - Construct training samples and train the logistic regression classifier. + """Construct training samples and train the logistic regression classifier. - :param utterances: List of utterances (texts). - :param labels: Intent class labels corresponding to the utterances. + Args: + utterances: List of utterances (texts) + labels: Intent class labels corresponding to the utterances """ if not self.train_classifier: - return # do nothing if the classifier is not to be re-trained + return pairs, labels_ = construct_samples(utterances, labels, balancing_factor=1) self._fit(pairs, labels_) # type: ignore[arg-type] def predict(self, pairs: list[tuple[str, str]]) -> npt.NDArray[Any]: - """ - Predict probabilities of two utterances having the same intent label. + """Predict probabilities of two utterances having the same intent label. + + Args: + pairs: List of text pairs to classify + + Returns: + Array of probabilities - :param pairs: List of text pairs to classify. - :return: Numpy array of probabilities. + Raises: + ValueError: If classifier is not trained yet """ if self.train_classifier and self._clf is None: msg = "Classifier is not trained yet" @@ -212,7 +215,6 @@ def predict(self, pairs: list[tuple[str, str]]) -> npt.NDArray[Any]: if self._clf is not None: return np.array(self._clf.predict_proba(features)[:, 1]) - return features def rank( @@ -221,13 +223,15 @@ def rank( query_docs: list[str], top_k: int | None = None, ) -> list[dict[str, Any]]: - """ - Rank documents according to meaning closeness to the query. + """Rank documents according to meaning closeness to the query. + + Args: + query: Reference document + query_docs: List of documents to rank + top_k: Number of documents to return - :param query: The reference document. - :param query_docs: List of documents to rank - :param top_k: how many document to return - :return: array of dictionaries of ranked items. + Returns: + List of dictionaries containing ranked items with scores """ query_doc_pairs = [(query, doc) for doc in query_docs] scores = self.predict(query_doc_pairs) @@ -240,10 +244,10 @@ def rank( return results[:top_k] def save(self, path: str) -> None: - """ - Save the model and classifier to disk. + """Save the model and classifier to disk. - :param path: Directory path to save the model and classifier. + Args: + path: Directory path to save the model and classifier """ dump_dir = Path(path) dump_dir.mkdir(parents=True) @@ -263,11 +267,13 @@ def save(self, path: str) -> None: @classmethod def load(cls, path: Path) -> "Ranker": - """ - Load the model and classifier from disk. + """Load the model and classifier from disk. + + Args: + path: Directory path containing the saved model and classifier - :param path: Directory path containing the saved model and classifier. - :return: Initialized Ranker instance. + Returns: + Initialized Ranker instance """ clf = joblib.load(path / cls.classifier_file_name) @@ -286,6 +292,7 @@ def load(cls, path: Path) -> "Ranker": ) def clear_ram(self) -> None: + """Clear model from RAM and GPU memory.""" self.cross_encoder.model.cpu() del self.cross_encoder gc.collect() diff --git a/autointent/_utils.py b/autointent/_utils.py index 203765ab9..d371ac551 100644 --- a/autointent/_utils.py +++ b/autointent/_utils.py @@ -6,4 +6,11 @@ def _funcs_to_dict(*funcs: T) -> dict[str, T]: + """Convert functions to a dictionary. + + Args: + *funcs: Functions to convert + Returns: + Dictionary of functions + """ return {func.__name__: func for func in funcs} # type: ignore[attr-defined] diff --git a/autointent/_vector_index.py b/autointent/_vector_index.py index 0d3983fe0..41c3b875d 100644 --- a/autointent/_vector_index.py +++ b/autointent/_vector_index.py @@ -33,8 +33,7 @@ class VectorIndexData(TypedDict): class VectorIndex: - """ - A class for managing a vector index using FAISS and embedding models. + """A class for managing a vector index using FAISS and embedding models. This class allows adding, querying, and managing embeddings and their associated labels for efficient nearest neighbor search. @@ -44,10 +43,10 @@ class VectorIndex: _meta_data_file = "metadata.json" def __init__(self, embedder_config: EmbedderConfig) -> None: - """ - Initialize the vector index. + """Initialize the VectorIndex with an embedding model. - :param embedder_config: Config of the embedding model to use. + Args: + embedder_config: Configuration for the embedding model. """ self.embedder = Embedder(embedder_config) @@ -57,11 +56,11 @@ def __init__(self, embedder_config: EmbedderConfig) -> None: self.logger = logging.getLogger(__name__) def add(self, texts: list[str], labels: ListOfLabels) -> None: - """ - Add texts and their corresponding labels to the index. + """Add texts and their corresponding labels to the index. - :param texts: List of input texts. - :param labels: List of labels corresponding to the texts. + Args: + texts: List of input texts. + labels: List of labels corresponding to the texts. """ self.logger.debug("Adding embeddings to vector index %s", self.embedder.model_name) embeddings = self.embedder.embed(texts, TaskTypeEnum.passage) @@ -73,10 +72,10 @@ def add(self, texts: list[str], labels: ListOfLabels) -> None: self.texts.extend(texts) def is_empty(self) -> bool: - """ - Check if the index is empty. + """Check if the index is empty. - :return: True if the index contains no embeddings, False otherwise. + Returns: + True if the index contains no embeddings, False otherwise. """ return len(self.labels) == 0 @@ -96,24 +95,27 @@ def clear_ram(self) -> None: self.texts = [] def _search_by_text(self, texts: list[str], k: int) -> list[list[dict[str, Any]]]: - """ - Search the index using text queries. + """Search the index using text queries. - :param texts: List of input texts to search for. - :param k: Number of nearest neighbors to return. - :return: List of search results for each query. + Args: + texts: List of input texts to search for. + k: Number of nearest neighbors to return. + + Returns: + List of search results for each query. """ query_embedding: npt.NDArray[np.float64] = self.embedder.embed(texts, TaskTypeEnum.query) # type: ignore[assignment] return self._search_by_embedding(query_embedding, k) def _search_by_embedding(self, embedding: npt.NDArray[Any], k: int) -> list[list[dict[str, Any]]]: - """ - Search the index using embedding vectors. + """Search the index using embedding vectors. - :param embedding: 2D array of shape (n_queries, dim_size) representing query embeddings. - :param k: Number of nearest neighbors to return. - :return: List of search results for each query. - :raises ValueError: If the embedding array is not 2D. + Args: + embedding: 2D array of shape (n_queries, dim_size) representing query embeddings. + k: Number of nearest neighbors to return. + + Returns: + List of search results for each query. """ if embedding.ndim != 2: # noqa: PLR2004 msg = "`embedding` should be a 2D array of shape (n_queries, dim_size)" @@ -132,11 +134,13 @@ def _search_by_embedding(self, embedding: npt.NDArray[Any], k: int) -> list[list return results def get_all_embeddings(self) -> npt.NDArray[Any]: - """ - Retrieve all embeddings stored in the index. + """Retrieve all embeddings stored in the index. - :return: Array of all embeddings. - :raises ValueError: If the index has not been created yet. + Returns: + Array of all embeddings. + + Raises: + ValueError: If the index has not been created yet. """ if not hasattr(self, "index"): msg = "Index is not created yet" @@ -144,10 +148,10 @@ def get_all_embeddings(self) -> npt.NDArray[Any]: return self.index.reconstruct_n(0, self.index.ntotal) # type: ignore[no-any-return] def get_all_labels(self) -> ListOfLabels: - """ - Retrieve all labels stored in the index. + """Retrieve all labels stored in the index. - :return: List of all labels. + Returns: + List of all labels. """ return self.labels @@ -156,15 +160,17 @@ def query( queries: list[str] | npt.NDArray[np.float32], k: int, ) -> tuple[list[ListOfLabels], list[list[float]], list[list[str]]]: - """ - Query the index to retrieve nearest neighbors. - - :param queries: List of text queries or embedding vectors. - :param k: Number of nearest neighbors to return for each query. - :return: A tuple containing: - - `labels`: List of retrieved labels for each query. - - `distances`: Corresponding distances for each neighbor. - - `texts`: Corresponding texts for each neighbor. + """Query the index to retrieve nearest neighbors. + + Args: + queries: List of text queries or embedding vectors. + k: Number of nearest neighbors to return for each query. + + Returns: + A tuple containing: + - `labels`: List of retrieved labels for each query. + - `distances`: Corresponding distances for each neighbor. + - `texts`: Corresponding texts for each neighbor. """ func = self._search_by_text if isinstance(queries[0], str) else self._search_by_embedding all_results = func(queries, k) # type: ignore[arg-type] @@ -176,10 +182,10 @@ def query( return all_labels, all_distances, all_texts def dump(self, dir_path: Path) -> None: - """ - Save the index and associated data to disk. + """Save the index and associated data to disk. - :param dir_path: Directory path to save the data. + Args: + dir_path: Directory path where the data will be stored. """ dir_path.mkdir(parents=True, exist_ok=True) self.dump_dir = dir_path @@ -207,10 +213,16 @@ def load( embedder_batch_size: int | None = None, embedder_use_cache: bool | None = None, ) -> "VectorIndex": - """ - Load the index and associated data from disk. + """Load the index and associated data from disk. + + Args: + dir_path: Directory path where the data is stored. + embedder_device: Device for the embedding model. + embedder_batch_size: Batch size for the embedding model. + embedder_use_cache: Whether to use caching for the embedding model. - :param dir_path: Directory path where the data is stored. + Returns: + VectorIndex instance with loaded data. """ with (dir_path / cls._meta_data_file).open() as file: metadata: VectorIndexMetadata = json.load(file) diff --git a/autointent/configs/_name.py b/autointent/configs/_name.py index 30ae0c0f7..f74638bdf 100644 --- a/autointent/configs/_name.py +++ b/autointent/configs/_name.py @@ -342,10 +342,10 @@ def generate_name() -> str: - """ - Generate a random name for a run. + """Generate a random name for a run. - :return: Random name + Returns: + Random name """ adjective = random.choice(adjectives) noun = random.choice(nouns) @@ -353,11 +353,13 @@ def generate_name() -> str: def get_run_name(run_name: str | None = None) -> str: - """ - Get a run name. + """Get a run name. + + Args: + run_name: Name of the run. - :param run_name: Run name. If None, generate a random name - :return: Run name with a timestamp + Returns: + Run name """ if run_name is None: run_name = generate_name() diff --git a/autointent/configs/_transformers.py b/autointent/configs/_transformers.py index b64343691..81ea3d57f 100644 --- a/autointent/configs/_transformers.py +++ b/autointent/configs/_transformers.py @@ -22,7 +22,11 @@ class STModelConfig(ModelConfig): def from_search_config(cls, values: dict[str, Any] | str | BaseModel | None) -> Self: """Validate the model configuration. - :param values: Model configuration values. If a string is provided, it is converted to a dictionary. + Args: + values: Model configuration values. If a string is provided, it is converted to a dictionary. + + Returns: + Model configuration. """ if values is None: return cls() # type: ignore[call-arg] @@ -58,7 +62,8 @@ class EmbedderConfig(STModelConfig): def get_prompt_config(self) -> dict[str, str] | None: """Get the prompt config for the given prompt type. - :return: The prompt config for the given prompt type. + Returns: + The prompt config for the given prompt type. """ prompts = {} if self.default_prompt: @@ -78,9 +83,11 @@ def get_prompt_config(self) -> dict[str, str] | None: def get_prompt_type(self, prompt_type: TaskTypeEnum | None) -> str | None: # noqa: PLR0911 """Get the prompt type for the given task type. - :param prompt_type: Task type for which to get the prompt. + Args: + prompt_type: Task type for which to get the prompt. - :return: The prompt for the given task type. + Returns: + The prompt for the given task type. """ if prompt_type is None: return self.default_prompt diff --git a/autointent/context/_context.py b/autointent/context/_context.py index a62d7e1e9..9fb21f7a2 100644 --- a/autointent/context/_context.py +++ b/autointent/context/_context.py @@ -17,11 +17,15 @@ class Context: - """ - Context manager for configuring and managing data handling, vector indexing, and optimization. + """Context manager for configuring and managing data handling, vector indexing, and optimization. This class provides methods to set up logging, configure data and vector index components, manage datasets, and retrieve various configurations for inference and optimization. + + Attributes: + data_handler: Handler for managing datasets. + optimization_info: Container for optimization information. + callback_handler: Handler for managing callbacks. """ data_handler: DataHandler @@ -29,29 +33,29 @@ class Context: callback_handler = CallbackHandler() def __init__(self, seed: int = 42) -> None: - """ - Initialize the Context object with a specified random seed. + """Initialize the Context object. - :param seed: Random seed for reproducibility, defaults to 42. + Args: + seed: Random seed for reproducibility. """ self.seed = seed self._logger = logging.getLogger(__name__) def configure_logging(self, config: LoggingConfig) -> None: - """ - Configure logging settings. + """Configure logging settings. - :param config: Logging configuration settings. + Args: + config: Logging configuration settings. """ self.logging_config = config self.callback_handler = get_callbacks(config.report_to) self.optimization_info = OptimizationInfo() def configure_transformer(self, config: EmbedderConfig | CrossEncoderConfig) -> None: - """ - Configure the vector index client and embedder. + """Configure the vector index client and embedder. - :param config: Configuration for the vector index. + Args: + config: Configuration for the vector index. """ if isinstance(config, EmbedderConfig): self.embedder_config = config @@ -59,18 +63,19 @@ def configure_transformer(self, config: EmbedderConfig | CrossEncoderConfig) -> self.cross_encoder_config = config def set_dataset(self, dataset: Dataset, config: DataConfig) -> None: - """ - Set the datasets for training, validation and testing. + """Set the datasets for training, validation and testing. - :param dataset: Dataset. + Args: + dataset: Dataset. + config: Data configuration settings. """ self.data_handler = DataHandler(dataset=dataset, random_seed=self.seed, config=config) def get_inference_config(self) -> dict[str, Any]: - """ - Generate configuration settings for inference. + """Generate configuration settings for inference. - :return: Dictionary containing inference configuration. + Returns: + Dictionary containing inference configuration. """ nodes_configs = self.optimization_info.get_inference_nodes_config(asdict=True) return { @@ -83,12 +88,7 @@ def get_inference_config(self) -> dict[str, Any]: } def dump(self) -> None: - """ - Save logs, configurations, and datasets to disk. - - Dumps evaluation results, training/test data splits, and inference configurations - to the specified logging directory. - """ + """Save logs, configurations, and datasets to disk.""" self._logger.debug("dumping logs...") optimization_results = self.optimization_info.dump_evaluation_results() @@ -99,9 +99,6 @@ def dump(self) -> None: with logs_path.open("w") as file: json.dump(optimization_results, file, indent=4, ensure_ascii=False, cls=NumpyEncoder) - # self._logger.info(make_report(optimization_results, nodes=nodes)) - - # dump train and test data splits self.data_handler.dataset.to_json(logs_dir / "dataset.json") self._logger.info("logs and other assets are saved to %s", logs_dir) @@ -112,49 +109,57 @@ def dump(self) -> None: yaml.dump(inference_config, file) def get_dump_dir(self) -> Path | None: - """ - Get the directory for saving dumped modules. + """Get the directory for saving dumped modules. - :return: Path to the dump directory or None if dumping is disabled. + Returns: + Path to the dump directory or None if dumping is disabled. """ if self.logging_config.dump_modules: return self.logging_config.dump_dir return None def is_multilabel(self) -> bool: - """ - Check if the dataset is configured for multilabel classification. + """Check if the dataset is configured for multilabel classification. - :return: True if multilabel classification is enabled, False otherwise. + Returns: + True if multilabel classification is enabled, False otherwise. """ return self.data_handler.multilabel def get_n_classes(self) -> int: - """ - Get the number of classes in the dataset. + """Get the number of classes in the dataset. - :return: Number of classes. + Returns: + Number of classes. """ return self.data_handler.n_classes def is_ram_to_clear(self) -> bool: - """ - Check if RAM clearing is enabled in the logging configuration. + """Check if RAM clearing is enabled in the logging configuration. - :return: True if RAM clearing is enabled, False otherwise. + Returns: + True if RAM clearing is enabled, False otherwise. """ return self.logging_config.clear_ram def has_saved_modules(self) -> bool: - """ - Check if any modules have been saved. + """Check if any modules have been saved. - :return: True if there are saved modules, False otherwise. + Returns: + True if there are saved modules, False otherwise. """ node_types = ["regex", "embedding", "scoring", "decision"] return any(len(self.optimization_info.modules.get(nt)) > 0 for nt in node_types) def resolve_embedder(self) -> EmbedderConfig: + """Resolve the embedder configuration. + + Returns: + The best embedder configuration or default configuration. + + Raises: + RuntimeError: If embedder configuration cannot be resolved. + """ try: return self.optimization_info.get_best_embedder() except ValueError as e: @@ -167,6 +172,14 @@ def resolve_embedder(self) -> EmbedderConfig: raise RuntimeError(msg) from e def resolve_ranker(self) -> CrossEncoderConfig: + """Resolve the cross-encoder configuration. + + Returns: + The cross-encoder configuration. + + Raises: + RuntimeError: If cross-encoder configuration cannot be resolved. + """ if hasattr(self, "cross_encoder_config"): return self.cross_encoder_config msg = "Cross-encoder could't be resolved. Set default config with Context.configure_transformer." diff --git a/autointent/context/_utils.py b/autointent/context/_utils.py index 7e8280ce7..3ae8eda84 100644 --- a/autointent/context/_utils.py +++ b/autointent/context/_utils.py @@ -1,8 +1,4 @@ -"""Module for loading datasets and handling JSON serialization with numpy compatibility. - -This module provides utilities for loading datasets and serializing objects -that include numpy data types. -""" +"""Module for loading datasets and handling JSON serialization with numpy compatibility.""" import json from pathlib import Path @@ -14,19 +10,23 @@ class NumpyEncoder(json.JSONEncoder): - """ - JSON encoder that handles numpy data types. + """JSON encoder that handles numpy data types. This encoder extends the default `json.JSONEncoder` to serialize numpy - arrays, numpy data types. + arrays and numpy data types. + + Attributes: + Inherits all attributes from json.JSONEncoder. """ def default(self, obj: Any) -> str | int | float | list[Any] | Any: # noqa: ANN401 - """ - Serialize objects with special handling for numpy. + """Serialize objects with special handling for numpy. - :param obj: Object to serialize. - :return: JSON-serializable representation of the object. + Args: + obj: Object to serialize. + + Returns: + JSON-serializable representation of the object. """ if isinstance(obj, np.integer): return int(obj) @@ -38,17 +38,19 @@ def default(self, obj: Any) -> str | int | float | list[Any] | Any: # noqa: ANN def load_dataset(path: str | Path) -> Dataset: - """ - Load data from a specified path or use default sample data or load from hugging face hub. + """Load data from a specified path or use default sample data. This function loads a dataset from a JSON file or retrieves sample data included with the `autointent` package for default multiclass or multilabel - datasets. + datasets. If the path doesn't exist, it attempts to load from the Hugging Face hub. + + Args: + path: Path to the dataset file, or a predefined key: + - "default-multiclass": Loads sample multiclass dataset. + - "default-multilabel": Loads sample multilabel dataset. - :param data_path: Path to the dataset file, or a predefined key: - - "default-multiclass": Loads sample multiclass dataset. - - "default-multilabel": Loads sample multilabel dataset. - :return: A `Dataset` object containing the loaded data. + Returns: + A Dataset object containing the loaded data. """ if path == "default-multiclass": return Dataset.from_hub("AutoIntent/clinc150_subset") diff --git a/autointent/context/data_handler/_data_handler.py b/autointent/context/data_handler/_data_handler.py index 0a677c811..6a7291b5a 100644 --- a/autointent/context/data_handler/_data_handler.py +++ b/autointent/context/data_handler/_data_handler.py @@ -17,14 +17,17 @@ class RegexPatterns(TypedDict): - """Regex patterns for each intent class.""" + """Regex patterns for each intent class. + + Attributes: + id: Intent class id. + regex_full_match: Full match regex patterns. + regex_partial_match: Partial match regex patterns. + """ id: int - """Intent class id.""" regex_full_match: list[str] - """Full match regex patterns.""" regex_partial_match: list[str] - """Partial match regex patterns.""" class DataHandler: # TODO rename to Validator @@ -36,12 +39,12 @@ def __init__( config: DataConfig | None = None, random_seed: int = 0, ) -> None: - """ - Initialize the data handler. + """Initialize the data handler. - :param dataset: Training dataset. - :param random_seed: Seed for random number generation. - :param config: config + Args: + dataset: Training dataset. + config: Configuration object + random_seed: Seed for random number generation. """ set_seed(random_seed) self.random_seed = random_seed @@ -72,10 +75,10 @@ def __init__( @property def multilabel(self) -> bool: - """ - Check if the dataset is multilabel. + """Check if the dataset is multilabel. - :return: True if the dataset is multilabel, False otherwise. + Returns: + True if the dataset is multilabel, False otherwise. """ return self.dataset.multilabel @@ -89,29 +92,33 @@ def _choose_split(self, split_name: str, idx: int | None = None) -> str: return split def train_utterances(self, idx: int | None = None) -> list[str]: - """ - Retrieve training utterances from the dataset. + """Retrieve training utterances from the dataset. If a specific training split index is provided, retrieves utterances from the indexed training split. Otherwise, retrieves utterances from the primary training split. - :param idx: Optional index for a specific training split. - :return: List of training utterances. + Args: + idx: Optional index for a specific training split. + + Returns: + List of training utterances. """ split = self._choose_split(Split.TRAIN, idx) return cast(list[str], self.dataset[split][self.dataset.utterance_feature]) def train_labels(self, idx: int | None = None) -> ListOfGenericLabels: - """ - Retrieve training labels from the dataset. + """Retrieve training labels from the dataset. If a specific training split index is provided, retrieves labels from the indexed training split. Otherwise, retrieves labels from the primary training split. - :param idx: Optional index for a specific training split. - :return: List of training labels. + Args: + idx: Optional index for a specific training split. + + Returns: + List of training labels. """ split = self._choose_split(Split.TRAIN, idx) return cast(ListOfGenericLabels, self.dataset[split][self.dataset.label_feature]) @@ -120,58 +127,52 @@ def train_labels_folded(self) -> list[ListOfGenericLabels]: return [self.train_labels(j) for j in range(self.config.n_folds)] def validation_utterances(self, idx: int | None = None) -> list[str]: - """ - Retrieve validation utterances from the dataset. + """Retrieve validation utterances from the dataset. If a specific validation split index is provided, retrieves utterances from the indexed validation split. Otherwise, retrieves utterances from the primary validation split. - :param idx: Optional index for a specific validation split. - :return: List of validation utterances. + Args: + idx: Optional index for a specific validation split. + + Returns: + List of validation utterances. """ split = self._choose_split(Split.VALIDATION, idx) return cast(list[str], self.dataset[split][self.dataset.utterance_feature]) def validation_labels(self, idx: int | None = None) -> ListOfGenericLabels: - """ - Retrieve validation labels from the dataset. + """Retrieve validation labels from the dataset. If a specific validation split index is provided, retrieves labels from the indexed validation split. Otherwise, retrieves labels from the primary validation split. - :param idx: Optional index for a specific validation split. - :return: List of validation labels. + Args: + idx: Optional index for a specific validation split. + + Returns: + List of validation labels. """ split = self._choose_split(Split.VALIDATION, idx) return cast(ListOfGenericLabels, self.dataset[split][self.dataset.label_feature]) def test_utterances(self) -> list[str] | None: - """ - Retrieve test utterances from the dataset. - - If a specific test split index is provided, retrieves utterances - from the indexed test split. Otherwise, retrieves utterances from - the primary test split. + """Retrieve test utterances from the dataset. - :param idx: Optional index for a specific test split. - :return: List of test utterances. + Returns: + List of test utterances. """ if Split.TEST not in self.dataset: return None return cast(list[str], self.dataset[Split.TEST][self.dataset.utterance_feature]) def test_labels(self) -> ListOfGenericLabels: - """ - Retrieve test labels from the dataset. - - If a specific test split index is provided, retrieves labels - from the indexed test split. Otherwise, retrieves labels from - the primary test split. + """Retrieve test labels from the dataset. - :param idx: Optional index for a specific test split. - :return: List of test labels. + Returns: + List of test labels. """ return cast(ListOfGenericLabels, self.dataset[Split.TEST][self.dataset.label_feature]) @@ -210,10 +211,12 @@ def _split_ho(self, separation_ratio: FloatFromZeroToOne | None, validation_size raise ValueError(message) def _split_train(self, ratio: FloatFromZeroToOne) -> None: - """ - Split on two sets. + """Split on two sets. One is for scoring node optimizaton, one is for decision node. + + Args: + ratio: Split ratio """ self.dataset[f"{Split.TRAIN}_0"], self.dataset[f"{Split.TRAIN}_1"] = split_dataset( self.dataset, diff --git a/autointent/context/data_handler/_stratification.py b/autointent/context/data_handler/_stratification.py index 801d28408..a28d1469b 100644 --- a/autointent/context/data_handler/_stratification.py +++ b/autointent/context/data_handler/_stratification.py @@ -18,8 +18,7 @@ class StratifiedSplitter: - """ - A class for stratified splitting of datasets. + """A class for stratified splitting of datasets. This class provides methods to split a dataset into training and testing subsets while preserving the distribution of target labels. It supports both single-label @@ -33,13 +32,13 @@ def __init__( random_seed: int, shuffle: bool = True, ) -> None: - """ - Initialize the StratifiedSplitter. + """Initialize the StratifiedSplitter. - :param test_size: Proportion of the dataset to include in the test split. - :param label_feature: Name of the feature containing labels for stratification. - :param random_seed: Seed for random number generation to ensure reproducibility. - :param shuffle: Whether to shuffle the data before splitting. Defaults to True. + Args: + test_size: Proportion of the dataset to include in the test split. + label_feature: Name of the feature containing labels for stratification. + random_seed: Seed for random number generation to ensure reproducibility. + shuffle: Whether to shuffle the data before splitting. """ self.test_size = test_size self.label_feature = label_feature @@ -49,13 +48,18 @@ def __init__( def __call__( self, dataset: HFDataset, multilabel: bool, allow_oos_in_train: bool | None = None ) -> tuple[HFDataset, HFDataset]: - """ - Split the dataset into training and testing subsets. + """Split the dataset into training and testing subsets. - :param dataset: The input dataset to be split. - :param multilabel: Whether the dataset is multi-label. - :param allow_oos_in_train: Set to True if you want to see out-of-scope utterances in train split. - :return: A tuple containing the training and testing datasets. + Args: + dataset: The input dataset to be split. + multilabel: Whether the dataset is multi-label. + allow_oos_in_train: Set to True if you want to see out-of-scope utterances in train split. + + Returns: + A tuple containing the training and testing datasets. + + Raises: + ValueError: If OOS samples are present but allow_oos_in_train is not specified. """ if not self._has_oos_samples(dataset): return self._split_without_oos(dataset, multilabel, self.test_size) @@ -69,15 +73,42 @@ def __call__( return splitter(dataset, multilabel) def _has_oos_samples(self, dataset: HFDataset) -> bool: + """Check if the dataset contains out-of-scope samples. + + Args: + dataset: The dataset to check. + + Returns: + True if the dataset contains OOS samples, False otherwise. + """ oos_samples = dataset.filter(lambda sample: sample[self.label_feature] is None) return len(oos_samples) > 0 def _split_without_oos(self, dataset: HFDataset, multilabel: bool, test_size: float) -> tuple[HFDataset, HFDataset]: + """Split dataset that doesn't contain OOS samples. + + Args: + dataset: Dataset to split. + multilabel: Whether the dataset is multi-label. + test_size: Proportion of the dataset to include in the test split. + + Returns: + A tuple containing training and testing datasets. + """ splitter = self._split_multilabel if multilabel else self._split_multiclass splits = splitter(dataset, test_size) return dataset.select(splits[0]), dataset.select(splits[1]) def _split_multiclass(self, dataset: HFDataset, test_size: float) -> Sequence[npt.NDArray[np.int_]]: + """Split multiclass dataset. + + Args: + dataset: Dataset to split. + test_size: Proportion of the dataset to include in the test split. + + Returns: + A sequence containing indices for train and test splits. + """ return train_test_split( # type: ignore[no-any-return] np.arange(len(dataset)), test_size=test_size, @@ -87,6 +118,15 @@ def _split_multiclass(self, dataset: HFDataset, test_size: float) -> Sequence[np ) def _split_multilabel(self, dataset: HFDataset, test_size: float) -> Sequence[npt.NDArray[np.int_]]: + """Split multilabel dataset. + + Args: + dataset: Dataset to split. + test_size: Proportion of the dataset to include in the test split. + + Returns: + A sequence containing indices for train and test splits. + """ splitter = IterativeStratification( n_splits=2, order=2, @@ -95,11 +135,18 @@ def _split_multilabel(self, dataset: HFDataset, test_size: float) -> Sequence[np return next(splitter.split(np.arange(len(dataset)), np.array(dataset[self.label_feature]))) def _split_allow_oos_in_train(self, dataset: HFDataset, multilabel: bool) -> tuple[HFDataset, HFDataset]: - """ - Proportionally distribute OOS samples between two splits. + """Proportionally distribute OOS samples between two splits. + + Internally creates a dataset copy with some integer assigned as OOS class id. + With OOS samples treated as a separate class we obtain proportional distribution + of them between two splits. - Internally, this method creates a dataset copy with some integer assigned as OOS class id. - With OOS samples treated as a separate class we obtain proportional distribution of them between two splits. + Args: + dataset: Dataset to split. + multilabel: Whether the dataset is multi-label. + + Returns: + A tuple containing training and testing datasets. """ # add oos as a class if multilabel: @@ -126,30 +173,62 @@ def _split_allow_oos_in_train(self, dataset: HFDataset, multilabel: bool) -> tup def _map_label( self, sample: dict[str, str | LabelType], old: LabelType, new: LabelType ) -> dict[str, str | LabelType]: + """Map labels from old value to new value. + + Args: + sample: Sample containing the label to map. + old: Old label value. + new: New label value. + + Returns: + Sample with mapped label. + """ if sample[self.label_feature] == old: sample[self.label_feature] = new return sample def _add_oos_label(self, sample: dict[str, str | LabelType], n_classes: int) -> dict[str, str | LabelType]: - """Add OOS as a class for multi-label case.""" + """Add OOS as a class for multi-label case. + + Args: + sample: Sample to modify. + n_classes: Number of classes in the dataset. + + Returns: + Sample with added OOS label. + """ if sample[self.label_feature] is None: sample[self.label_feature] = [0] * n_classes sample[self.label_feature] += [1] # type: ignore[operator] return sample def _remove_oos_label(self, sample: dict[str, str | LabelType], n_classes: int) -> dict[str, str | LabelType]: - """Remove OOS as a class for multi-label case.""" + """Remove OOS as a class for multi-label case. + + Args: + sample: Sample to modify. + n_classes: Number of classes in the dataset. + + Returns: + Sample with removed OOS label. + """ sample[self.label_feature] = sample[self.label_feature][:-1] # type: ignore[index] if sample[self.label_feature] == [0] * n_classes: sample[self.label_feature] = None # type: ignore[assignment] return sample def _split_disallow_oos_in_train(self, dataset: HFDataset, multilabel: bool) -> tuple[HFDataset, HFDataset]: - """ - Move all OOS samples to test split. + """Move all OOS samples to test split. This method preserves the defined test_size proportion so you won't get unexpectedly large test set even you have lots of OOS samples. + + Args: + dataset: Dataset to split. + multilabel: Whether the dataset is multi-label. + + Returns: + A tuple containing training and testing datasets. """ in_domain_dataset, out_of_domain_dataset = self._separate_oos(dataset) adjusted_test_size = self._get_adjusted_test_size(len(dataset), len(out_of_domain_dataset)) @@ -158,6 +237,14 @@ def _split_disallow_oos_in_train(self, dataset: HFDataset, multilabel: bool) -> return train, test def _separate_oos(self, dataset: HFDataset) -> tuple[HFDataset, HFDataset]: + """Separate OOS samples from in-domain samples. + + Args: + dataset: Dataset to separate. + + Returns: + A tuple containing in-domain and out-of-domain datasets. + """ in_domain_ids = [] out_of_domain_ids = [] for i, sample in enumerate(dataset): @@ -168,11 +255,17 @@ def _separate_oos(self, dataset: HFDataset) -> tuple[HFDataset, HFDataset]: return dataset.select(in_domain_ids), dataset.select(out_of_domain_ids) def _get_adjusted_test_size(self, n: int, k: int) -> float: - """ - Calculate effective test_size in order to preserve original proportion. + """Calculate effective test_size to preserve original proportion. - :param n: size of original dataset (both with in-domain and out-of-domain samples) - :param k: number of out-of-domain samples within the dataset + Args: + n: Size of original dataset (both with in-domain and out-of-domain samples). + k: Number of out-of-domain samples within the dataset. + + Returns: + Adjusted test size. + + Raises: + ValueError: If dataset contains too many OOS samples. """ if k == 0: return self.test_size @@ -193,22 +286,17 @@ def split_dataset( random_seed: int, allow_oos_in_train: bool | None = None, ) -> tuple[HFDataset, HFDataset]: - """ - Split a Dataset object into training and testing subsets. - - This function uses the StratifiedSplitter to perform stratified splitting - while preserving the distribution of labels. - - :param dataset: The dataset to be split, which must include training data. - :param split: The specific data split to be divided, e.g., "train" or - another split within the dataset. - :param test_size: Proportion of the dataset to include in the test split. - Should be a float value between 0.0 and 1.0, where 0.0 - means no data will be assigned to the test set, and 1.0 - means all data will be assigned to the test set. For example, - a value of 0.2 indicates 20% of the data will be used for testing. - :param random_seed: Seed for random number generation to ensure reproducibility. - :return: A tuple containing two subsets of the selected split. + """Split a Dataset object into training and testing subsets. + + Args: + dataset: The dataset to split, which must include training data. + split: The specific data split to divide. + test_size: Proportion of the dataset to include in the test split. + random_seed: Seed for random number generation. + allow_oos_in_train: Whether to allow OOS samples in train split. + + Returns: + A tuple containing two subsets of the selected split. """ splitter = StratifiedSplitter( test_size=test_size, diff --git a/autointent/context/optimization_info/_data_models.py b/autointent/context/optimization_info/_data_models.py index 4d5795109..fc56dd6e6 100644 --- a/autointent/context/optimization_info/_data_models.py +++ b/autointent/context/optimization_info/_data_models.py @@ -23,20 +23,25 @@ class RegexArtifact(Artifact): class EmbeddingArtifact(Artifact): - """ - Artifact containing details from the embedding node. + """Artifact containing details from the embedding node. - Name of the embedding model chosen after embedding optimization. + Attributes: + config: Configuration settings for the embedder. """ config: EmbedderConfig class ScorerArtifact(Artifact): - """ - Artifact containing outputs from the scoring node. + """Artifact containing outputs from the scoring node. Outputs from the best scorer, numpy arrays of shape (n_samples, n_classes). + + Attributes: + train_scores: Scorer outputs for train utterances. + validation_scores: Scorer outputs for validation utterances. + test_scores: Scorer outputs for test utterances. + folded_scores: Scores for each fold from cross-validation. """ model_config = ConfigDict(arbitrary_types_allowed=True) @@ -49,11 +54,13 @@ class ScorerArtifact(Artifact): class DecisionArtifact(Artifact): - """ - Artifact containing outputs from the predictor node. + """Artifact containing outputs from the predictor node. Outputs from the best predictor, numpy array of shape (n_samples,) or - (n_samples, n_classes) depending on classification mode (multi-class or multi-label) + (n_samples, n_classes) depending on classification mode (multi-class or multi-label). + + Attributes: + labels: Predicted labels for the samples. """ model_config = ConfigDict(arbitrary_types_allowed=True) @@ -61,12 +68,16 @@ class DecisionArtifact(Artifact): def validate_node_name(value: str) -> str: - """ - Validate and return the node type. + """Validate and return the node type. + + Args: + value: Node type as a string. - :param value: Node type as a string. - :return: Validated node type string. - :raises ValueError: If the node type is invalid. + Returns: + Validated node type string. + + Raises: + ValueError: If the node type is invalid. """ if value in [NodeType.embedding, NodeType.scoring, NodeType.decision, NodeType.regex]: return value @@ -75,10 +86,15 @@ def validate_node_name(value: str) -> str: class Artifacts(BaseModel): - """ - Container for storing and managing artifacts generated by pipeline nodes. + """Container for storing and managing artifacts generated by pipeline nodes. Modules hyperparams and outputs. The best ones are transmitted between nodes of the pipeline. + + Attributes: + regex: List of artifacts from the regex node. + embedding: List of artifacts from the embedding node. + scoring: List of artifacts from the scoring node. + decision: List of artifacts from the decision node. """ model_config = ConfigDict(arbitrary_types_allowed=True) @@ -89,52 +105,67 @@ class Artifacts(BaseModel): decision: list[DecisionArtifact] = [] def add_artifact(self, node_type: str, artifact: Artifact) -> None: - """ - Add an artifact to the specified node type. + """Add an artifact to the specified node type. - :param node_type: Node type as a string. - :param artifact: The artifact to add. + Args: + node_type: Node type as a string. + artifact: The artifact to add. """ self.get_artifacts(node_type).append(artifact) def get_artifacts(self, node_type: str) -> list[Artifact]: - """ - Retrieve all artifacts for a specified node type. + """Retrieve all artifacts for a specified node type. + + Args: + node_type: Node type as a string. - :param node_type: Node type as a string. - :return: A list of artifacts for the node type. + Returns: + A list of artifacts for the node type. """ return getattr(self, validate_node_name(node_type)) # type: ignore[no-any-return] def get_best_artifact(self, node_type: str, idx: int) -> Artifact: - """ - Retrieve the best artifact for a specified node type and index. + """Retrieve the best artifact for a specified node type and index. - :param node_type: Node type as a string. - :param idx: Index of the artifact. - :return: The best artifact. + Args: + node_type: Node type as a string. + idx: Index of the artifact. + + Returns: + The best artifact. """ return self.get_artifacts(node_type)[idx] class Trial(BaseModel): - """Representation of an individual optimization trial.""" + """Representation of an individual optimization trial. + + Attributes: + module_name: Type of the module being optimized. + module_params: Parameters of the module for the trial. + metric_name: Name of the evaluation metric. + metric_value: Value of the evaluation metric for this trial. + module_dump_dir: Directory where the module is dumped. + metrics: Dictionary of metric names and their values. + """ module_name: str - """Type of the module being optimized.""" module_params: dict[str, Any] - """Parameters of the module for the trial.""" metric_name: str - """Name of the evaluation metric.""" metric_value: float - """Value of the evaluation metric for this trial.""" module_dump_dir: str | None - """Directory where the module is dumped.""" metrics: dict[str, float] class Trials(BaseModel): - """Container for managing optimization trials for pipeline nodes.""" + """Container for managing optimization trials for pipeline nodes. + + Attributes: + regex: List of trials for the regex node. + embedding: List of trials for the embedding node. + scoring: List of trials for the scoring node. + decision: List of trials for the decision node. + """ regex: list[Trial] = [] embedding: list[Trial] = [] @@ -142,60 +173,69 @@ class Trials(BaseModel): decision: list[Trial] = [] def get_trial(self, node_type: str, idx: int) -> Trial: - """ - Retrieve a specific trial for a node type and index. + """Retrieve a specific trial for a node type and index. + + Args: + node_type: Node type as a string. + idx: Index of the trial. - :param node_type: Node type as a string. - :param idx: Index of the trial. - :return: The requested trial. + Returns: + The requested trial. """ return self.get_trials(node_type)[idx] def get_trials(self, node_type: str) -> list[Trial]: - """ - Retrieve all trials for a specified node type. + """Retrieve all trials for a specified node type. - :param node_type: Node type as a string. - :return: A list of trials for the node type. + Args: + node_type: Node type as a string. + + Returns: + A list of trials for the node type. """ return getattr(self, validate_node_name(node_type)) # type: ignore[no-any-return] def add_trial(self, node_type: str, trial: Trial) -> None: - """ - Add a trial to a specified node type. + """Add a trial to a specified node type. - :param node_type: Node type as a string. - :param trial: The trial to add. + Args: + node_type: Node type as a string. + trial: The trial to add. """ self.get_trials(node_type).append(trial) class TrialsIds(BaseModel): - """Representation of the best trial IDs for each pipeline node.""" + """Representation of the best trial IDs for each pipeline node. + + Attributes: + regex: Best trial index for the regex node. + embedding: Best trial index for the embedding node. + scoring: Best trial index for the scoring node. + decision: Best trial index for the decision node. + """ regex: int | None = None - """Best trial index for the regex node.""" embedding: int | None = None - """Best trial index for the embedding node.""" scoring: int | None = None - """Best trial index for the scoring""" decision: int | None = None - """Best trial index for the decision node.""" def get_best_trial_idx(self, node_type: str) -> int | None: - """ - Retrieve the best trial index for a specified node type. + """Retrieve the best trial index for a specified node type. + + Args: + node_type: Node type as a string. - :param node_type: Node type as a string. - :return: The index of the best trial, or None if not set. + Returns: + The index of the best trial, or None if not set. """ return getattr(self, validate_node_name(node_type)) # type: ignore[no-any-return] def set_best_trial_idx(self, node_type: str, idx: int) -> None: - """ - Set the best trial index for a specified node type. + """Set the best trial index for a specified node type. - :param node_type: Node type as a string. - :param idx: Index of the best trial. + Args: + node_type: Node type as a string. + idx: Index of the best trial. """ setattr(self, validate_node_name(node_type), idx) diff --git a/autointent/context/optimization_info/_optimization_info.py b/autointent/context/optimization_info/_optimization_info.py index fadf91cc2..0a4901b12 100644 --- a/autointent/context/optimization_info/_optimization_info.py +++ b/autointent/context/optimization_info/_optimization_info.py @@ -17,12 +17,19 @@ from ._data_models import Artifact, Artifacts, EmbeddingArtifact, ScorerArtifact, Trial, Trials, TrialsIds if TYPE_CHECKING: - from autointent.modules.abc import BaseModule + from autointent.modules.base import BaseModule @dataclass class ModulesList: - """Container for managing lists of modules for each node type.""" + """Container for managing lists of modules for each node type. + + Attributes: + regex: List of modules for the regex node. + embedding: List of modules for the embedding node. + scoring: List of modules for the scoring node. + decision: List of modules for the decision node. + """ regex: list["BaseModule"] = field(default_factory=list) embedding: list["BaseModule"] = field(default_factory=list) @@ -30,31 +37,38 @@ class ModulesList: decision: list["BaseModule"] = field(default_factory=list) def get(self, node_type: str) -> list["BaseModule"]: - """ - Retrieve the list of modules for a specific node type. + """Retrieve the list of modules for a specific node type. + + Args: + node_type: The type of node (e.g., "regex", "embedding"). - :param node_type: The type of node (e.g., "regex", "embedding"). - :return: List of modules for the specified node type. + Returns: + List of modules for the specified node type. """ return getattr(self, node_type) # type: ignore[no-any-return] def add_module(self, node_type: str, module: "BaseModule") -> None: - """ - Add a module to the list for a specific node type. + """Add a module to the list for a specific node type. - :param node_type: The type of node. - :param module: The module to add. + Args: + node_type: The type of node. + module: The module to add. """ self.get(node_type).append(module) class OptimizationInfo: - """ - Tracks optimization results, including trials, artifacts, and modules. + """Tracks optimization results, including trials, artifacts, and modules. This class provides methods for logging optimization results, retrieving the best-performing modules and artifacts, and generating configuration for inference nodes. + + Attributes: + artifacts: Container for storing optimization artifacts. + trials: Container for storing optimization trials. + modules: Container for storing module instances. + pipeline_metrics: Dictionary storing pipeline-level metrics. """ def __init__(self) -> None: @@ -79,17 +93,18 @@ def log_module_optimization( module_dump_dir: str | None, module: "BaseModule | None" = None, ) -> None: - """ - Log optimization results for a module. - - :param node_type: Type of the node being optimized. - :param module_name: Type of the module. - :param module_params: Parameters of the module for the trial. - :param metric_value: Metric value achieved by the module. - :param metric_name: Name of the evaluation metric. - :param artifact: Artifact generated by the module. - :param module_dump_dir: Directory where the module is dumped. - :param module: The module instance, if available. + """Log optimization results for a module. + + Args: + node_type: Type of the node being optimized. + module_name: Type of the module. + module_params: Parameters of the module for the trial. + metric_value: Metric value achieved by the module. + metric_name: Name of the evaluation metric. + metrics: Dictionary of metric names and their values. + artifact: Artifact generated by the module. + module_dump_dir: Directory where the module is dumped. + module: The module instance, if available. """ trial = Trial( module_name=module_name, @@ -108,15 +123,24 @@ def log_module_optimization( self.artifacts.add_artifact(node_type, artifact) def _get_metrics_values(self, node_type: str) -> list[float]: - """Retrieve all metric values for a specific node type.""" + """Retrieve all metric values for a specific node type. + + Args: + node_type: Type of the node. + + Returns: + List of metric values. + """ return [trial.metric_value for trial in self.trials.get_trials(node_type)] def _get_best_trial_idx(self, node_type: str) -> int | None: - """ - Retrieve the index of the best trial for a node type. + """Retrieve the index of the best trial for a node type. + + Args: + node_type: Type of the node. - :param node_type: Type of the node. - :return: Index of the best trial, or None if no trials exist. + Returns: + Index of the best trial, or None if no trials exist. """ if not self.trials.get_trials(node_type): return None @@ -128,12 +152,16 @@ def _get_best_trial_idx(self, node_type: str) -> int | None: return best_idx def _get_best_artifact(self, node_type: str) -> EmbeddingArtifact | ScorerArtifact | Artifact: - """ - Retrieve the best artifact for a specific node type. + """Retrieve the best artifact for a specific node type. - :param node_type: Type of the node. - :return: The best artifact for the node type. - :raises ValueError: If no best trial exists for the node type. + Args: + node_type: Type of the node. + + Returns: + The best artifact for the node type. + + Raises: + ValueError: If no best trial exists for the node type. """ best_idx = self._get_best_trial_idx(node_type) if best_idx is None: @@ -142,55 +170,55 @@ def _get_best_artifact(self, node_type: str) -> EmbeddingArtifact | ScorerArtifa return self.artifacts.get_best_artifact(node_type, best_idx) def get_best_embedder(self) -> EmbedderConfig: - """ - Retrieve the name of the best embedder from the retriever node. + """Retrieve the name of the best embedder from the retriever node. - :return: Name of the best embedder. + Returns: + Configuration of the best embedder. """ best_retriever_artifact: EmbeddingArtifact = self._get_best_artifact(node_type=NodeType.embedding) # type: ignore[assignment] return best_retriever_artifact.config def get_best_train_scores(self) -> NDArray[np.float64] | None: - """ - Retrieve the train scores from the best scorer node. + """Retrieve the train scores from the best scorer node. - :return: Train scores as a numpy array. + Returns: + Train scores as a numpy array. """ best_scorer_artifact: ScorerArtifact = self._get_best_artifact(node_type=NodeType.scoring) # type: ignore[assignment] return best_scorer_artifact.train_scores def get_best_validation_scores(self) -> NDArray[np.float64] | None: - """ - Retrieve the validation scores from the best scorer node. + """Retrieve the validation scores from the best scorer node. - :return: Validation scores as a numpy array. + Returns: + Validation scores as a numpy array. """ best_scorer_artifact: ScorerArtifact = self._get_best_artifact(node_type=NodeType.scoring) # type: ignore[assignment] return best_scorer_artifact.validation_scores def get_best_folded_scores(self) -> list[NDArray[np.float64]] | None: - """ - Retrieve the validation scores from the best scorer node. + """Retrieve the validation scores from the best scorer node. - :return: Validation scores as a numpy array. + Returns: + Validation scores as a numpy array. """ best_scorer_artifact: ScorerArtifact = self._get_best_artifact(node_type=NodeType.scoring) # type: ignore[assignment] return best_scorer_artifact.folded_scores def get_best_test_scores(self) -> NDArray[np.float64] | None: - """ - Retrieve the test scores from the best scorer node. + """Retrieve the test scores from the best scorer node. - :return: Test scores as a numpy array. + Returns: + Test scores as a numpy array. """ best_scorer_artifact: ScorerArtifact = self._get_best_artifact(node_type=NodeType.scoring) # type: ignore[assignment] return best_scorer_artifact.test_scores def dump_evaluation_results(self) -> dict[str, Any]: - """ - Dump evaluation results for all nodes. + """Dump evaluation results for all nodes. - :return: Dictionary containing metrics and configurations for all nodes. + Returns: + Dictionary containing metrics and configurations for all nodes. """ node_wise_metrics = {node_type: self._get_metrics_values(node_type) for node_type in NodeType} return { @@ -200,10 +228,13 @@ def dump_evaluation_results(self) -> dict[str, Any]: } def get_inference_nodes_config(self, asdict: bool = False) -> list[InferenceNodeConfig]: - """ - Generate configuration for inference nodes based on the best trials. + """Generate configuration for inference nodes based on the best trials. + + Args: + asdict: Whether to return the configuration as dictionaries. - :return: List of `InferenceNodeConfig` objects for inference nodes. + Returns: + List of configurations for inference nodes. """ trial_ids = [self._get_best_trial_idx(node_type) for node_type in NodeType] res = [] @@ -221,11 +252,13 @@ def get_inference_nodes_config(self, asdict: bool = False) -> list[InferenceNode return res # type: ignore[return-value] def _get_best_module(self, node_type: str) -> "BaseModule | None": - """ - Retrieve the best module for a specific node type. + """Retrieve the best module for a specific node type. + + Args: + node_type: Type of the node. - :param node_type: Type of the node. - :return: The best module, or None if no best trial exists. + Returns: + The best module, or None if no best trial exists. """ idx = self._get_best_trial_idx(node_type) if idx is not None: @@ -233,10 +266,10 @@ def _get_best_module(self, node_type: str) -> "BaseModule | None": return None def get_best_modules(self) -> dict[NodeType, "BaseModule"]: - """ - Retrieve the best modules for all node types. + """Retrieve the best modules for all node types. - :return: Dictionary of the best modules for each node type. + Returns: + Dictionary of the best modules for each node type. """ res = {nt: self._get_best_module(nt) for nt in NodeType} return {nt: m for nt, m in res.items() if m is not None} diff --git a/autointent/custom_types.py b/autointent/custom_types.py index d2b360717..c57a62546 100644 --- a/autointent/custom_types.py +++ b/autointent/custom_types.py @@ -58,13 +58,13 @@ class NodeType(str, Enum): class Split: - """ - Constants representing dataset splits. + """Enumeration of data splits in the AutoIntent framework. - :cvar str TRAIN: Training split. - :cvar str VALIDATION: Validation split. - :cvar str TEST: Testing split. - :cvar str INTENTS: Intents split. + Attributes: + TRAIN: Represents the training data split. + VALIDATION: Represents the validation data split. + TEST: Represents the test data split. + INTENTS: Represents the intents data split. """ TRAIN = "train" diff --git a/autointent/exceptions.py b/autointent/exceptions.py index 711916a6b..96b9a901b 100644 --- a/autointent/exceptions.py +++ b/autointent/exceptions.py @@ -2,40 +2,40 @@ class WrongClassificationError(Exception): - """ - Exception raised when a classification module is used with incompatible data. + """Exception raised when a classification module is used with incompatible data. This error typically occurs when a multiclass module is called on multilabel data or vice versa. - :param message: Error message, defaults to a standard incompatibility message. + Args: + message: Error message, defaults to a standard incompatibility message """ def __init__(self, message: str = "Multiclass module is called on multilabel data or vice-versa") -> None: - """ - Initialize the exception. + """Initialize the exception. - :param message: Error message, defaults to a standard incompatibility message. + Args: + message: Error message, defaults to a standard incompatibility message """ self.message = message super().__init__(message) class MismatchNumClassesError(Exception): - """ - Exception raised when the data contains an incompatible number of classes. + """Exception raised when the data contains an incompatible number of classes. This error indicates that the number of classes in the input data does not match the expected number of classes for the module. - :param message: Error message, defaults to a standard class incompatibility message. + Args: + message: Error message, defaults to a standard class incompatibility message """ def __init__(self, message: str | None = None) -> None: - """ - Initialize the exception. + """Initialize the exception. - :param message: Error message, defaults to a standard incompatibility message. + Args: + message: Error message, defaults to a standard incompatibility message """ self.message = ( message or "Provided scores number don't match with number of classes which module was trained on." diff --git a/autointent/generation/intents/description_generation.py b/autointent/generation/intents/description_generation.py index a3772669f..bc1a477c7 100644 --- a/autointent/generation/intents/description_generation.py +++ b/autointent/generation/intents/description_generation.py @@ -1,4 +1,9 @@ -"""Description generation for intents using OpenAI models.""" +"""Module for generating intent descriptions using OpenAI models. + +This module provides functionality to generate descriptions for intents using OpenAI's +language models. It includes utilities for grouping utterances, creating descriptions +for individual intents, and enhancing datasets with generated descriptions. +""" import asyncio import random @@ -12,19 +17,20 @@ def group_utterances_by_label(samples: list[Sample]) -> dict[int, list[str]]: - """ - Group samples by their labels. + """Group utterances from samples by their corresponding labels. - :param samples: List of samples with `label` and `utterance` attributes. + Args: + samples: List of samples, each containing a label and utterance. - :returns: A dictionary where labels map to lists of utterances. + Returns: + Dictionary mapping label IDs to lists of utterances. """ label_mapping = defaultdict(list) for sample in samples: match sample.label: case list(): - # parse one hot encoding + # Handle one-hot encoding for class_id, label in enumerate(sample.label): if label: label_mapping[class_id].append(sample.utterance) @@ -42,19 +48,22 @@ async def create_intent_description( prompt: PromptDescription, model_name: str, ) -> str: - """ - Generate a description for a specific intent using an OpenAI model. - - :param client: The OpenAI client instance used to communicate with the model. - :param intent_name: The name of the intent to describe. If None, an empty string will be used. - :param utterances: A list of example utterances related to the intent. - :param regex_patterns: A list of regular expression patterns associated with the intent. - - :param prompt: A string template for the prompt, which must include placeholders for {intent_name} - and {user_utterances} to format the content sent to the model. - :param model_name: The identifier of the OpenAI model to use for generating the description. - - :returns: The generated description of the intent. + """Generate a description for a specific intent using an OpenAI model. + + Args: + client: OpenAI client instance for model communication. + intent_name: Name of the intent to describe (empty string if None). + utterances: Example utterances related to the intent. + regex_patterns: Regular expression patterns associated with the intent. + prompt: Template for model prompt with placeholders for intent_name, + user_utterances, and regex_patterns. + model_name: Identifier of the OpenAI model to use. + + Returns: + Generated description of the intent. + + Raises: + TypeError: If the model response is not a string. """ intent_name = intent_name if intent_name is not None else "" utterances = random.sample(utterances, min(5, len(utterances))) @@ -85,17 +94,18 @@ async def generate_intent_descriptions( prompt: PromptDescription, model_name: str, ) -> list[Intent]: - """ - Generate descriptions for a list of intents using an OpenAI model. - - :param client: The OpenAI client used to generate the descriptions. - :param intent_utterances: A dictionary mapping intent IDs to their corresponding utterances. - :param intents: A list of intents to generate descriptions for. - :param prompt: A string template for the prompt, which must include placeholders for {intent_name} - and {user_utterances} to format the content sent to the model. - :param model_name: The name of the OpenAI model to use for generating descriptions. - - :returns: The list of intents with updated descriptions. + """Generate descriptions for multiple intents using an OpenAI model. + + Args: + client: OpenAI client for generating descriptions. + intent_utterances: Dictionary mapping intent IDs to utterances. + intents: List of intents needing descriptions. + prompt: Template for model prompt with placeholders for intent_name, + user_utterances, and regex_patterns. + model_name: Name of the OpenAI model to use. + + Returns: + List of intents with updated descriptions. """ tasks = [] for intent in intents: @@ -127,16 +137,17 @@ def enhance_dataset_with_descriptions( prompt: PromptDescription, model_name: str = "gpt-4o-mini", ) -> Dataset: - """ - Enhances a dataset by generating descriptions for intents using an OpenAI model. + """Enhance a dataset by adding generated descriptions to its intents. - :param dataset: The dataset containing utterances and intents that require descriptions. - :param client: The OpenAI client used to generate the descriptions. - :param prompt: A string template for the prompt, which must include placeholders for {intent_name} - and {user_utterances} to format the content sent to the model. - :param model_name: The OpenAI model to use for generating descriptions. + Args: + dataset: Dataset containing utterances and intents needing descriptions. + client: OpenAI client for generating descriptions. + prompt: Template for model prompt with placeholders for intent_name, + user_utterances, and regex_patterns. + model_name: OpenAI model identifier for generating descriptions. - :returns: The dataset with intents enhanced by generated descriptions. + Returns: + Dataset with enhanced intent descriptions. """ samples = [] for split in dataset.values(): diff --git a/autointent/generation/intents/prompt_scheme.py b/autointent/generation/intents/prompt_scheme.py index 0e15a0826..90384b771 100644 --- a/autointent/generation/intents/prompt_scheme.py +++ b/autointent/generation/intents/prompt_scheme.py @@ -20,10 +20,13 @@ class PromptDescription(BaseModel): @classmethod @field_validator("text") def check_valid_prompt(cls, value: str) -> str: - """ - Validate the prompt description template. + """Validate the prompt description template. + + Args: + value: The prompt description template. - :param value: Check the prompt description template. + Returns: + The validated prompt description template. """ if value.find("{intent_name}") == -1 or value.find("{user_utterances}") == -1: text_error = ( diff --git a/autointent/generation/utterances/__init__.py b/autointent/generation/utterances/__init__.py index 14d7c6048..533151455 100644 --- a/autointent/generation/utterances/__init__.py +++ b/autointent/generation/utterances/__init__.py @@ -26,7 +26,6 @@ "IncrementalUtteranceEvolver", "InformalEvolution", "ReasoningEvolution", - "SynthesizerChatTemplate", "UtteranceEvolver", "UtteranceGenerator", ] diff --git a/autointent/generation/utterances/balancer.py b/autointent/generation/utterances/balancer.py index 4be7772f5..baaf68caf 100644 --- a/autointent/generation/utterances/balancer.py +++ b/autointent/generation/utterances/balancer.py @@ -24,16 +24,15 @@ def __init__( async_mode: bool = False, max_samples_per_class: int | None = None, ) -> None: - """ - Initialize the UtteranceBalancer. + """Initialize the UtteranceBalancer. Args: generator (Generator): The generator object used to create utterances. prompt_maker (Callable[[Intent, int], list[Message]]): A callable that creates prompts for the generator. - seed (int, optional): The seed for random number generation. Defaults to 42. async_mode (bool, optional): Whether to run the generator in asynchronous mode. Defaults to False. max_samples_per_class (int | None, optional): The maximum number of samples per class. Must be a positive integer or None. Defaults to None. + Raises: ValueError: If max_samples_per_class is not None and is less than or equal to 0. """ @@ -47,12 +46,10 @@ def __init__( self.max_samples = max_samples_per_class def balance(self, dataset: Dataset, split: str = Split.TRAIN, batch_size: int = 4) -> Dataset: - """ - Balances the specified dataset split. + """Balances the specified dataset split. :param dataset: Source dataset :param split: Target split for balancing - :param n_evolutions: Number of augmentations per example :param batch_size: Batch size for asynchronous processing :return: Balanced dataset """ @@ -142,7 +139,11 @@ def _augment_class(self, dataset: Dataset, split: str, class_id: int, needed: in logger.debug("Total samples after augmentation: %s", final_count) def _process_utterances(self, generated: list[str]) -> list[str]: - """Process and clean generated utterances.""" + """Process and clean generated utterances. + + Args: + generated: Generated list + """ processed = [] for ut in generated: if "', '" in ut or "',\n" in ut: diff --git a/autointent/generation/utterances/basic/chat_templates/_base.py b/autointent/generation/utterances/basic/chat_templates/_base.py index 7c66d716e..c14bc07eb 100644 --- a/autointent/generation/utterances/basic/chat_templates/_base.py +++ b/autointent/generation/utterances/basic/chat_templates/_base.py @@ -15,7 +15,15 @@ class BaseChatTemplate(ABC): @abstractmethod def __call__(self, intent_data: Intent, n_examples: int) -> list[Message]: - """Generate examples for this intent.""" + """Generate a list of messages to request additional examples for the given intent. + + Args: + intent_data: Intent data for which to generate examples. + n_examples: Number of examples to generate. + + Returns: + List of messages for the chat template. + """ class BaseSynthesizerTemplate(BaseChatTemplate): @@ -33,7 +41,17 @@ def __init__( extra_instructions: str | None = None, max_sample_utterances: int | None = None, ) -> None: - """Initialize the chat template with dataset, split, and optional instructions.""" + """Initialize the BaseSynthesizerTemplate. + + Args: + dataset: Dataset to use for generating examples. + split: Dataset split to use for generating examples. + extra_instructions: Additional instructions for the model. + max_sample_utterances: Maximum number of sample utterances to include. + + Raises: + ValueError: If the dataset is not provided. + """ if extra_instructions is None: extra_instructions = "" @@ -47,7 +65,15 @@ def __init__( self.max_sample_utterances = max_sample_utterances def __call__(self, intent_data: Intent, n_examples: int) -> list[Message]: - """Generate a list of messages to request additional examples for the given intent.""" + """Generate a list of messages to request additional examples for the given intent. + + Args: + intent_data: Intent data for which to generate examples. + n_examples: Number of examples to generate. + + Returns: + List of messages for the chat template. + """ in_domain_samples = self.dataset[self.split].filter(lambda sample: sample[Dataset.label_feature] is not None) if self.dataset.multilabel: filter_fn = lambda sample: sample[Dataset.label_feature][intent_data.id] == 1 # noqa: E731 @@ -66,6 +92,16 @@ def __call__(self, intent_data: Intent, n_examples: int) -> list[Message]: ] def _create_final_message(self, intent_data: Intent, n_examples: int, sample_utterances: list[str]) -> Message: + """Create the final message for the chat template. + + Args: + intent_data: Intent data for which to generate examples. + n_examples: Number of examples to generate. + sample_utterances: Sample utterances to include. + + Returns: + The final message for the chat template. + """ content = f"{self._INTENT_NAME_LABEL}: {intent_data.name}\n\n{self._EXAMPLE_UTTERANCES_LABEL}:\n" if sample_utterances: diff --git a/autointent/generation/utterances/basic/utterance_generator.py b/autointent/generation/utterances/basic/utterance_generator.py index a69df3c85..c84524370 100644 --- a/autointent/generation/utterances/basic/utterance_generator.py +++ b/autointent/generation/utterances/basic/utterance_generator.py @@ -13,8 +13,7 @@ class UtteranceGenerator: - """ - Basic generation of new utterances from existing ones. + """Basic generation of new utterances from existing ones. This augmentation method simply prompts LLM to look at existing examples and generate similar. Additionally, it can consider some aspects of style, @@ -22,19 +21,41 @@ class UtteranceGenerator: """ def __init__(self, generator: Generator, prompt_maker: BaseSynthesizerTemplate, async_mode: bool = False) -> None: - """Initialize.""" + """Initialize the UtteranceGenerator. + + Args: + generator: Generator instance for generating utterances. + prompt_maker: Prompt maker instance for generating prompts. + async_mode: Whether to use asynchronous mode for generation. + """ self.generator = generator self.prompt_maker = prompt_maker self.async_mode = async_mode def __call__(self, intent_data: Intent, n_generations: int) -> list[str]: - """Generate new utterances.""" + """Call the generator to generate new utterances. + + Args: + intent_data: Intent data for which to generate utterances. + n_generations: Number of utterances to generate. + + Returns: + List of generated utterances. + """ messages = self.prompt_maker(intent_data, n_generations) response_text = self.generator.get_chat_completion(messages) return _extract_utterances(response_text) async def _call_async(self, intent_data: Intent, n_generations: int) -> list[str]: - """Generate new utterances asynchronously.""" + """Call the generator to generate new utterances asynchronously. + + Args: + intent_data: Intent data for which to generate utterances. + n_generations: Number of utterances to generate. + + Returns: + List of generated utterances. + """ messages = self.prompt_maker(intent_data, n_generations) response_text = await self.generator.get_chat_completion_async(messages) return _extract_utterances(response_text) @@ -47,15 +68,17 @@ def augment( update_split: bool = True, batch_size: int = 4, ) -> list[Sample]: - """ - Augment some split of dataset. - - :param dataset: Dataset object - :param split_name: Dataset split (default is TRAIN) - :param n_generations: Number of utterances to generate per intent - :param update_split: Whether to update the dataset split - :param batch_size: Batch size for async generation - :return: List of generated samples + """Augment some split of dataset. + + Args: + dataset: Dataset object. + split_name: Dataset split (default is TRAIN). + n_generations: Number of utterances to generate per intent. + update_split: Whether to update the dataset split. + batch_size: Batch size for async generation. + + Returns: + List of generated samples. """ if self.async_mode: return asyncio.run(self._augment_async(dataset, split_name, n_generations, update_split, batch_size)) @@ -82,15 +105,17 @@ async def _augment_async( update_split: bool = True, batch_size: int = 4, ) -> list[Sample]: - """ - Augment some split of dataset asynchronously in batches. - - :param dataset: Dataset object - :param split_name: Dataset split (default is TRAIN) - :param n_generations: Number of utterances to generate per intent - :param update_split: Whether to update the dataset split - :param batch_size: Batch size for async generation - :return: List of generated samples + """Augment some split of dataset asynchronously. + + Args: + dataset: Dataset object. + split_name: Dataset split (default is TRAIN). + n_generations: Number of utterances to generate per intent. + update_split: Whether to update the dataset split. + batch_size: Batch size for async generation. + + Returns: + List of generated samples. """ original_split = dataset[split_name] new_samples = [] @@ -116,10 +141,13 @@ async def _augment_async( def _extract_utterances(response_text: str) -> list[str]: - """ - Parse LLM output. + """Extract utterances from LLM output. + + Args: + response_text: Response text from LLM. - Inverse function to :py:func:`_format_utterances`. + Returns: + List of utterances. """ raw_utterances = response_text.split("\n") # remove enumeration diff --git a/autointent/generation/utterances/evolution/chat_templates/base.py b/autointent/generation/utterances/evolution/chat_templates/base.py index 1a6c7bb1f..3f83feaf4 100644 --- a/autointent/generation/utterances/evolution/chat_templates/base.py +++ b/autointent/generation/utterances/evolution/chat_templates/base.py @@ -13,7 +13,15 @@ class EvolutionChatTemplate: name: str def __call__(self, utterance: str, intent_data: Intent) -> list[Message]: - """Make a chat to complete by LLM.""" + """Generate a list of messages to request additional examples for the given intent. + + Args: + utterance: Utterance to be used for generation. + intent_data: Intent data for which to generate examples. + + Returns: + List of messages for the chat template. + """ invoke_message = Message( role=Role.USER, content=f"Intent name: {intent_data.name or ''}\nUtterance: {utterance}", diff --git a/autointent/generation/utterances/evolution/evolver.py b/autointent/generation/utterances/evolution/evolver.py index 8c70bdc09..fef446a25 100644 --- a/autointent/generation/utterances/evolution/evolver.py +++ b/autointent/generation/utterances/evolution/evolver.py @@ -1,5 +1,4 @@ -""" -Evolutionary strategy to augmenting utterances. +"""Evolutionary strategy to augmenting utterances. Deeply inspired by DeepEval evolutions. """ @@ -19,8 +18,7 @@ class UtteranceEvolver: - """ - Evolutionary strategy to augmenting utterances. + """Evolutionary strategy to augmenting utterances. Deeply inspired by DeepEval evolutions. This method takes single utterance and prompts LLM to change it in a specific way. @@ -33,20 +31,43 @@ def __init__( seed: int = 0, async_mode: bool = False, ) -> None: - """Initialize.""" + """Initialize the UtteranceEvolver. + + Args: + generator: Generator instance for generating utterances. + prompt_makers: List of prompt makers for generating prompts. + seed: Random seed for reproducibility. + async_mode: Whether to use asynchronous mode for generation. + """ self.generator = generator self.prompt_makers = prompt_makers self.async_mode = async_mode random.seed(seed) def _evolve(self, utterance: str, intent_data: Intent) -> str: - """Apply evolutions single time synchronously.""" + """Apply evolutions a single time. + + Args: + utterance: Utterance to be evolved. + intent_data: Intent data for which to evolve the utterance. + + Returns: + Evolved utterance. + """ maker = random.choice(self.prompt_makers) chat = maker(utterance, intent_data) return self.generator.get_chat_completion(chat) async def _evolve_async(self, utterance: str, intent_data: Intent) -> str: - """Apply evolutions a single time (asynchronously).""" + """Apply evolutions a single time asynchronously. + + Args: + utterance: Utterance to be evolved. + intent_data: Intent data for which to evolve the utterance. + + Returns: + Evolved utterance. + """ maker = random.choice(self.prompt_makers) chat = maker(utterance, intent_data) return await self.generator.get_chat_completion_async(chat) @@ -54,7 +75,17 @@ async def _evolve_async(self, utterance: str, intent_data: Intent) -> str: def __call__( self, utterance: str, intent_data: Intent, n_evolutions: int = 1, sequential: bool = False ) -> list[str]: - """Apply evolutions multiple times (synchronously).""" + """Apply evolutions to the utterance. + + Args: + utterance: Utterance to be evolved. + intent_data: Intent data for which to evolve the utterance. + n_evolutions: Number of evolutions to apply. + sequential: Whether to apply evolutions sequentially. + + Returns: + List of evolved utterances. + """ current_utterance = utterance generated_utterances = [] @@ -76,10 +107,18 @@ def augment( batch_size: int = 4, sequential: bool = False, ) -> HFDataset: - """ - Augment some split of dataset. - - Note that for now it supports only single-label datasets. + """Augment some split of dataset. + + Args: + dataset: Dataset object. + split_name: Dataset split (default is TRAIN). + n_evolutions: Number of evolutions to apply. + update_split: Whether to update the dataset split. + batch_size: Batch size for async generation. + sequential: Whether to apply evolutions sequentially. + + Returns: + List of generated samples. """ if self.async_mode: if sequential: @@ -123,6 +162,18 @@ async def _augment_async( update_split: bool = True, batch_size: int = 4, ) -> HFDataset: + """Augment some split of dataset asynchronously. + + Args: + dataset: Dataset object. + split_name: Dataset split (default is TRAIN). + n_evolutions: Number of evolutions to apply. + update_split: Whether to update the dataset split. + batch_size: Batch size for async generation. + + Returns: + List of generated samples. + """ original_split = dataset[split_name] new_samples = [] diff --git a/autointent/generation/utterances/evolution/incremental_evolver.py b/autointent/generation/utterances/evolution/incremental_evolver.py index f04e76fd4..80aeea588 100644 --- a/autointent/generation/utterances/evolution/incremental_evolver.py +++ b/autointent/generation/utterances/evolution/incremental_evolver.py @@ -1,5 +1,4 @@ -""" -Evolutionary strategy to augmenting utterances. +"""Evolutionary strategy to augmenting utterances. Deeply inspired by DeepEval evolutions. """ @@ -51,11 +50,27 @@ def __init__( async_mode: bool = False, search_space: str | None = None, ) -> None: - """Initialize.""" + """Initialize the IncrementalUtteranceEvolver. + + Args: + generator: Generator instance for generating utterances. + prompt_makers: List of prompt makers for generating prompts. + seed: Random seed for reproducibility. + async_mode: Whether to use asynchronous mode for generation. + search_space: Search space for the pipeline optimizer. + """ super().__init__(generator, prompt_makers, seed, async_mode) self.search_space = self._choose_search_space(search_space) def _choose_search_space(self, search_space: str | None) -> list[dict[str, Any]] | Path | str: + """Choose search space for the pipeline optimizer. + + Args: + search_space: Search space for the pipeline optimizer. If None, default search space is used. + + Returns: + The chosen search space. + """ if search_space is None: return SEARCH_SPACE return search_space @@ -69,10 +84,18 @@ def augment( batch_size: int = 4, sequential: bool = False, ) -> HFDataset: - """ - Augment some split of dataset. - - Note that for now it supports only single-label datasets. + """Augment some split of dataset. + + Args: + dataset: Dataset object. + split_name: Dataset split (default is TRAIN). + n_evolutions: Number of evolutions to perform. + update_split: Whether to update the dataset split with the new samples. + batch_size: Batch size for augmentation. + sequential: Whether to perform augmentations sequentially. + + Returns: + List of generated samples. """ best_result = 0 merge_dataset = copy.deepcopy(dataset) diff --git a/autointent/generation/utterances/generator.py b/autointent/generation/utterances/generator.py index 3d5aa592a..043589fda 100644 --- a/autointent/generation/utterances/generator.py +++ b/autointent/generation/utterances/generator.py @@ -22,14 +22,13 @@ class Generator: } def __init__(self, base_url: str | None = None, model_name: str | None = None, **generation_params: Any) -> None: # noqa: ANN401 - """ - Initialize the wrapper for LLM. + """Initialize the wrapper for LLM. - :param base_url: HTTP-endpoint for sending API requests to OpenAI API compatible server. - Omit this to infer OPENAI_BASE_URL from environment. - :param model_name: Name of LLM. Omit this to infer OPENAI_MODEL_NAME from environment. - :param generation_params: kwargs that will be sent with a request to the endpoint. - Omit this to use AutoIntent's default parameters. + Args: + base_url: HTTP-endpoint for sending API requests to OpenAI API compatible server. + Omit this to infer OPENAI_BASE_URL from environment. + model_name: Name of LLM. Omit this to infer OPENAI_MODEL_NAME from environment. + **generation_params: kwargs that will be sent with a request to the endpoint. """ if not base_url: base_url = os.environ["OPENAI_BASE_URL"] @@ -44,7 +43,14 @@ def __init__(self, base_url: str | None = None, model_name: str | None = None, * } # https://stackoverflow.com/a/65539348 def get_chat_completion(self, messages: list[Message]) -> str: - """Prompt LLM and return its answer synchronously.""" + """Prompt LLM and return its answer. + + Args: + messages: List of messages to send to the model. + + Returns: + Model's response. + """ response = self.client.chat.completions.create( messages=messages, # type: ignore[arg-type] model=self.model_name, @@ -53,7 +59,14 @@ def get_chat_completion(self, messages: list[Message]) -> str: return response.choices[0].message.content # type: ignore[return-value] async def get_chat_completion_async(self, messages: list[Message]) -> str: - """Prompt LLM and return its answer asynchronously.""" + """Prompt LLM and return its answer asynchronously. + + Args: + messages: List of messages to send to the model. + + Returns: + Model's response. + """ response = await self.async_client.chat.completions.create( messages=messages, # type: ignore[arg-type] model=self.model_name, diff --git a/autointent/metrics/_converter.py b/autointent/metrics/_converter.py index 59bb99cbf..6db2f7954 100644 --- a/autointent/metrics/_converter.py +++ b/autointent/metrics/_converter.py @@ -15,12 +15,14 @@ def transform( y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE | CANDIDATE_TYPE | SCORES_VALUE_TYPE, ) -> tuple[npt.NDArray[Any], npt.NDArray[Any]]: - """ - Transform y_true and y_pred to numpy arrays. + """Transform y_true and y_pred to numpy arrays. + + Args: + y_true: Y_true values + y_pred: Y_pred values - :param y_true: Y_true values - :param y_pred: Y_pred values - :return: + Returns: + Tuple of numpy arrays (y_true, y_pred) """ y_pred_ = np.array(y_pred) y_true_ = np.array(y_true) diff --git a/autointent/metrics/decision.py b/autointent/metrics/decision.py index f32cc1620..df18aa4f4 100644 --- a/autointent/metrics/decision.py +++ b/autointent/metrics/decision.py @@ -19,20 +19,29 @@ class DecisionMetricFn(Protocol): """Protocol for decision metrics.""" def __call__(self, y_true: ListOfGenericLabels, y_pred: ListOfGenericLabels) -> float: - """ - Calculate decision metric. - - :param y_true: True values of labels - - multiclass case: list representing an array shape `(n_samples,)` of integer class labels - - multilabel case: list representing a matrix of shape `(n_samples, n_classes)` with binary values - :param y_pred: Predicted values of labels. Same shape as `y_true` - :return: Score of the decision metric + """Calculate decision metric. + + Args: + y_true: True values of labels + - multiclass case: list representing an array shape `(n_samples,)` of integer class labels + - multilabel case: list representing a matrix of shape `(n_samples, n_classes)` with binary values + y_pred: Predicted values of labels. Same shape as `y_true` + Returns: + Score of the decision metric """ ... def handle_oos(y_true: ListOfGenericLabels, y_pred: ListOfGenericLabels) -> tuple[ListOfLabels, ListOfLabels]: - """Convert labels of OOS samples to make them usable in decision metrics.""" + """Convert labels of OOS samples to make them usable in decision metrics. + + Args: + y_true: True values of labels + y_pred: Predicted values of labels + + Returns: + Tuple of transformed true and predicted labels + """ in_domain_labels = list(filter(lambda lab: lab is not None, y_true)) if isinstance(in_domain_labels[0], list): func = _add_oos_multilabel @@ -45,20 +54,37 @@ def handle_oos(y_true: ListOfGenericLabels, y_pred: ListOfGenericLabels) -> tupl def _add_oos_multiclass(label: int | None, n_classes: int) -> int: + """Add OOS label for multiclass classification. + + Args: + label: Original label + n_classes: Number of classes + + Returns: + Transformed label + """ if label is None: return n_classes return label def _add_oos_multilabel(label: list[int] | None, n_classes: int) -> list[int]: + """Add OOS label for multilabel classification. + + Args: + label: Original label + n_classes: Number of classes + + Returns: + Transformed label + """ if label is None: return [0] * n_classes + [1] return [*label, 1] def decision_accuracy(y_true: ListOfGenericLabels, y_pred: ListOfGenericLabels) -> float: - r""" - Calculate decision accuracy. Supports both multiclass and multilabel. + r"""Calculate decision accuracy. Supports both multiclass and multilabel. The decision accuracy is calculated as: @@ -73,17 +99,19 @@ def decision_accuracy(y_true: ListOfGenericLabels, y_pred: ListOfGenericLabels) - :math:`\mathbb{1}(\text{condition})` is the indicator function that equals 1 if the condition is true and 0 otherwise. - :param y_true: True values of labels - :param y_pred: Predicted values of labels - :return: Score of the decision accuracy + Args: + y_true: True values of labels + y_pred: Predicted values of labels + + Returns: + Score of the decision accuracy """ y_true_, y_pred_ = transform(*handle_oos(y_true, y_pred)) return float(np.mean(y_true_ == y_pred_)) def _decision_roc_auc_multiclass(y_true: npt.NDArray[Any], y_pred: npt.NDArray[Any]) -> float: - r""" - Calculate roc_auc for multiclass. + r"""Calculate ROC AUC for multiclass. The ROC AUC score for multiclass is calculated as the mean ROC AUC score across all classes, where each class is treated as a binary classification task @@ -98,9 +126,12 @@ def _decision_roc_auc_multiclass(y_true: npt.NDArray[Any], y_pred: npt.NDArray[A - :math:`\text{ROC AUC}_k` is the ROC AUC score for the :math:`k`-th class, calculated by treating it as a binary classification problem (class :math:`k` vs rest). - :param y_true: True values of labels - :param y_pred: Predicted values of labels - :return: Score of the decision roc_auc + Args: + y_true: True values of labels + y_pred: Predicted values of labels + + Returns: + Score of the decision ROC AUC """ n_classes = len(np.unique(y_true)) roc_auc_scores: list[float] = [] @@ -113,31 +144,35 @@ def _decision_roc_auc_multiclass(y_true: npt.NDArray[Any], y_pred: npt.NDArray[A def _decision_roc_auc_multilabel(y_true: npt.NDArray[Any], y_pred: npt.NDArray[Any]) -> float: - r""" - Calculate roc_auc for multilabel. + r"""Calculate ROC AUC for multilabel. This function internally uses :func:`sklearn.metrics.roc_auc_score` with `average=macro`. Refer to the `scikit-learn documentation `__ for more details. - :param y_true: True values of labels - :param y_pred: Predicted values of labels - :return: Score of the decision accuracy + Args: + y_true: True values of labels + y_pred: Predicted values of labels + + Returns: + Score of the decision ROC AUC """ return float(roc_auc_score(y_true, y_pred, average="macro")) def decision_roc_auc(y_true: ListOfGenericLabels, y_pred: ListOfGenericLabels) -> float: - r""" - Calculate ROC AUC for multiclass and multilabel classification. + r"""Calculate ROC AUC for multiclass and multilabel classification. The ROC AUC measures the ability of a model to distinguish between classes. It is calculated as the area under the curve of the true positive rate (TPR) against the false positive rate (FPR) at various threshold settings. - :param y_true: True values of labels - :param y_pred: Predicted values of labels - :return: Score of the decision ROC AUC + Args: + y_true: True values of labels + y_pred: Predicted values of labels + + Returns: + Score of the decision ROC AUC """ y_true_, y_pred_ = transform(*handle_oos(y_true, y_pred)) if y_pred_.ndim == y_true_.ndim == 1: @@ -150,45 +185,51 @@ def decision_roc_auc(y_true: ListOfGenericLabels, y_pred: ListOfGenericLabels) - def decision_precision(y_true: ListOfGenericLabels, y_pred: ListOfGenericLabels) -> float: - r""" - Calculate decision precision. Supports both multiclass and multilabel. + r"""Calculate decision precision. Supports both multiclass and multilabel. This function internally uses :func:`sklearn.metrics.precision_score` with `average=macro`. Refer to the `scikit-learn documentation `__ for more details. - :param y_true: True values of labels - :param y_pred: Predicted values of labels - :return: Score of the decision precision + Args: + y_true: True values of labels + y_pred: Predicted values of labels + + Returns: + Score of the decision precision """ return float(precision_score(*handle_oos(y_true, y_pred), average="macro")) def decision_recall(y_true: ListOfGenericLabels, y_pred: ListOfGenericLabels) -> float: - r""" - Calculate decision recall. Supports both multiclass and multilabel. + r"""Calculate decision recall. Supports both multiclass and multilabel. This function internally uses :func:`sklearn.metrics.recall_score` with `average=macro`. Refer to the `scikit-learn documentation `__ for more details. - :param y_true: True values of labels - :param y_pred: Predicted values of labels - :return: Score of the decision recall + Args: + y_true: True values of labels + y_pred: Predicted values of labels + + Returns: + Score of the decision recall """ return float(recall_score(*handle_oos(y_true, y_pred), average="macro")) def decision_f1(y_true: ListOfGenericLabels, y_pred: ListOfGenericLabels) -> float: - r""" - Calculate decision f1 score. Supports both multiclass and multilabel. + r"""Calculate decision F1 score. Supports both multiclass and multilabel. This function internally uses :func:`sklearn.metrics.f1_score` with `average=macro`. Refer to the `scikit-learn documentation `__ for more details. - :param y_true: True values of labels - :param y_pred: Predicted values of labels - :return: Score of the decision accuracy + Args: + y_true: True values of labels + y_pred: Predicted values of labels + + Returns: + Score of the decision F1 score """ return float(f1_score(*handle_oos(y_true, y_pred), average="macro")) diff --git a/autointent/metrics/regex.py b/autointent/metrics/regex.py index a953b1b84..29a8b5099 100644 --- a/autointent/metrics/regex.py +++ b/autointent/metrics/regex.py @@ -12,19 +12,20 @@ class RegexMetricFn(Protocol): """Protocol for regex metrics.""" def __call__(self, y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float: - """ - Calculate regex metric. + """Calculate regex metric. + + Args: + y_true: True values of labels. + y_pred: Predicted values of labels. - :param y_true: True values of labels - :param y_pred: Predicted values of labels - :return: Score of the regex metric + Returns: + Score of the regex metric. """ ... def regex_partial_accuracy(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float: - r""" - Calculate regex partial accuracy. + r"""Calculate regex partial accuracy. The regex partial accuracy is calculated as: @@ -37,30 +38,32 @@ def regex_partial_accuracy(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) - :math:`y_{\text{true},i}` is the true label for the :math:`i`-th sample, - :math:`y_{\text{pred},i}` is the predicted label for the :math:`i`-th sample, - :math:`\mathbb{1}(\text{condition})` is the indicator function that equals 1 if the condition - is true and 0 otherwise. + is true and 0 otherwise. - :param y_true: True values of labels - :param y_pred: Predicted values of labels - :return: Score of the regex metric + Args: + y_true: True values of labels. + y_pred: Predicted values of labels. + + Returns: + Score of the regex partial accuracy. """ y_true_, y_pred_ = transform(y_true, y_pred) correct = np.mean([true in pred for true, pred in zip(y_true_, y_pred_, strict=True)]) total = y_true_.shape[0] if total == 0: - return -1 # TODO think about it + return -1 # TODO: think about it return float(correct / total) def regex_partial_precision(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float: - r""" - Calculate regex partial precision. + r"""Calculate regex partial precision. The regex partial precision is calculated as: .. math:: - \text{Partial Precision} = \frac{\sum_{i=1}^N \mathbb{1}(y_{\text{true},i} - \in y_{\text{pred},i})}{\sum_{i=1}^N \mathbb{1}(|y_{\text{pred},i}| > 0)} + \text{Partial Precision} = \frac{\sum_{i=1}^N \mathbb{1}(y_{\text{true},i} \in y_{\text{pred},i})}{\sum_{i=1}^N + \mathbb{1}(|y_{\text{pred},i}| > 0)} where: - :math:`N` is the total number of samples, @@ -68,11 +71,14 @@ def regex_partial_precision(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE - :math:`y_{\text{pred},i}` is the predicted label for the :math:`i`-th sample, - :math:`|y_{\text{pred},i}|` is the number of predicted labels for the :math:`i`-th sample, - :math:`\mathbb{1}(\text{condition})` is the indicator function that equals 1 if the condition - is true and 0 otherwise. + is true and 0 otherwise. + + Args: + y_true: True values of labels. + y_pred: Predicted values of labels. - :param y_true: True values of labels - :param y_pred: Predicted values of labels - :return: Score of the regex metric + Returns: + Score of the regex partial precision. """ y_true_, y_pred_ = transform(y_true, y_pred) diff --git a/autointent/metrics/retrieval.py b/autointent/metrics/retrieval.py index a9546ce9a..caf80626b 100644 --- a/autointent/metrics/retrieval.py +++ b/autointent/metrics/retrieval.py @@ -19,8 +19,7 @@ def __call__( candidates_labels: CANDIDATE_TYPE, k: int | None = None, ) -> float: - """ - Calculate retrieval metric. + """Calculate retrieval metric. - multiclass case: labels are integer - multilabel case: labels are binary @@ -41,8 +40,7 @@ def _macrofy( candidates_labels: CANDIDATE_TYPE, k: int | None = None, ) -> float: - r""" - Extend single-label `metric_fn` to a multi-label case via macro averaging. + r"""Extend single-label `metric_fn` to a multi-label case via macro averaging. The macro-average score is calculated as: @@ -78,8 +76,7 @@ def _macrofy( def _average_precision(query_label: int, candidate_labels: npt.NDArray[np.int64], k: int | None = None) -> float: - r""" - Calculate the average precision at position k. + r"""Calculate the average precision at position k. The average precision is calculated as: @@ -126,8 +123,7 @@ def wrapper(query_labels: LABELS_VALUE_TYPE, candidates_labels: CANDIDATE_TYPE, @ignore_oos def retrieval_map(query_labels: LABELS_VALUE_TYPE, candidates_labels: CANDIDATE_TYPE, k: int | None = None) -> float: - r""" - Calculate the mean average precision at position k. + r"""Calculate the mean average precision at position k. The Mean Average Precision (MAP) is computed as the average of the average precision (AP) scores for all queries. The average precision for a single query computes the precision at each rank @@ -158,8 +154,7 @@ def retrieval_map(query_labels: LABELS_VALUE_TYPE, candidates_labels: CANDIDATE_ def _average_precision_intersecting( query_label: list[int], candidate_labels: CANDIDATE_TYPE, k: int | None = None ) -> float: - r""" - Calculate the average precision at position k for the intersecting labels. + r"""Calculate the average precision at position k for the intersecting labels. The average precision for intersecting labels is calculated as: @@ -202,8 +197,7 @@ def retrieval_map_intersecting( candidates_labels: CANDIDATE_TYPE, k: int | None = None, ) -> float: - r""" - Calculate the mean average precision at position k for the intersecting labels. + r"""Calculate the mean average precision at position k for the intersecting labels. The Mean Average Precision (MAP) for intersecting labels is computed as the average of the average precision (AP) scores for all queries. The average @@ -238,8 +232,7 @@ def retrieval_map_macro( candidates_labels: CANDIDATE_TYPE, k: int | None = None, ) -> float: - r""" - Calculate the mean average precision at position k for the intersecting labels. + r"""Calculate the mean average precision at position k for the intersecting labels. This function internally uses :func:`retrieval_map` to calculate the MAP for each query and performs macro-averaging across multiple queries. @@ -259,8 +252,7 @@ def retrieval_hit_rate( candidates_labels: CANDIDATE_TYPE, k: int | None = None, ) -> float: - r""" - Calculate the hit rate at position k. + r"""Calculate the hit rate at position k. The hit rate is calculated as: @@ -299,8 +291,7 @@ def retrieval_hit_rate_intersecting( candidates_labels: CANDIDATE_TYPE, k: int | None = None, ) -> float: - r""" - Calculate the hit rate at position k for the intersecting labels. + r"""Calculate the hit rate at position k for the intersecting labels. The intersecting hit rate is calculated as: @@ -345,8 +336,7 @@ def retrieval_hit_rate_macro( candidates_labels: CANDIDATE_TYPE, k: int | None = None, ) -> float: - r""" - Calculate the hit rate at position k for the intersecting labels. + r"""Calculate the hit rate at position k for the intersecting labels. This function internally uses :func:`retrieval_hit_rate` to calculate the hit rate at position :math:`k` for each query and performs macro-averaging across multiple queries. @@ -366,8 +356,7 @@ def retrieval_precision( candidates_labels: CANDIDATE_TYPE, k: int | None = None, ) -> float: - r""" - Calculate the precision at position k. + r"""Calculate the precision at position k. Precision at position :math:`k` is calculated as: @@ -408,8 +397,7 @@ def retrieval_precision_intersecting( candidates_labels: CANDIDATE_TYPE, k: int | None = None, ) -> float: - r""" - Calculate the precision at position k for the intersecting labels. + r"""Calculate the precision at position k for the intersecting labels. Precision at position :math:`k` for intersecting labels is calculated as: @@ -456,8 +444,7 @@ def retrieval_precision_macro( candidates_labels: CANDIDATE_TYPE, k: int | None = None, ) -> float: - r""" - Calculate the precision at position k for the intersecting labels. + r"""Calculate the precision at position k for the intersecting labels. This function internally uses :func:`retrieval_precision` to calculate the precision at position :math:`k` for each query and performs macro-averaging across multiple queries. @@ -472,8 +459,7 @@ def retrieval_precision_macro( def _dcg(relevance_scores: npt.NDArray[Any], k: int | None = None) -> float: - r""" - Calculate the Discounted Cumulative Gain (DCG) at position k. + r"""Calculate the Discounted Cumulative Gain (DCG) at position k. DCG is calculated as: @@ -495,8 +481,7 @@ def _dcg(relevance_scores: npt.NDArray[Any], k: int | None = None) -> float: def _idcg(relevance_scores: npt.NDArray[Any], k: int | None = None) -> float: - r""" - Calculate the Ideal Discounted Cumulative Gain (IDCG) at position k. + r"""Calculate the Ideal Discounted Cumulative Gain (IDCG) at position k. IDCG is the maximum possible DCG that can be achieved if the relevance scores are sorted in descending order. It is calculated as: @@ -519,8 +504,7 @@ def _idcg(relevance_scores: npt.NDArray[Any], k: int | None = None) -> float: @ignore_oos def retrieval_ndcg(query_labels: LABELS_VALUE_TYPE, candidates_labels: CANDIDATE_TYPE, k: int | None = None) -> float: - r""" - Calculate the Normalized Discounted Cumulative Gain (NDCG) at position k. + r"""Calculate the Normalized Discounted Cumulative Gain (NDCG) at position k. NDCG at position :math:`k` is calculated as: @@ -559,8 +543,7 @@ def retrieval_ndcg_intersecting( candidates_labels: CANDIDATE_TYPE, k: int | None = None, ) -> float: - r""" - Calculate the Normalized Discounted Cumulative Gain (NDCG) at position k for the intersecting labels. + r"""Calculate the Normalized Discounted Cumulative Gain (NDCG) at position k for the intersecting labels. NDCG at position :math:`k` for intersecting labels is calculated as: @@ -602,8 +585,7 @@ def retrieval_ndcg_macro( candidates_labels: CANDIDATE_TYPE, k: int | None = None, ) -> float: - r""" - Calculate the Normalized Discounted Cumulative Gain (NDCG) at position k for the intersecting labels. + r"""Calculate the Normalized Discounted Cumulative Gain (NDCG) at position k for the intersecting labels. This function calculates NDCG using :func:`retrieval_ndcg` and computes the macro-averaged score. @@ -617,8 +599,7 @@ def retrieval_ndcg_macro( @ignore_oos def retrieval_mrr(query_labels: LABELS_VALUE_TYPE, candidates_labels: CANDIDATE_TYPE, k: int | None = None) -> float: - r""" - Calculate the Mean Reciprocal Rank (MRR) at position k. + r"""Calculate the Mean Reciprocal Rank (MRR) at position k. MRR is calculated as: @@ -656,8 +637,7 @@ def retrieval_mrr_intersecting( candidates_labels: CANDIDATE_TYPE, k: int | None = None, ) -> float: - r""" - Calculate the Mean Reciprocal Rank (MRR) at position k for the intersecting labels. + r"""Calculate the Mean Reciprocal Rank (MRR) at position k for the intersecting labels. MRR is calculated as: @@ -697,8 +677,7 @@ def retrieval_mrr_macro( candidates_labels: CANDIDATE_TYPE, k: int | None = None, ) -> float: - r""" - Calculate the Mean Reciprocal Rank (MRR) at position k for the intersecting labels. + r"""Calculate the Mean Reciprocal Rank (MRR) at position k for the intersecting labels. This function calculates MRR using :func:`retrieval_mrr` and computes the macro-averaged score. diff --git a/autointent/metrics/scoring.py b/autointent/metrics/scoring.py index ab5023be1..bc2cbefa5 100644 --- a/autointent/metrics/scoring.py +++ b/autointent/metrics/scoring.py @@ -15,17 +15,27 @@ class ScoringMetricFn(Protocol): - """Protocol for scoring metrics.""" + """Protocol for scoring metrics. + + Args: + labels: Ground truth labels for each utterance. + - multiclass case: list representing an array of shape (n_samples,) with integer values + - multilabel case: list representing a matrix of shape (n_samples, n_classes) with integer values + scores: For each utterance, this list contains scores for each of n_classes classes. + + Returns: + Score of the scoring metric. + """ def __call__(self, labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> float: - """ - Calculate scoring metric. + """Calculate scoring metric. + + Args: + labels: Ground truth labels for each utterance. + scores: Scores for each utterance. - :param labels: ground truth labels for each utterance - - multiclass case: list representing an array of shape `(n_samples,)` with integer values - - multilabel case: list representing a matrix of shape `(n_samples, n_classes)` with integer values - :param scores: for each utterance, this list contains scores for each of `n_classes` classes - :return: Score of the scoring metric + Returns: + Score of the scoring metric. """ ... @@ -44,15 +54,14 @@ def wrapper(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> float: @ignore_oos def scoring_log_likelihood(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE, eps: float = 1e-10) -> float: - r""" - Supports multiclass and multilabel cases. + r"""Calculate log likelihood score for multiclass and multilabel cases. Multiclass case: Mean negative cross-entropy for each utterance classification result: .. math:: - \frac{1}{\ell}\sum_{i=1}^{\ell}\log(s[y[i]]) + \\frac{1}{\\ell}\\sum_{i=1}^{\\ell}\\log(s[y[i]]) where ``s[y[i]]`` is the predicted score of the ``i``-th utterance having the ground truth label. @@ -61,14 +70,18 @@ def scoring_log_likelihood(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE, .. math:: - \frac{1}{\ell}\sum_{i=1}^\ell\sum_{c=1}^C\Big[y[i,c]\cdot\log(s[i,c])+(1-y[i,c])\cdot\log(1-s[i,c])\Big] + \\frac{1}{\\ell}\\sum_{i=1}^\\ell\\sum_{c=1}^C\\Big[y[i,c]\\cdot\\log(s[i,c])+(1-y[i,c])\\cdot\\log(1-s[i,c])\\Big] - where ``s[i,c]`` is the predicted score of the ``i``-th utterance having the ground truth label ``c``. + Args: + labels: Ground truth labels for each utterance. + scores: For each utterance, a list containing scores for each of n_classes classes. + eps: A small value to avoid division by zero. - :param labels: Ground truth labels for each utterance. - :param scores: For each utterance, a list containing scores for each of `n_classes` classes. - :param eps: A small value to avoid division by zero. - :return: Score of the scoring metric. + Returns: + Score of the scoring metric. + + Raises: + ValueError: If any scores are not in the range (0,1]. """ labels_array, scores_array = transform(labels, scores) scores_array[scores_array == 0] = eps @@ -85,26 +98,27 @@ def scoring_log_likelihood(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE, log_likelihood = labels_array * np.log(scores_array) + (1 - labels_array) * np.log(1 - scores_array) clipped_one = log_likelihood.clip(min=-100, max=100) res = clipped_one.mean() - # test produces different output return round(float(res), 6) @ignore_oos def scoring_roc_auc(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> float: - r""" - Supports multiclass and multilabel cases. + r"""Calculate ROC AUC score for multiclass and multilabel cases. - Macro averaged roc-auc for utterance classification task, i.e. + Macro averaged roc-auc for utterance classification task: .. math:: - \frac{1}{C}\sum_{k=1}^C ROCAUC(scores[:, k], labels[:, k]) + \\frac{1}{C}\\sum_{k=1}^C ROCAUC(scores[:, k], labels[:, k]) + + where ``C`` is the number of classes. - where ``C`` is the number of classes + Args: + labels: Ground truth labels for each utterance. + scores: For each utterance, scores for each of n_classes classes. - :param labels: ground truth labels for each utterance - :param scores: for each utterance, this list contains scores for each of `n_classes` classes - :return: Score of the scoring metric + Returns: + ROC AUC score. """ labels_, scores_ = transform(labels, scores) @@ -118,18 +132,19 @@ def scoring_roc_auc(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> flo def _calculate_decision_metric( func: DecisionMetricFn, labels: list[int] | list[list[int]], scores: SCORES_VALUE_TYPE ) -> float: - r""" - Calculate decision metric. - - This function applies the given decision metric function `func` to evaluate the decisions. - It transforms the inputs and computes decisions based on the input scores: - - For multiclass classification, decisions are generated using `np.argmax`. - - For multilabel classification, decisions are generated using a threshold of 0.5. - - :param func: decision metric function - :param labels: ground truth labels for each utterance - :param scores: for each utterance, this list contains scores for each of `n_classes` classes - :return: Score of the scoring metric + """Calculate decision metric. + + This function applies the given decision metric function to evaluate the decisions. + For multiclass classification, decisions are generated using np.argmax. + For multilabel classification, decisions are generated using a threshold of 0.5. + + Args: + func: Decision metric function. + labels: Ground truth labels for each utterance. + scores: For each utterance, scores for each of n_classes classes. + + Returns: + Score of the decision metric. """ if isinstance(labels[0], int): pred_labels = np.argmax(scores, axis=1).tolist() @@ -143,84 +158,85 @@ def _calculate_decision_metric( @ignore_oos def scoring_accuracy(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> float: - r""" - Calculate accuracy for multiclass and multilabel classification. + """Calculate accuracy for multiclass and multilabel classification. - This function computes accuracy by using :func:`autointent.metrics.decision.decision_accuracy` - to evaluate decisions. + Uses decision_accuracy to evaluate decisions. - :param labels: ground truth labels for each utterance - :param scores: for each utterance, this list contains scores for each of `n_classes` classes - :return: Score of the scoring metric + Args: + labels: Ground truth labels for each utterance. + scores: For each utterance, scores for each of n_classes classes. + + Returns: + Classification accuracy score. """ return _calculate_decision_metric(decision_accuracy, labels, scores) @ignore_oos def scoring_f1(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> float: - r""" - Calculate the F1 score for multiclass and multilabel classification. + """Calculate F1 score for multiclass and multilabel classification. + + Uses decision_f1 to evaluate decisions. - This function computes the F1 score by using :func:`autointent.metrics.decision.decision_f1` - to evaluate decisions. + Args: + labels: Ground truth labels for each sample. + scores: For each sample, scores for each of n_classes classes. - :param labels: Ground truth labels for each sample - :param scores: For each sample, this list contains scores for each of `n_classes` classes - :return: F1 score + Returns: + F1 score. """ return _calculate_decision_metric(decision_f1, labels, scores) @ignore_oos def scoring_precision(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> float: - r""" - Calculate precision for multiclass and multilabel classification. + """Calculate precision for multiclass and multilabel classification. + + Uses decision_precision to evaluate decisions. - This function computes precision by using :func:`autointent.metrics.decision.decision_precision` - to evaluate decisions. + Args: + labels: Ground truth labels for each sample. + scores: For each sample, scores for each of n_classes classes. - :param labels: Ground truth labels for each sample - :param scores: For each sample, this list contains scores for each of `n_classes` classes - :return: Precision score + Returns: + Precision score. """ return _calculate_decision_metric(decision_precision, labels, scores) @ignore_oos def scoring_recall(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> float: - r""" - Calculate recall for multiclass and multilabel classification. + """Calculate recall for multiclass and multilabel classification. - This function computes recall by using :func:`autointent.metrics.decision.decision_recall` - to evaluate decisions. + Uses decision_recall to evaluate decisions. - :param labels: Ground truth labels for each sample - :param scores: For each sample, this list contains scores for each of `n_classes` classes - :return: Recall score + Args: + labels: Ground truth labels for each sample. + scores: For each sample, scores for each of n_classes classes. + + Returns: + Recall score. """ return _calculate_decision_metric(decision_recall, labels, scores) @ignore_oos def scoring_hit_rate(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> float: - r""" - Calculate the hit rate for multilabel classification. + r"""Calculate hit rate for multilabel classification. - The hit rate measures the fraction of cases where the top-ranked label is in the set + Hit rate measures the fraction of cases where the top-ranked label is in the set of true labels for the instance. .. math:: - \text{Hit Rate} = \frac{1}{N} \sum_{i=1}^N \mathbb{1}(y_{\text{top},i} \in y_{\text{true},i}) + \\text{Hit Rate} = \\frac{1}{N} \\sum_{i=1}^N \\mathbb{1}(y_{\\text{top},i} \\in y_{\\text{true},i}) - where: - - :math:`N` is the total number of instances, - - :math:`y_{\text{top},i}` is the top-ranked predicted label for instance :math:`i`, - - :math:`y_{\text{true},i}` is the set of ground truth labels for instance :math:`i`. + Args: + labels: Ground truth labels for each sample. + scores: For each sample, scores for each of n_classes classes. - :param labels: Ground truth labels for each sample - :param scores: For each sample, this list contains scores for each of `n_classes` classes - :return: Hit rate score + Returns: + Hit rate score. """ labels_, scores_ = transform(labels, scores) @@ -232,34 +248,17 @@ def scoring_hit_rate(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> fl @ignore_oos def scoring_neg_coverage(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> float: - """ - Supports multilabel classification. - - Evaluates how far we need, on average, to go down the list of labels in order to cover - all the proper labels of the instance. - - - The ideal value is 1 - - The worst value is 0 - - The result is equivalent to executing the following code: - - >>> def compute_rank_metric(): - ... import numpy as np - ... scores = np.array([[1, 2, 3]]) - ... labels = np.array([1, 0, 0]) - ... n_classes = scores.shape[1] - ... from scipy.stats import rankdata - ... int_ranks = rankdata(scores, axis=1) - ... filtered_ranks = int_ranks * labels - ... max_ranks = np.max(filtered_ranks, axis=1) - ... float_ranks = (max_ranks - 1) / (n_classes - 1) - ... return float(1 - np.mean(float_ranks)) - >>> print(f"{compute_rank_metric():.1f}") - 1.0 - - :param labels: ground truth labels for each utterance - :param scores: for each utterance, this list contains scores for each of `n_classes` classes - :return: Score of the scoring metric + """Calculate negative coverage for multilabel classification. + + Evaluates how far we need to go down the list of labels to cover all proper labels. + The ideal value is 1, the worst value is 0. + + Args: + labels: Ground truth labels for each utterance. + scores: For each utterance, scores for each of n_classes classes. + + Returns: + Negative coverage score. """ labels_, scores_ = transform(labels, scores) @@ -269,33 +268,33 @@ def scoring_neg_coverage(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) - @ignore_oos def scoring_neg_ranking_loss(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> float: - """ - Supports multilabel. + """Calculate negative ranking loss for multilabel classification. - Compute the average number of label pairs that are incorrectly ordered given y_score - weighted by the size of the label set and the number of labels not in the label set. + Computes average number of incorrectly ordered label pairs weighted by label set size. + The ideal value is 0. - the ideal value is 0 + Args: + labels: Ground truth labels for each utterance. + scores: For each utterance, scores for each of n_classes classes. - :param labels: ground truth labels for each utterance - :param scores: for each utterance, this list contains scores for each of `n_classes` classes - :return: Score of the scoring metric + Returns: + Negative ranking loss score. """ return float(-label_ranking_loss(labels, scores)) @ignore_oos def scoring_map(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> float: - r""" - Calculate the mean average precision (MAP) score for multilabel classification. + """Calculate mean average precision (MAP) score for multilabel classification. - The MAP score measures the precision at different levels of ranking, - averaged across all queries. The ideal value is 1, indicating perfect ranking, while the worst value is 0. + Measures precision at different ranking levels, averaged across all queries. + The ideal value is 1, indicating perfect ranking. - This function utilizes :func:`sklearn.metrics.label_ranking_average_precision_score` for computation. + Args: + labels: Ground truth labels for each sample. + scores: For each sample, scores for each of n_classes classes. - :param labels: ground truth labels for each sample - :param scores: for each sample, this list contains scores for each of `n_classes` classes - :return: mean average precision score + Returns: + Mean average precision score. """ return float(label_ranking_average_precision_score(labels, scores)) diff --git a/autointent/modules/__init__.py b/autointent/modules/__init__.py index be46c06a1..7cf4ae529 100644 --- a/autointent/modules/__init__.py +++ b/autointent/modules/__init__.py @@ -2,7 +2,7 @@ from typing import TypeVar -from .abc import BaseDecision, BaseEmbedding, BaseModule, BaseRegex, BaseScorer +from .base import BaseDecision, BaseEmbedding, BaseModule, BaseRegex, BaseScorer from .decision import ( AdaptiveDecision, ArgmaxDecision, diff --git a/autointent/modules/abc/_embedding.py b/autointent/modules/abc/_embedding.py deleted file mode 100644 index 9382a7565..000000000 --- a/autointent/modules/abc/_embedding.py +++ /dev/null @@ -1,14 +0,0 @@ -"""Base class for embedding modules.""" - -from abc import ABC - -from autointent import Context -from autointent.custom_types import ListOfLabels -from autointent.modules.abc import BaseModule - - -class BaseEmbedding(BaseModule, ABC): - """Base class for embedding modules.""" - - def get_train_data(self, context: Context) -> tuple[list[str], ListOfLabels]: - return (context.data_handler.train_utterances(0), context.data_handler.train_labels(0)) # type: ignore[return-value] diff --git a/autointent/modules/abc/__init__.py b/autointent/modules/base/__init__.py similarity index 100% rename from autointent/modules/abc/__init__.py rename to autointent/modules/base/__init__.py diff --git a/autointent/modules/abc/_base.py b/autointent/modules/base/_base.py similarity index 63% rename from autointent/modules/abc/_base.py rename to autointent/modules/base/_base.py index f44f53902..2614426f5 100644 --- a/autointent/modules/abc/_base.py +++ b/autointent/modules/base/_base.py @@ -8,6 +8,7 @@ import numpy as np import numpy.typing as npt +from typing_extensions import assert_never from autointent._dump_tools import Dumper from autointent.context import Context @@ -19,7 +20,14 @@ class BaseModule(ABC): - """Base module.""" + """Base module for all modules. + + Attributes: + supports_oos: Whether the module supports out-of-scope samples + supports_multilabel: Whether the module supports multilabel classification + supports_multiclass: Whether the module supports multiclass classification + name: Name of the module + """ supports_oos: bool supports_multilabel: bool @@ -28,27 +36,31 @@ class BaseModule(ABC): @abstractmethod def fit(self, *args: tuple[Any], **kwargs: dict[str, Any]) -> None: - """ - Fit the model. + """Fit the model. - :param args: Args to fit - :param kwargs: Kwargs to fit + Args: + *args: Args to fit + **kwargs: Kwargs to fit """ def score(self, context: Context, metrics: list[str]) -> dict[str, float]: - """ - Calculate metric on test set and return metric value. + """Calculate metric on test set and return metric value. + + Args: + context: Context to score + metrics: Metrics to score - :param context: Context to score - :param metrics: Metrics to score - :return: Computed metrics value for the test set or error code of metrics + Returns: + Computed metrics value for the test set or error code of metrics + + Raises: + ValueError: If unknown scheme is provided """ if context.data_handler.config.scheme == "ho": return self.score_ho(context, metrics) if context.data_handler.config.scheme == "cv": return self.score_cv(context, metrics) - msg = f"Unknown scheme: {context.data_handler.config.scheme}" - raise ValueError(msg) + assert_never(context.data_handler.config.scheme) @abstractmethod def score_cv(self, context: Context, metrics: list[str]) -> dict[str, float]: ... @@ -58,25 +70,29 @@ def score_ho(self, context: Context, metrics: list[str]) -> dict[str, float]: .. @abstractmethod def get_assets(self) -> Artifact: - """Return useful assets that represent intermediate data into context.""" + """Return useful assets that represent intermediate data into context. + + Returns: + Artifact containing intermediate data + """ @abstractmethod def clear_cache(self) -> None: """Clear cache.""" def dump(self, path: str) -> None: - """ - Dump all data needed for inference. + """Dump all data needed for inference. - :param path: Path to dump + Args: + path: Path to dump """ Dumper.dump(self, Path(path)) def load(self, path: str) -> None: - """ - Load data from dump. + """Load data from dump. - :param path: Path to load + Args: + path: Path to load """ Dumper.load(self, Path(path)) @@ -84,11 +100,14 @@ def load(self, path: str) -> None: def predict( self, *args: list[str] | npt.NDArray[Any], **kwargs: dict[str, Any] ) -> ListOfGenericLabels | npt.NDArray[Any]: - """ - Predict on the input. + """Predict on the input. - :param args: args to predict. - :param kwargs: kwargs to predict. + Args: + *args: args to predict + **kwargs: kwargs to predict + + Returns: + Predictions """ def predict_with_metadata( @@ -96,40 +115,48 @@ def predict_with_metadata( *args: list[str] | npt.NDArray[Any], **kwargs: dict[str, Any], ) -> tuple[ListOfGenericLabels | npt.NDArray[Any], list[dict[str, Any]] | None]: - """ - Predict on the input with metadata. + """Predict on the input with metadata. - :param args: args to predict. - :param kwargs: kwargs to predict. + Args: + *args: args to predict + **kwargs: kwargs to predict + + Returns: + Tuple of predictions and metadata """ return self.predict(*args, **kwargs), None @classmethod @abstractmethod def from_context(cls, context: Context, **kwargs: dict[str, Any]) -> "BaseModule": - """ - Initialize self from context. + """Initialize self from context. - :param context: Context to init from. - :param kwargs: Additional kwargs. + Args: + context: Context to init from + **kwargs: Additional kwargs + + Returns: + Initialized module """ def get_embedder_config(self) -> dict[str, Any] | None: - """ - Get the config of the embedder. + """Get the config of the embedder. - :return: Embedder config. + Returns: + Embedder config if available, None otherwise """ return None @staticmethod def score_metrics_ho(params: tuple[Any, Any], metrics_dict: dict[str, Any]) -> dict[str, float]: - """ - Score metrics on the test set. + """Score metrics on the test set. + + Args: + params: Params to score + metrics_dict: Dictionary of metrics to compute - :param params: Params to score - :param metrics_dict: - :return: + Returns: + Dictionary with computed metrics """ metrics = {} for metric_name, metric_fn in metrics_dict.items(): @@ -142,6 +169,16 @@ def score_metrics_cv( # type: ignore[no-untyped-def] cv_iterator: Iterable[tuple[list[str], ListOfLabels, list[str], ListOfLabels]], **fit_kwargs, # noqa: ANN003 ) -> tuple[dict[str, float], list[ListOfGenericLabels] | list[npt.NDArray[Any]]]: + """Score metrics using cross-validation. + + Args: + metrics_dict: Dictionary of metrics to compute + cv_iterator: Cross-validation iterator + **fit_kwargs: Additional arguments for fit method + + Returns: + Tuple of metrics dictionary and predictions + """ metrics_values: dict[str, list[float]] = {name: [] for name in metrics_dict} all_val_preds = [] @@ -156,6 +193,14 @@ def score_metrics_cv( # type: ignore[no-untyped-def] return metrics, all_val_preds # type: ignore[return-value] def _validate_multilabel(self, data_is_multilabel: bool) -> None: + """Validate if module supports the required classification type. + + Args: + data_is_multilabel: Whether the data is multilabel + + Raises: + WrongClassificationError: If module doesn't support the required classification type + """ if data_is_multilabel and not self.supports_multilabel: msg = f'"{self.name}" module is incompatible with multi-label classifiction.' logger.error(msg) @@ -166,6 +211,15 @@ def _validate_multilabel(self, data_is_multilabel: bool) -> None: raise WrongClassificationError(msg) def _validate_oos(self, data_contains_oos: bool, raise_error: bool = True) -> None: + """Validate if module supports out-of-scope samples. + + Args: + data_contains_oos: Whether data contains OOS samples + raise_error: Whether to raise error on validation failure + + Raises: + ValueError: If validation fails and raise_error is True + """ if data_contains_oos != self.supports_oos: if self.supports_oos and not data_contains_oos: msg = ( @@ -183,19 +237,27 @@ def _validate_oos(self, data_contains_oos: bool, raise_error: bool = True) -> No logger.warning(msg) def _validate_task(self, labels: ListOfGenericLabels) -> None: + """Validate task specifications. + + Args: + labels: Training labels + """ self._n_classes, self._multilabel, self._oos = self._get_task_specs(labels) self._validate_multilabel(self._multilabel) self._validate_oos(self._oos) @staticmethod def _get_task_specs(labels: ListOfGenericLabels) -> tuple[int, bool, bool]: - """ - Infer number of classes, type of classification and whether data contains OOS samples. + """Infer number of classes, type of classification and whether data contains OOS samples. + + Args: + labels: Training labels - :param scores: training scores - :param labels: training labels - :return: number of classes, indicator if it's a multi-label task, - indicator if data contains oos samples + Returns: + Tuple containing: + - number of classes + - indicator if it's a multi-label task + - indicator if data contains oos samples """ contains_oos_samples = any(label is None for label in labels) in_domain_label = next(lab for lab in labels if lab is not None) diff --git a/autointent/modules/abc/_decision.py b/autointent/modules/base/_decision.py similarity index 70% rename from autointent/modules/abc/_decision.py rename to autointent/modules/base/_decision.py index f97541a29..e9487ed31 100644 --- a/autointent/modules/abc/_decision.py +++ b/autointent/modules/base/_decision.py @@ -1,16 +1,17 @@ -"""Predictior module.""" +"""Predictor module.""" from abc import ABC, abstractmethod from typing import Any, Literal import numpy as np import numpy.typing as npt +from typing_extensions import assert_never from autointent import Context from autointent.context.optimization_info import DecisionArtifact from autointent.custom_types import ListOfGenericLabels from autointent.metrics import DECISION_METRICS -from autointent.modules.abc import BaseModule +from autointent.modules.base import BaseModule from autointent.schemas import Tag @@ -24,29 +25,37 @@ def fit( labels: ListOfGenericLabels, tags: list[Tag] | None = None, ) -> None: - """ - Fit the model. + """Fit the model. - :param scores: Scores to fit - :param labels: Labels to fit - :param tags: Tags to fit + Args: + scores: Scores to fit + labels: Labels to fit + tags: Tags to fit """ @abstractmethod def predict(self, scores: npt.NDArray[Any]) -> ListOfGenericLabels: - """ - Predict the best score. + """Predict the best score. + + Args: + scores: Scores to predict - :param scores: Scores to predict + Returns: + Predicted labels """ def score_ho(self, context: Context, metrics: list[str]) -> dict[str, float]: - """ - Calculate metric on test set and return metric value. + """Calculate metric on test set and return metric value. + + Args: + context: Context to score + metrics: List of metrics to compute + + Returns: + Dictionary of computed metrics values for the test set - :param context: Context to score - :param split: Target split - :return: Computed metrics value for the test set or error code of metrics + Raises: + RuntimeError: If no folded scores are found """ train_scores, train_labels, tags = self.get_train_data(context) self.fit(train_scores, train_labels, tags) @@ -58,12 +67,17 @@ def score_ho(self, context: Context, metrics: list[str]) -> dict[str, float]: return self.score_metrics_ho((val_labels, decisions), chosen_metrics) def score_cv(self, context: Context, metrics: list[str]) -> dict[str, float]: - """ - Calculate metric on test set and return metric value. + """Calculate metric on test set and return metric value. + + Args: + context: Context to score + metrics: List of metrics to compute + + Returns: + Dictionary of computed metrics values for the test set - :param context: Context to score - :param split: Target split - :return: Computed metrics value for the test set or error code of metrics + Raises: + RuntimeError: If no folded scores are found """ labels = context.data_handler.train_labels_folded() scores = context.optimization_info.get_best_folded_scores() @@ -91,13 +105,26 @@ def score_cv(self, context: Context, metrics: list[str]) -> dict[str, float]: return {name: float(np.mean(values_list)) for name, values_list in metrics_values.items()} def get_assets(self) -> DecisionArtifact: - """Return useful assets that represent intermediate data into context.""" + """Return useful assets that represent intermediate data into context. + + Returns: + Decision artifact containing intermediate data + """ return self._artifact def clear_cache(self) -> None: """Clear cache.""" def _validate_task(self, scores: npt.NDArray[Any], labels: ListOfGenericLabels) -> None: + """Validate task specifications. + + Args: + scores: Input scores + labels: Input labels + + Raises: + ValueError: If there is a mismatch between provided labels and scores + """ self._n_classes, self._multilabel, self._oos = self._get_task_specs(labels) self._validate_multilabel(self._multilabel) self._validate_oos(self._oos, raise_error=False) @@ -110,6 +137,14 @@ def _validate_task(self, scores: npt.NDArray[Any], labels: ListOfGenericLabels) raise ValueError(msg) def get_train_data(self, context: Context) -> tuple[npt.NDArray[Any], ListOfGenericLabels, list[Tag]]: + """Get training data from context. + + Args: + context: Context containing the data + + Returns: + Tuple containing scores, labels, and tags + """ labels, scores = get_decision_evaluation_data(context, "train") return (scores, labels, context.data_handler.tags) @@ -118,12 +153,17 @@ def get_decision_evaluation_data( context: Context, split: Literal["train", "validation"], ) -> tuple[ListOfGenericLabels, npt.NDArray[np.float64]]: - """ - Get decision evaluation data. + """Get decision evaluation data. + + Args: + context: Context containing the data + split: Target split (either 'train' or 'validation') - :param context: Context - :param split: Target split - :return: + Returns: + Tuple containing labels and scores for the specified split + + Raises: + ValueError: If invalid split name is provided or no scores are found """ if split == "train": labels = context.data_handler.train_labels(1) @@ -132,8 +172,7 @@ def get_decision_evaluation_data( labels = context.data_handler.validation_labels(1) scores = context.optimization_info.get_best_validation_scores() else: - message = f"Invalid split '{split}' provided. Expected one of 'train', 'validation'." - raise ValueError(message) + assert_never(split) if scores is None: message = f"No '{split}' scores found in the optimization info" diff --git a/autointent/modules/base/_embedding.py b/autointent/modules/base/_embedding.py new file mode 100644 index 000000000..376fe5a82 --- /dev/null +++ b/autointent/modules/base/_embedding.py @@ -0,0 +1,22 @@ +"""Base class for embedding modules.""" + +from abc import ABC + +from autointent import Context +from autointent.custom_types import ListOfLabels +from autointent.modules.base import BaseModule + + +class BaseEmbedding(BaseModule, ABC): + """Base class for embedding modules.""" + + def get_train_data(self, context: Context) -> tuple[list[str], ListOfLabels]: + """Get train data. + + Args: + context: Context to get train data from + + Returns: + Tuple of train utterances and train labels + """ + return context.data_handler.train_utterances(0), context.data_handler.train_labels(0) # type: ignore[return-value] diff --git a/autointent/modules/abc/_regex.py b/autointent/modules/base/_regex.py similarity index 75% rename from autointent/modules/abc/_regex.py rename to autointent/modules/base/_regex.py index c693d8c0a..34656b799 100644 --- a/autointent/modules/abc/_regex.py +++ b/autointent/modules/base/_regex.py @@ -2,7 +2,7 @@ from abc import ABC -from autointent.modules.abc import BaseModule +from autointent.modules.base import BaseModule class BaseRegex(BaseModule, ABC): diff --git a/autointent/modules/abc/_scoring.py b/autointent/modules/base/_scoring.py similarity index 63% rename from autointent/modules/abc/_scoring.py rename to autointent/modules/base/_scoring.py index bdf4ae939..a5c33855d 100644 --- a/autointent/modules/abc/_scoring.py +++ b/autointent/modules/base/_scoring.py @@ -9,12 +9,11 @@ from autointent.context.optimization_info import ScorerArtifact from autointent.custom_types import ListOfLabels from autointent.metrics import SCORING_METRICS_MULTICLASS, SCORING_METRICS_MULTILABEL -from autointent.modules.abc import BaseModule +from autointent.modules.base import BaseModule class BaseScorer(BaseModule, ABC): - """ - Abstract base class for scoring modules. + """Abstract base class for scoring modules. Scoring modules predict scores for utterances and evaluate their performance using a scoring metric. @@ -27,9 +26,25 @@ def fit( self, utterances: list[str], labels: ListOfLabels, - ) -> None: ... + ) -> None: + """Fit the scoring module to the training data. + + Args: + utterances: List of training utterances. + labels: List of training labels. + """ + ... def score_ho(self, context: Context, metrics: list[str]) -> dict[str, float]: + """Evaluate the scorer on a test set and compute the specified metric. + + Args: + context: Context containing test set and other data. + metrics: List of metrics to compute. + + Returns: + Computed metrics value for the test set or error code of metrics. + """ self.fit(*self.get_train_data(context)) val_utterances = context.data_handler.validation_utterances(0) @@ -47,12 +62,15 @@ def score_ho(self, context: Context, metrics: list[str]) -> dict[str, float]: return self.score_metrics_ho((val_labels, scores), chosen_metrics) def score_cv(self, context: Context, metrics: list[str]) -> dict[str, float]: - """ - Evaluate the scorer on a test set and compute the specified metric. + """Evaluate the scorer on a test set and compute the specified metric. + + Args: + context: Context containing test set and other data. + metrics: List of metrics to compute. + + Returns: + Computed metrics value for the test set or error code of metrics. - :param context: Context containing test set and other data. - :param split: Target split - :return: Computed metrics value for the test set or error code of metrics """ metrics_dict = SCORING_METRICS_MULTILABEL if context.is_multilabel() else SCORING_METRICS_MULTICLASS chosen_metrics = {name: fn for name, fn in metrics_dict.items() if name in metrics} @@ -66,21 +84,31 @@ def score_cv(self, context: Context, metrics: list[str]) -> dict[str, float]: return metrics_calculated def get_assets(self) -> ScorerArtifact: - """ - Retrieve assets generated during scoring. + """Retrieve assets generated during scoring. - :return: ScorerArtifact containing test, validation and test scores. + Returns: + ScorerArtifact containing test, validation and test scores. """ return self._artifact def get_train_data(self, context: Context) -> tuple[list[str], ListOfLabels]: + """Get train data. + + Args: + context: Context to get train data from + + Returns: + Tuple of train utterances and train labels + """ return context.data_handler.train_utterances(0), context.data_handler.train_labels(0) # type: ignore[return-value] @abstractmethod def predict(self, utterances: list[str]) -> npt.NDArray[Any]: - """ - Predict scores for a list of utterances. + """Predict scores for a list of utterances. + + Args: + utterances: List of utterances to score. - :param utterances: List of utterances to score. - :return: Array of predicted scores. + Returns: + Array of predicted scores. """ diff --git a/autointent/modules/decision/_adaptive.py b/autointent/modules/decision/_adaptive.py index ef16cb6d2..e980a3a8b 100644 --- a/autointent/modules/decision/_adaptive.py +++ b/autointent/modules/decision/_adaptive.py @@ -10,7 +10,7 @@ from autointent.custom_types import FloatFromZeroToOne, ListOfGenericLabels, ListOfLabelsWithOOS, MultiLabel from autointent.exceptions import MismatchNumClassesError from autointent.metrics import decision_f1 -from autointent.modules.abc import BaseDecision +from autointent.modules.base import BaseDecision from autointent.schemas import Tag from ._utils import apply_tags @@ -20,18 +20,21 @@ class AdaptiveDecision(BaseDecision): - """ - Decision for multi-label classification using adaptive thresholds. + """Decision for multi-label classification using adaptive thresholds. The AdaptiveDecision calculates optimal thresholds based on the given scores and labels, ensuring the best performance on multi-label data. - :ivar _n_classes: Number of classes in the dataset. - :ivar _r: Scaling factor for thresholds. - :ivar tags: List of Tag objects for mutually exclusive classes. - :ivar name: Name of the predictor, defaults to "adaptive". + Attributes: + _n_classes: Number of classes in the dataset + _r: Scaling factor for thresholds + tags: List of Tag objects for mutually exclusive classes + name: Name of the predictor, defaults to "adaptive" + supports_multilabel: Whether the module supports multilabel classification + supports_multiclass: Whether the module supports multiclass classification + supports_oos: Whether the module supports out-of-scope samples - Examples + Examples: -------- .. testcode:: @@ -59,11 +62,11 @@ class AdaptiveDecision(BaseDecision): name = "adaptive" def __init__(self, search_space: list[FloatFromZeroToOne] | None = None) -> None: - """ - Initialize the AdaptiveDecision. + """Initialize the AdaptiveDecision. - :param search_space: List of threshold scaling factors to search for optimal performance. - Defaults to a range between 0 and 1. + Args: + search_space: List of threshold scaling factors to search for optimal performance. + Defaults to a range between 0 and 1 """ self.search_space = search_space if search_space is not None else default_search_space @@ -73,12 +76,14 @@ def __init__(self, search_space: list[FloatFromZeroToOne] | None = None) -> None @classmethod def from_context(cls, context: Context, search_space: list[FloatFromZeroToOne] | None = None) -> "AdaptiveDecision": - """ - Create an AdaptiveDecision instance using a Context object. + """Create an AdaptiveDecision instance using a Context object. + + Args: + context: Context containing configurations and utilities + search_space: List of threshold scaling factors, or None for default - :param context: Context containing configurations and utilities. - :param search_space: List of threshold scaling factors, or None for default. - :return: Initialized AdaptiveDecision instance. + Returns: + Initialized AdaptiveDecision instance """ return cls( search_space=search_space, @@ -90,13 +95,15 @@ def fit( labels: ListOfGenericLabels, tags: list[Tag] | None = None, ) -> None: - """ - Fit the predictor by optimizing the threshold scaling factor. + """Fit the predictor by optimizing the threshold scaling factor. + + Args: + scores: Array of shape (n_samples, n_classes) with predicted scores + labels: List of true multi-label targets + tags: List of Tag objects for mutually exclusive classes, or None - :param scores: Array of shape (n_samples, n_classes) with predicted scores. - :param labels: List of true multi-label targets. - :param tags: List of Tag objects for mutually exclusive classes, or None. - :raises WrongClassificationError: If used on non-multi-label data. + Raises: + WrongClassificationError: If used on non-multi-label data """ self.tags = tags @@ -111,12 +118,16 @@ def fit( self._r = float(self.search_space[np.argmax(metrics_list)]) def predict(self, scores: npt.NDArray[Any]) -> ListOfLabelsWithOOS: - """ - Predict labels for the given scores. + """Predict labels for the given scores. + + Args: + scores: Array of shape (n_samples, n_classes) with predicted scores - :param scores: Array of shape (n_samples, n_classes) with predicted scores. - :return: Array of shape (n_samples, n_classes) with predicted binary labels. - :raises MismatchNumClassesError: If the number of classes does not match the trained predictor. + Returns: + Array of shape (n_samples, n_classes) with predicted binary labels + + Raises: + MismatchNumClassesError: If the number of classes does not match the trained predictor """ if scores.shape[1] != self._n_classes: raise MismatchNumClassesError @@ -124,24 +135,28 @@ def predict(self, scores: npt.NDArray[Any]) -> ListOfLabelsWithOOS: def get_adapted_threshes(r: float, scores: npt.NDArray[Any]) -> npt.NDArray[Any]: - """ - Compute adaptive thresholds based on scaling factor and scores. + """Compute adaptive thresholds based on scaling factor and scores. - :param r: Scaling factor for thresholds. - :param scores: Array of shape (n_samples, n_classes) with predicted scores. - :return: Array of thresholds for each class and sample. + Args: + r: Scaling factor for thresholds + scores: Array of shape (n_samples, n_classes) with predicted scores + + Returns: + Array of thresholds for each class and sample """ return r * np.max(scores, axis=1) + (1 - r) * np.min(scores, axis=1) # type: ignore[no-any-return] def multilabel_predict(scores: npt.NDArray[Any], r: float, tags: list[Tag] | None) -> ListOfLabelsWithOOS: - """ - Predict binary labels for multi-label classification. + """Predict binary labels for multi-label classification. - :param scores: Array of shape (n_samples, n_classes) with predicted scores. - :param r: Scaling factor for thresholds. - :param tags: List of Tag objects for mutually exclusive classes, or None. - :return: Array of shape (n_samples, n_classes) with predicted binary labels. + Args: + scores: Array of shape (n_samples, n_classes) with predicted scores + r: Scaling factor for thresholds + tags: List of Tag objects for mutually exclusive classes, or None + + Returns: + Array of shape (n_samples, n_classes) with predicted binary labels """ thresh = get_adapted_threshes(r, scores) res = (scores >= thresh[:, None]).astype(int) @@ -152,11 +167,13 @@ def multilabel_predict(scores: npt.NDArray[Any], r: float, tags: list[Tag] | Non def multilabel_score(y_true: ListOfGenericLabels, y_pred: ListOfGenericLabels) -> float: - """ - Calculate the weighted F1 score for multi-label classification. + """Calculate the weighted F1 score for multi-label classification. + + Args: + y_true: List of true multi-label targets + y_pred: Array of shape (n_samples, n_classes) with predicted labels - :param y_true: List of true multi-label targets. - :param y_pred: Array of shape (n_samples, n_classes) with predicted labels. - :return: Weighted F1 score. + Returns: + Weighted F1 score """ return decision_f1(y_true, y_pred) diff --git a/autointent/modules/decision/_argmax.py b/autointent/modules/decision/_argmax.py index df36dde3e..526cb745a 100644 --- a/autointent/modules/decision/_argmax.py +++ b/autointent/modules/decision/_argmax.py @@ -9,22 +9,26 @@ from autointent import Context from autointent.custom_types import ListOfGenericLabels from autointent.exceptions import MismatchNumClassesError -from autointent.modules.abc import BaseDecision +from autointent.modules.base import BaseDecision from autointent.schemas import Tag logger = logging.getLogger(__name__) class ArgmaxDecision(BaseDecision): - """ - Argmax decision module. + """Argmax decision module. The ArgmaxDecision is a simple predictor that selects the class with the highest score (argmax) for single-label classification tasks. - :ivar _n_classes: Number of classes in the dataset. + Attributes: + name: Name of the predictor, defaults to "argmax" + supports_oos: Whether the module supports out-of-scope samples + supports_multilabel: Whether the module supports multilabel classification + supports_multiclass: Whether the module supports multiclass classification + _n_classes: Number of classes in the dataset - Examples + Examples: -------- .. testcode:: @@ -51,14 +55,17 @@ class ArgmaxDecision(BaseDecision): _n_classes: int def __init__(self) -> None: - """Init.""" + """Initialize ArgmaxDecision.""" @classmethod def from_context(cls, context: Context) -> "ArgmaxDecision": - """ - Initialize form context. + """Initialize from context. - :param context: Context + Args: + context: Context object containing configurations and utilities + + Returns: + Initialized ArgmaxDecision instance """ return cls() @@ -68,22 +75,29 @@ def fit( labels: ListOfGenericLabels, tags: list[Tag] | None = None, ) -> None: - """ - Argmax not fitting anything. + """Fit the predictor (no-op for ArgmaxDecision). - :param scores: Scores to fit - :param labels: Labels to fit - :param tags: Tags to fit - :raises WrongClassificationError: If the classification is wrong. + Args: + scores: Array of shape (n_samples, n_classes) with predicted scores + labels: List of true labels + tags: List of Tag objects for mutually exclusive classes, or None + + Raises: + WrongClassificationError: If used on non-single-label data """ self._validate_task(scores, labels) def predict(self, scores: npt.NDArray[Any]) -> list[int]: - """ - Predict the argmax. + """Predict labels using argmax strategy. + + Args: + scores: Array of shape (n_samples, n_classes) with predicted scores + + Returns: + List of predicted class indices - :param scores: Scores to predict - :raises MismatchNumClassesError: If the number of classes is invalid. + Raises: + MismatchNumClassesError: If the number of classes does not match the trained predictor """ if scores.shape[1] != self._n_classes: raise MismatchNumClassesError diff --git a/autointent/modules/decision/_jinoos.py b/autointent/modules/decision/_jinoos.py index 08d965a06..e60248762 100644 --- a/autointent/modules/decision/_jinoos.py +++ b/autointent/modules/decision/_jinoos.py @@ -8,24 +8,27 @@ from autointent import Context from autointent.custom_types import FloatFromZeroToOne, ListOfGenericLabels from autointent.exceptions import MismatchNumClassesError -from autointent.modules.abc import BaseDecision +from autointent.modules.base import BaseDecision from autointent.schemas import Tag default_search_space = np.linspace(0, 1, num=100) class JinoosDecision(BaseDecision): - """ - Jinoos predictor module. + """Jinoos predictor module. JinoosDecision predicts the best scores for single-label classification tasks and detects out-of-scope (OOS) samples based on a threshold. - :ivar thresh: The optimized threshold value for OOS detection. - :ivar name: Name of the predictor, defaults to "adaptive". - :ivar _n_classes: Number of classes determined during fitting. + Attributes: + thresh: The optimized threshold value for OOS detection + name: Name of the predictor, defaults to "jinoos" + _n_classes: Number of classes determined during fitting + supports_multilabel: Whether the module supports multilabel classification + supports_multiclass: Whether the module supports multiclass classification + supports_oos: Whether the module supports out-of-scope samples - Examples + Examples: -------- .. testcode:: @@ -57,10 +60,10 @@ def __init__( self, search_space: list[FloatFromZeroToOne] | None = None, ) -> None: - """ - Initialize Jinoos predictor. + """Initialize Jinoos predictor. - :param search_space: Search space for threshold + Args: + search_space: List of threshold values to search through for OOS detection """ self.search_space = np.array(search_space) if search_space is not None else default_search_space @@ -70,11 +73,14 @@ def __init__( @classmethod def from_context(cls, context: Context, search_space: list[FloatFromZeroToOne] | None = None) -> "JinoosDecision": - """ - Initialize from context. + """Initialize from context. + + Args: + context: Context containing configurations and utilities + search_space: List of threshold values to search through - :param context: Context - :param search_space: Search space + Returns: + Initialized JinoosDecision instance """ return cls( search_space=search_space, @@ -86,12 +92,12 @@ def fit( labels: ListOfGenericLabels, tags: list[Tag] | None = None, ) -> None: - """ - Fit the model. + """Fit the model. - :param scores: Scores to fit - :param labels: Labels to fit - :param tags: Tags to fit + Args: + scores: Array of shape (n_samples, n_classes) with predicted scores + labels: List of true labels + tags: List of Tag objects for mutually exclusive classes, or None """ self._validate_task(scores, labels) @@ -106,10 +112,16 @@ def fit( self._thresh = float(self.search_space[np.argmax(metrics_list)]) def predict(self, scores: npt.NDArray[Any]) -> list[int | None]: - """ - Predict the best score. + """Predict the best score. + + Args: + scores: Array of shape (n_samples, n_classes) with predicted scores - :param scores: Scores to predict + Returns: + List of predicted class indices or None for OOS samples + + Raises: + MismatchNumClassesError: If the number of classes does not match the trained predictor """ if scores.shape[1] != self._n_classes: raise MismatchNumClassesError @@ -119,18 +131,23 @@ def predict(self, scores: npt.NDArray[Any]) -> list[int | None]: @staticmethod def jinoos_score(y_true: ListOfGenericLabels, y_pred: npt.NDArray[Any]) -> float: - r""" - Calculate Jinoos score. + r"""Calculate Jinoos score. + + The score is calculated as: .. math:: \\frac{C_{in}}{N_{in}}+\\frac{C_{oos}}{N_{oos}} - where $C_{in}$ is the number of correctly predicted in-domain labels - and $N_{in}$ is the total number of in-domain labels. The same for OOS samples + where C_in is the number of correctly predicted in-domain labels + and N_in is the total number of in-domain labels. The same for OOS samples. + + Args: + y_true: True labels + y_pred: Predicted labels - :param y_true: True labels - :param y_pred: Predicted labels + Returns: + Combined accuracy score for in-domain and OOS samples """ y_true_, y_pred_ = np.array(y_true), np.array(y_pred) @@ -149,11 +166,15 @@ def jinoos_score(y_true: ListOfGenericLabels, y_pred: npt.NDArray[Any]) -> float def _predict(scores: npt.NDArray[np.float64]) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]]: - """ - Predict the best score. + """Predict the best score. + + Args: + scores: Array of shape (n_samples, n_classes) with predicted scores - :param scores: Scores to predict - :return: + Returns: + Tuple containing: + - Array of predicted class indices + - Array of highest scores for each sample """ pred_classes = np.argmax(scores, axis=1) best_scores = scores[np.arange(len(scores)), pred_classes] @@ -161,15 +182,15 @@ def _predict(scores: npt.NDArray[np.float64]) -> tuple[npt.NDArray[np.int64], np def _detect_oos(classes: npt.NDArray[Any], scores: npt.NDArray[Any], thresh: float) -> npt.NDArray[Any]: - """ - Detect out of scope samples. + """Detect out of scope samples. - OOS samples are marked with label -1. + Args: + classes: Array of predicted class indices + scores: Array of confidence scores + thresh: Threshold for OOS detection - :param classes: Classes to detect - :param scores: Scores to detect - :param thresh: Threshold to detect - :return: + Returns: + Array of predicted class indices with OOS samples marked as -1 """ classes[scores < thresh] = -1 # out of scope return classes diff --git a/autointent/modules/decision/_threshold.py b/autointent/modules/decision/_threshold.py index ece61bb7f..b0ccf82a5 100644 --- a/autointent/modules/decision/_threshold.py +++ b/autointent/modules/decision/_threshold.py @@ -1,4 +1,4 @@ -"""Threshold.""" +"""Threshold decision module.""" import logging from typing import Any @@ -9,7 +9,7 @@ from autointent import Context from autointent.custom_types import FloatFromZeroToOne, ListOfGenericLabels, MultiLabel from autointent.exceptions import MismatchNumClassesError -from autointent.modules.abc import BaseDecision +from autointent.modules.base import BaseDecision from autointent.schemas import Tag from ._utils import apply_tags @@ -18,16 +18,21 @@ class ThresholdDecision(BaseDecision): - """ - Threshold predictor module. + """Threshold predictor module. ThresholdDecision uses a predefined threshold (or array of thresholds) to predict labels for single-label or multi-label classification tasks. - :ivar tags: Tags for predictions (if any). - :ivar name: Name of the predictor, defaults to "adaptive". + Attributes: + tags: Tags for predictions (if any) + name: Name of the predictor, defaults to "threshold" + supports_oos: Whether the module supports out-of-scope samples + supports_multilabel: Whether the module supports multilabel classification + supports_multiclass: Whether the module supports multiclass classification + _multilabel: Whether the task is multilabel + _n_classes: Number of classes in the dataset - Examples + Examples: -------- Single-label classification =========================== @@ -77,10 +82,10 @@ def __init__( self, thresh: FloatFromZeroToOne | list[FloatFromZeroToOne] = 0.5, ) -> None: - """ - Initialize threshold predictor. + """Initialize threshold predictor. - :param thresh: Threshold for the scores, shape (n_classes,) or float + Args: + thresh: Threshold for the scores, shape (n_classes,) or float """ val_error = False self.thresh = thresh if isinstance(thresh, float) else np.array(thresh) @@ -97,11 +102,14 @@ def __init__( def from_context( cls, context: Context, thresh: FloatFromZeroToOne | list[FloatFromZeroToOne] = 0.5 ) -> "ThresholdDecision": - """ - Initialize from context. + """Initialize from context. + + Args: + context: Context containing configurations and utilities + thresh: Threshold for classification - :param context: Context - :param thresh: Threshold + Returns: + Initialized ThresholdDecision instance """ return cls( thresh=thresh, @@ -113,12 +121,15 @@ def fit( labels: ListOfGenericLabels, tags: list[Tag] | None = None, ) -> None: - """ - Fit the model. + """Fit the model. + + Args: + scores: Array of shape (n_samples, n_classes) with predicted scores + labels: List of true labels + tags: List of Tag objects for mutually exclusive classes, or None - :param scores: Scores to fit - :param labels: Labels to fit - :param tags: Tags to fit + Raises: + MismatchNumClassesError: If number of thresholds doesn't match number of classes """ self.tags = tags self._validate_task(scores, labels) @@ -134,10 +145,16 @@ def fit( self.thresh = np.array(self.thresh) def predict(self, scores: npt.NDArray[Any]) -> ListOfGenericLabels: - """ - Predict the best score. + """Predict labels using thresholds. - :param scores: Scores to predict + Args: + scores: Array of shape (n_samples, n_classes) with predicted scores + + Returns: + Predicted labels (either single-label or multi-label) + + Raises: + MismatchNumClassesError: If number of classes in scores doesn't match training data """ if scores.shape[1] != self._n_classes: msg = "Provided scores number don't match with number of classes which predictor was trained on." @@ -148,12 +165,14 @@ def predict(self, scores: npt.NDArray[Any]) -> ListOfGenericLabels: def multiclass_predict(scores: npt.NDArray[Any], thresh: float | npt.NDArray[Any]) -> ListOfGenericLabels: - """ - Make predictions for multiclass classification task. + """Make predictions for multiclass classification task. - :param scores: Scores from the model, shape (n_samples, n_classes) - :param thresh: Threshold for the scores, shape (n_classes,) or float - :return: Predicted classes, shape (n_samples,) + Args: + scores: Array of shape (n_samples, n_classes) with predicted scores + thresh: Threshold for the scores, shape (n_classes,) or float + + Returns: + List of predicted class indices or None for OOS samples """ pred_classes: npt.NDArray[Any] = np.argmax(scores, axis=1) best_scores = scores[np.arange(len(scores)), pred_classes] @@ -173,13 +192,15 @@ def multilabel_predict( thresh: float | npt.NDArray[Any], tags: list[Tag] | None, ) -> ListOfGenericLabels: - """ - Make predictions for multilabel classification task. + """Make predictions for multilabel classification task. + + Args: + scores: Array of shape (n_samples, n_classes) with predicted scores + thresh: Threshold for the scores, shape (n_classes,) or float + tags: List of Tag objects for mutually exclusive classes, or None - :param scores: Scores from the model, shape (n_samples, n_classes) - :param thresh: Threshold for the scores, shape (n_classes,) or float - :param tags: Tags for predictions - :return: Multilabel prediction + Returns: + List of predicted multi-label targets or None for OOS samples """ res = (scores >= thresh).astype(int) if isinstance(thresh, float) else (scores >= thresh[None, :]).astype(int) if tags: diff --git a/autointent/modules/decision/_tunable.py b/autointent/modules/decision/_tunable.py index a2e43d03b..89ab24a11 100644 --- a/autointent/modules/decision/_tunable.py +++ b/autointent/modules/decision/_tunable.py @@ -12,7 +12,7 @@ from autointent.custom_types import ListOfGenericLabels from autointent.exceptions import MismatchNumClassesError from autointent.metrics import DECISION_METRICS, DecisionMetricFn -from autointent.modules.abc import BaseDecision +from autointent.modules.base import BaseDecision from autointent.schemas import Tag from ._threshold import multiclass_predict, multilabel_predict @@ -21,18 +21,22 @@ class TunableDecision(BaseDecision): - """ - Tunable predictor module. + """Tunable predictor module. TunableDecision uses an optimization process to find the best thresholds for predicting labels in single-label or multi-label classification tasks. It is designed for datasets with varying score distributions and supports out-of-scope (OOS) detection. - :ivar name: Name of the predictor, defaults to "tunable". - :ivar _n_classes: Number of classes determined during fitting. - :ivar tags: Tags for predictions, if any. + Attributes: + name: Name of the predictor, defaults to "tunable" + _n_classes: Number of classes determined during fitting + _multilabel: Whether the task is multilabel + tags: Tags for predictions (if any) + supports_multilabel: Whether the module supports multilabel classification + supports_multiclass: Whether the module supports multiclass classification + supports_oos: Whether the module supports out-of-scope samples - Examples + Examples: -------- Single-label classification =========================== @@ -84,12 +88,13 @@ def __init__( seed: int = 0, tags: list[Tag] | None = None, ) -> None: - """ - Initialize tunable predictor. + """Initialize tunable predictor. - :param n_trials: Number of trials - :param seed: Seed - :param tags: Tags + Args: + target_metric: Metric to optimize during threshold tuning + n_optuna_trials: Number of optimization trials + seed: Random seed for reproducibility + tags: Tags for predictions (if any) """ self.target_metric = target_metric self.n_optuna_trials = n_optuna_trials @@ -108,11 +113,15 @@ def __init__( def from_context( cls, context: Context, target_metric: MetricType = "decision_accuracy", n_optuna_trials: PositiveInt = 320 ) -> "TunableDecision": - """ - Initialize from context. + """Initialize from context. - :param context: Context - :param n_trials: Number of trials + Args: + context: Context containing configurations and utilities + target_metric: Metric to optimize during threshold tuning + n_optuna_trials: Number of optimization trials + + Returns: + Initialized TunableDecision instance """ return cls( target_metric=target_metric, @@ -127,15 +136,15 @@ def fit( labels: ListOfGenericLabels, tags: list[Tag] | None = None, ) -> None: - """ - Fit module. + """Fit the predictor by optimizing thresholds. - When data doesn't contain out-of-scope utterances, using TunableDecision imposes unnecessary - computational overhead. + Note: When data doesn't contain out-of-scope utterances, using TunableDecision imposes + unnecessary computational overhead. - :param scores: Scores to fit - :param labels: Labels to fit - :param tags: Tags to fit + Args: + scores: Array of shape (n_samples, n_classes) with predicted scores + labels: List of true labels + tags: Tags for predictions (if any) """ self.tags = tags self._validate_task(scores, labels) @@ -155,10 +164,16 @@ def fit( self.thresh = thresh_optimizer.best_thresholds def predict(self, scores: npt.NDArray[Any]) -> ListOfGenericLabels: - """ - Predict the best score. + """Predict labels using optimized thresholds. - :param scores: Scores to predict + Args: + scores: Array of shape (n_samples, n_classes) with predicted scores + + Returns: + Predicted labels (either single-label or multi-label) + + Raises: + MismatchNumClassesError: If number of classes in scores doesn't match training data """ if scores.shape[1] != self._n_classes: msg = "Provided scores number don't match with number of classes which predictor was trained on." @@ -169,17 +184,18 @@ def predict(self, scores: npt.NDArray[Any]) -> ListOfGenericLabels: class ThreshOptimizer: - """Threshold optimizer.""" + """Threshold optimizer using Optuna for hyperparameter tuning.""" def __init__( self, metric_fn: DecisionMetricFn, n_classes: int, multilabel: bool, n_trials: int | None = None ) -> None: - """ - Initialize threshold optimizer. + """Initialize threshold optimizer. - :param n_classes: Number of classes - :param multilabel: Is multilabel - :param n_trials: Number of trials + Args: + metric_fn: Metric function for optimization + n_classes: Number of classes in the dataset + multilabel: Whether the task is multilabel + n_trials: Number of optimization trials (defaults to n_classes * 10) """ self.metric_fn = metric_fn self.n_classes = n_classes @@ -187,10 +203,13 @@ def __init__( self.n_trials = n_trials if n_trials is not None else n_classes * 10 def objective(self, trial: Trial) -> float: - """ - Objective function to optimize. + """Objective function to optimize. + + Args: + trial: Optuna trial object - :param trial: Trial + Returns: + Metric value for the current thresholds """ thresholds = np.array([trial.suggest_float(f"threshold_{i}", 0.0, 1.0) for i in range(self.n_classes)]) if self.multilabel: @@ -206,13 +225,13 @@ def fit( seed: int, tags: list[Tag] | None = None, ) -> None: - """ - Fit the optimizer. + """Fit the optimizer by finding optimal thresholds. - :param probas: Probabilities - :param labels: Labels - :param seed: Seed - :param tags: Tags + Args: + probas: Array of shape (n_samples, n_classes) with predicted probabilities + labels: List of true labels + seed: Random seed for reproducibility + tags: Tags for predictions (if any) """ self.probas = probas self.labels = labels diff --git a/autointent/modules/decision/_utils.py b/autointent/modules/decision/_utils.py index bce1a72f8..d10789144 100644 --- a/autointent/modules/decision/_utils.py +++ b/autointent/modules/decision/_utils.py @@ -1,4 +1,4 @@ -"""Utility functions and custom exceptions for handling multilabel predictions and errors.""" +"""Utility functions for handling multilabel predictions.""" from typing import Any @@ -9,17 +9,30 @@ def apply_tags(labels: npt.NDArray[Any], scores: npt.NDArray[Any], tags: list[Tag]) -> npt.NDArray[Any]: - """ - Adjust multilabel predictions based on intent class tags. + """Adjust multilabel predictions based on intent class tags. If some intent classes share a common tag (i.e., they are mutually exclusive) and are assigned to the same sample, this function retains only the class with the highest score among those with the shared tag. - :param labels: Array of shape (n_samples, n_classes) with binary labels (0 or 1). - :param scores: Array of shape (n_samples, n_classes) with float values (0 to 1). - :param tags: List of `Tag` objects, where each tag specifies mutually exclusive intent IDs. - :return: Adjusted array of shape (n_samples, n_classes) with binary labels. + Args: + labels: Array of shape (n_samples, n_classes) with binary labels (0 or 1) + scores: Array of shape (n_samples, n_classes) with float values (0 to 1) + tags: List of Tag objects, where each tag specifies mutually exclusive intent IDs + + Returns: + Array of shape (n_samples, n_classes) with adjusted binary labels + + Examples: + >>> import numpy as np + >>> from autointent.schemas import Tag + >>> labels = np.array([[1, 1, 0], [1, 1, 1]]) + >>> scores = np.array([[0.8, 0.6, 0.3], [0.7, 0.9, 0.5]]) + >>> tags = [Tag(name="group1", intent_ids=[0, 1])] + >>> adjusted = apply_tags(labels, scores, tags) + >>> print(adjusted) + [[1 0 0] + [0 1 1]] """ labels = labels.copy() diff --git a/autointent/modules/embedding/_logreg.py b/autointent/modules/embedding/_logreg.py index 5be87fff2..f55ad983c 100644 --- a/autointent/modules/embedding/_logreg.py +++ b/autointent/modules/embedding/_logreg.py @@ -1,4 +1,4 @@ -"""LogregAimedEmbedding class for a proxy optimzation of embedding.""" +"""LogregAimedEmbedding class for a proxy optimization of embedding.""" from typing import Any @@ -14,21 +14,24 @@ from autointent.context.optimization_info import EmbeddingArtifact from autointent.custom_types import ListOfLabels from autointent.metrics import SCORING_METRICS_MULTICLASS, SCORING_METRICS_MULTILABEL -from autointent.modules.abc import BaseEmbedding +from autointent.modules.base import BaseEmbedding class LogregAimedEmbedding(BaseEmbedding): - r""" - Module for configuring embeddings optimized for linear classification. + """Module for configuring embeddings optimized for linear classification. The main purpose of this module is to be used at embedding node for optimizing embedding configuration using its logreg classification quality as a sort of proxy metric. - :ivar _classifier: The trained logistic regression model. - :ivar _label_encoder: Label encoder for converting labels to numerical format. - :ivar name: Name of the module, defaults to "logreg". + Attributes: + _classifier: The trained logistic regression model + _label_encoder: Label encoder for converting labels to numerical format + name: Name of the module, defaults to "logreg" + supports_multiclass: Whether the module supports multiclass classification + supports_multilabel: Whether the module supports multilabel classification + supports_oos: Whether the module supports out-of-scope detection - Examples + Examples: -------- .. testcode:: @@ -54,11 +57,11 @@ def __init__( embedder_config: EmbedderConfig | str | dict[str, Any], cv: PositiveInt = 3, ) -> None: - """ - Initialize the LogregAimedEmbedding. + """Initialize the LogregAimedEmbedding. - :param embedder_config: Config of the embedder used for creating embeddings. - :param cv: the number of folds used in LogisticRegressionCV + Args: + embedder_config: Config of the embedder used for creating embeddings + cv: Number of folds used in LogisticRegressionCV """ self.embedder_config = EmbedderConfig.from_search_config(embedder_config) self.cv = cv @@ -74,13 +77,15 @@ def from_context( embedder_config: EmbedderConfig | str, cv: PositiveInt = 3, ) -> "LogregAimedEmbedding": - """ - Create a LogregAimedEmbedding instance using a Context object. + """Create a LogregAimedEmbedding instance using a Context object. + + Args: + context: Context containing configurations and utilities + cv: Number of folds used in LogisticRegressionCV + embedder_config: Config of the embedder to use - :param context: The context containing configurations and utilities. - :param cv: the number of folds used in LogisticRegressionCV - :param embedder_config: Config of the embedder to use. - :return: Initialized LogregAimedEmbedding instance. + Returns: + Initialized LogregAimedEmbedding instance """ return cls( cv=cv, @@ -88,14 +93,15 @@ def from_context( ) def clear_cache(self) -> None: + """Clear embedder from memory.""" self._embedder.clear_ram() def fit(self, utterances: list[str], labels: ListOfLabels) -> None: - """ - Train the logistic regression model using the provided utterances and labels. + """Train the logistic regression model using the provided utterances and labels. - :param utterances: List of text data to index. - :param labels: List of corresponding labels for the utterances. + Args: + utterances: List of text data to index + labels: List of corresponding labels for the utterances """ if hasattr(self, "_embedder"): self.clear_cache() @@ -119,11 +125,14 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None: self._classifier.fit(embeddings, labels) def score_ho(self, context: Context, metrics: list[str]) -> dict[str, float]: - """ - Evaluate the embedding model using a specified metric function. + """Evaluate the embedding model using specified metric functions. + + Args: + context: Context containing test data and labels + metrics: List of metric names to compute - :param context: The context containing test data and labels. - :return: Computed metrics value for the test set or error code of metrics + Returns: + Dictionary of computed metric values for the test set """ train_utterances, train_labels = self.get_train_data(context) self.fit(train_utterances, train_labels) @@ -138,11 +147,14 @@ def score_ho(self, context: Context, metrics: list[str]) -> dict[str, float]: return self.score_metrics_ho((val_labels, probas), chosen_metrics) def score_cv(self, context: Context, metrics: list[str]) -> dict[str, float]: - """ - Evaluate the embedding model using a specified metric function. + """Evaluate the embedding model using specified metric functions. + + Args: + context: Context containing test data and labels + metrics: List of metric names to compute - :param context: The context containing test data and labels. - :return: Computed metrics value for the test set or error code of metrics + Returns: + Dictionary of computed metric values for the test set """ metrics_dict = SCORING_METRICS_MULTILABEL if context.is_multilabel() else SCORING_METRICS_MULTICLASS chosen_metrics = {name: fn for name, fn in metrics_dict.items() if name in metrics} @@ -151,14 +163,22 @@ def score_cv(self, context: Context, metrics: list[str]) -> dict[str, float]: return metrics_calculated def get_assets(self) -> EmbeddingArtifact: - """ - Get the classifier artifacts for this module. + """Get the classifier artifacts for this module. - :return: A EmbeddingArtifact object containing embedder information. + Returns: + EmbeddingArtifact object containing embedder information """ return EmbeddingArtifact(config=self.embedder_config) def predict(self, utterances: list[str]) -> NDArray[np.float64]: + """Predict probabilities for input utterances. + + Args: + utterances: List of texts to predict probabilities for + + Returns: + Array of predicted probabilities + """ embeddings = self._embedder.embed(utterances, TaskTypeEnum.classification) probas = self._classifier.predict_proba(embeddings) diff --git a/autointent/modules/embedding/_retrieval.py b/autointent/modules/embedding/_retrieval.py index c71d85309..a0fdfc665 100644 --- a/autointent/modules/embedding/_retrieval.py +++ b/autointent/modules/embedding/_retrieval.py @@ -9,20 +9,23 @@ from autointent.context.optimization_info import EmbeddingArtifact from autointent.custom_types import ListOfLabels from autointent.metrics import RETRIEVAL_METRICS_MULTICLASS, RETRIEVAL_METRICS_MULTILABEL -from autointent.modules.abc import BaseEmbedding +from autointent.modules.base import BaseEmbedding class RetrievalAimedEmbedding(BaseEmbedding): - r""" - Module for configuring embeddings optimized for retrieval tasks. + """Module for configuring embeddings optimized for retrieval tasks. The main purpose of this module is to be used at embedding node for optimizing embedding configuration using its retrieval quality as a sort of proxy metric. - :ivar _vector_index: The vector index used for nearest neighbor retrieval. - :ivar name: Name of the module, defaults to "retrieval". + Attributes: + _vector_index: The vector index used for nearest neighbor retrieval + name: Name of the module, defaults to "retrieval" + supports_multiclass: Whether the module supports multiclass classification + supports_multilabel: Whether the module supports multilabel classification + supports_oos: Whether the module supports out-of-scope detection - Examples + Examples: -------- .. testcode:: @@ -49,11 +52,11 @@ def __init__( embedder_config: EmbedderConfig | str | dict[str, Any], k: PositiveInt = 10, ) -> None: - """ - Initialize the RetrievalAimedEmbedding. + """Initialize the RetrievalAimedEmbedding. - :param k: Number of nearest neighbors to retrieve. - :param embedder_config: Config of the embedder used for creating embeddings. + Args: + k: Number of nearest neighbors to retrieve + embedder_config: Config of the embedder used for creating embeddings """ self.k = k embedder_config = EmbedderConfig.from_search_config(embedder_config) @@ -70,13 +73,15 @@ def from_context( embedder_config: EmbedderConfig | str, k: PositiveInt = 10, ) -> "RetrievalAimedEmbedding": - """ - Create an instance using a Context object. + """Create an instance using a Context object. + + Args: + context: The context containing configurations and utilities + k: Number of nearest neighbors to retrieve + embedder_config: Config of the embedder to use - :param context: The context containing configurations and utilities. - :param k: Number of nearest neighbors to retrieve. - :param embedder_config: Config of the embedder to use. - :return: Initialized RetrievalAimedEmbedding instance. + Returns: + Initialized RetrievalAimedEmbedding instance """ return cls( k=k, @@ -84,11 +89,11 @@ def from_context( ) def fit(self, utterances: list[str], labels: ListOfLabels) -> None: - """ - Fit the vector index using the provided utterances and labels. + """Fit the vector index using the provided utterances and labels. - :param utterances: List of text data to index. - :param labels: List of corresponding labels for the utterances. + Args: + utterances: List of text data to index + labels: List of corresponding labels for the utterances """ if hasattr(self, "_vector_index"): self.clear_cache() @@ -101,11 +106,14 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None: self._vector_index.add(utterances, labels) def score_ho(self, context: Context, metrics: list[str]) -> dict[str, float]: - """ - Evaluate the embedding model using a specified metric function. + """Evaluate the embedding model using specified metric functions. + + Args: + context: Context containing test data and labels + metrics: List of metric names to compute - :param context: The context containing test data and labels. - :return: Computed metrics value for the test set or error code of metrics + Returns: + Dictionary of computed metric values for the test set """ train_utterances, train_labels = self.get_train_data(context) self.fit(train_utterances, train_labels) @@ -119,6 +127,15 @@ def score_ho(self, context: Context, metrics: list[str]) -> dict[str, float]: return self.score_metrics_ho((val_labels, predictions), chosen_metrics) def score_cv(self, context: Context, metrics: list[str]) -> dict[str, float]: + """Evaluate the embedding model using specified metric functions. + + Args: + context: Context containing test data and labels + metrics: List of metric names to compute + + Returns: + Dictionary of computed metric values for the test set + """ metrics_dict = RETRIEVAL_METRICS_MULTILABEL if context.is_multilabel() else RETRIEVAL_METRICS_MULTICLASS chosen_metrics = {name: fn for name, fn in metrics_dict.items() if name in metrics} @@ -126,10 +143,10 @@ def score_cv(self, context: Context, metrics: list[str]) -> dict[str, float]: return metrics_calculated def get_assets(self) -> EmbeddingArtifact: - """ - Get the retriever artifacts for this module. + """Get the retriever artifacts for this module. - :return: A EmbeddingArtifact object containing embedder information. + Returns: + A EmbeddingArtifact object containing embedder information """ return EmbeddingArtifact(config=self.embedder_config) @@ -138,11 +155,13 @@ def clear_cache(self) -> None: self._vector_index.clear_ram() def predict(self, utterances: list[str]) -> list[ListOfLabels]: - """ - Predict the nearest neighbors for a list of utterances. + """Predict the nearest neighbors for a list of utterances. + + Args: + utterances: List of utterances for which nearest neighbors are to be retrieved - :param utterances: List of utterances for which nearest neighbors are to be retrieved. - :return: List of labels for each retrieved utterance. + Returns: + List of labels for each retrieved utterance """ predictions, _, _ = self._vector_index.query(utterances, self.k) return predictions diff --git a/autointent/modules/regex/_simple.py b/autointent/modules/regex/_simple.py index e7b2e5304..76d33ca11 100644 --- a/autointent/modules/regex/_simple.py +++ b/autointent/modules/regex/_simple.py @@ -8,36 +8,53 @@ from autointent.context.optimization_info import Artifact from autointent.custom_types import LabelType from autointent.metrics import REGEX_METRICS -from autointent.modules.abc import BaseRegex +from autointent.modules.base import BaseRegex from autointent.schemas import Intent class RegexPatternsCompiled(TypedDict): - """Compiled regex patterns.""" + """Compiled regex patterns. + + Attributes: + id: Intent ID + regex_full_match: Compiled regex patterns for full match + regex_partial_match: Compiled regex patterns for partial match + """ id: int - """Intent ID.""" regex_full_match: list[re.Pattern[str]] - """Compiled regex patterns for full match.""" regex_partial_match: list[re.Pattern[str]] - """Compiled regex patterns for partial match.""" class Regex(BaseRegex): - """Regular expressions based intent detection module.""" + """Regular expressions based intent detection module. + + A module that uses regular expressions to detect intents in text utterances. + Supports both full and partial pattern matching. + + Attributes: + name: Name of the module, defaults to "regex" + """ name = "regex" @classmethod def from_context(cls, context: Context) -> "Regex": - """Initialize from context.""" + """Initialize from context. + + Args: + context: Context object containing configuration + + Returns: + Initialized Regex instance + """ return cls() def fit(self, intents: list[Intent]) -> None: - """ - Fit the model. + """Fit the model with intent patterns. - :param intents: Intents to fit + Args: + intents: List of intents to fit the model with """ self.regex_patterns = [ RegexPatterns( @@ -50,10 +67,13 @@ def fit(self, intents: list[Intent]) -> None: self._compile_regex_patterns() def predict(self, utterances: list[str]) -> list[LabelType]: - """ - Predict intents for utterances. + """Predict intents for given utterances. + + Args: + utterances: List of utterances to predict intents for - :param utterances: Utterances to predict + Returns: + List of predicted intent labels """ return [self._predict_single(utterance)[0] for utterance in utterances] @@ -61,10 +81,15 @@ def predict_with_metadata( self, utterances: list[str], ) -> tuple[list[LabelType], list[dict[str, Any]] | None]: - """ - Predict intents for utterances with metadata. + """Predict intents for utterances with pattern matching metadata. + + Args: + utterances: List of utterances to predict intents for - :param utterances: Utterances to predict + Returns: + Tuple containing: + - List of predicted intent labels + - List of pattern matching metadata for each utterance """ predictions, metadata = [], [] for utterance in utterances: @@ -74,11 +99,14 @@ def predict_with_metadata( return predictions, metadata def _match(self, utterance: str, intent_record: RegexPatternsCompiled) -> dict[str, list[str]]: - """ - Match utterance with intent record. + """Match utterance with intent record patterns. + + Args: + utterance: Utterance to match + intent_record: Intent record containing patterns to match against - :param utterance: Utterance to match - :param intent_record: Intent record to match + Returns: + Dictionary containing matched full and partial patterns """ full_matches = [ pattern.pattern for pattern in intent_record["regex_full_match"] if pattern.fullmatch(utterance) is not None @@ -89,12 +117,16 @@ def _match(self, utterance: str, intent_record: RegexPatternsCompiled) -> dict[s return {"full_matches": full_matches, "partial_matches": partial_matches} def _predict_single(self, utterance: str) -> tuple[LabelType, dict[str, list[str]]]: - """ - Predict intent for a single utterance. + """Predict intent for a single utterance. - :param utterance: Utterance to predict + Args: + utterance: Utterance to predict intent for + + Returns: + Tuple containing: + - Predicted intent labels + - Dictionary of matched patterns """ - # todo test this prediction = set() matches: dict[str, list[str]] = {"full_matches": [], "partial_matches": []} for intent_record in self.regex_patterns_compiled: @@ -106,6 +138,15 @@ def _predict_single(self, utterance: str) -> tuple[LabelType, dict[str, list[str return list(prediction), matches def score_ho(self, context: Context, metrics: list[str]) -> dict[str, float]: + """Score the model using holdout validation. + + Args: + context: Context containing validation data + metrics: List of metric names to compute + + Returns: + Dictionary of computed metric values + """ self.fit(context.data_handler.dataset.intents) val_utterances = context.data_handler.validation_utterances(0) @@ -117,12 +158,14 @@ def score_ho(self, context: Context, metrics: list[str]) -> dict[str, float]: return self.score_metrics_ho((val_labels, pred_labels), chosen_metrics) def score_cv(self, context: Context, metrics: list[str]) -> dict[str, float]: - """ - Evaluate the scorer on a test set and compute the specified metric. + """Score the model using cross-validation. + + Args: + context: Context containing validation data + metrics: List of metric names to compute - :param context: Context containing test set and other data. - :param split: Target split - :return: Computed metrics value for the test set or error code of metrics + Returns: + Dictionary of computed metric values """ chosen_metrics = {name: fn for name, fn in REGEX_METRICS.items() if name in metrics} @@ -131,15 +174,19 @@ def score_cv(self, context: Context, metrics: list[str]) -> dict[str, float]: return metrics_calculated def clear_cache(self) -> None: - """Clear cache.""" + """Clear cached regex patterns.""" del self.regex_patterns def get_assets(self) -> Artifact: - """Get assets.""" + """Get model assets. + + Returns: + Empty Artifact object + """ return Artifact() def _compile_regex_patterns(self) -> None: - """Compile regex patterns.""" + """Compile regex patterns with case-insensitive flag.""" self.regex_patterns_compiled = [ RegexPatternsCompiled( id=regex_patterns["id"], diff --git a/autointent/modules/scoring/_description/description.py b/autointent/modules/scoring/_description/description.py index 3ccc52f04..f52e05076 100644 --- a/autointent/modules/scoring/_description/description.py +++ b/autointent/modules/scoring/_description/description.py @@ -13,19 +13,23 @@ from autointent.context.optimization_info import ScorerArtifact from autointent.custom_types import ListOfLabels from autointent.metrics import SCORING_METRICS_MULTICLASS, SCORING_METRICS_MULTILABEL -from autointent.modules.abc import BaseScorer +from autointent.modules.base import BaseScorer class DescriptionScorer(BaseScorer): - r""" - Scoring module that scores utterances based on similarity to intent descriptions. - - DescriptionScorer embeds both the utterances and the intent descriptions, then computes a similarity score - between the two, using either cosine similarity and softmax. - - :ivar _embedder: The embedder used to generate embeddings for utterances and descriptions. - :ivar name: Name of the scorer, defaults to "description". - + """Scoring module that scores utterances based on similarity to intent descriptions. + + DescriptionScorer embeds both the utterances and the intent descriptions, then computes + a similarity score between the two, using either cosine similarity and softmax. + + Attributes: + _embedder: The embedder used to generate embeddings for utterances and descriptions + name: Name of the scorer, defaults to "description" + _n_classes: Number of intent classes + _multilabel: Whether the task is multilabel + _description_vectors: Embedded vectors of intent descriptions + supports_multiclass: Whether multiclass classification is supported + supports_multilabel: Whether multilabel classification is supported """ _embedder: Embedder @@ -41,11 +45,11 @@ def __init__( embedder_config: EmbedderConfig | str | dict[str, Any] | None = None, temperature: PositiveFloat = 1.0, ) -> None: - """ - Initialize the DescriptionScorer. + """Initialize the DescriptionScorer. - :param embedder_config: Config of the embedder model. - :param temperature: Temperature parameter for scaling logits, defaults to 1.0. + Args: + embedder_config: Config of the embedder model + temperature: Temperature parameter for scaling logits, defaults to 1.0 """ self.temperature = temperature self.embedder_config = EmbedderConfig.from_search_config(embedder_config) @@ -61,13 +65,15 @@ def from_context( temperature: PositiveFloat, embedder_config: EmbedderConfig | str | None = None, ) -> "DescriptionScorer": - """ - Create a DescriptionScorer instance using a Context object. + """Create a DescriptionScorer instance using a Context object. + + Args: + context: Context containing configurations and utilities + temperature: Temperature parameter for scaling logits + embedder_config: Config of the embedder model. If None, the best embedder is used - :param context: Context containing configurations and utilities. - :param temperature: Temperature parameter for scaling logits. - :param embedder_config: Config of the embedder model. If None, the best embedder is used. - :return: Initialized DescriptionScorer instance. + Returns: + Initialized DescriptionScorer instance """ if embedder_config is None: embedder_config = context.resolve_embedder() @@ -78,10 +84,10 @@ def from_context( ) def get_embedder_config(self) -> dict[str, Any]: - """ - Get the name of the embedder. + """Get the name of the embedder. - :return: Embedder name. + Returns: + Embedder name """ return self.embedder_config.model_dump() @@ -91,13 +97,15 @@ def fit( labels: ListOfLabels, descriptions: list[str], ) -> None: - """ - Fit the scorer by embedding utterances and descriptions. + """Fit the scorer by embedding utterances and descriptions. - :param utterances: List of utterances to embed. - :param labels: List of labels corresponding to the utterances. - :param descriptions: List of intent descriptions. - :raises ValueError: If descriptions contain None values or embeddings mismatch utterances. + Args: + utterances: List of utterances to embed + labels: List of labels corresponding to the utterances + descriptions: List of intent descriptions + + Raises: + ValueError: If descriptions contain None values or embeddings mismatch utterances """ if hasattr(self, "_embedder"): self._embedder.clear_ram() @@ -117,11 +125,13 @@ def fit( self._embedder = embedder def predict(self, utterances: list[str]) -> NDArray[np.float64]: - """ - Predict scores for utterances based on similarity to intent descriptions. + """Predict scores for utterances based on similarity to intent descriptions. - :param utterances: List of utterances to score. - :return: Array of probabilities for each utterance. + Args: + utterances: List of utterances to score + + Returns: + Array of probabilities for each utterance """ utterance_vectors = self._embedder.embed(utterances, TaskTypeEnum.sts) similarities: NDArray[np.float64] = cosine_similarity(utterance_vectors, self._description_vectors) @@ -137,6 +147,14 @@ def clear_cache(self) -> None: self._embedder.clear_ram() def get_train_data(self, context: Context) -> tuple[list[str], ListOfLabels, list[str]]: + """Get training data from context. + + Args: + context: Context containing training data + + Returns: + Tuple containing utterances, labels, and descriptions + """ return ( # type: ignore[return-value] context.data_handler.train_utterances(0), context.data_handler.train_labels(0), @@ -144,12 +162,14 @@ def get_train_data(self, context: Context) -> tuple[list[str], ListOfLabels, lis ) def score_cv(self, context: Context, metrics: list[str]) -> dict[str, float]: - """ - Evaluate the scorer on a test set and compute the specified metric. + """Evaluate the scorer on a test set and compute the specified metrics. + + Args: + context: Context containing test set and other data + metrics: List of metric names to compute - :param context: Context containing test set and other data. - :param metrics: List of metric names to compute. - :return: Computed metrics value for the test set or error code of metrics + Returns: + Dictionary of computed metric values """ metrics_dict = SCORING_METRICS_MULTILABEL if context.is_multilabel() else SCORING_METRICS_MULTICLASS chosen_metrics = {name: fn for name, fn in metrics_dict.items() if name in metrics} diff --git a/autointent/modules/scoring/_dnnc/dnnc.py b/autointent/modules/scoring/_dnnc/dnnc.py index fa759ea89..a947da09a 100644 --- a/autointent/modules/scoring/_dnnc/dnnc.py +++ b/autointent/modules/scoring/_dnnc/dnnc.py @@ -11,37 +11,31 @@ from autointent import Context, Ranker, VectorIndex from autointent.configs import CrossEncoderConfig, EmbedderConfig from autointent.custom_types import ListOfLabels -from autointent.modules.abc import BaseScorer +from autointent.modules.base import BaseScorer logger = logging.getLogger(__name__) class DNNCScorer(BaseScorer): - r""" - Scoring module for intent classification using a discriminative nearest neighbor classification (DNNC). + """Scoring module for intent classification using discriminative nearest neighbor classification. This module uses a Ranker for scoring candidate intents and can optionally train a logistic regression head on top of cross-encoder features. - .. code-block:: bibtex - - @misc{zhang2020discriminativenearestneighborfewshot, - title={Discriminative Nearest Neighbor Few-Shot Intent Detection by Transferring Natural Language Inference}, - author={Jian-Guo Zhang and Kazuma Hashimoto and Wenhao Liu and Chien-Sheng Wu and Yao Wan and - Philip S. Yu and Richard Socher and Caiming Xiong}, - year={2020}, - eprint={2010.13009}, - archivePrefix={arXiv}, - primaryClass={cs.CL}, - url={https://arxiv.org/abs/2010.13009}, - } - - :ivar crossencoder_subdir: Subdirectory for storing the cross-encoder model (`Ranker`). - :ivar model: The model used for scoring, which could be a `Ranker` or a `CrossEncoderWithLogreg`. - :ivar _db_dir: Path to the database directory where the vector index is stored. - :ivar name: Name of the scorer, defaults to "dnnc". - - Examples + Reference: + Zhang, J. G., Hashimoto, K., Liu, W., Wu, C. S., Wan, Y., Yu, P. S., ... & Xiong, C. (2020). + Discriminative Nearest Neighbor Few-Shot Intent Detection by Transferring Natural Language Inference. + arXiv preprint arXiv:2010.13009. + + Attributes: + _n_classes: Number of intent classes + _vector_index: Index for nearest neighbor search + _cross_encoder: Ranker model for scoring pairs + name: Name of the scorer, defaults to "dnnc" + supports_multilabel: Whether multilabel classification is supported + supports_multiclass: Whether multiclass classification is supported + + Examples: -------- .. testcode:: @@ -81,12 +75,12 @@ def __init__( cross_encoder_config: CrossEncoderConfig | str | dict[str, Any] | None = None, embedder_config: EmbedderConfig | str | dict[str, Any] | None = None, ) -> None: - """ - Initialize the DNNCScorer. + """Initialize the DNNCScorer. - :param cross_encoder_config: Config of the cross-encoder model. - :param embedder_config: Config of the embedder model. - :param k: Number of nearest neighbors to retrieve. + Args: + cross_encoder_config: Config of the cross-encoder model + embedder_config: Config of the embedder model + k: Number of nearest neighbors to retrieve """ self.cross_encoder_config = CrossEncoderConfig.from_search_config(cross_encoder_config) self.embedder_config = EmbedderConfig.from_search_config(embedder_config) @@ -104,14 +98,16 @@ def from_context( cross_encoder_config: CrossEncoderConfig | str | None = None, embedder_config: EmbedderConfig | str | None = None, ) -> "DNNCScorer": - """ - Create a DNNCScorer instance using a Context object. + """Create a DNNCScorer instance using a Context object. + + Args: + context: Context containing configurations and utilities + cross_encoder_config: Config of the cross-encoder model + k: Number of nearest neighbors to retrieve + embedder_config: Config of the embedder model, or None to use the best embedder - :param context: Context containing configurations and utilities. - :param cross_encoder_config: Config of the cross-encoder model. - :param k: Number of nearest neighbors to retrieve. - :param embedder_config: Config of the embedder model, or None to use the best embedder. - :return: Initialized DNNCScorer instance. + Returns: + Initialized DNNCScorer instance """ if embedder_config is None: embedder_config = context.resolve_embedder() @@ -126,12 +122,14 @@ def from_context( ) def fit(self, utterances: list[str], labels: ListOfLabels) -> None: - """ - Fit the scorer by training or loading the vector index and optionally training a logistic regression head. + """Fit the scorer by training or loading the vector index. + + Args: + utterances: List of training utterances + labels: List of labels corresponding to the utterances - :param utterances: List of training utterances. - :param labels: List of labels corresponding to the utterances. - :raises ValueError: If the vector index mismatches the provided utterances. + Raises: + ValueError: If the vector index mismatches the provided utterances """ if hasattr(self, "_vector_index"): self.clear_cache() @@ -145,20 +143,26 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None: self._cross_encoder.fit(utterances, labels) def predict(self, utterances: list[str]) -> npt.NDArray[Any]: - """ - Predict class scores for the given utterances. + """Predict class scores for the given utterances. - :param utterances: List of utterances to score. - :return: Array of predicted scores. + Args: + utterances: List of utterances to score + + Returns: + Array of predicted scores """ return self._predict(utterances)[0] def predict_with_metadata(self, utterances: list[str]) -> tuple[npt.NDArray[Any], list[dict[str, Any]] | None]: - """ - Predict class scores along with metadata for the given utterances. + """Predict class scores along with metadata for the given utterances. - :param utterances: List of utterances to score. - :return: Tuple of scores and metadata containing neighbor details and scores. + Args: + utterances: List of utterances to score + + Returns: + Tuple containing: + - Array of predicted scores + - List of metadata with neighbor details and scores """ scores, neighbors, neighbors_scores = self._predict(utterances) metadata = [ @@ -168,13 +172,17 @@ def predict_with_metadata(self, utterances: list[str]) -> tuple[npt.NDArray[Any] return scores, metadata def _get_cross_encoder_scores(self, utterances: list[str], candidates: list[list[str]]) -> list[list[float]]: - """ - Compute cross-encoder scores for utterances against their candidate neighbors. + """Compute cross-encoder scores for utterances against their candidate neighbors. + + Args: + utterances: List of query utterances + candidates: List of candidate utterances for each query - :param utterances: List of query utterances. - :param candidates: List of candidate utterances for each query. - :return: List of cross-encoder scores for each query-candidate pair. - :raises ValueError: If the number of utterances and candidates do not match. + Returns: + List of cross-encoder scores for each query-candidate pair + + Raises: + ValueError: If the number of utterances and candidates do not match """ if len(utterances) != len(candidates): msg = "Number of utterances doesn't match number of retrieved candidates" @@ -197,13 +205,14 @@ def _get_cross_encoder_scores(self, utterances: list[str], candidates: list[list ] def _build_result(self, scores: list[list[float]], labels: list[ListOfLabels]) -> npt.NDArray[Any]: - """ - Build a result matrix with scores assigned to the best neighbor's class. + """Build a result matrix with scores assigned to the best neighbor's class. - :param scores: for each query utterance, cross encoder scores of its k closest utterances - :param labels: corresponding intent labels + Args: + scores: Cross encoder scores for each query's k closest utterances + labels: Corresponding intent labels - :return: (n_queries, n_classes) matrix with zeros everywhere except the class of the best neighbor utterance + Returns: + Matrix of shape (n_queries, n_classes) with zeros everywhere except the class of the best neighbor """ return build_result(np.array(scores), np.array(labels), self._n_classes) @@ -212,11 +221,16 @@ def clear_cache(self) -> None: self._vector_index.clear_ram() def _predict(self, utterances: list[str]) -> tuple[npt.NDArray[Any], list[list[str]], list[list[float]]]: - """ - Predict class scores for the given utterances using the vector index and cross-encoder. + """Predict class scores using vector index and cross-encoder. + + Args: + utterances: List of query utterances - :param utterances: List of query utterances. - :return: Tuple containing class scores, neighbor utterances, and neighbor scores. + Returns: + Tuple containing: + - Class scores matrix + - List of neighbor utterances + - List of neighbor scores """ labels, _, neighbors = self._vector_index.query( utterances, @@ -229,13 +243,15 @@ def _predict(self, utterances: list[str]) -> tuple[npt.NDArray[Any], list[list[s def build_result(scores: npt.NDArray[Any], labels: npt.NDArray[Any], n_classes: int) -> npt.NDArray[Any]: - """ - Build a result matrix with scores assigned to the best neighbor's class. + """Build a result matrix with scores assigned to the best neighbor's class. + + Args: + scores: Cross-encoder scores for each query's neighbors + labels: Labels corresponding to each neighbor + n_classes: Total number of classes - :param scores: Cross-encoder scores for each query's neighbors. - :param labels: Labels corresponding to each neighbor. - :param n_classes: Total number of classes. - :return: Matrix of size (n_queries, n_classes) with scores for the best class. + Returns: + Matrix of shape (n_queries, n_classes) with scores for the best class """ res = np.zeros((len(scores), n_classes)) best_neighbors = np.argmax(scores, axis=1) diff --git a/autointent/modules/scoring/_knn/count_neighbors.py b/autointent/modules/scoring/_knn/count_neighbors.py index fd1504978..463fae2cf 100644 --- a/autointent/modules/scoring/_knn/count_neighbors.py +++ b/autointent/modules/scoring/_knn/count_neighbors.py @@ -5,14 +5,17 @@ def get_counts(labels: NDArray[np.int_], n_classes: int, weights: NDArray[np.float64]) -> NDArray[np.int64]: - """ - Get counts of labels in candidates for multiclass classification. + """Get counts of labels in candidates for multiclass classification. + + Args: + labels: np.ndarray of shape (n_samples, n_candidates) with integer labels from `[0,n_classes-1]` + n_classes: number of classes + weights: np.ndarray of shape (n_samples, n_candidates) with integer labels from `[0,n_classes-1]` + + Returns: + np.ndarray of shape (n_samples, n_classes) with statistics of how many times each class + label occured in candidates - :param labels: np.ndarray of shape (n_samples, n_candidates) with integer labels from `[0,n_classes-1]` - :param n_classes: number of classes - :param weights: np.ndarray of shape (n_samples, n_candidates) with integer labels from `[0,n_classes-1]` - :return: np.ndarray of shape (n_samples, n_classes) with statistics of how many times each class - label occured in candidates """ n_queries = labels.shape[0] labels += n_classes * np.arange(n_queries)[:, None] @@ -23,8 +26,7 @@ def get_counts(labels: NDArray[np.int_], n_classes: int, weights: NDArray[np.flo def get_counts_multilabel(labels: NDArray[np.int_], weights: NDArray[np.float64]) -> NDArray[np.float64]: - """ - Get counts of labels in candidates for multilabel classification. + """Get counts of labels in candidates for multilabel classification. :param labels: np.ndarray of shape (n_samples, n_candidates, n_classes) with binary labels :param weights: np.ndarray of shape (n_samples, n_candidates) with float weights diff --git a/autointent/modules/scoring/_knn/knn.py b/autointent/modules/scoring/_knn/knn.py index adf107212..b92a3eb44 100644 --- a/autointent/modules/scoring/_knn/knn.py +++ b/autointent/modules/scoring/_knn/knn.py @@ -9,23 +9,25 @@ from autointent import Context, VectorIndex from autointent.configs import EmbedderConfig from autointent.custom_types import WEIGHT_TYPES, ListOfLabels -from autointent.modules.abc import BaseScorer +from autointent.modules.base import BaseScorer from .weighting import apply_weights class KNNScorer(BaseScorer): - """ - K-nearest neighbors (KNN) scorer for intent classification. + """K-nearest neighbors (KNN) scorer for intent classification. This module uses a vector index to retrieve nearest neighbors for query utterances and applies a weighting strategy to compute class probabilities. - :ivar weights: Weighting strategy used for scoring. - :ivar _vector_index: VectorIndex instance for neighbor retrieval. - :ivar name: Name of the scorer, defaults to "knn". + Attributes: + weights: Weighting strategy used for scoring + _vector_index: VectorIndex instance for neighbor retrieval + name: Name of the scorer, defaults to "knn" + supports_multiclass: Whether multiclass classification is supported + supports_multilabel: Whether multilabel classification is supported - Examples + Examples: -------- .. testcode:: @@ -62,15 +64,15 @@ def __init__( embedder_config: EmbedderConfig | str | dict[str, Any] | None = None, weights: WEIGHT_TYPES = "distance", ) -> None: - """ - Initialize the KNNScorer. - - :param embedder_config: Config of the embedder used for vectorization. - :param k: Number of closest neighbors to consider during inference. - :param weights: Weighting strategy: - - "uniform": Equal weight for all neighbors. - - "distance": Weight inversely proportional to distance. - - "closest": Only the closest neighbor of each class is weighted. + """Initialize the KNNScorer. + + Args: + embedder_config: Config of the embedder used for vectorization + k: Number of closest neighbors to consider during inference + weights: Weighting strategy: + - "uniform": Equal weight for all neighbors + - "distance": Weight inversely proportional to distance + - "closest": Only the closest neighbor of each class is weighted """ self.embedder_config = EmbedderConfig.from_search_config(embedder_config) self.k = k @@ -92,14 +94,16 @@ def from_context( weights: WEIGHT_TYPES = "distance", embedder_config: EmbedderConfig | str | None = None, ) -> "KNNScorer": - """ - Create a KNNScorer instance using a Context object. + """Create a KNNScorer instance using a Context object. - :param context: Context containing configurations and utilities. - :param k: Number of closest neighbors to consider during inference. - :param weights: Weighting strategy for scoring. - :param embedder_config: Config of the embedder, or None to use the best embedder. - :return: Initialized KNNScorer instance. + Args: + context: Context containing configurations and utilities + k: Number of closest neighbors to consider during inference + weights: Weighting strategy for scoring + embedder_config: Config of the embedder, or None to use the best embedder + + Returns: + Initialized KNNScorer instance """ if embedder_config is None: embedder_config = context.resolve_embedder() @@ -111,20 +115,23 @@ def from_context( ) def get_embedder_config(self) -> dict[str, Any]: - """ - Get the name of the embedder. + """Get the name of the embedder. - :return: Embedder name. + Returns: + Embedder name """ return self.embedder_config.model_dump() def fit(self, utterances: list[str], labels: ListOfLabels, clear_cache: bool = False) -> None: - """ - Fit the scorer by training or loading the vector index. + """Fit the scorer by training or loading the vector index. + + Args: + utterances: List of training utterances + labels: List of labels corresponding to the utterances + clear_cache: Whether to clear the vector index cache before fitting - :param utterances: List of training utterances. - :param labels: List of labels corresponding to the utterances. - :raises ValueError: If the vector index mismatches the provided utterances. + Raises: + ValueError: If the vector index mismatches the provided utterances """ if hasattr(self, "_vector_index") and clear_cache: self.clear_cache() @@ -135,20 +142,26 @@ def fit(self, utterances: list[str], labels: ListOfLabels, clear_cache: bool = F self._vector_index.add(utterances, labels) def predict(self, utterances: list[str]) -> npt.NDArray[Any]: - """ - Predict class probabilities for the given utterances. + """Predict class probabilities for the given utterances. - :param utterances: List of query utterances. - :return: Array of predicted probabilities for each class. + Args: + utterances: List of query utterances + + Returns: + Array of predicted probabilities for each class """ return self._predict(utterances)[0] def predict_with_metadata(self, utterances: list[str]) -> tuple[npt.NDArray[Any], list[dict[str, Any]] | None]: - """ - Predict class probabilities along with metadata for the given utterances. + """Predict class probabilities along with metadata for the given utterances. - :param utterances: List of query utterances. - :return: Tuple of predicted probabilities and metadata with neighbor information. + Args: + utterances: List of query utterances + + Returns: + Tuple containing: + - Array of predicted probabilities + - List of metadata with neighbor information """ scores, neighbors = self._predict(utterances) metadata = [{"neighbors": utterance_neighbors} for utterance_neighbors in neighbors] @@ -159,17 +172,41 @@ def clear_cache(self) -> None: self._vector_index.clear_ram() def _get_neighbours(self, utterances: list[str]) -> tuple[list[ListOfLabels], list[list[float]], list[list[str]]]: + """Get nearest neighbors for given utterances. + + Args: + utterances: List of query utterances + + Returns: + Tuple containing: + - List of labels for neighbors + - List of distances to neighbors + - List of neighbor utterances + """ return self._vector_index.query(utterances, self.k) def _count_scores(self, labels: npt.NDArray[Any], distances: npt.NDArray[Any]) -> npt.NDArray[Any]: + """Calculate weighted scores for labels based on distances. + + Args: + labels: Array of neighbor labels + distances: Array of distances to neighbors + + Returns: + Array of weighted scores + """ return apply_weights(labels, distances, self.weights, self._n_classes, self._multilabel) def _predict(self, utterances: list[str]) -> tuple[npt.NDArray[Any], list[list[str]]]: - """ - Predict class probabilities and retrieve neighbors for the given utterances. + """Predict class probabilities and retrieve neighbors for the given utterances. + + Args: + utterances: List of query utterances - :param utterances: List of query utterances. - :return: Tuple containing class probabilities and neighbor utterances. + Returns: + Tuple containing: + - Array of class probabilities + - List of neighbor utterances """ labels, distances, neighbors = self._get_neighbours(utterances) scores = self._count_scores(np.array(labels), np.array(distances)) diff --git a/autointent/modules/scoring/_knn/rerank_scorer.py b/autointent/modules/scoring/_knn/rerank_scorer.py index 0ee9d3bf5..8a9fc95b9 100644 --- a/autointent/modules/scoring/_knn/rerank_scorer.py +++ b/autointent/modules/scoring/_knn/rerank_scorer.py @@ -14,13 +14,13 @@ class RerankScorer(KNNScorer): - """ - Re-ranking scorer using a cross-encoder for intent classification. + """Re-ranking scorer using a cross-encoder for intent classification. This module uses a cross-encoder to re-rank the nearest neighbors retrieved by a KNN scorer. - :ivar name: Name of the scorer, defaults to "rerank". - :ivar _scorer: Ranker instance for re-ranking. + Attributes: + name: Name of the scorer, defaults to "rerank" + _scorer: Ranker instance for re-ranking """ name = "rerank" @@ -35,18 +35,18 @@ def __init__( cross_encoder_config: CrossEncoderConfig | str | dict[str, Any] | None = None, embedder_config: EmbedderConfig | str | dict[str, Any] | None = None, ) -> None: - """ - Initialize the RerankScorer. - - :param embedder_config: Config of the embedder used for vectorization. - :param k: Number of closest neighbors to consider during inference. - :param weights: Weighting strategy: - - "uniform": Equal weight for all neighbors. - - "distance": Weight inversely proportional to distance. - - "closest": Only the closest neighbor of each class is weighted. - :param cross_encoder_config: Config of the cross-encoder model used for re-ranking. - :param m: Number of top-ranked neighbors to consider, or None to use k. - :param rank_threshold_cutoff: Rank threshold cutoff for re-ranking, or None. + """Initialize the RerankScorer. + + Args: + embedder_config: Config of the embedder used for vectorization + k: Number of closest neighbors to consider during inference + weights: Weighting strategy: + - "uniform": Equal weight for all neighbors + - "distance": Weight inversely proportional to distance + - "closest": Only the closest neighbor of each class is weighted + cross_encoder_config: Config of the cross-encoder model used for re-ranking + m: Number of top-ranked neighbors to consider, or None to use k + rank_threshold_cutoff: Rank threshold cutoff for re-ranking, or None """ super().__init__( embedder_config=embedder_config, @@ -80,18 +80,20 @@ def from_context( embedder_config: EmbedderConfig | str | None = None, rank_threshold_cutoff: int | None = None, ) -> "RerankScorer": - """ - Create a RerankScorer instance from a given context. - - :param context: Context object containing optimization information and vector index client. - :param k: Number of closest neighbors to consider during inference. - :param weights: Weighting strategy. - :param cross_encoder_config: Config of the cross-encoder model used for re-ranking. - :param embedder_config: Config of the embedder used for vectorization, - or None to use the best existing embedder. - :param m: Number of top-ranked neighbors to consider, or None to use k. - :param rank_threshold_cutoff: Rank threshold cutoff for re-ranking, or None. - :return: An instance of RerankScorer. + """Create a RerankScorer instance from a given context. + + Args: + context: Context object containing optimization information and vector index client + k: Number of closest neighbors to consider during inference + weights: Weighting strategy + cross_encoder_config: Config of the cross-encoder model used for re-ranking + embedder_config: Config of the embedder used for vectorization, + or None to use the best existing embedder + m: Number of top-ranked neighbors to consider, or None to use k + rank_threshold_cutoff: Rank threshold cutoff for re-ranking, or None + + Returns: + An instance of RerankScorer """ if embedder_config is None: embedder_config = context.resolve_embedder() @@ -109,11 +111,11 @@ def from_context( ) def fit(self, utterances: list[str], labels: ListOfLabels) -> None: - """ - Fit the RerankScorer with utterances and labels. + """Fit the RerankScorer with utterances and labels. - :param utterances: List of utterances to fit the scorer. - :param labels: List of labels corresponding to the utterances. + Args: + utterances: List of utterances to fit the scorer + labels: List of labels corresponding to the utterances """ if hasattr(self, "_scorer"): self.clear_cache() @@ -126,15 +128,20 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None: super().fit(utterances, labels, clear_cache=False) def clear_cache(self) -> None: + """Clear cached data in memory used by the scorer and vector index.""" self._scorer.clear_ram() super().clear_cache() def _predict(self, utterances: list[str]) -> tuple[npt.NDArray[Any], list[list[str]]]: - """ - Predict the scores and neighbors for given utterances. + """Predict the scores and neighbors for given utterances. + + Args: + utterances: List of utterances to predict scores for - :param utterances: List of utterances to predict scores for. - :return: A tuple containing the scores and neighbors. + Returns: + Tuple containing: + - Array of predicted scores + - List of neighbor utterances """ knn_labels, knn_distances, knn_neighbors = self._get_neighbours(utterances) diff --git a/autointent/modules/scoring/_knn/weighting.py b/autointent/modules/scoring/_knn/weighting.py index 730e3dceb..4704d4520 100644 --- a/autointent/modules/scoring/_knn/weighting.py +++ b/autointent/modules/scoring/_knn/weighting.py @@ -17,17 +17,18 @@ def apply_weights( n_classes: int, multilabel: bool, ) -> NDArray[Any]: - """ - Calculate probabilities based on labels, distances, and weighting strategy. - - :param labels: - - For multiclass: Array of shape (n_samples, n_neighbors) with integer labels in [0, n_classes - 1]. - - For multilabel: Array of shape (n_samples, n_neighbors, n_classes) with binary labels. - :param distances: Array of shape (n_samples, n_neighbors) with float distances. - :param weights: Weighting strategy to apply. Options are "closest", "uniform", or "distance". - :param n_classes: Number of classes in the dataset. - :param multilabel: Whether the task is multilabel classification. - :return: Array of shape (n_samples, n_classes) with calculated probabilities. + """Apply weighting strategy to calculate probabilities. + + Args: + labels: Array of shape (n_samples, n_neighbors) with integer labels in [0, n_classes - 1] for multiclass + or shape (n_samples, n_neighbors, n_classes) with binary labels for multilabel. + distances: Array of shape (n_samples, n_neighbors) with float distances. + weights: Weighting strategy to apply. Options are "closest", "uniform", or "distance". + n_classes: Number of classes in the dataset. + multilabel: Whether the task is multilabel classification. + + Returns: + Array of shape (n_samples, n_classes) with calculated probabilities. """ n_samples, n_candidates = distances.shape @@ -51,16 +52,17 @@ def apply_weights( def closest_weighting(labels: NDArray[Any], distances: NDArray[Any], multilabel: bool, n_classes: int) -> NDArray[Any]: - """ - Apply closest weighting strategy. - - :param labels: - - For multiclass: Array of shape (n_samples, n_neighbors) with integer labels in [0, n_classes - 1]. - - For multilabel: Array of shape (n_samples, n_neighbors, n_classes) with binary labels. - :param distances: Array of shape (n_samples, n_neighbors) with cosine distances. - :param multilabel: Whether the task is multilabel classification. - :param n_classes: Number of classes in the dataset. - :return: Array of shape (n_samples, n_classes) with calculated probabilities. + """Apply closest weighting strategy. + + Args: + labels: Array of shape (n_samples, n_neighbors) with integer labels in [0, n_classes - 1] for multiclass + or shape (n_samples, n_neighbors, n_classes) with binary labels for multilabel. + distances: Array of shape (n_samples, n_neighbors) with float distances. + multilabel: Whether the task is multilabel classification. + n_classes: Number of classes in the dataset. + + Returns: + Array of shape (n_samples, n_classes) with calculated probabilities. """ if not multilabel: labels = to_onehot(labels, n_classes) @@ -68,12 +70,14 @@ def closest_weighting(labels: NDArray[Any], distances: NDArray[Any], multilabel: def _closest_weighting(labels: NDArray[Any], distances: NDArray[Any]) -> NDArray[Any]: - """ - Apply closest weighting strategy for multilabel classification. + """Apply closest weighting strategy for multilabel classification. + + Args: + labels: Array of shape (n_samples, n_candidates, n_classes) with binary labels. + distances: Array of shape (n_samples, n_candidates) with cosine distances. - :param labels: Array of shape (n_samples, n_candidates, n_classes) with binary labels. - :param distances: Array of shape (n_samples, n_candidates) with cosine distances. - :return: Array of shape (n_samples, n_classes) with calculated probabilities. + Returns: + Array of shape (n_samples, n_classes) with calculated probabilities. """ # Broadcast to (n_samples, n_candidates, n_classes) broadcasted_similarities = np.broadcast_to(1 - distances[..., None], shape=labels.shape) @@ -85,12 +89,14 @@ def _closest_weighting(labels: NDArray[Any], distances: NDArray[Any]) -> NDArray def to_onehot(labels: NDArray[Any], n_classes: int) -> NDArray[Any]: - """ - Convert an array of integer labels to a one-hot encoded array. + """Convert an array of integer labels to a one-hot encoded array. + + Args: + labels: Array of shape (n_samples, n_neighbors) with integer labels. + n_classes: Number of classes in the dataset. - :param labels: Array of shape (n_samples, n_neighbors) with integer labels. - :param n_classes: Number of classes in the dataset. - :return: One-hot encoded array of shape (n_samples, n_neighbors, n_classes). + Returns: + One-hot encoded array of shape (n_samples, n_neighbors, n_classes). """ new_shape = (*labels.shape, n_classes) onehot_labels = np.zeros(shape=new_shape) diff --git a/autointent/modules/scoring/_linear.py b/autointent/modules/scoring/_linear.py index 4a43851c1..86a004bff 100644 --- a/autointent/modules/scoring/_linear.py +++ b/autointent/modules/scoring/_linear.py @@ -10,19 +10,24 @@ from autointent import Context, Embedder from autointent.configs import EmbedderConfig, TaskTypeEnum from autointent.custom_types import ListOfLabels -from autointent.modules.abc import BaseScorer +from autointent.modules.base import BaseScorer class LinearScorer(BaseScorer): - """ - Scoring module for linear classification using logistic regression. + """Scoring module for linear classification using logistic regression. This module uses embeddings generated from a transformer model to train a logistic regression classifier for intent classification. - :ivar name: Name of the scorer, defaults to "linear". + Attributes: + name: Name of the scorer, defaults to "linear" + _multilabel: Whether multilabel classification is used + _clf: Trained classifier instance + _embedder: Embedder instance for feature extraction + supports_multiclass: Whether multiclass classification is supported + supports_multilabel: Whether multilabel classification is supported - Example + Example: -------- .. testcode:: @@ -57,13 +62,13 @@ def __init__( cv: int = 3, seed: int = 0, ) -> None: - """ - Initialize the LinearScorer. + """Initialize the LinearScorer. - :param embedder_config: Config of the embedder model. - :param cv: Number of cross-validation folds, defaults to 3. - :param n_jobs: Number of parallel jobs for cross-validation, defaults to -1 (all CPUs). - :param seed: Random seed for reproducibility, defaults to 0. + Args: + embedder_config: Config of the embedder model + cv: Number of cross-validation folds, defaults to 3 + n_jobs: Number of parallel jobs for cross-validation, defaults to None + seed: Random seed for reproducibility, defaults to 0 """ self.cv = cv self.seed = seed @@ -79,12 +84,14 @@ def from_context( context: Context, embedder_config: EmbedderConfig | str | None = None, ) -> "LinearScorer": - """ - Create a LinearScorer instance using a Context object. + """Create a LinearScorer instance using a Context object. - :param context: Context containing configurations and utilities. - :param embedder_config: Config of the embedder, or None to use the best embedder. - :return: Initialized LinearScorer instance. + Args: + context: Context containing configurations and utilities + embedder_config: Config of the embedder, or None to use the best embedder + + Returns: + Initialized LinearScorer instance """ if embedder_config is None: embedder_config = context.resolve_embedder() @@ -94,10 +101,10 @@ def from_context( ) def get_embedder_config(self) -> dict[str, Any]: - """ - Get the name of the embedder. + """Get the name of the embedder. - :return: Embedder name. + Returns: + Embedder name """ return self.embedder_config.model_dump() @@ -106,12 +113,14 @@ def fit( utterances: list[str], labels: ListOfLabels, ) -> None: - """ - Train the logistic regression classifier. + """Train the logistic regression classifier. + + Args: + utterances: List of training utterances + labels: List of labels corresponding to the utterances - :param utterances: List of training utterances. - :param labels: List of labels corresponding to the utterances. - :raises ValueError: If the vector index mismatches the provided utterances. + Raises: + ValueError: If the vector index mismatches the provided utterances """ if hasattr(self, "_clf"): self.clear_cache() @@ -135,11 +144,13 @@ def fit( self._embedder = embedder def predict(self, utterances: list[str]) -> npt.NDArray[Any]: - """ - Predict probabilities for the given utterances. + """Predict probabilities for the given utterances. + + Args: + utterances: List of query utterances - :param utterances: List of query utterances. - :return: Array of predicted probabilities for each class. + Returns: + Array of predicted probabilities for each class """ features = self._embedder.embed(utterances, TaskTypeEnum.classification) probas = self._clf.predict_proba(features) diff --git a/autointent/modules/scoring/_mlknn/mlknn.py b/autointent/modules/scoring/_mlknn/mlknn.py index fe42efcda..79b517c9f 100644 --- a/autointent/modules/scoring/_mlknn/mlknn.py +++ b/autointent/modules/scoring/_mlknn/mlknn.py @@ -10,19 +10,29 @@ from autointent import Context, VectorIndex from autointent.configs import EmbedderConfig from autointent.custom_types import ListOfLabels -from autointent.modules.abc import BaseScorer +from autointent.modules.base import BaseScorer class MLKnnScorer(BaseScorer): - """ - Multi-label k-nearest neighbors (ML-KNN) scorer. + """Multi-label k-nearest neighbors (ML-KNN) scorer. This module implements ML-KNN, a multi-label classifier that computes probabilities based on the k-nearest neighbors of a query instance. - :ivar name: Name of the scorer, defaults to "mlknn". - - Example + Attributes: + name: Name of the scorer, defaults to "mlknn" + _n_classes: Number of classes + _vector_index: Index for nearest neighbor search + _prior_prob_true: Prior probabilities for true labels + _prior_prob_false: Prior probabilities for false labels + _cond_prob_true: Conditional probabilities for true labels + _cond_prob_false: Conditional probabilities for false labels + _features: Embedded features of training data + _labels: Labels of training data + supports_multiclass: Whether multiclass classification is supported + supports_multilabel: Whether multilabel classification is supported + + Example: -------- .. testcode:: @@ -65,13 +75,13 @@ def __init__( s: float = 1.0, ignore_first_neighbours: int = 0, ) -> None: - """ - Initialize the MLKnnScorer. + """Initialize the MLKnnScorer. - :param k: Number of nearest neighbors to consider. - :param embedder_config: Config of the embedder used for vectorization. - :param s: Smoothing parameter for probability calculations, defaults to 1.0. - :param ignore_first_neighbours: Number of closest neighbors to ignore, defaults to 0. + Args: + k: Number of nearest neighbors to consider + embedder_config: Config of the embedder used for vectorization + s: Smoothing parameter for probability calculations, defaults to 1.0 + ignore_first_neighbours: Number of closest neighbors to ignore, defaults to 0 """ self.k = k self.embedder_config = EmbedderConfig.from_search_config(embedder_config) @@ -94,15 +104,17 @@ def from_context( ignore_first_neighbours: NonNegativeInt = 0, embedder_config: EmbedderConfig | str | None = None, ) -> "MLKnnScorer": - """ - Create an MLKnnScorer instance using a Context object. - - :param context: Context containing configurations and utilities. - :param k: Number of nearest neighbors to consider. - :param s: Smoothing parameter for probability calculations, defaults to 1.0. - :param ignore_first_neighbours: Number of closest neighbors to ignore, defaults to 0. - :param embedder_config: Config of the embedder, or None to use the best embedder. - :return: Initialized MLKnnScorer instance. + """Create an MLKnnScorer instance using a Context object. + + Args: + context: Context containing configurations and utilities + k: Number of nearest neighbors to consider + s: Smoothing parameter for probability calculations, defaults to 1.0 + ignore_first_neighbours: Number of closest neighbors to ignore, defaults to 0 + embedder_config: Config of the embedder, or None to use the best embedder + + Returns: + Initialized MLKnnScorer instance """ if embedder_config is None: embedder_config = context.resolve_embedder() @@ -115,21 +127,23 @@ def from_context( ) def get_embedder_config(self) -> dict[str, Any]: - """ - Get the name of the embedder. + """Get the name of the embedder. - :return: Embedder name. + Returns: + Embedder name """ return self.embedder_config.model_dump() def fit(self, utterances: list[str], labels: ListOfLabels) -> None: - """ - Fit the scorer by training or loading the vector index and calculating probabilities. + """Fit the scorer by training or loading the vector index and calculating probabilities. + + Args: + utterances: List of training utterances + labels: List of multi-label targets for each utterance - :param utterances: List of training utterances. - :param labels: List of multi-label targets for each utterance. - :raises TypeError: If the labels are not multi-label. - :raises ValueError: If the vector index mismatches the provided utterances. + Raises: + TypeError: If the labels are not multi-label + ValueError: If the vector index mismatches the provided utterances """ if hasattr(self, "_vector_index"): self.clear_cache() @@ -153,21 +167,23 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None: self._cond_prob_true, self._cond_prob_false = self._compute_cond() def _compute_prior(self, y: NDArray[np.float64]) -> tuple[NDArray[np.float64], NDArray[np.float64]]: - """ - Compute prior probabilities for each class. + """Compute prior probabilities for each class. + + Args: + y: Array of labels (multi-label format) - :param y: Array of labels (multi-label format). - :return: Tuple of prior probabilities for true and false labels. + Returns: + Tuple of prior probabilities for true and false labels """ prior_prob_true = (self.s + y.sum(axis=0)) / (self.s * 2 + y.shape[0]) prior_prob_false = 1 - prior_prob_true return prior_prob_true, prior_prob_false def _compute_cond(self) -> tuple[NDArray[np.float64], NDArray[np.float64]]: - """ - Compute conditional probabilities for neighbors. + """Compute conditional probabilities for neighbors. - :return: Tuple of conditional probabilities for true and false labels. + Returns: + Tuple of conditional probabilities for true and false labels """ c = np.zeros((self._n_classes, self.k + 1), dtype=int) cn = np.zeros((self._n_classes, self.k + 1), dtype=int) @@ -193,6 +209,16 @@ def _get_neighbors( self, queries: list[str] | NDArray[Any], ) -> tuple[NDArray[np.int64], list[list[str]]]: + """Get nearest neighbors for given queries. + + Args: + queries: List of query utterances or embedded features + + Returns: + Tuple containing: + - Array of neighbor labels + - List of neighbor utterances + """ labels, _, neighbors = self._vector_index.query( queries, self.k + self.ignore_first_neighbours, @@ -203,31 +229,39 @@ def _get_neighbors( ) def predict_labels(self, utterances: list[str], thresh: float = 0.5) -> NDArray[np.int64]: - """ - Predict labels for the given utterances. + """Predict labels for the given utterances. + + Args: + utterances: List of query utterances + thresh: Threshold for binary classification, defaults to 0.5 - :param utterances: List of query utterances. - :param thresh: Threshold for binary classification, defaults to 0.5. - :return: Predicted labels as a binary array. + Returns: + Predicted labels as a binary array """ probas = self.predict(utterances) return (probas > thresh).astype(int) def predict(self, utterances: list[str]) -> NDArray[np.float64]: - """ - Predict probabilities for the given utterances. + """Predict probabilities for the given utterances. - :param utterances: List of query utterances. - :return: Array of predicted probabilities for each class. + Args: + utterances: List of query utterances + + Returns: + Array of predicted probabilities for each class """ return self._predict(utterances)[0] def predict_with_metadata(self, utterances: list[str]) -> tuple[NDArray[Any], list[dict[str, Any]] | None]: - """ - Predict probabilities along with metadata for the given utterances. + """Predict probabilities along with metadata for the given utterances. + + Args: + utterances: List of query utterances - :param utterances: List of query utterances. - :return: Tuple of probabilities and metadata with neighbor information. + Returns: + Tuple containing: + - Array of predicted probabilities + - List of metadata with neighbor information """ scores, neighbors = self._predict(utterances) metadata = [{"neighbors": utterance_neighbors} for utterance_neighbors in neighbors] @@ -241,6 +275,16 @@ def _predict( self, utterances: list[str], ) -> tuple[NDArray[np.float64], list[list[str]]]: + """Predict probabilities and retrieve neighbors for the given utterances. + + Args: + utterances: List of query utterances + + Returns: + Tuple containing: + - Array of predicted probabilities + - List of neighbor utterances + """ result = np.zeros((len(utterances), self._n_classes), dtype=float) neighbors_labels, neighbors = self._get_neighbors(utterances) diff --git a/autointent/modules/scoring/_sklearn/sklearn_scorer.py b/autointent/modules/scoring/_sklearn/sklearn_scorer.py index 64119b924..fbc8c4405 100644 --- a/autointent/modules/scoring/_sklearn/sklearn_scorer.py +++ b/autointent/modules/scoring/_sklearn/sklearn_scorer.py @@ -1,3 +1,5 @@ +"""Module for classification scoring using sklearn classifiers with predict_proba() method.""" + import logging from typing import Any @@ -10,7 +12,7 @@ from autointent import Context, Embedder from autointent.configs import EmbedderConfig, TaskTypeEnum from autointent.custom_types import ListOfLabels -from autointent.modules.abc import BaseScorer +from autointent.modules.base import BaseScorer logger = logging.getLogger(__name__) AVAILABLE_CLASSIFIERS = { @@ -28,13 +30,27 @@ class SklearnScorer(BaseScorer): - """ - Scoring module for classification using sklearn classifiers with implemented predict_proba() method. + """Scoring module for classification using sklearn classifiers. This module uses embeddings generated from a transformer model to train chosen sklearn classifier for intent classification. - :ivar name: Name of the scorer, defaults to "linear". + Attributes: + name: Name of the scorer, defaults to "sklearn" + supports_multilabel: Whether multilabel classification is supported + supports_multiclass: Whether multiclass classification is supported + + Examples: + >>> from autointent.modules.scoring import SklearnScorer + >>> utterances = ["hello", "how are you?"] + >>> labels = [0, 1] + >>> scorer = SklearnScorer( + ... clf_name="LogisticRegression", + ... embedder_config="sergeyzh/rubert-tiny-turbo", + ... ) + >>> scorer.fit(utterances, labels) + >>> test_utterances = ["hi", "what's up?"] + >>> probabilities = scorer.predict(test_utterances) """ name = "sklearn" @@ -47,12 +63,15 @@ def __init__( embedder_config: EmbedderConfig | str | dict[str, Any] | None = None, **clf_args: Any, # noqa: ANN401 ) -> None: - """ - Initialize the SklearnScorer. + """Initialize the SklearnScorer. - :param embedder_config: Config of the embedder model. - :param clf_name: Name of the sklearn classifier to use. - :param clf_args: dictionary with the chosen sklearn classifier arguments. + Args: + clf_name: Name of the sklearn classifier to use + embedder_config: Config of the embedder model + **clf_args: Arguments for the chosen sklearn classifier + + Raises: + ValueError: If the specified classifier doesn't exist or lacks predict_proba """ self.embedder_config = EmbedderConfig.from_search_config(embedder_config) self.clf_name = clf_name @@ -73,14 +92,16 @@ def from_context( embedder_config: EmbedderConfig | str | None = None, **clf_args: float | str | bool, ) -> Self: - """ - Create a SklearnScorer instance using a Context object. + """Create a SklearnScorer instance using a Context object. - :param context: Context containing configurations and utilities. - :param clf_name: Name of the sklearn classifier to use. - :param clf_args: dictionary with the chosen sklearn classifier arguments, defaults to {}. - :param embedder_config: Config of the embedder, or None to use the best embedder. - :return: Initialized SklearnScorer instance. + Args: + context: Context containing configurations and utilities + clf_name: Name of the sklearn classifier to use + embedder_config: Config of the embedder, or None to use the best embedder + **clf_args: Arguments for the chosen sklearn classifier + + Returns: + Initialized SklearnScorer instance """ if embedder_config is None: embedder_config = context.resolve_embedder() @@ -96,12 +117,14 @@ def fit( utterances: list[str], labels: ListOfLabels, ) -> None: - """ - Train the chosen sklearn classifier. + """Train the chosen sklearn classifier. + + Args: + utterances: List of training utterances + labels: List of labels corresponding to the utterances - :param utterances: List of training utterances. - :param labels: List of labels corresponding to the utterances. - :raises ValueError: If the vector index mismatches the provided utterances. + Raises: + ValueError: If the vector index mismatches the provided utterances """ if hasattr(self, "_clf"): self.clear_cache() @@ -127,11 +150,13 @@ def fit( self._embedder = embedder def predict(self, utterances: list[str]) -> npt.NDArray[Any]: - """ - Predict probabilities for the given utterances. + """Predict probabilities for the given utterances. + + Args: + utterances: List of query utterances - :param utterances: List of query utterances. - :return: Array of predicted probabilities for each class. + Returns: + Array of predicted probabilities for each class """ features = self._embedder.embed(utterances, TaskTypeEnum.classification) probas = self._clf.predict_proba(features) diff --git a/autointent/nodes/_inference_node.py b/autointent/nodes/_inference_node.py index 1f6fea6aa..21ecded95 100644 --- a/autointent/nodes/_inference_node.py +++ b/autointent/nodes/_inference_node.py @@ -6,7 +6,7 @@ from autointent.configs import InferenceNodeConfig from autointent.custom_types import NodeType -from autointent.modules.abc import BaseModule +from autointent.modules.base import BaseModule from autointent.nodes.info import NODES_INFO @@ -14,21 +14,21 @@ class InferenceNode: """Inference node class.""" def __init__(self, module: BaseModule, node_type: NodeType) -> None: - """ - Initialize the inference node. + """Initialize the inference node. - :param module: Module to use for inference - :param node_type: Node types + Args: + module: Module to use for inference + node_type: Node types """ self.module = module self.node_type = node_type @classmethod def from_config(cls, config: InferenceNodeConfig) -> "InferenceNode": - """ - Initialize from config. + """Initialize from config. - :param config: Configuration for the node. + Args: + config: Config to init from """ node_info = NODES_INFO[config.node_type] module = node_info.modules_available[config.module_name](**config.module_config) diff --git a/autointent/nodes/_optimization/_node_optimizer.py b/autointent/nodes/_optimization/_node_optimizer.py index 8771e4a81..3d4b4798e 100644 --- a/autointent/nodes/_optimization/_node_optimizer.py +++ b/autointent/nodes/_optimization/_node_optimizer.py @@ -1,4 +1,4 @@ -"""Node optimizer.""" +"""Node optimizer for optimizing module configurations.""" import gc import itertools as it @@ -21,24 +21,32 @@ class ParamSpaceInt(BaseModel): - low: int = Field(..., description="Low boundary of the search space.") - high: int = Field(..., description="High boundary of the search space.") - step: int = Field(1, description="Step of the search space.") - log: bool = Field(False, description="Whether to use a logarithmic scale.") + """Integer parameter search space configuration.""" + + low: int = Field(..., description="Lower boundary of the search space.") + high: int = Field(..., description="Upper boundary of the search space.") + step: int = Field(1, description="Step size for the search space.") + log: bool = Field(False, description="Indicates whether to use a logarithmic scale.") class ParamSpaceFloat(BaseModel): - low: float = Field(..., description="Low boundary of the search space.") - high: float = Field(..., description="High boundary of the search space.") - step: float | None = Field(None, description="Step of the search space.") - log: bool = Field(False, description="Whether to use a logarithmic scale.") + """Float parameter search space configuration.""" + + low: float = Field(..., description="Lower boundary of the search space.") + high: float = Field(..., description="Upper boundary of the search space.") + step: float | None = Field(None, description="Step size for the search space (if applicable).") + log: bool = Field(False, description="Indicates whether to use a logarithmic scale.") logger = logging.getLogger(__name__) class NodeOptimizer: - """Node optimizer class.""" + """Class for optimizing nodes in a computational pipeline. + + This class is responsible for optimizing different modules within a node + using various search strategies and logging the results. + """ def __init__( self, @@ -47,12 +55,13 @@ def __init__( target_metric: str, metrics: list[str] | None = None, ) -> None: - """ - Initialize the node optimizer. + """Initializes the node optimizer. - :param node_type: Node type - :param search_space: Search space for the optimization - :param metrics: Metrics to optimize. + Args: + node_type: The type of node being optimized. + search_space: A list of dictionaries defining the search space. + target_metric: The primary metric to optimize. + metrics: Additional metrics to track during optimization. """ self._logger = logger self.node_type = node_type @@ -67,20 +76,22 @@ def __init__( self.modules_search_spaces = search_space def fit(self, context: Context, sampler: SamplerType = "brute") -> None: - """ - Fit the node optimizer. + """Performs the optimization process for the node. - :param context: Context - :param sampler: Sampler to use for optimization + Args: + context: The optimization context containing relevant data. + sampler: The sampling strategy used for optimization. + + Raises: + AssertionError: If an invalid sampler type is provided. """ - self._logger.info("starting %s node optimization...", self.node_info.node_type) + self._logger.info("Starting %s node optimization...", self.node_info.node_type) for search_space in deepcopy(self.modules_search_spaces): self._counter: int = 0 module_name = search_space.pop("module_name") - n_trials = None - if "n_trials" in search_space: - n_trials = search_space.pop("n_trials") + n_trials = search_space.pop("n_trials", None) + if sampler == "tpe": sampler_instance = optuna.samplers.TPESampler(seed=context.seed) n_trials = n_trials or 10 @@ -92,6 +103,7 @@ def fit(self, context: Context, sampler: SamplerType = "brute") -> None: n_trials = n_trials or 10 else: assert_never(sampler) + study = optuna.create_study(direction="maximize", sampler=sampler_instance) optuna.logging.set_verbosity(optuna.logging.WARNING) obj = partial(self.objective, module_name=module_name, search_space=search_space, context=context) @@ -106,9 +118,20 @@ def objective( search_space: dict[str, ParamSpaceInt | ParamSpaceFloat | list[Any]], context: Context, ) -> float: + """Defines the objective function for optimization. + + Args: + trial: The Optuna trial instance. + module_name: The name of the module being optimized. + search_space: The parameter search space. + context: The execution context. + + Returns: + The value of the target metric for the given trial. + """ config = self.suggest(trial, search_space) - self._logger.debug("initializing %s module...", module_name) + self._logger.debug("Initializing %s module...", module_name) module = self.node_info.modules_available[module_name].from_context(context, **config) embedder_config = module.get_embedder_config() @@ -117,7 +140,7 @@ def objective( context.callback_handler.start_module(module_name=module_name, num=self._counter, module_kwargs=config) - self._logger.debug("scoring %s module...", module_name) + self._logger.debug("Scoring %s module...", module_name) all_metrics = module.score(context, metrics=self.metrics) target_metric = all_metrics[self.target_metric] @@ -150,10 +173,21 @@ def objective( torch.cuda.empty_cache() self._counter += 1 - return target_metric def suggest(self, trial: Trial, search_space: dict[str, Any | list[Any]]) -> dict[str, Any]: + """Suggests parameter values based on the search space. + + Args: + trial: The Optuna trial instance. + search_space: A dictionary defining the parameter search space. + + Returns: + A dictionary containing the suggested parameter values. + + Raises: + TypeError: If an unsupported parameter search space type is encountered. + """ res: dict[str, Any] = {} for param_name, param_space in search_space.items(): @@ -178,23 +212,29 @@ def _is_valid_param_space( return False def get_module_dump_dir(self, dump_dir: Path, module_name: str, j_combination: int) -> str: - """ - Get module dump directory. + """Creates and returns the path to the module dump directory. + + Args: + dump_dir: The base directory for storing module dumps. + module_name: The name of the module being optimized. + j_combination: The combination index for the parameters. - :param dump_dir: The base directory where the module dump directories will be created. - :param module_name: The type of the module being optimized. - :param j_combination: The index of the parameter combination being used. - :return: The path to the module dump directory as a string. + Returns: + The path to the module dump directory. """ dump_dir_ = dump_dir / self.node_info.node_type / module_name / f"comb_{j_combination}" dump_dir_.mkdir(parents=True, exist_ok=True) return str(dump_dir_) def validate_nodes_with_dataset(self, dataset: Dataset, mode: SearchSpaceValidationMode) -> None: - """ - Validate nodes with dataset. + """Validates nodes against the dataset. + + Args: + dataset: The dataset used for validation. + mode: The validation mode ("raise" or "warning"). - :param dataset: Dataset to use + Raises: + ValueError: If validation fails and `mode` is set to "raise". """ is_multilabel = dataset.multilabel diff --git a/autointent/nodes/info/_base.py b/autointent/nodes/info/_base.py index eb6da4870..95aa95e9e 100644 --- a/autointent/nodes/info/_base.py +++ b/autointent/nodes/info/_base.py @@ -5,7 +5,7 @@ from autointent.custom_types import NodeType from autointent.metrics import METRIC_FN -from autointent.modules.abc import BaseModule +from autointent.modules.base import BaseModule class NodeInfo: diff --git a/autointent/nodes/info/_decision.py b/autointent/nodes/info/_decision.py index b2a62667f..971309afb 100644 --- a/autointent/nodes/info/_decision.py +++ b/autointent/nodes/info/_decision.py @@ -6,7 +6,7 @@ from autointent.custom_types import NodeType from autointent.metrics import DECISION_METRICS, DecisionMetricFn from autointent.modules import DECISION_MODULES -from autointent.modules.abc import BaseDecision +from autointent.modules.base import BaseDecision from ._base import NodeInfo diff --git a/autointent/nodes/info/_embedding.py b/autointent/nodes/info/_embedding.py index 0c82e0e6f..6ce22ce65 100644 --- a/autointent/nodes/info/_embedding.py +++ b/autointent/nodes/info/_embedding.py @@ -13,7 +13,7 @@ ScoringMetricFn, ) from autointent.modules import EMBEDDING_MODULES -from autointent.modules.abc import BaseEmbedding +from autointent.modules.base import BaseEmbedding from ._base import NodeInfo diff --git a/autointent/nodes/info/_regex.py b/autointent/nodes/info/_regex.py index 8703cadca..0f03ef261 100644 --- a/autointent/nodes/info/_regex.py +++ b/autointent/nodes/info/_regex.py @@ -6,7 +6,7 @@ from autointent.custom_types import NodeType from autointent.metrics import REGEX_METRICS from autointent.metrics.regex import RegexMetricFn -from autointent.modules.abc import BaseRegex +from autointent.modules.base import BaseRegex from autointent.modules.regex import Regex from ._base import NodeInfo diff --git a/autointent/nodes/info/_scoring.py b/autointent/nodes/info/_scoring.py index 23a794719..07cbe0f72 100644 --- a/autointent/nodes/info/_scoring.py +++ b/autointent/nodes/info/_scoring.py @@ -6,7 +6,7 @@ from autointent.custom_types import NodeType from autointent.metrics import SCORING_METRICS_MULTICLASS, SCORING_METRICS_MULTILABEL, ScoringMetricFn from autointent.modules import SCORING_MODULES -from autointent.modules.abc import BaseScorer +from autointent.modules.base import BaseScorer from ._base import NodeInfo diff --git a/autointent/schemas/_schemas.py b/autointent/schemas/_schemas.py index 8005a91d2..507d98a73 100644 --- a/autointent/schemas/_schemas.py +++ b/autointent/schemas/_schemas.py @@ -16,8 +16,7 @@ class Tag(BaseModel): - """ - Represents a tag associated with intent classes. + """Represents a tag associated with intent classes. Tags are used to define constraints such that if two intent classes share a common tag, they cannot both be assigned to the same sample. @@ -46,8 +45,7 @@ def load(cls, path: Path) -> "TagsList": class Sample(BaseModel): - """ - Represents a sample with an utterance and an optional label. + """Represents a sample with an utterance and an optional label. :param utterance: The textual content of the sample. :param label: The label(s) associated with the sample. Can be a single label (integer) @@ -59,8 +57,7 @@ class Sample(BaseModel): @model_validator(mode="after") def validate_sample(self) -> "Sample": - """ - Validate the sample after model instantiation. + """Validate the sample after model instantiation. This method ensures that the `label` field adheres to the expected constraints: - If `label` is provided, it must be a non-negative integer or a list of non-negative integers. @@ -73,8 +70,7 @@ def validate_sample(self) -> "Sample": return self._validate_label() def _validate_label(self) -> "Sample": - """ - Validate the `label` field of the sample. + """Validate the `label` field of the sample. - Ensures that the `label` is not empty for multilabel samples. - Validates that all provided labels are non-negative integers. diff --git a/autointent/utils.py b/autointent/utils.py index f56f408d4..a491e3363 100644 --- a/autointent/utils.py +++ b/autointent/utils.py @@ -10,21 +10,26 @@ def load_search_space(path: Path | str) -> list[dict[str, Any]]: - """ - Load hyperparameters search space from file. + """Load hyperparameters search space from file. + + Args: + path: Path to the search space file. - :param path: path to yaml file - :return: + Returns: + List of dictionaries representing the search space. """ with Path(path).open() as file: return yaml.safe_load(file) # type: ignore[no-any-return] def load_preset(name: SearchSpacePresets) -> dict[str, Any]: - """ - Load one of preset search spaces. + """Load one of preset search spaces. + + Args: + name: Name of the preset search space. - :param name: name of a presets. + Returns: + Dictionary representing the preset search space. """ path = ires.files("autointent._presets").joinpath(name + ".yaml") with path.open() as file: diff --git a/docs/source/conf.py b/docs/source/conf.py index 825ce2431..09029111a 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -49,6 +49,7 @@ "nbsphinx", "sphinx.ext.intersphinx", "sphinx_multiversion", + "sphinx.ext.napoleon", ] templates_path = ["_templates"] @@ -81,6 +82,21 @@ suppress_warnings = ["autoapi.python_import_resolution"] autoapi_add_toctree_entry = False +# Napoleon settings +napoleon_google_docstring = True +napoleon_include_init_with_doc = False +napoleon_include_private_with_doc = False +napoleon_include_special_with_doc = True +napoleon_use_admonition_for_examples = False +napoleon_use_admonition_for_notes = False +napoleon_use_admonition_for_references = False +napoleon_use_ivar = False +napoleon_use_param = True +napoleon_use_rtype = True +napoleon_preprocess_types = False +napoleon_type_aliases = None +napoleon_attr_annotations = True + # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output diff --git a/pyproject.toml b/pyproject.toml index e8867b7b7..2189a70d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -133,7 +133,7 @@ ignore = [ max-args = 10 [tool.ruff.lint.pydocstyle] -convention = "pep257" +convention = "google" [build-system] requires = ["poetry-core>=2.0"]