
Commit 60b1bd8: pull dev

2 parents: 3d10192 + 95be22a

37 files changed, +388 / -147 lines

CONTRIBUTING.md

Lines changed: 34 additions & 24 deletions
@@ -1,63 +1,73 @@
 # Contribute to AutoIntent

-## Минимальная конфигурация
+## Minimum Configuration

-Мы используем `poetry` в качесте менеджера зависимостей и упаковщика.
+We use `poetry` as our dependency manager and packager.

-1. Установить `poetry`. Советуем обратиться к разделу официальной документации [Installation with the official installer](https://python-poetry.org/docs/#installing-with-the-official-installer). Если кратко, то достаточно просто запустить команду:
+1. Install `poetry`. We recommend referring to the official documentation section [Installation with the official installer](https://python-poetry.org/docs/#installing-with-the-official-installer). In short, you just need to run:
 ```bash
 curl -sSL https://install.python-poetry.org | python3 -
 ```

-2. Склонировать проект, перейти в корень
+2. Clone the project and navigate to the root directory.

-3. Установить проект со всеми зависимостями:
+3. Install the project with all dependencies:
 ```bash
 make install
 ```

-## Дополнительно
+## Additional Setup

-Чтобы удобнее трекать ошибки в кодстайле, советуем установить расширение ruff для IDE. Например, для VSCode
+To make it easier to track code style errors, we recommend installing the ruff extension for your IDE. For example, for VSCode:
 ```
 https://marketplace.visualstudio.com/items?itemName=charliermarsh.ruff
 ```
-С этим расширением ошибки в кодстайле будут подчеркиваться прямо в редакторе.
+With this extension, code style errors will be underlined directly in the editor.

-В корень проекта добавлен файл `.vscode/settings.json`, который указывает расширению путь к конфигу линтера.
+A `.vscode/settings.json` file in the project root points the extension to the linter configuration.

 ## Contribute

-1. Создать ветку, в которой вы будете работать. Чтобы остальным было проще понимать характер вашего контрибьюта, нужно давать краткие, но понятные названия начинающиеся. Советем начинать названия на `feat/` для веток с новыми фичами, `fix/` для исправления багов, `refactor/` для рефакторинга, `test/` для добавления тестов.
+1. Create a branch for your work. To make it easier for others to understand the nature of your contribution, use brief but clear names. We recommend starting branch names with `feat/` for new features, `fix/` for bug fixes, `refactor/` for refactoring, and `test/` for adding tests.

-2. Коммит, коммит, коммит, коммит
+2. Commit, commit, commit, commit.

-3. Если есть новые фичи, желательно добавить для них тесты в директорию [tests](./tests).
+3. If you add new features, please also add tests for them in the [tests](./tests) directory.

-4. Проверить, что внесенные изменения не ломают имеющиеся фичи
+4. You can open a PR!
+
+Every commit in any PR triggers GitHub Actions with automated tests. All checks block merging into the main branch (with rare exceptions).
+
+Waiting for CI can take a while, so it is often more convenient to run the checks locally:
+
+- Check that your changes don't break existing features:
 ```bash
 make test
 ```
-
-5. Проверить кодстайл
+Or run a specific test (using `test_bert.py` as an example):
+```bash
+poetry run pytest tests/modules/scoring/test_bert.py
+```
+- Check code style (this also applies the formatter):
 ```bash
 make lint
 ```
+- Check type hints:
+```bash
+make typing
+```
+Note: if mypy shows different errors locally than in GitHub Actions, update your local dependencies:
+```bash
+make update
+```

-6. Ура, можно открывать Pull Request!
-
-## Устройство проекта
-
-![](assets/dependency-graph.png)
-
-## Построение документации
+## Building Documentation

-Построить html версию в папке `docs/build`:
+Build the HTML version in the `docs/build` folder:
 ```bash
 make docs
 ```

-Построить html версию и захостить локально:
+Build the HTML version and host it locally:
 ```bash
 make serve-docs
 ```

Makefile

Lines changed: 4 additions & 3 deletions
@@ -22,9 +22,10 @@ lint:
 	$(poetry) ruff format
 	$(poetry) ruff check --fix

-.PHONY: sync
-sync:
-	poetry sync --extras "dev test typing docs"
+.PHONY: update
+update:
+	rm -f poetry.lock
+	poetry install --extras "dev test typing docs"

 .PHONY: docs
 docs:

assets/classification_pipeline.png

-134 KB (binary file not shown)

assets/dependency-graph.png

-73.4 KB (binary file not shown)

autointent/_callbacks/wandb.py

Lines changed: 11 additions & 1 deletion
@@ -89,6 +89,11 @@ def log_metrics(self, metrics: dict[str, Any]) -> None:
         """
         self.wandb.log(metrics)

+    def _close_current_run(self) -> None:
+        """Close the current W&B run if open."""
+        if self.wandb.run is not None:
+            self.wandb.finish()
+
     def log_final_metrics(self, metrics: dict[str, Any]) -> None:
         """Logs final evaluation metrics to W&B.

@@ -97,6 +102,8 @@ def log_final_metrics(self, metrics: dict[str, Any]) -> None:
         Args:
             metrics: A dictionary of final performance metrics.
         """
+        self._close_current_run()
+
         wandb_run_init_args = {
             "project": self.project_name,
             "group": self.group,
@@ -105,11 +112,14 @@ def log_final_metrics(self, metrics: dict[str, Any]) -> None:
         }

         try:
-            self.wandb.init(config=metrics, **wandb_run_init_args)
+            config = metrics["configs"]
+            self.wandb.init(config=config, **wandb_run_init_args)
+            self.wandb.log(metrics)
         except Exception as e:
             if "run config cannot exceed" not in str(e):
                 # https://github.com/deeppavlov/AutoIntent/issues/202
                 raise
+            self._close_current_run()
             logger.warning("W&B run config is too large, skipping logging modules configs")
             logger.warning("'final_metrics' will be logged to W&B with pipeline_metrics only")
             logger.warning("If you want to access modules configs in future, address to the individual modules runs")

autointent/_wrappers/embedder.py

Lines changed: 7 additions & 2 deletions
@@ -188,10 +188,14 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
         Returns:
             A numpy array of embeddings.
         """
+        prompt = self.config.get_prompt(task_type)
+
         if self.config.use_cache:
            hasher = Hasher()
            hasher.update(self)
            hasher.update(utterances)
+           if prompt:
+               hasher.update(prompt)

            embeddings_path = _get_embeddings_path(hasher.hexdigest())
            if embeddings_path.exists():
@@ -200,11 +204,12 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
         self._load_model()

         logger.debug(
-            "Calculating embeddings with model %s, batch_size=%d, max_seq_length=%s, embedder_device=%s",
+            "Calculating embeddings with model %s, batch_size=%d, max_seq_length=%s, embedder_device=%s, prompt=%s",
             self.config.model_name,
             self.config.batch_size,
             str(self.config.tokenizer_config.max_length),
             self.config.device,
+            prompt,
         )

         if self.config.tokenizer_config.max_length is not None:
@@ -215,7 +220,7 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
             convert_to_numpy=True,
             batch_size=self.config.batch_size,
             normalize_embeddings=True,
-            prompt=self.config.get_prompt_type(task_type),
+            prompt=prompt,
         )

         if self.config.use_cache:

autointent/configs/__init__.py

Lines changed: 9 additions & 1 deletion
@@ -2,11 +2,19 @@

 from ._inference_node import InferenceNodeConfig
 from ._optimization import DataConfig, LoggingConfig
-from ._transformers import CrossEncoderConfig, EmbedderConfig, HFModelConfig, TaskTypeEnum, TokenizerConfig
+from ._transformers import (
+    CrossEncoderConfig,
+    EarlyStoppingConfig,
+    EmbedderConfig,
+    HFModelConfig,
+    TaskTypeEnum,
+    TokenizerConfig,
+)

 __all__ = [
     "CrossEncoderConfig",
     "DataConfig",
+    "EarlyStoppingConfig",
     "EmbedderConfig",
     "HFModelConfig",
     "InferenceNodeConfig",

autointent/configs/_transformers.py

Lines changed: 33 additions & 16 deletions
@@ -2,7 +2,10 @@
 from typing import Any, Literal

 from pydantic import BaseModel, ConfigDict, Field, PositiveInt
-from typing_extensions import Self, assert_never
+from typing_extensions import Self
+
+from autointent.custom_types import FloatFromZeroToOne
+from autointent.metrics import SCORING_METRICS_MULTICLASS, SCORING_METRICS_MULTILABEL


 class TokenizerConfig(BaseModel):
@@ -56,7 +59,7 @@ class EmbedderConfig(HFModelConfig):
     default_prompt: str | None = Field(
         None, description="Default prompt for the model. This is used when no task-specific prompt is provided."
     )
-    classifier_prompt: str | None = Field(None, description="Prompt for classifier.")
+    classification_prompt: str | None = Field(None, description="Prompt for classification.")
     cluster_prompt: str | None = Field(None, description="Prompt for clustering.")
     sts_prompt: str | None = Field(None, description="Prompt for finding most similar sentences.")
     query_prompt: str | None = Field(None, description="Prompt for query.")
@@ -76,8 +79,8 @@ def get_prompt_config(self) -> dict[str, str] | None:
         prompts = {}
         if self.default_prompt:
             prompts[TaskTypeEnum.default.value] = self.default_prompt
-        if self.classifier_prompt:
-            prompts[TaskTypeEnum.classification.value] = self.classifier_prompt
+        if self.classification_prompt:
+            prompts[TaskTypeEnum.classification.value] = self.classification_prompt
         if self.cluster_prompt:
             prompts[TaskTypeEnum.cluster.value] = self.cluster_prompt
         if self.query_prompt:
@@ -88,7 +91,7 @@ def get_prompt_config(self) -> dict[str, str] | None:
             prompts[TaskTypeEnum.sts.value] = self.sts_prompt
         return prompts if len(prompts) > 0 else None

-    def get_prompt_type(self, prompt_type: TaskTypeEnum | None) -> str | None:  # noqa: PLR0911
+    def get_prompt(self, prompt_type: TaskTypeEnum | None) -> str | None:
         """Get the prompt for the given task type.

         Args:
@@ -97,21 +100,17 @@ def get_prompt_type(self, prompt_type: TaskTypeEnum | None) -> str | None:  # noqa: PLR0911
         Returns:
             The prompt for the given task type.
         """
-        if prompt_type is None:
-            return self.default_prompt
-        if prompt_type == TaskTypeEnum.classification:
-            return self.classifier_prompt
-        if prompt_type == TaskTypeEnum.cluster:
+        if prompt_type == TaskTypeEnum.classification and self.classification_prompt is not None:
+            return self.classification_prompt
+        if prompt_type == TaskTypeEnum.cluster and self.cluster_prompt is not None:
             return self.cluster_prompt
-        if prompt_type == TaskTypeEnum.query:
+        if prompt_type == TaskTypeEnum.query and self.query_prompt is not None:
             return self.query_prompt
-        if prompt_type == TaskTypeEnum.passage:
+        if prompt_type == TaskTypeEnum.passage and self.passage_prompt is not None:
             return self.passage_prompt
-        if prompt_type == TaskTypeEnum.sts:
+        if prompt_type == TaskTypeEnum.sts and self.sts_prompt is not None:
             return self.sts_prompt
-        if prompt_type == TaskTypeEnum.default:
-            return self.default_prompt
-        assert_never(prompt_type)
+        return self.default_prompt


 class CrossEncoderConfig(HFModelConfig):
@@ -122,3 +121,21 @@ class CrossEncoderConfig(HFModelConfig):
     tokenizer_config: TokenizerConfig = Field(
         default_factory=lambda: TokenizerConfig(max_length=512)
     )  # this is because sentence-transformers doesn't allow you to customize tokenizer settings properly
+
+
+class EarlyStoppingConfig(BaseModel):
+    val_fraction: float = Field(
+        0.2,
+        description=(
+            "Fraction of train samples to allocate to a dev set to monitor quality "
+            "during training and perform early stopping if quality doesn't improve."
+        ),
+    )
+    patience: PositiveInt = Field(1, description="Maximum number of epochs to wait for quality to improve.")
+    threshold: FloatFromZeroToOne = Field(
+        0.0,
+        description="Minimum quality increment to count as an improvement. Default: any increment counts.",
+    )
+    metric: Literal[tuple((SCORING_METRICS_MULTILABEL | SCORING_METRICS_MULTICLASS).keys())] | None = Field(  # type: ignore[valid-type]
+        "scoring_f1", description="Metric to monitor."
+    )
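In effect, `get_prompt` now returns a task-specific prompt only when that field is actually set, and falls back to `default_prompt` otherwise, instead of returning `None` for unset task prompts. A hedged usage sketch: the model name and prompt strings are invented, and it assumes `EmbedderConfig` accepts `model_name` as used elsewhere in the codebase.

```python
from autointent.configs import EarlyStoppingConfig, EmbedderConfig, TaskTypeEnum

cfg = EmbedderConfig(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    default_prompt="Represent the sentence: ",
    classification_prompt="Classify the intent: ",
)

# a task-specific prompt wins when it is set...
assert cfg.get_prompt(TaskTypeEnum.classification) == "Classify the intent: "
# ...and unset task prompts now fall back to default_prompt instead of None
assert cfg.get_prompt(TaskTypeEnum.sts) == "Represent the sentence: "
assert cfg.get_prompt(None) == "Represent the sentence: "

# the new EarlyStoppingConfig is a plain pydantic model; metric defaults to "scoring_f1"
early_stopping = EarlyStoppingConfig(patience=3, threshold=0.01)
```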

autointent/context/data_handler/_stratification.py

Lines changed: 10 additions & 1 deletion
@@ -72,7 +72,16 @@ def __call__(
             ValueError: If OOS samples are present but allow_oos_in_train is not specified.
         """
         if not self._has_oos_samples(dataset):
-            return self._split_without_oos(dataset, multilabel, self.test_size)
+            train, test = self._split_without_oos(dataset, multilabel, self.test_size)
+            if self.is_few_shot:
+                train, test = create_few_shot_split(
+                    train,
+                    test,
+                    multilabel=multilabel,
+                    label_column=self.label_feature,
+                    examples_per_label=self.examples_per_label,
+                )
+            return train, test
         if allow_oos_in_train is None:
             msg = (
                 "Error while splitting dataset. It contains OOS samples, "

autointent/context/optimization_info/_optimization_info.py

Lines changed: 0 additions & 1 deletion
@@ -225,7 +225,6 @@ def dump_evaluation_results(self) -> dict[str, Any]:
             "pipeline_metrics": self.pipeline_metrics,
             "metrics": node_wise_metrics,
             "configs": self.trials.model_dump(),
-            "artifacts": self.artifacts.model_dump(),
         }

     def dump(self, path: Path) -> None:
