
Commit 9200e05

Commit message: pull actual code

Merge commit, 2 parents: 7338d7e + d9807cc

115 files changed: +2586 / -2896 lines


CONTRIBUTING.md

Lines changed: 0 additions & 35 deletions
````diff
@@ -50,41 +50,6 @@ make lint
 
 ![](assets/dependency-graph.png)
 
-## Configuring the logger
-To see debug lines you have several options:
-
-1. Enable all debug output via a command-line option:
-```bash
-autointent hydra.verbose=true
-```
-2. Enable debug output only for specific modules, for example for autointent.pipeline.optimization.cli_endpoint and for hydra itself:
-```bash
-autointent hydra.verbose=[hydra,autointent/pipeline/optimization/cli_endpoint] hydra.job_logging.root.level=DEBUG
-```
-
-The logger configuration itself is defined in autointent.configs.optimization_cli.logger_config. You can change any logger parameter from the command line. For example, here is how to set the logger level to ERROR:
-```bash
-autointent hydra.job_logging.root.level=ERROR
-```
-
-Logger parameters can also be changed via yaml files:
-1. Create a folder with the config files: test_config
-2. test_config/config.yaml:
-```yaml
-defaults:
-  - optimization_config
-  - _self_
-  - override hydra/job_logging: custom
-
-# set your config params for optimization here
-embedder_batch_size: 32
-```
-3. Put the logger configuration in test_config/hydra/job_logging/custom.yaml (see [here](https://docs.python.org/3/howto/logging.html) for the available parameters)
-4. Run with the config file config.yaml:
-```bash
-autointent --config-path FULL_PATH/test_config --config-name config
-```
-
 ## Building the documentation
 
 Build the html version in the `docs/build` folder:
````
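The removed step 3 points at Python's standard logging configuration schema for what goes into custom.yaml. As a rough, repository-independent illustration only, a hydra `job_logging` override maps onto the same structure that `logging.config.dictConfig` accepts; the handler and formatter names below are made up for this sketch and are not taken from autointent:

```python
import logging
import logging.config

# Illustrative only: roughly what a minimal
# test_config/hydra/job_logging/custom.yaml could express.
LOGGING_CONFIG = {
    "version": 1,
    "formatters": {
        "simple": {"format": "[%(asctime)s][%(name)s][%(levelname)s] - %(message)s"},
    },
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "formatter": "simple",
            "stream": "ext://sys.stdout",
        },
    },
    "root": {"level": "DEBUG", "handlers": ["console"]},
}

logging.config.dictConfig(LOGGING_CONFIG)
logging.getLogger(__name__).debug("debug lines are now visible")
```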

Makefile

Lines changed: 1 addition & 1 deletion
```diff
@@ -24,7 +24,7 @@ lint:
 
 .PHONY: sync
 sync:
-	poetry sync
+	poetry sync --with dev,test,typing,docs
 
 .PHONY: docs
 docs:
```

autointent/__init__.py

Lines changed: 16 additions & 3 deletions
```diff
@@ -1,9 +1,22 @@
 """This is AutoIntent API reference."""
 
-from ._embedder import Embedder
 from ._dataset import Dataset
+from ._embedder import Embedder
 from ._hash import Hasher
-from .context import Context, load_dataset
+from ._logging import setup_logging
 from ._pipeline import Pipeline
+from ._ranker import Ranker
+from ._vector_index import VectorIndex
+from .context import Context, load_dataset
 
-__all__ = ["Context", "Dataset", "Embedder", "Hasher", "Pipeline", "load_dataset"]
+__all__ = [
+    "Context",
+    "Dataset",
+    "Embedder",
+    "Hasher",
+    "Pipeline",
+    "Ranker",
+    "VectorIndex",
+    "load_dataset",
+    "setup_logging",
+]
```
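After this change `Ranker`, `VectorIndex`, and `setup_logging` are part of the top-level API. The diff shows only what is exported, not the constructors' signatures, so this sketch sticks to importing through the public namespace and inspecting `__all__`:

```python
import autointent

# Only imports and the __all__ contents are grounded in this diff;
# call signatures of the new exports are not shown here.
from autointent import Dataset, Pipeline, Ranker, VectorIndex, setup_logging

assert {"Ranker", "VectorIndex", "setup_logging"} <= set(autointent.__all__)
print(sorted(autointent.__all__))
```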

autointent/_callbacks/base.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -42,6 +42,14 @@ def log_value(self, **kwargs: dict[str, Any]) -> None:
         :param kwargs: Data to log.
         """
 
+    @abstractmethod
+    def log_metrics(self, metrics: dict[str, Any]) -> None:
+        """
+        Log metrics during training.
+
+        :param metrics: Metrics to log.
+        """
+
     @abstractmethod
     def end_module(self) -> None:
         """End a module."""
```

autointent/_callbacks/callback_handler.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -44,6 +44,14 @@ def log_value(self, **kwargs: dict[str, Any]) -> None:
         """
         self.call_events("log_value", **kwargs)
 
+    def log_metrics(self, metrics: dict[str, Any]) -> None:
+        """
+        Log metrics during training.
+
+        :param metrics: Metrics to log.
+        """
+        self.call_events("log_metrics", metrics=metrics)
+
     def end_module(self) -> None:
         """End a module."""
         self.call_events("end_module")
```

autointent/_callbacks/tensorboard.py

Lines changed: 16 additions & 0 deletions
```diff
@@ -73,6 +73,22 @@ def log_value(self, **kwargs: dict[str, Any]) -> None:
             else:
                 self.module_writer.add_text(key, str(value))  # type: ignore[no-untyped-call]
 
+    def log_metrics(self, metrics: dict[str, Any]) -> None:
+        """
+        Log metrics during training.
+
+        :param metrics: Metrics to log.
+        """
+        if self.module_writer is None:
+            msg = "start_run must be called before log_value."
+            raise RuntimeError(msg)
+
+        for key, value in metrics.items():
+            if isinstance(value, int | float):
+                self.module_writer.add_scalar(key, value)  # type: ignore[no-untyped-call]
+            else:
+                self.module_writer.add_text(key, str(value))  # type: ignore[no-untyped-call]
+
     def log_final_metrics(self, metrics: dict[str, Any]) -> None:
         """
         Log final metrics.
```
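The TensorBoard callback routes numeric metric values to `add_scalar` and everything else to `add_text`. A standalone sketch of the same routing with `torch.utils.tensorboard`, assuming `torch` is installed; the log directory name is arbitrary:

```python
from typing import Any

from torch.utils.tensorboard import SummaryWriter


def log_metrics(writer: SummaryWriter, metrics: dict[str, Any]) -> None:
    # Numbers become scalar curves, anything else becomes a text record,
    # mirroring the branching in the callback above.
    for key, value in metrics.items():
        if isinstance(value, int | float):
            writer.add_scalar(key, value)
        else:
            writer.add_text(key, str(value))


writer = SummaryWriter(log_dir="runs/example")  # hypothetical directory
log_metrics(writer, {"train/accuracy": 0.91, "train/note": "first sweep"})
writer.close()
```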

autointent/_callbacks/wandb.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -59,6 +59,14 @@ def log_value(self, **kwargs: dict[str, Any]) -> None:
         """
         self.wandb.log(kwargs)
 
+    def log_metrics(self, metrics: dict[str, Any]) -> None:
+        """
+        Log metrics during training.
+
+        :param metrics: Metrics to log.
+        """
+        self.wandb.log(metrics)
+
     def log_final_metrics(self, metrics: dict[str, Any]) -> None:
         """
         Log final metrics.
```
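For Weights & Biases the new hook is a plain pass-through, since `wandb.log` already accepts a dict of mixed values. A minimal usage sketch; the project name is made up and offline mode is used so it runs without an account:

```python
import wandb

# wandb.log takes a plain dict, so the callback above forwards metrics unchanged.
run = wandb.init(project="autointent-demo", mode="offline")  # hypothetical project name
wandb.log({"train/loss": 0.12, "train/accuracy": 0.91})
run.finish()
```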

autointent/_dataset/_dataset.py

Lines changed: 12 additions & 88 deletions
```diff
@@ -6,23 +6,23 @@
 from pathlib import Path
 from typing import Any, TypedDict
 
-from datasets import ClassLabel, Sequence, concatenate_datasets, get_dataset_config_names, load_dataset
 from datasets import Dataset as HFDataset
+from datasets import Sequence, get_dataset_config_names, load_dataset
 
-from autointent.custom_types import LabelType, Split
+from autointent.custom_types import LabelWithOOS, Split
 from autointent.schemas import Intent, Tag
 
 
 class Sample(TypedDict):
     """
     Typed dictionary representing a dataset sample.
 
-    :param str utterance: The text of the utterance.
-    :param LabelType | None label: The label associated with the utterance, or None if out-of-scope.
+    :param utterance: The text of the utterance.
+    :param label: The label associated with the utterance, or None if out-of-scope.
     """
 
     utterance: str
-    label: LabelType | None
+    label: LabelWithOOS
 
 
 class Dataset(dict[str, HFDataset]):
@@ -39,7 +39,7 @@ class Dataset(dict[str, HFDataset]):
 
     def __init__(self, *args: Any, intents: list[Intent], **kwargs: Any) -> None:  # noqa: ANN401
         """
-        Initialize the dataset and configure OOS split if applicable.
+        Initialize the dataset.
 
         :param args: Positional arguments to initialize the dataset.
         :param intents: List of intents associated with the dataset.
@@ -49,15 +49,6 @@ def __init__(self, *args: Any, intents: list[Intent], **kwargs: Any) -> None:  #
 
         self.intents = intents
 
-        self._encoded_labels = False
-
-        if self.multilabel:
-            self._encode_labels()
-
-        oos_split = self._create_oos_split()
-        if oos_split is not None:
-            self[Split.OOS] = oos_split
-
     @property
     def multilabel(self) -> bool:
         """
@@ -125,7 +116,6 @@ def to_multilabel(self) -> "Dataset":
         """
         for split_name, split in self.items():
             self[split_name] = split.map(self._to_multilabel)
-        self._encode_labels()
         return self
 
     def to_dict(self) -> dict[str, list[dict[str, Any]]]:
@@ -184,38 +174,15 @@ def get_n_classes(self, split: str) -> int:
         """
         classes = set()
         for label in self[split][self.label_feature]:
-            match (label, self._encoded_labels):
-                case (int(), _):
+            match label:
+                case int():
                     classes.add(label)
-                case (list(), False):
-                    for label_ in label:
-                        classes.add(label_)
-                case (list(), True):
+                case list():
                     for idx, label_ in enumerate(label):
                         if label_:
                             classes.add(idx)
         return len(classes)
 
-    def _encode_labels(self) -> "Dataset":
-        """
-        Encode dataset labels into one-hot or multilabel format.
-
-        :return: Self, with labels encoded.
-        """
-        for split_name, split in self.items():
-            self[split_name] = split.map(self._encode_label)
-        self._encoded_labels = True
-        return self
-
-    def _is_oos(self, sample: Sample) -> bool:
-        """
-        Check if a sample is out-of-scope.
-
-        :param sample: The sample to check.
-        :return: True if the sample is out-of-scope, False otherwise.
-        """
-        return sample["label"] is None
-
     def _to_multilabel(self, sample: Sample) -> Sample:
         """
         Convert a sample's label to multilabel format.
@@ -224,50 +191,7 @@ def _to_multilabel(self, sample: Sample) -> Sample:
         :return: Sample with label in multilabel format.
         """
         if isinstance(sample["label"], int):
-            sample["label"] = [sample["label"]]
-        return sample
-
-    def _encode_label(self, sample: Sample) -> Sample:
-        """
-        Encode a sample's label as a one-hot vector.
-
-        :param sample: The sample to encode.
-        :return: Sample with encoded label.
-        """
-        one_hot_label = [0] * self.n_classes
-        match sample["label"]:
-            case int():
-                one_hot_label[sample["label"]] = 1
-            case list():
-                for idx in sample["label"]:
-                    one_hot_label[idx] = 1
-        sample["label"] = one_hot_label
+            ohe_vector = [0] * self.n_classes
+            ohe_vector[sample["label"]] = 1
+            sample["label"] = ohe_vector
         return sample
-
-    def _create_oos_split(self) -> HFDataset | None:
-        """
-        Create an out-of-scope (OOS) split from the dataset.
-
-        :return: The OOS split if created, None otherwise.
-        """
-        oos_splits = [split.filter(self._is_oos) for split in self.values()]
-        oos_splits = [oos_split for oos_split in oos_splits if oos_split.num_rows]
-        if oos_splits:
-            for split_name, split in self.items():
-                self[split_name] = split.filter(lambda sample: not self._is_oos(sample))
-            return concatenate_datasets(oos_splits)
-        return None
-
-    def _cast_label_feature(self) -> None:
-        """Cast the label feature of the dataset to the appropriate type."""
-        for split_name, split in self.items():
-            new_features = split.features.copy()
-            if self.multilabel:
-                new_features[self.label_feature] = Sequence(
-                    ClassLabel(num_classes=self.n_classes),
-                )
-            else:
-                new_features[self.label_feature] = ClassLabel(
-                    num_classes=self.n_classes,
-                )
-            self[split_name] = split.cast(new_features)
```
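In this commit `_to_multilabel` produces a one-hot vector directly, while the separate `_encode_labels` pass, the OOS-split construction, and `_cast_label_feature` are removed. A small standalone sketch of the new label conversion, with `n_classes` standing in for the dataset property:

```python
def to_multilabel(label: int | list[int], n_classes: int) -> list[int]:
    """Convert an integer class id to a one-hot vector; keep list labels unchanged."""
    if isinstance(label, int):
        ohe_vector = [0] * n_classes
        ohe_vector[label] = 1
        return ohe_vector
    return label


print(to_multilabel(2, n_classes=4))           # [0, 0, 1, 0]
print(to_multilabel([0, 1, 0, 0], n_classes=4))  # already multilabel, returned as-is
```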

autointent/_dataset/_validation.py

Lines changed: 18 additions & 12 deletions
```diff
@@ -66,18 +66,9 @@ def validate_dataset(self) -> "DatasetReader":
         ]
         splits = [split for split in splits if split]
 
-        n_classes = [self._get_n_classes(split) for split in splits]
-        if len(set(n_classes)) != 1:
-            message = (
-                f"Mismatch in number of classes across splits. Found class counts: {n_classes}. "
-                "Ensure all splits have the same number of classes."
-            )
-            raise ValueError(message)
-        if not n_classes[0]:
-            message = "Number of classes is zero or undefined. " "Ensure at least one class is present in the splits."
-            raise ValueError(message)
+        n_classes = self._validate_classes(splits)
 
-        self._validate_intents(n_classes[0])
+        self._validate_intents(n_classes)
 
         for split in splits:
             self._validate_split(split)
@@ -100,6 +91,20 @@ def _get_n_classes(self, split: list[Sample]) -> int:
                     classes.add(label)
         return len(classes)
 
+    def _validate_classes(self, splits: list[list[Sample]]) -> int:
+        """Validate that each split has all classes."""
+        n_classes = [self._get_n_classes(split) for split in splits]
+        if len(set(n_classes)) != 1:
+            message = (
+                f"Mismatch in number of classes across splits. Found class counts: {n_classes}. "
+                "Ensure all splits have the same number of classes."
+            )
+            raise ValueError(message)
+        if not n_classes[0]:
+            message = "Number of classes is zero or undefined. " "Ensure at least one class is present in the splits."
+            raise ValueError(message)
+        return n_classes[0]
+
     def _validate_intents(self, n_classes: int) -> "DatasetReader":
         """
         Validate the intents by checking their IDs for sequential order.
@@ -132,7 +137,8 @@ def _validate_split(self, split: list[Sample]) -> "DatasetReader":
         intent_ids = {intent.id for intent in self.intents}
         for sample in split:
             message = (
-                f"Sample with label {sample.label} references a non-existent intent ID. " f"Valid IDs are {intent_ids}."
+                f"Sample with label {sample.label} and utterance {sample.utterance[:10]}... "
+                f"references a non-existent intent ID. Valid IDs are {intent_ids}."
             )
             if isinstance(sample.label, int) and sample.label not in intent_ids:
                 raise ValueError(message)
```
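The class-count check moves into a `_validate_classes` helper that computes the number of distinct labels per split and requires the counts to agree and be non-zero. A standalone sketch of the same logic over plain integer labels (the real code walks `Sample` models instead):

```python
def validate_classes(splits: list[list[int]]) -> int:
    """Return the shared class count, or raise if splits disagree or are empty."""
    n_classes = [len(set(split)) for split in splits]
    if len(set(n_classes)) != 1:
        msg = f"Mismatch in number of classes across splits. Found class counts: {n_classes}."
        raise ValueError(msg)
    if not n_classes[0]:
        msg = "Number of classes is zero or undefined."
        raise ValueError(msg)
    return n_classes[0]


print(validate_classes([[0, 1, 2, 1], [2, 0, 1]]))  # 3

try:
    validate_classes([[0, 1], [0, 1, 2]])
except ValueError as err:
    print(err)  # mismatch between a 2-class and a 3-class split
```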
