Commit 0823b34

correct type hint for tuple of multiple tensors
1 parent 93f9ba8 commit 0823b34
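For context on what the commit fixes: in Python's typing syntax, tuple[torch.Tensor] describes a tuple containing exactly one tensor, whereas tuple[torch.Tensor, ...] describes a tuple of any length. A minimal standalone sketch of the difference (illustrative only, not code from the repository):

import torch

# tuple[torch.Tensor] is a fixed-length type: a tuple of exactly one tensor.
single: tuple[torch.Tensor] = (torch.zeros(3),)

# tuple[torch.Tensor, ...] is variable-length: one, two, or more tensors.
several: tuple[torch.Tensor, ...] = (torch.zeros(3), torch.zeros(3))

# A static type checker rejects `several` under the old annotation,
# because a 2-element tuple is not a tuple[torch.Tensor].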

File tree

rsl_rl/modules/actor_critic_recurrent.py
rsl_rl/modules/student_teacher.py
rsl_rl/modules/student_teacher_recurrent.py
rsl_rl/networks/memory.py
rsl_rl/storage/rollout_storage.py

5 files changed: 11 additions & 11 deletions

rsl_rl/modules/actor_critic_recurrent.py

Lines changed: 3 additions & 3 deletions
@@ -166,7 +166,7 @@ def act(
         self,
         obs: TensorDict,
         masks: torch.Tensor | None = None,
-        hidden_states: torch.Tensor | tuple[torch.Tensor] | None = None,
+        hidden_states: torch.Tensor | tuple[torch.Tensor, ...] | None = None,
     ) -> torch.Tensor:
         obs = self.get_actor_obs(obs)
         obs = self.actor_obs_normalizer(obs)
@@ -187,7 +187,7 @@ def evaluate(
         self,
         obs: TensorDict,
         masks: torch.Tensor | None = None,
-        hidden_states: torch.Tensor | tuple[torch.Tensor] | None = None,
+        hidden_states: torch.Tensor | tuple[torch.Tensor, ...] | None = None,
     ) -> torch.Tensor:
         obs = self.get_critic_obs(obs)
         obs = self.critic_obs_normalizer(obs)
@@ -207,7 +207,7 @@ def get_actions_log_prob(self, actions: torch.Tensor) -> torch.Tensor:

     def get_hidden_states(
         self,
-    ) -> tuple[torch.Tensor | tuple[torch.Tensor] | None, torch.Tensor | tuple[torch.Tensor] | None]:
+    ) -> tuple[torch.Tensor | tuple[torch.Tensor, ...] | None, torch.Tensor | tuple[torch.Tensor, ...] | None]:
         return self.memory_a.hidden_states, self.memory_c.hidden_states

     def update_normalization(self, obs: TensorDict) -> None:

rsl_rl/modules/student_teacher.py

Lines changed: 2 additions & 2 deletions
@@ -93,7 +93,7 @@ def __init__(
     def reset(
         self,
         dones: torch.Tensor | None = None,
-        hidden_states: tuple[torch.Tensor | tuple[torch.Tensor] | None, ...] = (None, None),
+        hidden_states: tuple[torch.Tensor | tuple[torch.Tensor, ...] | None, ...] = (None, None),
     ) -> None:
         pass

@@ -150,7 +150,7 @@ def get_teacher_obs(self, obs: TensorDict) -> torch.Tensor:
         obs_list = [obs[obs_group] for obs_group in self.obs_groups["teacher"]]
         return torch.cat(obs_list, dim=-1)

-    def get_hidden_states(self) -> tuple[torch.Tensor | tuple[torch.Tensor] | None, ...]:
+    def get_hidden_states(self) -> tuple[torch.Tensor | tuple[torch.Tensor, ...] | None, ...]:
         return None, None

     def detach_hidden_states(self, dones: torch.Tensor | None = None) -> None:

rsl_rl/modules/student_teacher_recurrent.py

Lines changed: 2 additions & 2 deletions
@@ -112,7 +112,7 @@ def __init__(
     def reset(
         self,
         dones: torch.Tensor | None = None,
-        hidden_states: tuple[torch.Tensor | tuple[torch.Tensor] | None, ...] = (None, None),
+        hidden_states: tuple[torch.Tensor | tuple[torch.Tensor, ...] | None, ...] = (None, None),
     ) -> None:
         self.memory_s.reset(dones, hidden_states[0])
         if self.teacher_recurrent:
@@ -178,7 +178,7 @@ def get_teacher_obs(self, obs: TensorDict) -> torch.Tensor:

     def get_hidden_states(
         self,
-    ) -> tuple[torch.Tensor | tuple[torch.Tensor] | None, torch.Tensor | tuple[torch.Tensor] | None]:
+    ) -> tuple[torch.Tensor | tuple[torch.Tensor, ...] | None, torch.Tensor | tuple[torch.Tensor, ...] | None]:
         if self.teacher_recurrent:
             return self.memory_s.hidden_states, self.memory_t.hidden_states
         else:

rsl_rl/networks/memory.py

Lines changed: 2 additions & 2 deletions
@@ -27,7 +27,7 @@ def forward(
         self,
         input: torch.Tensor,
         masks: torch.Tensor | None = None,
-        hidden_states: torch.Tensor | tuple[torch.Tensor] | None = None,
+        hidden_states: torch.Tensor | tuple[torch.Tensor, ...] | None = None,
     ) -> torch.Tensor:
         batch_mode = masks is not None
         if batch_mode:
@@ -42,7 +42,7 @@ def forward(
         return out

     def reset(
-        self, dones: torch.Tensor | None = None, hidden_states: torch.Tensor | tuple[torch.Tensor] | None = None
+        self, dones: torch.Tensor | None = None, hidden_states: torch.Tensor | tuple[torch.Tensor, ...] | None = None
     ) -> None:
         if dones is None:  # Reset hidden states
             if hidden_states is None:
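The union type in Memory.forward and Memory.reset reflects that PyTorch's two recurrent cores package their hidden state differently: nn.GRU returns a single tensor, nn.LSTM a pair (h, c). A small self-contained illustration, independent of rsl_rl's Memory class (variable names here are placeholders):

import torch
import torch.nn as nn

x = torch.zeros(10, 4, 16)  # (seq_len, batch, input_size)

gru = nn.GRU(16, 32)
_, h_gru = gru(x)            # h_gru is a single torch.Tensor

lstm = nn.LSTM(16, 32)
_, h_lstm = lstm(x)          # h_lstm is a tuple of two tensors: (h, c)

# Either value fits the corrected annotation:
state: torch.Tensor | tuple[torch.Tensor, ...] | None = h_gru
state = h_lstm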

rsl_rl/storage/rollout_storage.py

Lines changed: 2 additions & 2 deletions
@@ -24,7 +24,7 @@ def __init__(self) -> None:
             self.actions_log_prob: torch.Tensor
             self.action_mean: torch.Tensor | None = None
             self.action_sigma: torch.Tensor | None = None
-            self.hidden_states: tuple[torch.Tensor | tuple[torch.Tensor] | None, ...] = (None, None)
+            self.hidden_states: tuple[torch.Tensor | tuple[torch.Tensor, ...] | None, ...] = (None, None)

         def clear(self) -> None:
             self.__init__()
@@ -102,7 +102,7 @@ def add_transitions(self, transition: Transition) -> None:
         # Increment the counter
         self.step += 1

-    def _save_hidden_states(self, hidden_states: tuple[torch.Tensor | tuple[torch.Tensor] | None, ...]) -> None:
+    def _save_hidden_states(self, hidden_states: tuple[torch.Tensor | tuple[torch.Tensor, ...] | None, ...]) -> None:
         if hidden_states == (None, None):
             return
         # Make a tuple out of GRU hidden states to match the LSTM format
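The context line above ("Make a tuple out of GRU hidden states to match the LSTM format") points at why the variable-length tuple is the right annotation: a lone GRU tensor gets wrapped into a one-element tuple so storage code can treat it like the LSTM's (h, c) pair. A rough sketch of that idea follows; the as_tuple helper is hypothetical and not part of the repository:

import torch


def as_tuple(hidden: torch.Tensor | tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, ...]:
    # Wrap a GRU-style tensor in a 1-element tuple; pass LSTM-style tuples through.
    return hidden if isinstance(hidden, tuple) else (hidden,)


# GRU state becomes a length-1 tuple, LSTM state stays a length-2 tuple,
# so downstream code can iterate over both uniformly.
assert len(as_tuple(torch.zeros(1, 4, 32))) == 1
assert len(as_tuple((torch.zeros(1, 4, 32), torch.zeros(1, 4, 32)))) == 2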
