Commit 4dcd54e

Feat/bert early stopping (#223)

* change how `clear_cache` is called
* first version of early stopping
* change mypy version
* `train_test_split` bug fix
* add `compute_metrics` and `EarlyStoppingCallback`
* bug fix
* fix mypy
* try to fix `"eval_f1" not found` error
* forgot to update `from_context`
* try to fix mypy
* try to fix "not found f1" error
* refactor a little bit
* disable early stopping for LoRA
* fix typing errors
* update CONTRIBUTING and Makefile
* minor change
* use our metrics
* add docstrings
* set 3.10 for mypy
* update CONTRIBUTING.md
* try to fix bug
* try to fix typing issue
* try to fix
* add early stopping to p-tuning
1 parent 2d7c380 commit 4dcd54e

22 files changed: +237 −102 lines


CONTRIBUTING.md

Lines changed: 34 additions & 24 deletions

@@ -1,63 +1,73 @@
 # Contribute to AutoIntent

-## Минимальная конфигурация
+## Minimum Configuration

-Мы используем `poetry` в качестве менеджера зависимостей и упаковщика.
+We use `poetry` as our dependency manager and packager.

-1. Установить `poetry`. Советуем обратиться к разделу официальной документации [Installation with the official installer](https://python-poetry.org/docs/#installing-with-the-official-installer). Если кратко, то достаточно просто запустить команду:
+1. Install `poetry`. We recommend referring to the official documentation section [Installation with the official installer](https://python-poetry.org/docs/#installing-with-the-official-installer). In short, you just need to run:
 ```bash
 curl -sSL https://install.python-poetry.org | python3 -
 ```

-2. Склонировать проект, перейти в корень
+2. Clone the project and navigate to the root directory

-3. Установить проект со всеми зависимостями:
+3. Install the project with all dependencies:
 ```bash
 make install
 ```

-## Дополнительно
+## Additional Setup

-Чтобы удобнее трекать ошибки в кодстайле, советуем установить расширение ruff для IDE. Например, для VSCode:
+To make it easier to track code style errors, we recommend installing the ruff extension for your IDE. For example, for VSCode:
 ```
 https://marketplace.visualstudio.com/items?itemName=charliermarsh.ruff
 ```
-С этим расширением ошибки в кодстайле будут подчеркиваться прямо в редакторе.
+With this extension, code style errors will be underlined directly in the editor.

-В корень проекта добавлен файл `.vscode/settings.json`, который указывает расширению путь к конфигу линтера.
+A `.vscode/settings.json` file has been added to the project root, which points the extension to the linter configuration.

 ## Contribute

-1. Создать ветку, в которой вы будете работать. Чтобы остальным было проще понимать характер вашего контрибьюта, нужно давать краткие, но понятные названия. Советуем начинать названия на `feat/` для веток с новыми фичами, `fix/` для исправления багов, `refactor/` для рефакторинга, `test/` для добавления тестов.
+1. Create a branch for your work. To make it easier for others to understand the nature of your contribution, use brief but clear names. We recommend starting branch names with `feat/` for new features, `fix/` for bug fixes, `refactor/` for refactoring, and `test/` for adding tests.

-2. Коммит, коммит, коммит, коммит
+2. Commit, commit, commit, commit

-3. Если есть новые фичи, желательно добавить для них тесты в директорию [tests](./tests).
+3. If there are new features, it's advisable to add tests for them in the [tests](./tests) directory.

-4. Проверить, что внесенные изменения не ломают имеющиеся фичи
+4. You can open a PR!
+
+Every commit in any PR triggers GitHub Actions with automated tests. All checks block merging into the main branch (with rare exceptions).
+
+Sometimes waiting for CI can be long, and it's often more convenient to run checks locally:
+- Check that your changes don't break existing features:
 ```bash
 make test
 ```
-
-5. Проверить кодстайл
+Or run a specific test (using `test_bert.py` as an example):
+```bash
+poetry run pytest tests/modules/scoring/test_bert.py
+```
+- Check code style (this also applies the formatter):
 ```bash
 make lint
 ```
+- Check type hints:
+```bash
+make typing
+```
+Note: if mypy shows different errors locally compared to GitHub Actions, update your local dependencies:
+```bash
+make update
+```

-6. Ура, можно открывать Pull Request!
-
-## Устройство проекта
-
-![](assets/dependency-graph.png)
-
-## Построение документации
+## Building Documentation

-Построить html версию в папке `docs/build`:
+Build the HTML version in the `docs/build` folder:
 ```bash
 make docs
 ```

-Построить html версию и захостить локально:
+Build the HTML version and host it locally:
 ```bash
 make serve-docs
 ```

Makefile

Lines changed: 4 additions & 3 deletions

@@ -22,9 +22,10 @@ lint:
 	$(poetry) ruff format
 	$(poetry) ruff check --fix

-.PHONY: sync
-sync:
-	poetry sync --extras "dev test typing docs"
+.PHONY: update
+update:
+	rm -f poetry.lock
+	poetry install --extras "dev test typing docs"

 .PHONY: docs
 docs:

assets/classification_pipeline.png

-134 KB
Binary file not shown.

assets/dependency-graph.png

-73.4 KB
Binary file not shown.

autointent/configs/__init__.py

Lines changed: 9 additions & 1 deletion

@@ -2,11 +2,19 @@

 from ._inference_node import InferenceNodeConfig
 from ._optimization import DataConfig, LoggingConfig
-from ._transformers import CrossEncoderConfig, EmbedderConfig, HFModelConfig, TaskTypeEnum, TokenizerConfig
+from ._transformers import (
+    CrossEncoderConfig,
+    EarlyStoppingConfig,
+    EmbedderConfig,
+    HFModelConfig,
+    TaskTypeEnum,
+    TokenizerConfig,
+)

 __all__ = [
     "CrossEncoderConfig",
     "DataConfig",
+    "EarlyStoppingConfig",
     "EmbedderConfig",
     "HFModelConfig",
     "InferenceNodeConfig",

autointent/configs/_transformers.py

Lines changed: 21 additions & 0 deletions

@@ -4,6 +4,9 @@
 from pydantic import BaseModel, ConfigDict, Field, PositiveInt
 from typing_extensions import Self, assert_never

+from autointent.custom_types import FloatFromZeroToOne
+from autointent.metrics import SCORING_METRICS_MULTICLASS, SCORING_METRICS_MULTILABEL
+

 class TokenizerConfig(BaseModel):
     padding: bool | Literal["longest", "max_length", "do_not_pad"] = True

@@ -122,3 +125,21 @@ class CrossEncoderConfig(HFModelConfig):
     tokenizer_config: TokenizerConfig = Field(
         default_factory=lambda: TokenizerConfig(max_length=512)
     )  # this is because sentence-transformers doesn't allow you to customize tokenizer settings properly
+
+
+class EarlyStoppingConfig(BaseModel):
+    val_fraction: float = Field(
+        0.2,
+        description=(
+            "Fraction of train samples to allocate to a dev set to monitor quality "
+            "during training and perform early stopping if quality doesn't improve."
+        ),
+    )
+    patience: PositiveInt = Field(1, description="Maximum number of epochs to wait for quality to improve.")
+    threshold: FloatFromZeroToOne = Field(
+        0.0,
+        description="Minimum quality increment to count as an improvement. Default: any increment is counted.",
+    )
+    metric: Literal[tuple((SCORING_METRICS_MULTILABEL | SCORING_METRICS_MULTICLASS).keys())] | None = Field(  # type: ignore[valid-type]
+        "scoring_f1", description="Metric to monitor."
+    )
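Since `EarlyStoppingConfig` is re-exported from `autointent.configs` (see the `__init__.py` diff above), it can be constructed directly. Below is a minimal sketch of instantiating it with non-default values; the field names and defaults come from the diff, while the chosen values are illustrative:

```python
from autointent.configs import EarlyStoppingConfig

# Hold out 10% of the training split for validation and stop training after
# 2 epochs in which scoring_f1 fails to improve by at least 0.01.
early_stopping = EarlyStoppingConfig(
    val_fraction=0.1,
    patience=2,
    threshold=0.01,
    metric="scoring_f1",  # also the default monitored metric per the diff
)
```

Note the `metric` annotation: `Literal[tuple(...)]` builds the set of allowed metric names at runtime from the two scoring-metric registries, which mypy cannot verify statically; hence the `# type: ignore[valid-type]`.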

autointent/metrics/retrieval.py

Lines changed: 1 addition & 1 deletion

@@ -539,7 +539,7 @@ def retrieval_ndcg(query_labels: LABELS_VALUE_TYPE, candidates_labels: CANDIDATE
     query_label_, candidates_labels_ = transform(query_labels, candidates_labels)

     ndcg_scores: list[float] = []
-    relevance_scores: npt.NDArray[np.bool] = query_label_[:, None] == candidates_labels_
+    relevance_scores = query_label_[:, None] == candidates_labels_

     for rel_scores in relevance_scores:
         cur_dcg = _dcg(rel_scores, k)
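The dropped annotation appears redundant rather than load-bearing: an equality comparison between broadcast NumPy arrays already produces a boolean array, and `np.bool` is not a stable alias across NumPy versions, which plausibly explains the mypy friction mentioned in the commit message. A small self-contained sketch of the inference:

```python
import numpy as np

query_labels = np.array([0, 1, 2])
candidates_labels = np.array([[0, 1], [1, 1], [2, 0]])

# Broadcasting (3, 1) against (3, 2) yields a (3, 2) boolean relevance matrix;
# the dtype is inferred, so no explicit npt.NDArray annotation is needed.
relevance_scores = query_labels[:, None] == candidates_labels
print(relevance_scores.dtype)  # bool
```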

autointent/modules/base/_base.py

Lines changed: 1 addition & 0 deletions

@@ -187,6 +187,7 @@ def score_metrics_cv(  # type: ignore[no-untyped-def]
         all_val_preds = []

         for train_utterances, train_labels, val_utterances, val_labels in cv_iterator:
+            self.clear_cache()
             self.fit(train_utterances, train_labels, **fit_kwargs)  # type: ignore[arg-type]
             val_preds = self.predict(val_utterances)
             for name, fn in metrics_dict.items():

autointent/modules/embedding/_logreg.py

Lines changed: 2 additions & 4 deletions

@@ -81,7 +81,8 @@ def from_context(

     def clear_cache(self) -> None:
         """Clear embedder from memory."""
-        self._embedder.clear_ram()
+        if hasattr(self, "_embedder"):
+            self._embedder.clear_ram()

     def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
         """Train the logistic regression model using the provided utterances and labels.

@@ -90,9 +91,6 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
             utterances: List of text data to index
             labels: List of corresponding labels for the utterances
         """
-        if hasattr(self, "_embedder"):
-            self.clear_cache()
-
         self._validate_task(labels)

         self._embedder = Embedder(

autointent/modules/embedding/_retrieval.py

Lines changed: 2 additions & 4 deletions

@@ -83,9 +83,6 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
             utterances: List of text data to index
             labels: List of corresponding labels for the utterances
         """
-        if hasattr(self, "_vector_index"):
-            self.clear_cache()
-
         self._validate_task(labels)

         self._vector_index = VectorIndex(

@@ -140,7 +137,8 @@ def get_assets(self) -> EmbeddingArtifact:

     def clear_cache(self) -> None:
         """Clear cached data in memory used by the vector index."""
-        self._vector_index.clear_ram()
+        if hasattr(self, "_vector_index"):
+            self._vector_index.clear_ram()

     def predict(self, utterances: list[str]) -> list[ListOfLabels]:
         """Predict the nearest neighbors for a list of utterances.
