Skip to content

Commit a96c27d

Browse files
committed
correct some syntax isues
1 parent 3052628 commit a96c27d

File tree

3 files changed

+16
-16
lines changed

3 files changed

+16
-16
lines changed

autointent/_wrappers/embedder.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,24 @@
77
import json
88
import logging
99
import shutil
10+
import tempfile
1011
from functools import lru_cache
1112
from pathlib import Path
1213
from typing import TypedDict
13-
import tempfile
1414

1515
import huggingface_hub
1616
import numpy as np
1717
import numpy.typing as npt
1818
import torch
1919
from appdirs import user_cache_dir
20+
from datasets import Dataset
2021
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments
21-
from sentence_transformers.similarity_functions import SimilarityFunction
2222
from sentence_transformers.losses import BatchAllTripletLoss
23+
from sentence_transformers.similarity_functions import SimilarityFunction
2324
from sentence_transformers.training_args import BatchSamplers
24-
from datasets import Dataset
25-
2625

2726
from autointent._hash import Hasher
28-
from autointent.configs import EmbedderConfig, TaskTypeEnum, EmbedderFineTuningConfig
27+
from autointent.configs import EmbedderConfig, EmbedderFineTuningConfig, TaskTypeEnum
2928

3029
logger = logging.getLogger(__name__)
3130

@@ -128,7 +127,7 @@ def _load_model(self) -> None:
128127
trust_remote_code=self.config.trust_remote_code,
129128
)
130129
def train(self, utterances: list[str], labels: list[int], config: EmbedderFineTuningConfig) -> None:
131-
"""Train the embedding model"""
130+
"""Train the embedding model."""
132131
self._load_model()
133132

134133
tr_ds = Dataset.from_dict({
@@ -137,7 +136,7 @@ def train(self, utterances: list[str], labels: list[int], config: EmbedderFineTu
137136
})
138137

139138
loss = BatchAllTripletLoss(
140-
model=self.embedding_model,
139+
model=self.embedding_model,
141140
margin=config.margin
142141
)
143142
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -159,9 +158,9 @@ def train(self, utterances: list[str], labels: list[int], config: EmbedderFineTu
159158
train_dataset=tr_ds,
160159
loss=loss,
161160
)
162-
161+
163162
trainer.train()
164-
163+
165164
def clear_ram(self) -> None:
166165
"""Move the embedding model to CPU and delete it from memory."""
167166
if hasattr(self, "embedding_model"):

autointent/configs/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from ._transformers import (
77
CrossEncoderConfig,
88
EarlyStoppingConfig,
9-
EmbedderFineTuningConfig,
109
EmbedderConfig,
10+
EmbedderFineTuningConfig,
1111
HFModelConfig,
1212
TaskTypeEnum,
1313
TokenizerConfig,
@@ -17,8 +17,8 @@
1717
"CrossEncoderConfig",
1818
"DataConfig",
1919
"EarlyStoppingConfig",
20-
"EmbedderFineTuningConfig",
2120
"EmbedderConfig",
21+
"EmbedderFineTuningConfig",
2222
"HFModelConfig",
2323
"HPOConfig",
2424
"InferenceNodeConfig",

tests/embedder/test_fine_tuning.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
from autointent.context.data_handler import DataHandler
2-
from autointent._wrappers.embedder import Embedder
3-
from autointent.configs._transformers import HFModelConfig, EmbedderConfig, EmbedderFineTuningConfig
41
import numpy as np
52

3+
from autointent._wrappers.embedder import Embedder
4+
from autointent.configs._transformers import EmbedderConfig, EmbedderFineTuningConfig, HFModelConfig
5+
from autointent.context.data_handler import DataHandler
6+
7+
68
def test_model_updates_after_training(dataset):
79
"""Test that model weights actually change after training"""
810
data_handler = DataHandler(dataset)
@@ -48,7 +50,6 @@ def test_model_updates_after_training(dataset):
4850

4951
weights_changed = any(
5052
not np.allclose(orig, trained, atol=1e-6)
51-
for orig, trained in zip(original_weights, trained_weights)
53+
for orig, trained in zip(original_weights, trained_weights, strict=True)
5254
)
53-
5455
assert weights_changed, "Model weights should change after training"

0 commit comments

Comments
 (0)