1-
21import numpy as np
32import pytest
43
@@ -11,11 +10,7 @@ def test_rnn_prediction(dataset):
1110 """Test that the RNN model can fit and make predictions."""
1211 data_handler = DataHandler (dataset )
1312
14- scorer = RNNScorer (
15- rnn_config = RNNConfig (embed_dim = 64 , hidden_dim = 128 , n_layers = 1 ),
16- num_train_epochs = 1 ,
17- batch_size = 8
18- )
13+ scorer = RNNScorer (rnn_config = RNNConfig (embed_dim = 64 , hidden_dim = 128 , n_layers = 1 ), num_train_epochs = 1 , batch_size = 8 )
1914
2015 scorer .fit (data_handler .train_utterances (0 ), data_handler .train_labels (0 ))
2116
@@ -52,11 +47,7 @@ def test_rnn_cache_clearing(dataset):
5247 """Test that the RNN model properly handles cache clearing."""
5348 data_handler = DataHandler (dataset )
5449
55- scorer = RNNScorer (
56- rnn_config = RNNConfig (embed_dim = 64 , hidden_dim = 128 , n_layers = 1 ),
57- num_train_epochs = 1 ,
58- batch_size = 8
59- )
50+ scorer = RNNScorer (rnn_config = RNNConfig (embed_dim = 64 , hidden_dim = 128 , n_layers = 1 ), num_train_epochs = 1 , batch_size = 8 )
6051
6152 scorer .fit (data_handler .train_utterances (0 ), data_handler .train_labels (0 ))
6253
@@ -82,9 +73,7 @@ def test_rnn_device(dataset):
8273
8374 # Force CPU
8475 scorer_cpu = RNNScorer (
85- rnn_config = RNNConfig (embed_dim = 64 , hidden_dim = 128 , n_layers = 1 , device = "cpu" ),
86- num_train_epochs = 1 ,
87- batch_size = 8
76+ rnn_config = RNNConfig (embed_dim = 64 , hidden_dim = 128 , n_layers = 1 , device = "cpu" ), num_train_epochs = 1 , batch_size = 8
8877 )
8978
9079 scorer_cpu .fit (data_handler .train_utterances (0 ), data_handler .train_labels (0 ))
@@ -97,9 +86,7 @@ def test_rnn_device(dataset):
9786
9887 # Test with default device
9988 scorer_default = RNNScorer (
100- rnn_config = RNNConfig (embed_dim = 64 , hidden_dim = 128 , n_layers = 1 ),
101- num_train_epochs = 1 ,
102- batch_size = 8
89+ rnn_config = RNNConfig (embed_dim = 64 , hidden_dim = 128 , n_layers = 1 ), num_train_epochs = 1 , batch_size = 8
10390 )
10491
10592 scorer_default .fit (data_handler .train_utterances (0 ), data_handler .train_labels (0 ))
0 commit comments