Skip to content

Commit 1e161c6

Browse files
committed
the number of epochs and train objects have been increased
1 parent 0739413 commit 1e161c6

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

tests/embedder/test_fine_tuning.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,20 @@ def test_model_updates_after_training(dataset):
2121
freeze=False,
2222
)
2323

24-
train_config = EmbedderFineTuningConfig(epoch_num=1, batch_size=8)
24+
train_config = EmbedderFineTuningConfig(epoch_num=3, batch_size=8)
2525
embedder = Embedder(embedder_config)
2626
embedder._load_model()
2727

28+
for param in embedder.embedding_model.parameters():
29+
assert param.requires_grad, "All trainable parameters should have requires_grad=True"
30+
2831
original_weights = [
2932
param.data.detach().cpu().numpy().copy()
3033
for param in embedder.embedding_model.parameters()
3134
if param.requires_grad
3235
]
3336
embedder.train(
34-
utterances=data_handler.train_utterances(0)[:10], labels=data_handler.train_labels(0)[:10], config=train_config
37+
utterances=data_handler.train_utterances(0)[:1000], labels=data_handler.train_labels(0)[:1000], config=train_config
3538
)
3639

3740
trained_weights = [

0 commit comments

Comments
 (0)