Skip to content

Commit 93f9ba8

Browse files
Fix type hints
1 parent aeb326c commit 93f9ba8

File tree

6 files changed

+21
-23
lines changed

6 files changed

+21
-23
lines changed

rsl_rl/algorithms/ppo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def __init__(
101101
self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate)
102102

103103
# Create rollout storage
104-
self.storage: RolloutStorage = None # type: ignore
104+
self.storage: RolloutStorage | None = None
105105
self.transition = RolloutStorage.Transition()
106106

107107
# PPO parameters
@@ -362,7 +362,7 @@ def update(self) -> dict[str, float]:
362362
loss.backward()
363363
# Compute the gradients for RND
364364
if self.rnd:
365-
self.rnd_optimizer.zero_grad() # type: ignore
365+
self.rnd_optimizer.zero_grad()
366366
rnd_loss.backward()
367367

368368
# Collect gradients from all GPUs

rsl_rl/modules/actor_critic.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,11 +103,7 @@ def __init__(
103103
# Disable args validation for speedup
104104
Normal.set_default_validate_args(False)
105105

106-
def reset(
107-
self,
108-
dones: torch.Tensor | None = None,
109-
hidden_states: tuple[torch.Tensor | tuple[torch.Tensor] | None] = (None, None),
110-
) -> None:
106+
def reset(self, dones: torch.Tensor | None = None) -> None:
111107
pass
112108

113109
def forward(self) -> NoReturn:

rsl_rl/modules/student_teacher.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def __init__(
9393
def reset(
9494
self,
9595
dones: torch.Tensor | None = None,
96-
hidden_states: tuple[torch.Tensor | tuple[torch.Tensor] | None] = (None, None),
96+
hidden_states: tuple[torch.Tensor | tuple[torch.Tensor] | None, ...] = (None, None),
9797
) -> None:
9898
pass
9999

@@ -150,7 +150,7 @@ def get_teacher_obs(self, obs: TensorDict) -> torch.Tensor:
150150
obs_list = [obs[obs_group] for obs_group in self.obs_groups["teacher"]]
151151
return torch.cat(obs_list, dim=-1)
152152

153-
def get_hidden_states(self) -> tuple[torch.Tensor | tuple[torch.Tensor] | None]:
153+
def get_hidden_states(self) -> tuple[torch.Tensor | tuple[torch.Tensor] | None, ...]:
154154
return None, None
155155

156156
def detach_hidden_states(self, dones: torch.Tensor | None = None) -> None:

rsl_rl/modules/student_teacher_recurrent.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def __init__(
112112
def reset(
113113
self,
114114
dones: torch.Tensor | None = None,
115-
hidden_states: tuple[torch.Tensor | tuple[torch.Tensor] | None] = (None, None),
115+
hidden_states: tuple[torch.Tensor | tuple[torch.Tensor] | None, ...] = (None, None),
116116
) -> None:
117117
self.memory_s.reset(dones, hidden_states[0])
118118
if self.teacher_recurrent:
@@ -176,7 +176,9 @@ def get_teacher_obs(self, obs: TensorDict) -> torch.Tensor:
176176
obs_list = [obs[obs_group] for obs_group in self.obs_groups["teacher"]]
177177
return torch.cat(obs_list, dim=-1)
178178

179-
def get_hidden_states(self) -> tuple[torch.Tensor | tuple[torch.Tensor] | None]:
179+
def get_hidden_states(
180+
self,
181+
) -> tuple[torch.Tensor | tuple[torch.Tensor] | None, torch.Tensor | tuple[torch.Tensor] | None]:
180182
if self.teacher_recurrent:
181183
return self.memory_s.hidden_states, self.memory_t.hidden_states
182184
else:

rsl_rl/runners/on_policy_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = Fals
119119
# Update rewards
120120
if self.alg.rnd:
121121
cur_ereward_sum += rewards
122-
cur_ireward_sum += intrinsic_rewards # type: ignore
122+
cur_ireward_sum += intrinsic_rewards
123123
cur_reward_sum += rewards + intrinsic_rewards
124124
else:
125125
cur_reward_sum += rewards

rsl_rl/storage/rollout_storage.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,16 @@
1515
class RolloutStorage:
1616
class Transition:
1717
def __init__(self) -> None:
18-
self.observations: TensorDict = None # type: ignore
19-
self.actions: torch.Tensor = None # type: ignore
20-
self.privileged_actions: torch.Tensor = None # type: ignore
21-
self.rewards: torch.Tensor = None # type: ignore
22-
self.dones: torch.Tensor = None # type: ignore
23-
self.values: torch.Tensor = None # type: ignore
24-
self.actions_log_prob: torch.Tensor = None # type: ignore
25-
self.action_mean: torch.Tensor = None # type: ignore
26-
self.action_sigma: torch.Tensor = None # type: ignore
27-
self.hidden_states: tuple[torch.Tensor | tuple[torch.Tensor] | None] = (None, None) # type: ignore
18+
self.observations: TensorDict | None = None
19+
self.actions: torch.Tensor | None = None
20+
self.privileged_actions: torch.Tensor | None = None
21+
self.rewards: torch.Tensor | None = None
22+
self.dones: torch.Tensor | None = None
23+
self.values: torch.Tensor | None = None
24+
self.actions_log_prob: torch.Tensor
25+
self.action_mean: torch.Tensor | None = None
26+
self.action_sigma: torch.Tensor | None = None
27+
self.hidden_states: tuple[torch.Tensor | tuple[torch.Tensor] | None, ...] = (None, None)
2828

2929
def clear(self) -> None:
3030
self.__init__()
@@ -102,7 +102,7 @@ def add_transitions(self, transition: Transition) -> None:
102102
# Increment the counter
103103
self.step += 1
104104

105-
def _save_hidden_states(self, hidden_states: tuple[torch.Tensor | tuple[torch.Tensor] | None]) -> None:
105+
def _save_hidden_states(self, hidden_states: tuple[torch.Tensor | tuple[torch.Tensor] | None, ...]) -> None:
106106
if hidden_states == (None, None):
107107
return
108108
# Make a tuple out of GRU hidden states to match the LSTM format

0 commit comments

Comments (0)