@@ -10,9 +10,6 @@
 import numpy.typing as npt
 import torch
 from datasets import Dataset
-from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments
-from sentence_transformers.losses import BatchAllTripletLoss
-from sentence_transformers.training_args import BatchSamplers
 from sklearn.model_selection import train_test_split

 from autointent._hash import Hasher
@@ -25,6 +22,7 @@
 from .utils import get_embeddings_path

 if TYPE_CHECKING:
+    from sentence_transformers import SentenceTransformer
     from transformers import TrainerCallback

 logger = logging.getLogger(__name__)
@@ -51,6 +49,7 @@ class SentenceTransformerEmbeddingBackend(BaseEmbeddingBackend):
     """SentenceTransformer-based embedding backend implementation."""

     supports_training: bool = True
+    _model: "SentenceTransformer | None"

     def __init__(self, config: SentenceTransformerEmbeddingConfig) -> None:
         """Initialize the SentenceTransformer backend.
@@ -59,7 +58,7 @@ def __init__(self, config: SentenceTransformerEmbeddingConfig) -> None:
             config: Configuration for SentenceTransformer embeddings.
         """
         self.config = config
-        self._model: SentenceTransformer | None = None
+        self._model = None
         self._trained: bool = False

     def clear_ram(self) -> None:
@@ -71,10 +70,12 @@ def clear_ram(self) -> None:
         self._model = None
         torch.cuda.empty_cache()

-    def _load_model(self) -> SentenceTransformer:
+    def _load_model(self) -> "SentenceTransformer":
         """Load sentence transformers model to device."""
         if self._model is None:
-            res = SentenceTransformer(
+            # Lazy import sentence-transformers
+            st = require("sentence_transformers", extra="sentence-transformers")
+            res = st.SentenceTransformer(
                 self.config.model_name,
                 device=self.config.device,
                 prompts=self.config.get_prompt_config(),
@@ -231,16 +232,17 @@ def train(self, utterances: list[str], labels: ListOfLabels, config: EmbedderFin
 
         model = self._load_model()

+        # Lazy import sentence-transformers training components (only needed for fine-tuning)
+        st = require("sentence_transformers", extra="sentence-transformers")
+        transformers = require("transformers", extra="transformers")
+
         x_train, x_val, y_train, y_val = train_test_split(utterances, labels, test_size=config.val_fraction)
         tr_ds = Dataset.from_dict({"text": x_train, "label": y_train})
         val_ds = Dataset.from_dict({"text": x_val, "label": y_val})

-        loss = BatchAllTripletLoss(model=model, margin=config.margin)
+        loss = st.losses.BatchAllTripletLoss(model=model, margin=config.margin)
         with tempfile.TemporaryDirectory() as tmp_dir:
-            # Lazy import transformers (only needed for fine-tuning)
-            transformers = require("transformers", extra="transformers")
-
-            args = SentenceTransformerTrainingArguments(
+            args = st.SentenceTransformerTrainingArguments(
                 save_strategy="epoch",
                 save_total_limit=1,
                 output_dir=tmp_dir,
@@ -251,7 +253,7 @@ def train(self, utterances: list[str], labels: ListOfLabels, config: EmbedderFin
                 warmup_ratio=config.warmup_ratio,
                 fp16=config.fp16,
                 bf16=config.bf16,
-                batch_sampler=BatchSamplers.NO_DUPLICATES,
+                batch_sampler=st.training_args.BatchSamplers.NO_DUPLICATES,
                 metric_for_best_model="eval_loss",
                 load_best_model_at_end=True,
                 eval_strategy="epoch",
@@ -263,7 +265,7 @@ def train(self, utterances: list[str], labels: ListOfLabels, config: EmbedderFin
                     early_stopping_threshold=config.early_stopping_threshold,
                 )
             ]
-            trainer = SentenceTransformerTrainer(
+            trainer = st.SentenceTransformerTrainer(
                 model=model,
                 args=args,
                 train_dataset=tr_ds,
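
The change relies on a require(module_name, extra=...) helper that already exists elsewhere in autointent (it was used for transformers before this commit), but whose implementation is not part of this diff. As a rough sketch only, assuming importlib under the hood and an illustrative error message, such a helper typically looks like this:

import importlib
from types import ModuleType


def require(module_name: str, extra: str) -> ModuleType:
    """Import an optional dependency on demand, pointing at the pip extra if it is missing."""
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:
        # The install hint below is illustrative; the real helper may word it differently.
        msg = f"'{module_name}' is not installed. Try: pip install 'autointent[{extra}]'"
        raise ImportError(msg) from exc

With a helper of this shape, st = require("sentence_transformers", extra="sentence-transformers") returns the imported module, so st.SentenceTransformer, st.losses.BatchAllTripletLoss, and st.training_args.BatchSamplers are resolved only when the backend is actually constructed or fine-tuned, and importing the backend module itself no longer pulls in sentence-transformers.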
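
The import moved under if TYPE_CHECKING:, together with the quoted annotations ("SentenceTransformer | None" and "SentenceTransformer"), follows the standard typing.TYPE_CHECKING pattern: the name stays visible to static type checkers while nothing heavy is imported at runtime. A minimal standalone illustration of the same idea (class and attribute names here are generic, not copied from autointent):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by type checkers, never at runtime.
    from sentence_transformers import SentenceTransformer


class LazyBackend:
    # String annotation, so it does not need the import at runtime.
    _model: "SentenceTransformer | None"

    def __init__(self) -> None:
        self._model = None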