@@ -10,6 +10,7 @@
 import tempfile
 from functools import lru_cache
 from pathlib import Path
+from uuid import uuid4
 
 import huggingface_hub
 import numpy as np
@@ -26,6 +27,7 @@
 
 from autointent._hash import Hasher
 from autointent.configs import EmbedderConfig, EmbedderFineTuningConfig, TaskTypeEnum
+from autointent.custom_types import ListOfLabels
 
 logger = logging.getLogger(__name__)
 
@@ -72,7 +74,9 @@ class Embedder:
     """
 
     _metadata_dict_name: str = "metadata.json"
+    _weights_dir_name: str = "sentence_transformer"
     _dump_dir: Path | None = None
+    _trained: bool = False
 
     def __init__(self, embedder_config: EmbedderConfig) -> None:
         """Initialize the Embedder.
@@ -89,7 +93,7 @@ def _get_hash(self) -> int:
             The hash value of the Embedder.
         """
         hasher = Hasher()
-        if self.config.freeze:
+        if not Path(self.config.model_name).exists():
            commit_hash = _get_latest_commit_hash(self.config.model_name)
            hasher.update(commit_hash)
        else:
@@ -113,8 +117,22 @@ def _load_model(self) -> SentenceTransformer:
         res = self.embedding_model
         return res
 
-    def train(self, utterances: list[str], labels: list[int], config: EmbedderFineTuningConfig) -> None:
+    def train(self, utterances: list[str], labels: ListOfLabels, config: EmbedderFineTuningConfig) -> None:
         """Train the embedding model."""
+        if len(utterances) != len(labels):
+            msg = f"Utterances and labels lists lengths mismatch: {len(utterances)=} != {len(labels)=}"
+            raise ValueError(msg)
+
+        if len(labels) == 0:
+            msg = "Empty data"
+            raise ValueError(msg)
+
+        # TODO support multi-label data
+        if isinstance(labels[0], list):
+            msg = "Multi-label data is not supported for embeddings fine-tuning for now"
+            logger.warning(msg)
+            return
+
         self._load_model()
         if config.early_stopping:
             x_train, x_val, y_train, y_val = train_test_split(utterances, labels, test_size=0.1, random_state=42)
@@ -131,8 +149,7 @@ def train(self, utterances: list[str], labels: list[int], config: EmbedderFineTu
             output_dir=tmp_dir,
             num_train_epochs=config.epoch_num,
             per_device_train_batch_size=config.batch_size,
-            per_device_eval_batch_size=8,
-            eval_steps=1,
+            per_device_eval_batch_size=config.batch_size,
             learning_rate=config.learning_rate,
             warmup_ratio=config.warmup_ratio,
             fp16=config.fp16,
@@ -143,9 +160,9 @@ def train(self, utterances: list[str], labels: list[int], config: EmbedderFineTu
             eval_strategy="epoch",
             greater_is_better=False,
         )
-        callback: list[TrainerCallback] = []
+        callbacks: list[TrainerCallback] = []
         if config.early_stopping:
-            callback.append(
+            callbacks.append(
                 EarlyStoppingCallback(
                     early_stopping_patience=config.early_stopping,
                     early_stopping_threshold=config.early_stopping_threshold,
@@ -157,11 +174,18 @@ def train(self, utterances: list[str], labels: list[int], config: EmbedderFineTu
             train_dataset=tr_ds,
             eval_dataset=val_ds,
             loss=loss,
-            callbacks=callback,
+            callbacks=callbacks,
         )
 
         trainer.train()
 
+        # save the fine-tuned weights to a temporary path so the model can be reloaded later
+        model_path = str(Path(tempfile.mkdtemp("autointent_embedders")) / str(uuid4()))
+        self.embedding_model.save(model_path)
+        self.config.model_name = model_path
+
+        self._trained = True
+
     def clear_ram(self) -> None:
         """Move the embedding model to CPU and delete it from memory."""
         if hasattr(self, "embedding_model"):
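Taken together, the train() changes validate the inputs before fine-tuning and persist the tuned weights to a temporary directory so the same Embedder can reload them. A usage sketch under assumed constructor signatures (EmbedderConfig(model_name=...) and the EmbedderFineTuningConfig field names are inferred from this diff; the values here are illustrative):

embedder = Embedder(EmbedderConfig(model_name="sentence-transformers/all-MiniLM-L6-v2"))
ft_config = EmbedderFineTuningConfig(epoch_num=3, batch_size=16)

embedder.train(["hi there", "goodbye"], [0, 1], ft_config)  # fine-tunes in place
embedder.train(["hi there"], [0, 1], ft_config)             # raises ValueError (length mismatch)
embedder.train(["hi there"], [[0, 1]], ft_config)           # multi-label: logs a warning and returns
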
@@ -182,6 +206,11 @@ def dump(self, path: Path) -> None:
         Args:
             path: Path to the directory where the model will be saved.
         """
+        if self._trained:
+            model_path = str((path / self._weights_dir_name).resolve())
+            self.embedding_model.save(model_path, create_model_card=False)
+            self.config.model_name = model_path
+
         self._dump_dir = path
         path.mkdir(parents=True, exist_ok=True)
         with (path / self._metadata_dict_name).open("w") as file:
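With the dump() change, a fine-tuned embedder writes its weights next to its metadata, and model_name is repointed at that persistent copy instead of the temporary directory created in train(). A sketch of the expected on-disk layout (the file and directory names come from the class attributes above):

embedder.train(utterances, labels, ft_config)
embedder.dump(Path("my_embedder"))
# my_embedder/
# ├── metadata.json            <- config, with model_name pointing at the copy below
# └── sentence_transformer/    <- fine-tuned SentenceTransformer weights (no model card)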