Commit 55619e2

rename hidden_states and define type
1 parent e935b5a · commit 55619e2

7 files changed: +91 −95 lines changed


rsl_rl/algorithms/ppo.py

Lines changed: 3 additions & 3 deletions
@@ -213,7 +213,7 @@ def update(self) -> dict[str, float]:
             old_actions_log_prob_batch,
             old_mu_batch,
             old_sigma_batch,
-            hid_states_batch,
+            hidden_states_batch,
             masks_batch,
         ) in generator:
             num_aug = 1  # Number of augmentations per sample. Starts at 1 for no augmentation.
@@ -244,9 +244,9 @@ def update(self) -> dict[str, float]:
 
             # Recompute actions log prob and entropy for current batch of transitions
             # Note: We need to do this because we updated the policy with the new parameters
-            self.policy.act(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0])
+            self.policy.act(obs_batch, masks=masks_batch, hidden_state=hidden_states_batch[0])
             actions_log_prob_batch = self.policy.get_actions_log_prob(actions_batch)
-            value_batch = self.policy.evaluate(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[1])
+            value_batch = self.policy.evaluate(obs_batch, masks=masks_batch, hidden_state=hidden_states_batch[1])
             # Note: We only keep the entropy of the first augmentation (the original one)
             mu_batch = self.policy.action_mean[:original_batch_size]
             sigma_batch = self.policy.action_std[:original_batch_size]
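
The tuple indexing above follows the convention set by ActorCriticRecurrent.get_hidden_states(), which returns the actor memory's state first and the critic memory's state second. A minimal sketch of that pairing, with illustrative shapes (this is not the library's actual mini-batch generator):

    import torch

    # Illustrative sizes: one single-layer GRU-style state per network.
    num_layers, num_envs, hidden_dim = 1, 4, 8
    actor_state = torch.zeros(num_layers, num_envs, hidden_dim)
    critic_state = torch.zeros(num_layers, num_envs, hidden_dim)

    # get_hidden_states() returns (actor, critic); the recurrent mini-batch
    # generator carries this pair through as hidden_states_batch.
    hidden_states_batch = (actor_state, critic_state)

    # Index 0 feeds policy.act(...), index 1 feeds policy.evaluate(...).
    assert hidden_states_batch[0] is actor_state
    assert hidden_states_batch[1] is critic_state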

rsl_rl/modules/actor_critic_recurrent.py

Lines changed: 7 additions & 17 deletions
@@ -12,7 +12,7 @@
 from torch.distributions import Normal
 from typing import Any, NoReturn
 
-from rsl_rl.networks import MLP, EmpiricalNormalization, Memory
+from rsl_rl.networks import MLP, EmpiricalNormalization, HiddenState, Memory
 
 
 class ActorCriticRecurrent(nn.Module):
@@ -162,15 +162,10 @@ def _update_distribution(self, obs: TensorDict) -> None:
         # Create distribution
         self.distribution = Normal(mean, std)
 
-    def act(
-        self,
-        obs: TensorDict,
-        masks: torch.Tensor | None = None,
-        hidden_states: torch.Tensor | tuple[torch.Tensor, ...] | None = None,
-    ) -> torch.Tensor:
+    def act(self, obs: TensorDict, masks: torch.Tensor | None = None, hidden_state: HiddenState = None) -> torch.Tensor:
         obs = self.get_actor_obs(obs)
         obs = self.actor_obs_normalizer(obs)
-        out_mem = self.memory_a(obs, masks, hidden_states).squeeze(0)
+        out_mem = self.memory_a(obs, masks, hidden_state).squeeze(0)
         self._update_distribution(out_mem)
         return self.distribution.sample()
 
@@ -184,14 +179,11 @@ def act_inference(self, obs: TensorDict) -> torch.Tensor:
         return self.actor(out_mem)
 
     def evaluate(
-        self,
-        obs: TensorDict,
-        masks: torch.Tensor | None = None,
-        hidden_states: torch.Tensor | tuple[torch.Tensor, ...] | None = None,
+        self, obs: TensorDict, masks: torch.Tensor | None = None, hidden_state: HiddenState = None
     ) -> torch.Tensor:
         obs = self.get_critic_obs(obs)
         obs = self.critic_obs_normalizer(obs)
-        out_mem = self.memory_c(obs, masks, hidden_states).squeeze(0)
+        out_mem = self.memory_c(obs, masks, hidden_state).squeeze(0)
         return self.critic(out_mem)
 
     def get_actor_obs(self, obs: TensorDict) -> torch.Tensor:
@@ -205,10 +197,8 @@ def get_critic_obs(self, obs: TensorDict) -> torch.Tensor:
     def get_actions_log_prob(self, actions: torch.Tensor) -> torch.Tensor:
         return self.distribution.log_prob(actions).sum(dim=-1)
 
-    def get_hidden_states(
-        self,
-    ) -> tuple[torch.Tensor | tuple[torch.Tensor, ...] | None, torch.Tensor | tuple[torch.Tensor, ...] | None]:
-        return self.memory_a.hidden_states, self.memory_c.hidden_states
+    def get_hidden_states(self) -> tuple[HiddenState, HiddenState]:
+        return self.memory_a.hidden_state, self.memory_c.hidden_state
 
     def update_normalization(self, obs: TensorDict) -> None:
         if self.actor_obs_normalization:
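
The .squeeze(0) in act() and evaluate() exists because Memory.forward() unsqueezes the per-step observation into a length-1 sequence during rollout, so the RNN output keeps a leading time dimension of 1. A minimal sketch of that shape handling with a plain nn.LSTM (illustrative sizes, not the Memory defaults):

    import torch
    import torch.nn as nn

    rnn = nn.LSTM(input_size=16, hidden_size=32, num_layers=1)
    obs = torch.randn(4, 16)                   # (num_envs, obs_dim)
    out, hidden_state = rnn(obs.unsqueeze(0))  # out: (1, num_envs, 32)
    out_mem = out.squeeze(0)                   # (num_envs, 32), what the actor MLP consumes
    print(out_mem.shape)                       # torch.Size([4, 32])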

rsl_rl/modules/student_teacher.py

Lines changed: 3 additions & 5 deletions
@@ -11,7 +11,7 @@
 from torch.distributions import Normal
 from typing import Any, NoReturn
 
-from rsl_rl.networks import MLP, EmpiricalNormalization
+from rsl_rl.networks import MLP, EmpiricalNormalization, HiddenState
 
 
 class StudentTeacher(nn.Module):
@@ -91,9 +91,7 @@ def __init__(
         Normal.set_default_validate_args(False)
 
     def reset(
-        self,
-        dones: torch.Tensor | None = None,
-        hidden_states: tuple[torch.Tensor | tuple[torch.Tensor, ...] | None, ...] = (None, None),
+        self, dones: torch.Tensor | None = None, hidden_states: tuple[HiddenState, HiddenState] = (None, None)
     ) -> None:
         pass
 
@@ -150,7 +148,7 @@ def get_teacher_obs(self, obs: TensorDict) -> torch.Tensor:
         obs_list = [obs[obs_group] for obs_group in self.obs_groups["teacher"]]
         return torch.cat(obs_list, dim=-1)
 
-    def get_hidden_states(self) -> tuple[torch.Tensor | tuple[torch.Tensor, ...] | None, ...]:
+    def get_hidden_states(self) -> tuple[HiddenState, HiddenState]:
         return None, None
 
     def detach_hidden_states(self, dones: torch.Tensor | None = None) -> None:
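
Returning (None, None) still satisfies the tuple[HiddenState, HiddenState] annotation because None is one member of the HiddenState union: a non-recurrent student-teacher simply has no memory to expose in either the student or the teacher slot. A self-contained sketch (the alias is restated here for illustration; Python 3.10+ union syntax):

    import torch

    # Restated from rsl_rl/networks/memory.py for the example.
    HiddenState = torch.Tensor | tuple[torch.Tensor, torch.Tensor] | None

    def get_hidden_states() -> tuple[HiddenState, HiddenState]:
        # No recurrent memory, so both the student and teacher slots are None.
        return None, None

    print(get_hidden_states())  # (None, None)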

rsl_rl/modules/student_teacher_recurrent.py

Lines changed: 7 additions & 11 deletions
@@ -12,7 +12,7 @@
 from torch.distributions import Normal
 from typing import Any, NoReturn
 
-from rsl_rl.networks import MLP, EmpiricalNormalization, Memory
+from rsl_rl.networks import MLP, EmpiricalNormalization, HiddenState, Memory
 
 
 class StudentTeacherRecurrent(nn.Module):
@@ -110,9 +110,7 @@ def __init__(
         Normal.set_default_validate_args(False)
 
     def reset(
-        self,
-        dones: torch.Tensor | None = None,
-        hidden_states: tuple[torch.Tensor | tuple[torch.Tensor, ...] | None, ...] = (None, None),
+        self, dones: torch.Tensor | None = None, hidden_states: tuple[HiddenState, HiddenState] = (None, None)
     ) -> None:
         self.memory_s.reset(dones, hidden_states[0])
         if self.teacher_recurrent:
@@ -176,18 +174,16 @@ def get_teacher_obs(self, obs: TensorDict) -> torch.Tensor:
         obs_list = [obs[obs_group] for obs_group in self.obs_groups["teacher"]]
         return torch.cat(obs_list, dim=-1)
 
-    def get_hidden_states(
-        self,
-    ) -> tuple[torch.Tensor | tuple[torch.Tensor, ...] | None, torch.Tensor | tuple[torch.Tensor, ...] | None]:
+    def get_hidden_states(self) -> tuple[HiddenState, HiddenState]:
         if self.teacher_recurrent:
-            return self.memory_s.hidden_states, self.memory_t.hidden_states
+            return self.memory_s.hidden_state, self.memory_t.hidden_state
         else:
-            return self.memory_s.hidden_states, None
+            return self.memory_s.hidden_state, None
 
     def detach_hidden_states(self, dones: torch.Tensor | None = None) -> None:
-        self.memory_s.detach_hidden_states(dones)
+        self.memory_s.detach_hidden_state(dones)
         if self.teacher_recurrent:
-            self.memory_t.detach_hidden_states(dones)
+            self.memory_t.detach_hidden_state(dones)
 
     def train(self, mode: bool = True) -> None:
         super().train(mode)
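
Note that the module keeps the plural detach_hidden_states() as its public method; only the per-network Memory method becomes detach_hidden_state(). Detaching matters because it truncates backpropagation through time at the step boundary. A minimal sketch of the effect with a plain nn.GRU (illustrative sizes):

    import torch
    import torch.nn as nn

    rnn = nn.GRU(input_size=8, hidden_size=16)
    x = torch.randn(1, 2, 8)  # (seq_len=1, num_envs=2, obs_dim=8)

    _, h = rnn(x)           # h still references the autograd graph of this step
    print(h.requires_grad)  # True
    h = h.detach()          # what Memory.detach_hidden_state(dones=None) does for a GRU
    print(h.requires_grad)  # False: later steps will not backprop into this one
    _, h = rnn(x, h)        # training continues with a truncated history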

rsl_rl/networks/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -5,13 +5,14 @@
 
 """Definitions for components of modules."""
 
-from .memory import Memory
+from .memory import HiddenState, Memory
 from .mlp import MLP
 from .normalization import EmpiricalDiscountedVariationNormalization, EmpiricalNormalization
 
 __all__ = [
     "MLP",
     "EmpiricalDiscountedVariationNormalization",
     "EmpiricalNormalization",
+    "HiddenState",
     "Memory",
 ]

rsl_rl/networks/memory.py

Lines changed: 34 additions & 30 deletions
@@ -10,67 +10,71 @@
 
 from rsl_rl.utils import unpad_trajectories
 
+HiddenState = torch.Tensor | tuple[torch.Tensor, torch.Tensor] | None
+"""Type alias for the hidden state of RNNs (GRU/LSTM).
+
+For GRUs, this is a single tensor while for LSTMs, this is a tuple of two tensors (hidden state and cell state).
+"""
+
 
 class Memory(nn.Module):
     """Memory module for recurrent networks.
 
-    This module is used to store the hidden states of the policy. It currently only supports GRU and LSTM.
+    This module is used to store the hidden state of the policy. It currently supports GRU and LSTM.
     """
 
     def __init__(self, input_size: int, hidden_dim: int = 256, num_layers: int = 1, type: str = "lstm") -> None:
         super().__init__()
         rnn_cls = nn.GRU if type.lower() == "gru" else nn.LSTM
         self.rnn = rnn_cls(input_size=input_size, hidden_size=hidden_dim, num_layers=num_layers)
-        self.hidden_states = None
+        self.hidden_state = None
 
     def forward(
         self,
         input: torch.Tensor,
         masks: torch.Tensor | None = None,
-        hidden_states: torch.Tensor | tuple[torch.Tensor, ...] | None = None,
+        hidden_state: HiddenState = None,
     ) -> torch.Tensor:
         batch_mode = masks is not None
         if batch_mode:
             # Batch mode needs saved hidden states
-            if hidden_states is None:
+            if hidden_state is None:
                 raise ValueError("Hidden states not passed to memory module during policy update")
-            out, _ = self.rnn(input, hidden_states)
+            out, _ = self.rnn(input, hidden_state)
             out = unpad_trajectories(out, masks)
         else:
-            # Inference/distillation mode uses hidden states of last step
-            out, self.hidden_states = self.rnn(input.unsqueeze(0), self.hidden_states)
+            # Inference/distillation mode uses hidden state of last step
+            out, self.hidden_state = self.rnn(input.unsqueeze(0), self.hidden_state)
         return out
 
-    def reset(
-        self, dones: torch.Tensor | None = None, hidden_states: torch.Tensor | tuple[torch.Tensor, ...] | None = None
-    ) -> None:
-        if dones is None:  # Reset hidden states
-            if hidden_states is None:
-                self.hidden_states = None
+    def reset(self, dones: torch.Tensor | None = None, hidden_state: HiddenState = None) -> None:
+        if dones is None:  # Reset hidden state
+            if hidden_state is None:
+                self.hidden_state = None
             else:
-                self.hidden_states = hidden_states
-        elif self.hidden_states is not None:  # Reset hidden states of done environments
-            if hidden_states is None:
-                if isinstance(self.hidden_states, tuple):  # Tuple in case of LSTM
-                    for hidden_state in self.hidden_states:
+                self.hidden_state = hidden_state
+        elif self.hidden_state is not None:  # Reset hidden state of done environments
+            if hidden_state is None:
+                if isinstance(self.hidden_state, tuple):  # Tuple in case of LSTM
+                    for hidden_state in self.hidden_state:
                         hidden_state[..., dones == 1, :] = 0.0
                 else:
-                    self.hidden_states[..., dones == 1, :] = 0.0
+                    self.hidden_state[..., dones == 1, :] = 0.0
             else:
                 NotImplementedError(
-                    "Resetting hidden states of done environments with custom hidden states is not implemented"
+                    "Resetting the hidden state of done environments with a custom hidden state is not implemented"
                 )
 
-    def detach_hidden_states(self, dones: torch.Tensor | None = None) -> None:
-        if self.hidden_states is not None:
-            if dones is None:  # Detach all hidden states
-                if isinstance(self.hidden_states, tuple):  # Tuple in case of LSTM
-                    self.hidden_states = tuple(hidden_state.detach() for hidden_state in self.hidden_states)
+    def detach_hidden_state(self, dones: torch.Tensor | None = None) -> None:
+        if self.hidden_state is not None:
+            if dones is None:  # Detach hidden state
+                if isinstance(self.hidden_state, tuple):  # Tuple in case of LSTM
+                    self.hidden_state = tuple(hidden_state.detach() for hidden_state in self.hidden_state)
                 else:
-                    self.hidden_states = self.hidden_states.detach()
-            else:  # Detach hidden states of done environments
-                if isinstance(self.hidden_states, tuple):  # Tuple in case of LSTM
-                    for hidden_state in self.hidden_states:
+                    self.hidden_state = self.hidden_state.detach()
+            else:  # Detach hidden state of done environments
+                if isinstance(self.hidden_state, tuple):  # Tuple in case of LSTM
+                    for hidden_state in self.hidden_state:
                         hidden_state[..., dones == 1, :] = hidden_state[..., dones == 1, :].detach()
                 else:
-                    self.hidden_states[..., dones == 1, :] = self.hidden_states[..., dones == 1, :].detach()
+                    self.hidden_state[..., dones == 1, :] = self.hidden_state[..., dones == 1, :].detach()
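
The new alias covers exactly the three forms a hidden state takes in this module: a single tensor (GRU), an (h, c) tuple (LSTM), or None before the first step. A minimal sketch of the first two with plain PyTorch RNNs (illustrative sizes, not the Memory defaults):

    import torch
    import torch.nn as nn

    seq_len, num_envs, input_size, hidden_dim, num_layers = 5, 3, 10, 32, 1
    x = torch.randn(seq_len, num_envs, input_size)

    # GRU: a single tensor of shape (num_layers, num_envs, hidden_dim).
    gru = nn.GRU(input_size, hidden_dim, num_layers)
    _, gru_state = gru(x)
    print(type(gru_state).__name__, tuple(gru_state.shape))  # Tensor (1, 3, 32)

    # LSTM: a (h, c) tuple, each tensor with the same shape as the GRU state.
    lstm = nn.LSTM(input_size, hidden_dim, num_layers)
    _, lstm_state = lstm(x)
    print(type(lstm_state).__name__, tuple(lstm_state[0].shape), tuple(lstm_state[1].shape))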
