Skip to content

Commit 08b3a55

Browse files
authored
Pass output_dir to superclass (#585)
1 parent 0451b92 commit 08b3a55

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

src/setfit/trainer.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from sentence_transformers import SentenceTransformerTrainer, losses
99
from sentence_transformers.losses.BatchHardTripletLoss import BatchHardTripletLossDistanceFunction
1010
from sentence_transformers.model_card import ModelCardCallback as STModelCardCallback
11-
from sentence_transformers.training_args import BatchSamplers
11+
from sentence_transformers.training_args import BatchSamplers, SentenceTransformerTrainingArguments
1212
from sklearn.preprocessing import LabelEncoder
1313
from torch import nn
1414
from transformers import __version__ as transformers_version
@@ -47,7 +47,11 @@ def __init__(
4747
self._setfit_model = setfit_model
4848
self._setfit_args = setfit_args
4949
self.logs_prefix = "embedding"
50-
super().__init__(model=setfit_model.model_body, **kwargs)
50+
super().__init__(
51+
model=setfit_model.model_body,
52+
args=SentenceTransformerTrainingArguments(output_dir=setfit_args.output_dir),
53+
**kwargs,
54+
)
5155
self._apply_training_arguments(setfit_args)
5256

5357
for callback in list(self.callback_handler.callbacks):

0 commit comments

Comments
 (0)