import json
import logging
import shutil
+ import tempfile
from functools import lru_cache
from pathlib import Path
+ from uuid import uuid4

import huggingface_hub
import numpy as np
import numpy.typing as npt
import torch
from appdirs import user_cache_dir
- from sentence_transformers import SentenceTransformer
+ from datasets import Dataset
+ from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments
+ from sentence_transformers.losses import BatchAllTripletLoss
from sentence_transformers.similarity_functions import SimilarityFunction
+ from sentence_transformers.training_args import BatchSamplers
+ from sklearn.model_selection import train_test_split
+ from transformers import EarlyStoppingCallback, TrainerCallback

from autointent._hash import Hasher
- from autointent.configs import EmbedderConfig, TaskTypeEnum
+ from autointent.configs import EmbedderConfig, EmbedderFineTuningConfig, TaskTypeEnum
+ from autointent.custom_types import ListOfLabels

logger = logging.getLogger(__name__)

@@ -66,15 +74,18 @@ class Embedder:
    """

    _metadata_dict_name: str = "metadata.json"
+     _weights_dir_name: str = "sentence_transformer"
    _dump_dir: Path | None = None
+     _trained: bool = False
+     _model: SentenceTransformer

    def __init__(self, embedder_config: EmbedderConfig) -> None:
        """Initialize the Embedder.

        Args:
            embedder_config: Config of embedder.
        """
-         self.config = embedder_config
+         self.config = embedder_config.model_copy(deep=True)

    def _get_hash(self) -> int:
        """Compute a hash value for the Embedder.
@@ -83,19 +94,19 @@ def _get_hash(self) -> int:
            The hash value of the Embedder.
        """
        hasher = Hasher()
-         if self.config.freeze:
+         if not Path(self.config.model_name).exists():
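+             # model_name is not a local path, so hash the latest hub commit instead of loading the weights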
            commit_hash = _get_latest_commit_hash(self.config.model_name)
            hasher.update(commit_hash)
        else:
-             self.embedding_model = self._load_model()
-             for parameter in self.embedding_model.parameters():
+             self._model = self._load_model()
+             for parameter in self._model.parameters():
                hasher.update(parameter.detach().cpu().numpy())
        hasher.update(self.config.tokenizer_config.max_length)
        return hasher.intdigest()

    def _load_model(self) -> SentenceTransformer:
        """Load sentence transformers model to device."""
-         if not hasattr(self, "embedding_model"):
+         if not hasattr(self, "_model"):
            res = SentenceTransformer(
                self.config.model_name,
                device=self.config.device,
@@ -104,15 +115,80 @@ def _load_model(self) -> SentenceTransformer:
                trust_remote_code=self.config.trust_remote_code,
            )
        else:
-             res = self.embedding_model
+             res = self._model
        return res

+     def train(self, utterances: list[str], labels: ListOfLabels, config: EmbedderFineTuningConfig) -> None:
+         """Train the embedding model on labeled utterances.
+ 
+         Args:
+             utterances: Training texts.
+             labels: Intent labels aligned with utterances.
+             config: Fine-tuning hyperparameters.
+         """
+         if len(utterances) != len(labels):
+             msg = f"Utterances and labels lists lengths mismatch: {len(utterances)=} != {len(labels)=}"
+             raise ValueError(msg)
+ 
+         if len(labels) == 0:
+             msg = "Empty data"
+             raise ValueError(msg)
+ 
+         # TODO support multi-label data
+         if isinstance(labels[0], list):
+             msg = "Multi-label data is not supported for embeddings fine-tuning for now"
+             logger.warning(msg)
+             return
+ 
+         self._model = self._load_model()
+ 
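+         # hold out a validation split for per-epoch evaluation and early stopping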
+         x_train, x_val, y_train, y_val = train_test_split(utterances, labels, test_size=config.val_fraction)
+         tr_ds = Dataset.from_dict({"text": x_train, "label": y_train})
+         val_ds = Dataset.from_dict({"text": x_val, "label": y_val})
+ 
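+         # BatchAllTripletLoss forms triplets from every valid anchor/positive/negative combination within a batch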
+         loss = BatchAllTripletLoss(model=self._model, margin=config.margin)
+         with tempfile.TemporaryDirectory() as tmp_dir:
+             args = SentenceTransformerTrainingArguments(
+                 save_strategy="epoch",
+                 save_total_limit=1,
+                 output_dir=tmp_dir,
+                 num_train_epochs=config.epoch_num,
+                 per_device_train_batch_size=config.batch_size,
+                 per_device_eval_batch_size=config.batch_size,
+                 learning_rate=config.learning_rate,
+                 warmup_ratio=config.warmup_ratio,
+                 fp16=config.fp16,
+                 bf16=config.bf16,
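+                 # NO_DUPLICATES ensures a batch never contains the same text twice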
+                 batch_sampler=BatchSamplers.NO_DUPLICATES,
+                 metric_for_best_model="eval_loss",
+                 load_best_model_at_end=True,
+                 eval_strategy="epoch",
+                 greater_is_better=False,
+             )
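+             # stop early once eval loss stops improving and restore the best epoch's checkpoint at the end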
+             callbacks: list[TrainerCallback] = [
+                 EarlyStoppingCallback(
+                     early_stopping_patience=config.early_stopping_patience,
+                     early_stopping_threshold=config.early_stopping_threshold,
+                 )
+             ]
+             trainer = SentenceTransformerTrainer(
+                 model=self._model,
+                 args=args,
+                 train_dataset=tr_ds,
+                 eval_dataset=val_ds,
+                 loss=loss,
+                 callbacks=callbacks,
+             )
+ 
+             trainer.train()
+ 
+         # save the fine-tuned weights to a fresh temporary path so the model can be reloaded later
+         model_path = str(Path(tempfile.mkdtemp("autointent_embedders")) / str(uuid4()))
+         self._model.save(model_path)
+         self.config.model_name = model_path
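+         # subsequent reloads (e.g. after clear_ram) and cache hashing now point at the fine-tuned weights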
+ 
+         self._trained = True
+ 
    def clear_ram(self) -> None:
        """Move the embedding model to CPU and delete it from memory."""
-         if hasattr(self, "embedding_model"):
+         if hasattr(self, "_model"):
            logger.debug("Clearing embedder %s from memory", self.config.model_name)
-             self.embedding_model.cpu()
-             del self.embedding_model
+             self._model.cpu()
+             del self._model
        torch.cuda.empty_cache()

    def delete(self) -> None:
@@ -127,6 +203,11 @@ def dump(self, path: Path) -> None:
        Args:
            path: Path to the directory where the model will be saved.
        """
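+         # persist fine-tuned weights inside the dump directory so it is self-contained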
+         if self._trained:
+             model_path = str((path / self._weights_dir_name).resolve())
+             self._model.save(model_path, create_model_card=False)
+             self.config.model_name = model_path
+ 
        self._dump_dir = path
        path.mkdir(parents=True, exist_ok=True)
        with (path / self._metadata_dict_name).open("w") as file:
@@ -164,6 +245,11 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
        Returns:
            A numpy array of embeddings.
        """
+         if len(utterances) == 0:
+             msg = "Empty input"
+             logger.error(msg)
+             raise ValueError(msg)
+ 
        prompt = self.config.get_prompt(task_type)

        if self.config.use_cache:
@@ -179,7 +265,7 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
            logger.debug("loading embeddings from %s", str(embeddings_path))
            return np.load(embeddings_path)  # type: ignore[no-any-return]

-         self.embedding_model = self._load_model()
+         self._model = self._load_model()

        logger.debug(
            "Calculating embeddings with model %s, batch_size=%d, max_seq_length=%s, embedder_device=%s, prompt=%s",
@@ -191,9 +277,9 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
        )

        if self.config.tokenizer_config.max_length is not None:
-             self.embedding_model.max_seq_length = self.config.tokenizer_config.max_length
+             self._model.max_seq_length = self.config.tokenizer_config.max_length

-         embeddings = self.embedding_model.encode(
+         embeddings = self._model.encode(
            utterances,
            convert_to_numpy=True,
            batch_size=self.config.batch_size,