 import json
 import logging
 import shutil
+import tempfile
 from functools import lru_cache
 from pathlib import Path
 from typing import TypedDict
-import tempfile
 
 import huggingface_hub
 import numpy as np
 import numpy.typing as npt
 import torch
 from appdirs import user_cache_dir
+from datasets import Dataset
 from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments
-from sentence_transformers.similarity_functions import SimilarityFunction
 from sentence_transformers.losses import BatchAllTripletLoss
+from sentence_transformers.similarity_functions import SimilarityFunction
 from sentence_transformers.training_args import BatchSamplers
-from datasets import Dataset
-
 
 from autointent._hash import Hasher
-from autointent.configs import EmbedderConfig, TaskTypeEnum, EmbedderFineTuningConfig
+from autointent.configs import EmbedderConfig, EmbedderFineTuningConfig, TaskTypeEnum
 
 logger = logging.getLogger(__name__)
 
@@ -128,7 +127,7 @@ def _load_model(self) -> None:
             trust_remote_code=self.config.trust_remote_code,
         )
     def train(self, utterances: list[str], labels: list[int], config: EmbedderFineTuningConfig) -> None:
-        """Train the embedding model"""
+        """Train the embedding model."""
         self._load_model()
 
         tr_ds = Dataset.from_dict({
@@ -137,7 +136,7 @@ def train(self, utterances: list[str], labels: list[int], config: EmbedderFineTuningConfig) -> None:
         })
 
         loss = BatchAllTripletLoss(
-            model = self.embedding_model,
+            model=self.embedding_model,
             margin=config.margin
         )
         with tempfile.TemporaryDirectory() as tmp_dir:
@@ -159,9 +158,9 @@ def train(self, utterances: list[str], labels: list[int], config: EmbedderFineTuningConfig) -> None:
                 train_dataset=tr_ds,
                 loss=loss,
             )
-
+
             trainer.train()
-
+
     def clear_ram(self) -> None:
         """Move the embedding model to CPU and delete it from memory."""
         if hasattr(self, "embedding_model"):
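
For reference, the pieces imported and wired up above fit together roughly as follows. This is a minimal, self-contained sketch of label-based fine-tuning with `BatchAllTripletLoss`, not AutoIntent's actual `train` method: the checkpoint name, dataset column names, margin, and training arguments are illustrative stand-ins for the values the real code takes from `EmbedderConfig` and `EmbedderFineTuningConfig`.

```python
import tempfile

from datasets import Dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import BatchAllTripletLoss
from sentence_transformers.training_args import BatchSamplers

# Any sentence-transformers checkpoint works here; this one is just an example.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# The trainer treats the "label" column as the target; the name of the
# text column is arbitrary (these toy rows are purely illustrative).
train_ds = Dataset.from_dict({
    "sentence": ["hi there", "hello", "see you later", "goodbye"],
    "label": [0, 0, 1, 1],
})

# Batch-all triplet loss mines every valid (anchor, positive, negative)
# triplet inside each batch, so batches must contain several labels.
loss = BatchAllTripletLoss(model=model, margin=0.5)

with tempfile.TemporaryDirectory() as tmp_dir:
    args = SentenceTransformerTrainingArguments(
        output_dir=tmp_dir,  # checkpoints go to a throwaway directory
        num_train_epochs=1,
        per_device_train_batch_size=4,
        # Group samples by label so every batch yields usable triplets.
        batch_sampler=BatchSamplers.GROUP_BY_LABEL,
    )
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_ds,
        loss=loss,
    )
    trainer.train()  # fine-tunes the model in place
```

Writing checkpoints into a `tempfile.TemporaryDirectory()`, as the diff does, satisfies the trainer's required `output_dir` without leaving artifacts on disk; presumably only the fine-tuned model held in memory is used afterwards.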