
from __future__ import annotations

-import torch
-import torch.nn as nn
-
-from rsl_rl.modules.actor_critic import ActorCritic
-from rsl_rl.utils import resolve_nn_activation, unpad_trajectories
+from rsl_rl.modules import ActorCritic
+from rsl_rl.networks import Memory
+from rsl_rl.utils import resolve_nn_activation


class ActorCriticRecurrent(ActorCritic):
@@ -24,7 +22,7 @@ def __init__(
        critic_hidden_dims=[256, 256, 256],
        activation="elu",
        rnn_type="lstm",
-        rnn_hidden_size=256,
+        rnn_hidden_dim=256,
        rnn_num_layers=1,
        init_noise_std=1.0,
        **kwargs,
@@ -35,8 +33,8 @@ def __init__(
            )

        super().__init__(
-            num_actor_obs=rnn_hidden_size,
-            num_critic_obs=rnn_hidden_size,
+            num_actor_obs=rnn_hidden_dim,
+            num_critic_obs=rnn_hidden_dim,
            num_actions=num_actions,
            actor_hidden_dims=actor_hidden_dims,
            critic_hidden_dims=critic_hidden_dims,
@@ -46,8 +44,8 @@ def __init__(

        activation = resolve_nn_activation(activation)

-        self.memory_a = Memory(num_actor_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size)
-        self.memory_c = Memory(num_critic_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size)
+        self.memory_a = Memory(num_actor_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_dim)
+        self.memory_c = Memory(num_critic_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_dim)

        print(f"Actor RNN: {self.memory_a}")
        print(f"Critic RNN: {self.memory_c}")
@@ -70,32 +68,3 @@ def evaluate(self, critic_observations, masks=None, hidden_states=None):

    def get_hidden_states(self):
        return self.memory_a.hidden_states, self.memory_c.hidden_states
-
-
-class Memory(torch.nn.Module):
-    def __init__(self, input_size, type="lstm", num_layers=1, hidden_size=256):
-        super().__init__()
-        # RNN
-        rnn_cls = nn.GRU if type.lower() == "gru" else nn.LSTM
-        self.rnn = rnn_cls(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
-        self.hidden_states = None
-
-    def forward(self, input, masks=None, hidden_states=None):
-        batch_mode = masks is not None
-        if batch_mode:
-            # batch mode (policy update): need saved hidden states
-            if hidden_states is None:
-                raise ValueError("Hidden states not passed to memory module during policy update")
-            out, _ = self.rnn(input, hidden_states)
-            out = unpad_trajectories(out, masks)
-        else:
-            # inference mode (collection): use hidden states of last step
-            out, self.hidden_states = self.rnn(input.unsqueeze(0), self.hidden_states)
-        return out
-
-    def reset(self, dones=None):
-        # When the RNN is an LSTM, self.hidden_states_a is a list with hidden_state and cell_state
-        if self.hidden_states is None:
-            return
-        for hidden_state in self.hidden_states:
-            hidden_state[..., dones == 1, :] = 0.0
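
For callers updating past this change, a minimal usage sketch of the revised constructor follows. It is not part of the diff; it assumes ActorCriticRecurrent is importable from rsl_rl.modules, and the observation/action sizes are illustrative placeholders. The only keyword taken from the diff is the rename of rnn_hidden_size to rnn_hidden_dim.

# Minimal usage sketch, assuming ActorCriticRecurrent is exposed from
# rsl_rl.modules; all sizes below are illustrative placeholders.
from rsl_rl.modules import ActorCriticRecurrent

policy = ActorCriticRecurrent(
    num_actor_obs=48,      # placeholder actor observation size
    num_critic_obs=48,     # placeholder critic observation size
    num_actions=12,        # placeholder action size
    rnn_type="lstm",
    rnn_hidden_dim=256,    # renamed from `rnn_hidden_size` in this change
    rnn_num_layers=1,
)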