Skip to content

Commit c743c0b

Browse files
committed
remake train args
1 parent 3c38ec8 commit c743c0b

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

autointent/_wrappers/embedder.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -147,17 +147,18 @@ def train(self, utterances: list[str], labels: list[int], config: EmbedderFineTu
147147
output_dir=tmp_dir,
148148
num_train_epochs=config.epoch_num,
149149
per_device_train_batch_size=config.batch_size,
150+
per_device_eval_batch_size=8,
151+
eval_steps=1,
150152
learning_rate=config.learning_rate,
151153
warmup_ratio=config.warmup_ratio,
152-
metric_for_best_model="eval_loss",
153-
greater_is_better=False,
154154
fp16=config.fp16,
155155
bf16=config.bf16,
156156
batch_sampler=BatchSamplers.NO_DUPLICATES,
157+
metric_for_best_model="eval_loss",
158+
load_best_model_at_end=True,
159+
evaluation_strategy = "epoch",
160+
greater_is_better=False,
157161
)
158-
if config.early_stopping:
159-
args.set_training(load_best_model_at_end=True)
160-
args.set_evaluate(strategy="epoch", steps=1)
161162
trainer = SentenceTransformerTrainer(
162163
model=self.embedding_model,
163164
args=args,

0 commit comments

Comments
 (0)