Skip to content

Commit 71bf957

Browse files
committed
make a list of callbacks
1 parent c743c0b commit 71bf957

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

autointent/_wrappers/embedder.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -156,19 +156,24 @@ def train(self, utterances: list[str], labels: list[int], config: EmbedderFineTu
156156
batch_sampler=BatchSamplers.NO_DUPLICATES,
157157
metric_for_best_model="eval_loss",
158158
load_best_model_at_end=True,
159-
evaluation_strategy = "epoch",
159+
evaluation_strategy="epoch",
160160
greater_is_better=False,
161161
)
162+
callback = []
163+
if config.early_stopping:
164+
callback.append(
165+
EarlyStoppingCallback(
166+
early_stopping_patience=config.early_stopping,
167+
early_stopping_threshold=config.early_stopping_threshold,
168+
)
169+
)
162170
trainer = SentenceTransformerTrainer(
163171
model=self.embedding_model,
164172
args=args,
165173
train_dataset=tr_ds,
166174
eval_dataset=val_ds,
167175
loss=loss,
168-
callbacks=EarlyStoppingCallback(
169-
early_stopping_patience=config.early_stopping,
170-
early_stopping_threshold=config.early_stopping_threshold,
171-
),
176+
callbacks=callback,
172177
)
173178

174179
trainer.train()

0 commit comments

Comments
 (0)