@@ -10,21 +10,22 @@
 from functools import lru_cache
 from pathlib import Path
 from typing import TypedDict
+import tempfile

 import huggingface_hub
 import numpy as np
 import numpy.typing as npt
 import torch
 from appdirs import user_cache_dir
-from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments, InputExample
+from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments
 from sentence_transformers.similarity_functions import SimilarityFunction
 from sentence_transformers.losses import BatchAllTripletLoss
 from sentence_transformers.training_args import BatchSamplers
 from datasets import Dataset


 from autointent._hash import Hasher
-from autointent.configs import EmbedderConfig, TaskTypeEnum
+from autointent.configs import EmbedderConfig, TaskTypeEnum, EmbedderFineTuningConfig

 logger = logging.getLogger(__name__)

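Aside: `EmbedderFineTuningConfig` itself is not part of this diff. Judging from the `config.*` reads in the new `train` body below, it plausibly looks like the following sketch. This is a hypothetical reconstruction: the real definition lives in `autointent.configs`, the field names are inferred from usage, and the defaults simply mirror the old `kwargs.get()` fallbacks being removed.

    from dataclasses import dataclass

    # Hypothetical sketch -- may well be a pydantic model in the real codebase.
    @dataclass
    class EmbedderFineTuningConfig:
        epoch_num: int                # was a required kwargs key before this change
        margin: float = 0.5           # BatchAllTripletLoss margin (old default)
        learning_rate: float = 2e-5   # old kwargs.get() fallback
        warmup_ratio: float = 0.1     # old kwargs.get() fallback
        fp16: bool = True             # mixed-precision flags forwarded to the trainer
        bf16: bool = False
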
@@ -126,7 +127,7 @@ def _load_model(self) -> None:
             similarity_fn_name=self.config.similarity_fn_name,
             trust_remote_code=self.config.trust_remote_code,
         )
-    def train(self, utterances: list[str], labels: list[int], **kwargs) -> None:
+    def train(self, utterances: list[str], labels: list[int], config: EmbedderFineTuningConfig) -> None:
         """Train the embedding model"""
         self._load_model()

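To illustrate the signature change at a call site: fine-tuning knobs that used to travel as loose keyword arguments are now bundled in one typed object. A minimal usage sketch, assuming an already-constructed `Embedder` instance (construction details are not shown in this diff):

    # Hypothetical call site; names other than train()'s parameters are illustrative.
    utterances = ["hi there", "goodbye", "hello!"]
    labels = [0, 1, 0]

    ft_config = EmbedderFineTuningConfig(epoch_num=3, margin=0.5)
    embedder.train(utterances, labels, ft_config)
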
@@ -137,31 +138,29 @@ def train(self, utterances: list[str], labels: list[int], **kwargs) -> None:

         loss = BatchAllTripletLoss(
             model=self.embedding_model,
-            margin=kwargs.get("margin", 0.5)
-        )
-
-        args = SentenceTransformerTrainingArguments(
-            save_strategy="no",
-            output_dir=kwargs['out_dir'],
-            num_train_epochs=kwargs['epoch_num'],
-            per_device_train_batch_size=self.config.batch_size,
-            learning_rate=kwargs.get("learning_rate", 2e-5),
-            warmup_ratio=kwargs.get("warmup_ratio", 0.1),
-            fp16=kwargs.get("fp16", True),
-            bf16=kwargs.get("bf16", False),
-            batch_sampler=BatchSamplers.NO_DUPLICATES,
+            margin=config.margin
         )
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            args = SentenceTransformerTrainingArguments(
+                save_strategy="no",
+                output_dir=tmp_dir,
+                num_train_epochs=config.epoch_num,
+                per_device_train_batch_size=self.config.batch_size,
+                learning_rate=config.learning_rate,
+                warmup_ratio=config.warmup_ratio,
+                fp16=config.fp16,
+                bf16=config.bf16,
+                batch_sampler=BatchSamplers.NO_DUPLICATES,
+            )

-        trainer = SentenceTransformerTrainer(
-            model=self.embedding_model,
-            args=args,
-            train_dataset=tr_ds,
-            loss=loss,
-        )
-
-        trainer.train()
-
-        self.embedding_model.save(kwargs['out_dir'])
+            trainer = SentenceTransformerTrainer(
+                model=self.embedding_model,
+                args=args,
+                train_dataset=tr_ds,
+                loss=loss,
+            )
+
+            trainer.train()

     def clear_ram(self) -> None:
         """Move the embedding model to CPU and delete it from memory."""