
Commit cc44a1e

Docs/update a lot (#158)
* add inventories for `sentence_transformers`, `sklearn`, `numpy`
* refactor `Embedder`
* refactor `Ranker`
* `custom_types` -> `types`
* `types` -> `custom_types`
* minor fix
* refactor `Dataset`
* minor fix
* refactor `Hasher`
* fix codestyle
* refactor `Context`
* fix codestyle
* refactor `OptimizationConfig`
* refactor `Pipeline`
* fix codestyle
* update release info
* minor fix
* forbid inference node config without module dumped
* minor changes
* refactor `DataHandler`
* Update optimizer_config.schema.json
* fix stratification issues
* upd test on vector index
* fix codestyle

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent de4d37d commit cc44a1e

30 files changed: +338 -333 lines changed

autointent/_dataset/_dataset.py

Lines changed: 25 additions & 58 deletions
@@ -33,16 +33,19 @@ class Dataset(dict[str, HFDataset]):
     This class extends a dictionary where the keys represent dataset splits (e.g., 'train', 'test'),
     and the values are Hugging Face datasets.
-
-    Attributes:
-        label_feature: The feature name corresponding to labels in the dataset.
-        utterance_feature: The feature name corresponding to utterances in the dataset.
-        has_descriptions: Whether the dataset includes descriptions for intents.
     """

-    label_feature = "label"
-    utterance_feature = "utterance"
+    label_feature: str = "label"
+    """The feature name corresponding to labels in the dataset."""
+
+    utterance_feature: str = "utterance"
+    """The feature name corresponding to utterances in the dataset"""
+
     has_descriptions: bool
+    """Whether the dataset includes descriptions for intents."""
+
+    intents: list[Intent]
+    """All metadata about intents used in this dataset."""

     def __init__(self, *args: Any, intents: list[Intent], **kwargs: Any) -> None:  # noqa: ANN401
         """Initializes the dataset.
@@ -59,21 +62,13 @@ def __init__(self, *args: Any, intents: list[Intent], **kwargs: Any) -> None: # noqa: ANN401

     @property
     def multilabel(self) -> bool:
-        """Checks if the dataset is multilabel.
-
-        Returns:
-            True if the dataset supports multilabel classification, False otherwise.
-        """
+        """Checks if the dataset is multilabel."""
         split = Split.TRAIN if Split.TRAIN in self else f"{Split.TRAIN}_0"
         return isinstance(self[split].features[self.label_feature], Sequence)

     @cached_property
     def n_classes(self) -> int:
-        """Returns the number of classes in the dataset.
-
-        Returns:
-            The number of unique classes in the training split.
-        """
+        """Returns the number of classes in the dataset."""
         return len(self.intents)

     @classmethod
@@ -82,9 +77,6 @@ def from_dict(cls, mapping: dict[str, Any]) -> "Dataset":

         Args:
             mapping: A dictionary representation of the dataset.
-
-        Returns:
-            A `Dataset` instance initialized from the dictionary.
         """
         from ._reader import DictReader

@@ -96,48 +88,37 @@ def from_json(cls, filepath: str | Path) -> "Dataset":

         Args:
             filepath: Path to the JSON file.
-
-        Returns:
-            A `Dataset` instance initialized from the JSON file.
         """
         from ._reader import JsonReader

         return JsonReader().read(filepath)

     @classmethod
-    def from_hub(cls, repo_id: str) -> "Dataset":
+    def from_hub(cls, repo_name: str) -> "Dataset":
         """Loads a dataset from the Hugging Face Hub.

         Args:
-            repo_id: The ID of the Hugging Face repository.
-
-        Returns:
-            A `Dataset` instance initialized from the Hugging Face dataset repository.
+            repo_name: The name of the Hugging Face repository, like `AutoIntent/clinc150`.
         """
         from ._reader import DictReader

-        splits = load_dataset(repo_id)
+        splits = load_dataset(repo_name)
         mapping = dict(**splits)
-        if Split.INTENTS in get_dataset_config_names(repo_id):
-            mapping["intents"] = load_dataset(repo_id, Split.INTENTS)[Split.INTENTS].to_list()
+        if Split.INTENTS in get_dataset_config_names(repo_name):
+            mapping["intents"] = load_dataset(repo_name, Split.INTENTS)[Split.INTENTS].to_list()

         return DictReader().read(mapping)

     def to_multilabel(self) -> "Dataset":
-        """Converts dataset labels to multilabel format.
-
-        Returns:
-            The dataset with labels converted to multilabel format.
-        """
+        """Converts dataset labels to multilabel format."""
         for split_name, split in self.items():
             self[split_name] = split.map(self._to_multilabel)
         return self

     def to_dict(self) -> dict[str, list[dict[str, Any]]]:
         """Converts the dataset into a dictionary format.

-        Returns:
-            A dictionary where the keys are dataset splits and the values are lists of samples.
+        Returns a dictionary where the keys are dataset splits and the values are lists of samples.
         """
         mapping = {split_name: split.to_list() for split_name, split in self.items()}
         mapping[Split.INTENTS] = [intent.model_dump() for intent in self.intents]
@@ -155,26 +136,22 @@ def to_json(self, filepath: str | Path) -> None:
         with path.open("w") as file:
             json.dump(self.to_dict(), file, indent=4, ensure_ascii=False)

-    def push_to_hub(self, repo_id: str, private: bool = False) -> None:
+    def push_to_hub(self, repo_name: str, private: bool = False) -> None:
         """Uploads the dataset to the Hugging Face Hub.

         Args:
-            repo_id: The ID of the Hugging Face repository.
+            repo_name: The ID of the Hugging Face repository.
             private: Whether to make the repository private.
         """
         for split_name, split in self.items():
-            split.push_to_hub(repo_id, split=split_name, private=private)
+            split.push_to_hub(repo_name, split=split_name, private=private)

         if self.intents:
             intents = HFDataset.from_list([intent.model_dump() for intent in self.intents])
-            intents.push_to_hub(repo_id, config_name=Split.INTENTS, split=Split.INTENTS)
+            intents.push_to_hub(repo_name, config_name=Split.INTENTS, split=Split.INTENTS)

     def get_tags(self) -> list[Tag]:
-        """Extracts unique tags from the dataset's intents.
-
-        Returns:
-            A list of `Tag` objects containing unique tag names and associated intent IDs.
-        """
+        """Extracts unique tags from the dataset's intents."""
         tag_mapping = defaultdict(list)
         for intent in self.intents:
             for tag in intent.tags:
@@ -186,9 +163,6 @@ def get_n_classes(self, split: str) -> int:

         Args:
             split: The dataset split to analyze.
-
-        Returns:
-            The number of unique classes in the split.
         """
         classes = set()
         for label in self[split][self.label_feature]:
@@ -206,9 +180,6 @@ def _to_multilabel(self, sample: Sample) -> Sample:

         Args:
             sample: A sample from the dataset.
-
-        Returns:
-            The sample with its label converted to a multilabel format.
         """
         if isinstance(sample["label"], int):
             ohe_vector = [0] * self.n_classes
@@ -217,11 +188,7 @@ def _to_multilabel(self, sample: Sample) -> Sample:
         return sample

     def validate_descriptions(self) -> bool:
-        """Validates whether all intents in the dataset contain descriptions.
-
-        Returns:
-            True if all intents have descriptions, False otherwise.
-        """
+        """Validates whether all intents in the dataset contain descriptions."""
         has_any = any(intent.description is not None for intent in self.intents)
         has_all = all(intent.description is not None for intent in self.intents)
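For orientation, here is a minimal usage sketch of the refactored `Dataset` API, assuming `Dataset` is re-exported at the package root and using the `AutoIntent/clinc150` repository named in the new `from_hub` docstring:

from autointent import Dataset  # assumed re-export of autointent._dataset.Dataset

# Load all splits (plus the optional "intents" config) from the Hugging Face Hub
# via the renamed `repo_name` argument.
dataset = Dataset.from_hub("AutoIntent/clinc150")

print(dataset.n_classes)   # number of intents, now computed as len(dataset.intents)
print(dataset.multilabel)  # True if the train split stores label sequences

# Convert single-label annotations into one-hot multilabel vectors.
dataset = dataset.to_multilabel()

# Serialize splits and intent metadata using the dict layout shown in to_dict().
dataset.to_json("clinc150.json")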

autointent/_embedder.py

Lines changed: 31 additions & 34 deletions
@@ -53,33 +53,30 @@ class EmbedderDumpMetadata(TypedDict):


 class Embedder:
-    """A wrapper for managing embedding models using Sentence Transformers.
+    """A wrapper for managing embedding models using :py:class:`sentence_transformers.SentenceTransformer`.

     This class handles initialization, saving, loading, and clearing of
     embedding models, as well as calculating embeddings for input texts.
     """

-    metadata_dict_name: str = "metadata.json"
-    dump_dir: Path | None = None
+    _metadata_dict_name: str = "metadata.json"
+    _dump_dir: Path | None = None
+    config: EmbedderConfig
+    embedding_model: SentenceTransformer

     def __init__(self, embedder_config: EmbedderConfig) -> None:
         """Initialize the Embedder.

         Args:
             embedder_config: Config of embedder.
         """
-        self.model_name = embedder_config.model_name
-        self.device = embedder_config.device
-        self.batch_size = embedder_config.batch_size
-        self.max_length = embedder_config.max_length
-        self.use_cache = embedder_config.use_cache
-        self.embedding_config = embedder_config
+        self.config = embedder_config

         self.embedding_model = SentenceTransformer(
-            self.model_name, device=self.device, prompts=embedder_config.get_prompt_config()
+            self.config.model_name, device=self.config.device, prompts=embedder_config.get_prompt_config()
         )

-        self.logger = logging.getLogger(__name__)
+        self._logger = logging.getLogger(__name__)

     def __hash__(self) -> int:
         """Compute a hash value for the Embedder.
@@ -90,38 +87,38 @@ def __hash__(self) -> int:
         hasher = Hasher()
         for parameter in self.embedding_model.parameters():
             hasher.update(parameter.detach().cpu().numpy())
-        hasher.update(self.max_length)
+        hasher.update(self.config.max_length)
         return hasher.intdigest()

     def clear_ram(self) -> None:
         """Move the embedding model to CPU and delete it from memory."""
-        self.logger.debug("Clearing embedder %s from memory", self.model_name)
+        self._logger.debug("Clearing embedder %s from memory", self.config.model_name)
         self.embedding_model.cpu()
         del self.embedding_model
         torch.cuda.empty_cache()

     def delete(self) -> None:
         """Delete the embedding model and its associated directory."""
         self.clear_ram()
-        if self.dump_dir is not None:
-            shutil.rmtree(self.dump_dir)
+        if self._dump_dir is not None:
+            shutil.rmtree(self._dump_dir)

     def dump(self, path: Path) -> None:
         """Save the embedding model and metadata to disk.

         Args:
             path: Path to the directory where the model will be saved.
         """
-        self.dump_dir = path
+        self._dump_dir = path
         metadata = EmbedderDumpMetadata(
-            model_name=str(self.model_name),
-            device=self.device,
-            batch_size=self.batch_size,
-            max_length=self.max_length,
-            use_cache=self.use_cache,
+            model_name=str(self.config.model_name),
+            device=self.config.device,
+            batch_size=self.config.batch_size,
+            max_length=self.config.max_length,
+            use_cache=self.config.use_cache,
         )
         path.mkdir(parents=True, exist_ok=True)
-        with (path / self.metadata_dict_name).open("w") as file:
+        with (path / self._metadata_dict_name).open("w") as file:
             json.dump(metadata, file, indent=4)

     @classmethod
@@ -132,7 +129,7 @@ def load(cls, path: Path | str, override_config: EmbedderConfig | None = None) -
         path: Path to the directory where the model is stored.
         override_config: one can override presaved settings
         """
-        with (Path(path) / cls.metadata_dict_name).open() as file:
+        with (Path(path) / cls._metadata_dict_name).open() as file:
             metadata: EmbedderDumpMetadata = json.load(file)

         if override_config is not None:
@@ -152,7 +149,7 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
         Returns:
             A numpy array of embeddings.
         """
-        if self.use_cache:
+        if self.config.use_cache:
             hasher = Hasher()
             hasher.update(self)
             hasher.update(utterances)
@@ -161,26 +158,26 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
         if embeddings_path.exists():
             return np.load(embeddings_path)  # type: ignore[no-any-return]

-        self.logger.debug(
+        self._logger.debug(
             "Calculating embeddings with model %s, batch_size=%d, max_seq_length=%s, embedder_device=%s",
-            self.model_name,
-            self.batch_size,
-            str(self.max_length),
-            self.device,
+            self.config.model_name,
+            self.config.batch_size,
+            str(self.config.max_length),
+            self.config.device,
         )

-        if self.max_length is not None:
-            self.embedding_model.max_seq_length = self.max_length
+        if self.config.max_length is not None:
+            self.embedding_model.max_seq_length = self.config.max_length

         embeddings = self.embedding_model.encode(
             utterances,
             convert_to_numpy=True,
-            batch_size=self.batch_size,
+            batch_size=self.config.batch_size,
             normalize_embeddings=True,
-            prompt_name=self.embedding_config.get_prompt_type(task_type),
+            prompt_name=self.config.get_prompt_type(task_type),
         )

-        if self.use_cache:
+        if self.config.use_cache:
             embeddings_path.parent.mkdir(parents=True, exist_ok=True)
             np.save(embeddings_path, embeddings)
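A short sketch of driving the refactored `Embedder` through its single `config` attribute; the `autointent.configs` import path and the `model_name` keyword are assumptions inferred from the attribute accesses in this diff:

from pathlib import Path

from autointent import Embedder
from autointent.configs import EmbedderConfig  # assumed import path

config = EmbedderConfig(model_name="sentence-transformers/all-MiniLM-L6-v2")  # field name assumed
embedder = Embedder(config)

# device, batch_size, max_length, and use_cache now live on embedder.config
# instead of being copied onto the instance one by one.
vectors = embedder.embed(["book a flight", "play some jazz"])
print(vectors.shape)  # (2, embedding_dim); rows are L2-normalized

# dump() writes metadata.json into the target directory; load() restores from it.
embedder.dump(Path("dumps/embedder"))
restored = Embedder.load("dumps/embedder")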

autointent/_hash.py

Lines changed: 3 additions & 2 deletions
@@ -7,11 +7,12 @@


 class Hasher:
-    """A class that provides methods for hashing data using xxhash.
+    """A class that provides methods for hashing data using `xxhash <https://github.com/ifduyue/python-xxhash>`_.

     This class supports both a class-level method for generating hashes from
     any given value, as well as an instance-level method for progressively
-    updating a hash state with new values.
+    updating a hash state with new values. We use this class for
+    hashing embeddings from :py:class:`autointent.Embedder`.
     """

     def __init__(self) -> None:
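The incremental interface that `Embedder.__hash__` relies on, as a small sketch (importing from the private `_hash` module is for illustration only):

from autointent._hash import Hasher  # private module; shown for illustration

hasher = Hasher()
hasher.update("an utterance")  # progressively fold values into the hash state
hasher.update([1, 2, 3])
print(hasher.intdigest())      # integer digest, used e.g. for embedding cache keys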

autointent/_optimization_config.py

Lines changed: 14 additions & 1 deletion
@@ -7,12 +7,25 @@


 class OptimizationConfig(BaseModel):
-    """Configuration for the optimization process."""
+    """Configuration for the optimization process.
+
+    One can use it to customize optimization beyond choosing different preset.
+    Instantiate it and pass to :py:meth:`autointent.Pipeline.from_optimization_config`.
+    """

     data_config: DataConfig = DataConfig()
+
     search_space: list[dict[str, Any]]
+    """See tutorial on search space customization."""
+
     logging_config: LoggingConfig = LoggingConfig()
+    """See tutorial on logging configuration."""
+
     embedder_config: EmbedderConfig = EmbedderConfig()
+
     cross_encoder_config: CrossEncoderConfig = CrossEncoderConfig()
+
     sampler: SamplerType = "brute"
+    """See tutorial on optuna and presets."""
+
     seed: PositiveInt = 42
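A sketch of the workflow the new docstring points to. The top-level `OptimizationConfig` import and every key inside `search_space` are illustrative assumptions; the real node types, metrics, and module parameters are covered by the tutorials the field docstrings reference:

from autointent import Pipeline
from autointent import OptimizationConfig  # assumed re-export of autointent._optimization_config

# Hypothetical single-node search space, shown for shape only.
search_space = [
    {
        "node_type": "scoring",
        "metric": "scoring_roc_auc",
        "search_space": [{"module_name": "knn", "k": [5, 10]}],
    },
]

config = OptimizationConfig(search_space=search_space, sampler="brute", seed=42)
pipeline = Pipeline.from_optimization_config(config)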
