File tree Expand file tree Collapse file tree 1 file changed +5
-2
lines changed
Expand file tree Collapse file tree 1 file changed +5
-2
lines changed Original file line number Diff line number Diff line change @@ -21,17 +21,20 @@ def test_model_updates_after_training(dataset):
2121 freeze = False ,
2222 )
2323
24- train_config = EmbedderFineTuningConfig (epoch_num = 1 , batch_size = 8 )
24+ train_config = EmbedderFineTuningConfig (epoch_num = 3 , batch_size = 8 )
2525 embedder = Embedder (embedder_config )
2626 embedder ._load_model ()
2727
28+ for param in embedder .embedding_model .parameters ():
29+ assert param .requires_grad , "All trainable parameters should have requires_grad=True"
30+
2831 original_weights = [
2932 param .data .detach ().cpu ().numpy ().copy ()
3033 for param in embedder .embedding_model .parameters ()
3134 if param .requires_grad
3235 ]
3336 embedder .train (
34- utterances = data_handler .train_utterances (0 )[:10 ], labels = data_handler .train_labels (0 )[:10 ], config = train_config
37+ utterances = data_handler .train_utterances (0 )[:1000 ], labels = data_handler .train_labels (0 )[:1000 ], config = train_config
3538 )
3639
3740 trained_weights = [
You can’t perform that action at this time.
0 commit comments