83 changes: 25 additions & 58 deletions autointent/_dataset/_dataset.py
@@ -33,16 +33,19 @@ class Dataset(dict[str, HFDataset]):

This class extends a dictionary where the keys represent dataset splits (e.g., 'train', 'test'),
and the values are Hugging Face datasets.

Attributes:
label_feature: The feature name corresponding to labels in the dataset.
utterance_feature: The feature name corresponding to utterances in the dataset.
has_descriptions: Whether the dataset includes descriptions for intents.
"""

label_feature = "label"
utterance_feature = "utterance"
label_feature: str = "label"
"""The feature name corresponding to labels in the dataset."""

utterance_feature: str = "utterance"
"""The feature name corresponding to utterances in the dataset"""

has_descriptions: bool
"""Whether the dataset includes descriptions for intents."""

intents: list[Intent]
"""All metadata about intents used in this dataset."""

def __init__(self, *args: Any, intents: list[Intent], **kwargs: Any) -> None: # noqa: ANN401
"""Initializes the dataset.
@@ -59,21 +62,13 @@ def __init__(self, *args: Any, intents: list[Intent], **kwargs: Any) -> None: #

@property
def multilabel(self) -> bool:
"""Checks if the dataset is multilabel.

Returns:
True if the dataset supports multilabel classification, False otherwise.
"""
"""Checks if the dataset is multilabel."""
split = Split.TRAIN if Split.TRAIN in self else f"{Split.TRAIN}_0"
return isinstance(self[split].features[self.label_feature], Sequence)

@cached_property
def n_classes(self) -> int:
"""Returns the number of classes in the dataset.

Returns:
The number of unique classes in the training split.
"""
"""Returns the number of classes in the dataset."""
return len(self.intents)

@classmethod
@@ -82,9 +77,6 @@ def from_dict(cls, mapping: dict[str, Any]) -> "Dataset":

Args:
mapping: A dictionary representation of the dataset.

Returns:
A `Dataset` instance initialized from the dictionary.
"""
from ._reader import DictReader

@@ -96,48 +88,37 @@ def from_json(cls, filepath: str | Path) -> "Dataset":

Args:
filepath: Path to the JSON file.

Returns:
A `Dataset` instance initialized from the JSON file.
"""
from ._reader import JsonReader

return JsonReader().read(filepath)

@classmethod
def from_hub(cls, repo_id: str) -> "Dataset":
def from_hub(cls, repo_name: str) -> "Dataset":
"""Loads a dataset from the Hugging Face Hub.

Args:
repo_id: The ID of the Hugging Face repository.

Returns:
A `Dataset` instance initialized from the Hugging Face dataset repository.
repo_name: The name of the Hugging Face repository, like `AutoIntent/clinc150`.
"""
from ._reader import DictReader

splits = load_dataset(repo_id)
splits = load_dataset(repo_name)
mapping = dict(**splits)
if Split.INTENTS in get_dataset_config_names(repo_id):
mapping["intents"] = load_dataset(repo_id, Split.INTENTS)[Split.INTENTS].to_list()
if Split.INTENTS in get_dataset_config_names(repo_name):
mapping["intents"] = load_dataset(repo_name, Split.INTENTS)[Split.INTENTS].to_list()

return DictReader().read(mapping)

def to_multilabel(self) -> "Dataset":
"""Converts dataset labels to multilabel format.

Returns:
The dataset with labels converted to multilabel format.
"""
"""Converts dataset labels to multilabel format."""
for split_name, split in self.items():
self[split_name] = split.map(self._to_multilabel)
return self

def to_dict(self) -> dict[str, list[dict[str, Any]]]:
"""Converts the dataset into a dictionary format.

Returns:
A dictionary where the keys are dataset splits and the values are lists of samples.
Returns a dictionary where the keys are dataset splits and the values are lists of samples.
"""
mapping = {split_name: split.to_list() for split_name, split in self.items()}
mapping[Split.INTENTS] = [intent.model_dump() for intent in self.intents]
@@ -155,26 +136,22 @@ def to_json(self, filepath: str | Path) -> None:
with path.open("w") as file:
json.dump(self.to_dict(), file, indent=4, ensure_ascii=False)

def push_to_hub(self, repo_id: str, private: bool = False) -> None:
def push_to_hub(self, repo_name: str, private: bool = False) -> None:
"""Uploads the dataset to the Hugging Face Hub.

Args:
repo_id: The ID of the Hugging Face repository.
repo_name: The name of the Hugging Face repository.
private: Whether to make the repository private.
"""
for split_name, split in self.items():
split.push_to_hub(repo_id, split=split_name, private=private)
split.push_to_hub(repo_name, split=split_name, private=private)

if self.intents:
intents = HFDataset.from_list([intent.model_dump() for intent in self.intents])
intents.push_to_hub(repo_id, config_name=Split.INTENTS, split=Split.INTENTS)
intents.push_to_hub(repo_name, config_name=Split.INTENTS, split=Split.INTENTS)

def get_tags(self) -> list[Tag]:
"""Extracts unique tags from the dataset's intents.

Returns:
A list of `Tag` objects containing unique tag names and associated intent IDs.
"""
"""Extracts unique tags from the dataset's intents."""
tag_mapping = defaultdict(list)
for intent in self.intents:
for tag in intent.tags:
@@ -186,9 +163,6 @@ def get_n_classes(self, split: str) -> int:

Args:
split: The dataset split to analyze.

Returns:
The number of unique classes in the split.
"""
classes = set()
for label in self[split][self.label_feature]:
@@ -206,9 +180,6 @@ def _to_multilabel(self, sample: Sample) -> Sample:

Args:
sample: A sample from the dataset.

Returns:
The sample with its label converted to a multilabel format.
"""
if isinstance(sample["label"], int):
ohe_vector = [0] * self.n_classes
@@ -217,11 +188,7 @@ def _to_multilabel(self, sample: Sample) -> Sample:
return sample

def validate_descriptions(self) -> bool:
"""Validates whether all intents in the dataset contain descriptions.

Returns:
True if all intents have descriptions, False otherwise.
"""
"""Validates whether all intents in the dataset contain descriptions."""
has_any = any(intent.description is not None for intent in self.intents)
has_all = all(intent.description is not None for intent in self.intents)

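The refactored `Dataset` keeps its dict-of-splits interface, so the whole lifecycle reads naturally. A minimal usage sketch, assuming `Dataset` is importable from the package root (the docstrings above confirm `from_hub`, `from_json`, `to_json`, and `to_multilabel`):

from autointent import Dataset  # assumed import path

# Load from the Hugging Face Hub; the repo name format is shown in the from_hub docstring.
dataset = Dataset.from_hub("AutoIntent/clinc150")

print(dataset.n_classes)   # number of intents
print(dataset.multilabel)  # True if the train split stores label sequences

# Convert integer labels to one-hot multilabel vectors in place.
dataset = dataset.to_multilabel()

# Round-trip through JSON on disk.
dataset.to_json("clinc150.json")
restored = Dataset.from_json("clinc150.json")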
65 changes: 31 additions & 34 deletions autointent/_embedder.py
@@ -53,33 +53,30 @@ class EmbedderDumpMetadata(TypedDict):


class Embedder:
"""A wrapper for managing embedding models using Sentence Transformers.
"""A wrapper for managing embedding models using :py:class:`sentence_transformers.SentenceTransformer`.

This class handles initialization, saving, loading, and clearing of
embedding models, as well as calculating embeddings for input texts.
"""

metadata_dict_name: str = "metadata.json"
dump_dir: Path | None = None
_metadata_dict_name: str = "metadata.json"
_dump_dir: Path | None = None
config: EmbedderConfig
embedding_model: SentenceTransformer

def __init__(self, embedder_config: EmbedderConfig) -> None:
"""Initialize the Embedder.

Args:
embedder_config: Config of embedder.
"""
self.model_name = embedder_config.model_name
self.device = embedder_config.device
self.batch_size = embedder_config.batch_size
self.max_length = embedder_config.max_length
self.use_cache = embedder_config.use_cache
self.embedding_config = embedder_config
self.config = embedder_config

self.embedding_model = SentenceTransformer(
self.model_name, device=self.device, prompts=embedder_config.get_prompt_config()
self.config.model_name, device=self.config.device, prompts=embedder_config.get_prompt_config()
)

self.logger = logging.getLogger(__name__)
self._logger = logging.getLogger(__name__)

def __hash__(self) -> int:
"""Compute a hash value for the Embedder.
@@ -90,38 +87,38 @@ def __hash__(self) -> int:
hasher = Hasher()
for parameter in self.embedding_model.parameters():
hasher.update(parameter.detach().cpu().numpy())
hasher.update(self.max_length)
hasher.update(self.config.max_length)
return hasher.intdigest()

def clear_ram(self) -> None:
"""Move the embedding model to CPU and delete it from memory."""
self.logger.debug("Clearing embedder %s from memory", self.model_name)
self._logger.debug("Clearing embedder %s from memory", self.config.model_name)
self.embedding_model.cpu()
del self.embedding_model
torch.cuda.empty_cache()

def delete(self) -> None:
"""Delete the embedding model and its associated directory."""
self.clear_ram()
if self.dump_dir is not None:
shutil.rmtree(self.dump_dir)
if self._dump_dir is not None:
shutil.rmtree(self._dump_dir)

def dump(self, path: Path) -> None:
"""Save the embedding model and metadata to disk.

Args:
path: Path to the directory where the model will be saved.
"""
self.dump_dir = path
self._dump_dir = path
metadata = EmbedderDumpMetadata(
model_name=str(self.model_name),
device=self.device,
batch_size=self.batch_size,
max_length=self.max_length,
use_cache=self.use_cache,
model_name=str(self.config.model_name),
device=self.config.device,
batch_size=self.config.batch_size,
max_length=self.config.max_length,
use_cache=self.config.use_cache,
)
path.mkdir(parents=True, exist_ok=True)
with (path / self.metadata_dict_name).open("w") as file:
with (path / self._metadata_dict_name).open("w") as file:
json.dump(metadata, file, indent=4)

@classmethod
@@ -132,7 +129,7 @@ def load(cls, path: Path | str, override_config: EmbedderConfig | None = None) -
path: Path to the directory where the model is stored.
override_config: optional config that overrides the saved settings
"""
with (Path(path) / cls.metadata_dict_name).open() as file:
with (Path(path) / cls._metadata_dict_name).open() as file:
metadata: EmbedderDumpMetadata = json.load(file)

if override_config is not None:
@@ -152,7 +149,7 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
Returns:
A numpy array of embeddings.
"""
if self.use_cache:
if self.config.use_cache:
hasher = Hasher()
hasher.update(self)
hasher.update(utterances)
@@ -161,26 +158,26 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
if embeddings_path.exists():
return np.load(embeddings_path) # type: ignore[no-any-return]

self.logger.debug(
self._logger.debug(
"Calculating embeddings with model %s, batch_size=%d, max_seq_length=%s, embedder_device=%s",
self.model_name,
self.batch_size,
str(self.max_length),
self.device,
self.config.model_name,
self.config.batch_size,
str(self.config.max_length),
self.config.device,
)

if self.max_length is not None:
self.embedding_model.max_seq_length = self.max_length
if self.config.max_length is not None:
self.embedding_model.max_seq_length = self.config.max_length

embeddings = self.embedding_model.encode(
utterances,
convert_to_numpy=True,
batch_size=self.batch_size,
batch_size=self.config.batch_size,
normalize_embeddings=True,
prompt_name=self.embedding_config.get_prompt_type(task_type),
prompt_name=self.config.get_prompt_type(task_type),
)

if self.use_cache:
if self.config.use_cache:
embeddings_path.parent.mkdir(parents=True, exist_ok=True)
np.save(embeddings_path, embeddings)

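With this refactor, every runtime setting is read from `self.config`, so constructing an `Embedder` is a matter of building one config object. A sketch under the assumption that `EmbedderConfig` accepts the five field names read throughout the diff as keyword arguments, and that both classes are importable as shown:

from pathlib import Path
from autointent import Embedder  # assumed import path
from autointent.configs import EmbedderConfig  # assumed import path

config = EmbedderConfig(
    model_name="sentence-transformers/all-MiniLM-L6-v2",  # hypothetical model choice
    device="cpu",
    batch_size=32,
    max_length=128,
    use_cache=True,  # when enabled, embed() caches arrays keyed by a content hash
)

embedder = Embedder(config)
embeddings = embedder.embed(["turn on the lights", "play some jazz"])
print(embeddings.shape)  # (2, embedding_dim); rows are L2-normalized

# Persist metadata to disk and restore later.
embedder.dump(Path("dumps/embedder"))
restored = Embedder.load("dumps/embedder")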
5 changes: 3 additions & 2 deletions autointent/_hash.py
@@ -7,11 +7,12 @@


class Hasher:
"""A class that provides methods for hashing data using xxhash.
"""A class that provides methods for hashing data using `xxhash <https://github.com/ifduyue/python-xxhash>`_.

This class supports both a class-level method for generating hashes from
any given value, as well as an instance-level method for progressively
updating a hash state with new values.
updating a hash state with new values. We use this class for
hashing embeddings from :py:class:`autointent.Embedder`.
"""

def __init__(self) -> None:
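The `Hasher` call pattern used by `Embedder.__hash__` and `Embedder.embed` above condenses to a short sketch (module path taken from the diff header):

from autointent._hash import Hasher  # module path from the diff header

hasher = Hasher()
hasher.update("an utterance")      # fold values into the hash state one by one
hasher.update(["more", "values"])
print(hasher.intdigest())          # integer digest, as consumed in Embedder.__hash__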
15 changes: 14 additions & 1 deletion autointent/_optimization_config.py
@@ -7,12 +7,25 @@


class OptimizationConfig(BaseModel):
"""Configuration for the optimization process."""
"""Configuration for the optimization process.

One can use it to customize optimization beyond choosing a different preset.
Instantiate it and pass it to :py:meth:`autointent.Pipeline.from_optimization_config`.
"""

data_config: DataConfig = DataConfig()

search_space: list[dict[str, Any]]
"""See the tutorial on search space customization."""

logging_config: LoggingConfig = LoggingConfig()
"""See the tutorial on logging configuration."""

embedder_config: EmbedderConfig = EmbedderConfig()

cross_encoder_config: CrossEncoderConfig = CrossEncoderConfig()

sampler: SamplerType = "brute"
"""See the tutorial on Optuna and presets."""

seed: PositiveInt = 42
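Per the new docstring, the intended flow is to instantiate the config and hand it to the pipeline. A sketch with a placeholder `search_space` entry, since the entry schema lives in the referenced tutorial:

from autointent import Pipeline, OptimizationConfig  # assumed import paths

config = OptimizationConfig(
    search_space=[{"node_type": "scoring"}],  # hypothetical entry; see the search space tutorial
    sampler="brute",  # default shown in the diff
    seed=42,
)
pipeline = Pipeline.from_optimization_config(config)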