36
36
OnlineDTActor ,
37
37
ProbabilisticActor ,
38
38
SafeModule ,
39
+ set_recurrent_mode ,
39
40
TanhDelta ,
40
41
TanhNormal ,
41
42
ValueOperator ,
@@ -729,6 +730,31 @@ def test_errs(self):
729
730
with pytest .raises (KeyError , match = "is_init" ):
730
731
lstm_module (td )
731
732
733
+ @pytest .mark .parametrize ("default_val" , [False , True , None ])
734
+ def test_set_recurrent_mode (self , default_val ):
735
+ lstm_module = LSTMModule (
736
+ input_size = 3 ,
737
+ hidden_size = 12 ,
738
+ batch_first = True ,
739
+ in_keys = ["observation" , "hidden0" , "hidden1" ],
740
+ out_keys = ["intermediate" , ("next" , "hidden0" ), ("next" , "hidden1" )],
741
+ default_recurrent_mode = default_val ,
742
+ )
743
+ assert lstm_module .recurrent_mode is bool (default_val )
744
+ with set_recurrent_mode (True ):
745
+ assert lstm_module .recurrent_mode
746
+ with set_recurrent_mode (False ):
747
+ assert not lstm_module .recurrent_mode
748
+ with set_recurrent_mode ("recurrent" ):
749
+ assert lstm_module .recurrent_mode
750
+ with set_recurrent_mode ("sequential" ):
751
+ assert not lstm_module .recurrent_mode
752
+ assert lstm_module .recurrent_mode
753
+ assert not lstm_module .recurrent_mode
754
+ assert lstm_module .recurrent_mode
755
+ assert lstm_module .recurrent_mode is bool (default_val )
756
+
757
+ @pytest .mark .filterwarnings ("ignore::DeprecationWarning" )
732
758
def test_set_temporal_mode (self ):
733
759
lstm_module = LSTMModule (
734
760
input_size = 3 ,
@@ -754,7 +780,8 @@ def test_python_cudnn(self):
754
780
num_layers = 2 ,
755
781
in_keys = ["observation" , "hidden0" , "hidden1" ],
756
782
out_keys = ["intermediate" , ("next" , "hidden0" ), ("next" , "hidden1" )],
757
- ).set_recurrent_mode (True )
783
+ default_recurrent_mode = True ,
784
+ )
758
785
obs = torch .rand (10 , 20 , 3 )
759
786
760
787
hidden0 = torch .rand (10 , 20 , 2 , 12 )
@@ -1109,6 +1136,31 @@ def test_errs(self):
1109
1136
with pytest .raises (KeyError , match = "is_init" ):
1110
1137
gru_module (td )
1111
1138
1139
+ @pytest .mark .parametrize ("default_val" , [False , True , None ])
1140
+ def test_set_recurrent_mode (self , default_val ):
1141
+ gru_module = GRUModule (
1142
+ input_size = 3 ,
1143
+ hidden_size = 12 ,
1144
+ batch_first = True ,
1145
+ in_keys = ["observation" , "hidden" ],
1146
+ out_keys = ["intermediate" , ("next" , "hidden" )],
1147
+ default_recurrent_mode = default_val ,
1148
+ )
1149
+ assert gru_module .recurrent_mode is bool (default_val )
1150
+ with set_recurrent_mode (True ):
1151
+ assert gru_module .recurrent_mode
1152
+ with set_recurrent_mode (False ):
1153
+ assert not gru_module .recurrent_mode
1154
+ with set_recurrent_mode ("recurrent" ):
1155
+ assert gru_module .recurrent_mode
1156
+ with set_recurrent_mode ("sequential" ):
1157
+ assert not gru_module .recurrent_mode
1158
+ assert gru_module .recurrent_mode
1159
+ assert not gru_module .recurrent_mode
1160
+ assert gru_module .recurrent_mode
1161
+ assert gru_module .recurrent_mode is bool (default_val )
1162
+
1163
+ @pytest .mark .filterwarnings ("ignore::DeprecationWarning" )
1112
1164
def test_set_temporal_mode (self ):
1113
1165
gru_module = GRUModule (
1114
1166
input_size = 3 ,
0 commit comments