
Commit cf86da0

add types
1 parent 4de6f15 commit cf86da0

18 files changed: +317, -242 lines

rsl_rl/algorithms/distillation.py

Lines changed: 22 additions & 12 deletions
@@ -5,6 +5,7 @@
 
 import torch
 import torch.nn as nn
+from tensordict import TensorDict
 
 from rsl_rl.modules import StudentTeacher, StudentTeacherRecurrent
 from rsl_rl.storage import RolloutStorage
@@ -19,14 +20,14 @@ class Distillation:
 
     def __init__(
         self,
-        policy,
-        num_learning_epochs=1,
-        gradient_length=15,
-        learning_rate=1e-3,
-        max_grad_norm=None,
-        loss_type="mse",
-        optimizer="adam",
-        device="cpu",
+        policy: StudentTeacher | StudentTeacherRecurrent,
+        num_learning_epochs: int = 1,
+        gradient_length: int = 15,
+        learning_rate: float = 1e-3,
+        max_grad_norm: float | None = None,
+        loss_type: str = "mse",
+        optimizer: str = "adam",
+        device: str = "cpu",
         # Distributed training parameters
         multi_gpu_cfg: dict | None = None,
     ):
@@ -71,7 +72,14 @@ def __init__(
 
         self.num_updates = 0
 
-    def init_storage(self, training_type, num_envs, num_transitions_per_env, obs, actions_shape):
+    def init_storage(
+        self,
+        training_type: str,
+        num_envs: int,
+        num_transitions_per_env: int,
+        obs: TensorDict,
+        actions_shape: tuple[int],
+    ):
         # create rollout storage
         self.storage = RolloutStorage(
             training_type,
@@ -82,15 +90,17 @@ def init_storage(self, training_type, num_envs, num_transitions_per_env, obs, ac
             self.device,
         )
 
-    def act(self, obs):
+    def act(self, obs: TensorDict) -> torch.Tensor:
         # compute the actions
         self.transition.actions = self.policy.act(obs).detach()
         self.transition.privileged_actions = self.policy.evaluate(obs).detach()
         # record the observations
         self.transition.observations = obs
         return self.transition.actions
 
-    def process_env_step(self, obs, rewards, dones, extras):
+    def process_env_step(
+        self, obs: TensorDict, rewards: torch.Tensor, dones: torch.Tensor, extras: dict[str, torch.Tensor]
+    ):
         # update the normalizers
         self.policy.update_normalization(obs)
 
@@ -102,7 +112,7 @@ def process_env_step(self, obs, rewards, dones, extras):
         self.transition.clear()
         self.policy.reset(dones)
 
-    def update(self):
+    def update(self) -> dict[str, float]:
         self.num_updates += 1
         mean_behavior_loss = 0
         loss = 0
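
For orientation, a minimal sketch of how these newly annotated entry points fit together in one distillation rollout. The `policy` (a StudentTeacher), `env` (a VecEnv), and the step/action counts are placeholders that are not part of this commit, and the "distillation" training-type string is an assumption about what RolloutStorage accepts rather than something this diff shows.

    from rsl_rl.algorithms import Distillation

    # Placeholders: a StudentTeacher `policy` and a VecEnv `env` are built elsewhere;
    # the numbers are illustrative only.
    num_steps, num_actions = 24, 12
    algo = Distillation(policy, learning_rate=1e-3, device="cpu")
    algo.init_storage(
        training_type="distillation",  # assumed storage mode, not shown in this diff
        num_envs=env.num_envs,
        num_transitions_per_env=num_steps,
        obs=env.get_observations(),    # TensorDict, matching the new annotation
        actions_shape=(num_actions,),
    )

    obs = env.get_observations()
    for _ in range(num_steps):
        actions = algo.act(obs)                      # student actions; teacher actions are recorded internally
        obs, rewards, dones, extras = env.step(actions)
        algo.process_env_step(obs, rewards, dones, extras)

    losses: dict[str, float] = algo.update()         # e.g. the mean behavior-cloning loss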

rsl_rl/algorithms/ppo.py

Lines changed: 32 additions & 22 deletions
@@ -9,8 +9,9 @@
 import torch.nn as nn
 import torch.optim as optim
 from itertools import chain
+from tensordict import TensorDict
 
-from rsl_rl.modules import ActorCritic
+from rsl_rl.modules import ActorCritic, ActorCriticRecurrent
 from rsl_rl.modules.rnd import RandomNetworkDistillation
 from rsl_rl.storage import RolloutStorage
 from rsl_rl.utils import string_to_callable
@@ -19,26 +20,26 @@
 class PPO:
     """Proximal Policy Optimization algorithm (https://arxiv.org/abs/1707.06347)."""
 
-    policy: ActorCritic
+    policy: ActorCritic | ActorCriticRecurrent
     """The actor critic module."""
 
     def __init__(
         self,
-        policy,
-        num_learning_epochs=5,
-        num_mini_batches=4,
-        clip_param=0.2,
-        gamma=0.99,
-        lam=0.95,
-        value_loss_coef=1.0,
-        entropy_coef=0.01,
-        learning_rate=0.001,
-        max_grad_norm=1.0,
-        use_clipped_value_loss=True,
-        schedule="adaptive",
-        desired_kl=0.01,
-        device="cpu",
-        normalize_advantage_per_mini_batch=False,
+        policy: ActorCritic | ActorCriticRecurrent,
+        num_learning_epochs: int = 5,
+        num_mini_batches: int = 4,
+        clip_param: float = 0.2,
+        gamma: float = 0.99,
+        lam: float = 0.95,
+        value_loss_coef: float = 1.0,
+        entropy_coef: float = 0.01,
+        learning_rate: float = 0.001,
+        max_grad_norm: float = 1.0,
+        use_clipped_value_loss: bool = True,
+        schedule: str = "adaptive",
+        desired_kl: float = 0.01,
+        device: str = "cpu",
+        normalize_advantage_per_mini_batch: bool = False,
         # RND parameters
         rnd_cfg: dict | None = None,
         # Symmetry parameters
@@ -115,7 +116,14 @@ def __init__(
         self.learning_rate = learning_rate
         self.normalize_advantage_per_mini_batch = normalize_advantage_per_mini_batch
 
-    def init_storage(self, training_type, num_envs, num_transitions_per_env, obs, actions_shape):
+    def init_storage(
+        self,
+        training_type: str,
+        num_envs: int,
+        num_transitions_per_env: int,
+        obs: TensorDict,
+        actions_shape: tuple[int] | list[int],
+    ):
         # create rollout storage
         self.storage = RolloutStorage(
             training_type,
@@ -126,7 +134,7 @@ def init_storage(self, training_type, num_envs, num_transitions_per_env, obs, ac
             self.device,
         )
 
-    def act(self, obs):
+    def act(self, obs: TensorDict) -> torch.Tensor:
         if self.policy.is_recurrent:
             self.transition.hidden_states = self.policy.get_hidden_states()
         # compute the actions and values
@@ -139,7 +147,9 @@ def act(self, obs):
         self.transition.observations = obs
         return self.transition.actions
 
-    def process_env_step(self, obs, rewards, dones, extras):
+    def process_env_step(
+        self, obs: TensorDict, rewards: torch.Tensor, dones: torch.Tensor, extras: dict[str, torch.Tensor]
+    ):
         # update the normalizers
         self.policy.update_normalization(obs)
         if self.rnd:
@@ -168,14 +178,14 @@ def process_env_step(self, obs, rewards, dones, extras):
         self.transition.clear()
         self.policy.reset(dones)
 
-    def compute_returns(self, obs):
+    def compute_returns(self, obs: TensorDict):
         # compute value for the last step
         last_values = self.policy.evaluate(obs).detach()
         self.storage.compute_returns(
             last_values, self.gamma, self.lam, normalize_advantage=not self.normalize_advantage_per_mini_batch
         )
 
-    def update(self):
+    def update(self) -> dict[str, float]:
         mean_value_loss = 0
         mean_surrogate_loss = 0
         mean_entropy = 0
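
The PPO loop follows the same pattern, with the observation TensorDict threaded through every call. A rough sketch of one learning iteration, again with `policy` (an ActorCritic) and `env` (a VecEnv) assumed to exist elsewhere, and with the "rl" training-type string being an assumption rather than something this diff shows.

    from rsl_rl.algorithms import PPO

    # Placeholders: `policy` (ActorCritic) and `env` (VecEnv) come from elsewhere.
    num_steps, num_actions = 24, 12
    algo = PPO(policy, gamma=0.99, lam=0.95, device="cpu")
    algo.init_storage("rl", env.num_envs, num_steps, env.get_observations(), (num_actions,))

    obs = env.get_observations()                 # TensorDict
    for _ in range(num_steps):
        actions = algo.act(obs)                  # torch.Tensor, per the new return annotation
        obs, rewards, dones, extras = env.step(actions)
        algo.process_env_step(obs, rewards, dones, extras)

    algo.compute_returns(obs)                    # bootstrap the value of the final observation
    stats: dict[str, float] = algo.update()      # loss terms keyed by name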

rsl_rl/env/vec_env.py

Lines changed: 6 additions & 6 deletions
@@ -50,7 +50,7 @@ def get_observations(self) -> TensorDict:
         """Return the current observations.
 
         Returns:
-            observations (TensorDict): Observations from the environment.
+            observations: Observations from the environment.
         """
         raise NotImplementedError
 
@@ -59,13 +59,13 @@ def step(self, actions: torch.Tensor) -> tuple[TensorDict, torch.Tensor, torch.T
         """Apply input action to the environment.
 
         Args:
-            actions (torch.Tensor): Input actions to apply. Shape: (num_envs, num_actions)
+            actions: Input actions to apply. Shape: (num_envs, num_actions)
 
         Returns:
-            observations (TensorDict): Observations from the environment.
-            rewards (torch.Tensor): Rewards from the environment. Shape: (num_envs,)
-            dones (torch.Tensor): Done flags from the environment. Shape: (num_envs,)
-            extras (dict): Extra information from the environment.
+            observations: Observations from the environment.
+            rewards: Rewards from the environment. Shape: (num_envs,)
+            dones: Done flags from the environment. Shape: (num_envs,)
+            extras: Extra information from the environment.
 
         Observations:
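
Since only the docstrings change here, the interface contract is easiest to see from a stub. A toy subclass sketched against the two methods shown above; the four-tuple return type is inferred from the documented return values, and the constructor, attribute names, and shapes are illustrative, not quoted from this file. A complete environment would also have to fill in the rest of the VecEnv interface.

    import torch
    from tensordict import TensorDict

    from rsl_rl.env import VecEnv


    class DummyEnv(VecEnv):
        """Toy vectorized environment; only the two documented methods are filled in."""

        def __init__(self, num_envs: int = 8, num_obs: int = 4, num_actions: int = 2, device: str = "cpu"):
            self.num_envs = num_envs
            self.num_actions = num_actions
            self.device = device
            self._obs = TensorDict({"policy": torch.zeros(num_envs, num_obs)}, batch_size=[num_envs], device=device)

        def get_observations(self) -> TensorDict:
            # Return the current observations.
            return self._obs

        def step(self, actions: torch.Tensor) -> tuple[TensorDict, torch.Tensor, torch.Tensor, dict]:
            # A real environment would advance the simulation with `actions` here.
            rewards = torch.zeros(self.num_envs, device=self.device)
            dones = torch.zeros(self.num_envs, dtype=torch.bool, device=self.device)
            extras: dict[str, torch.Tensor] = {}
            return self._obs, rewards, dones, extras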

rsl_rl/modules/actor_critic.py

Lines changed: 30 additions & 25 deletions
@@ -7,25 +7,26 @@
 
 import torch
 import torch.nn as nn
+from tensordict import TensorDict
 from torch.distributions import Normal
 
 from rsl_rl.networks import MLP, EmpiricalNormalization
 
 
 class ActorCritic(nn.Module):
-    is_recurrent = False
+    is_recurrent: bool = False
 
     def __init__(
         self,
-        obs,
-        obs_groups,
-        num_actions,
-        actor_obs_normalization=False,
-        critic_obs_normalization=False,
-        actor_hidden_dims=[256, 256, 256],
-        critic_hidden_dims=[256, 256, 256],
-        activation="elu",
-        init_noise_std=1.0,
+        obs: TensorDict,
+        obs_groups: dict[str, list[str]],
+        num_actions: int,
+        actor_obs_normalization: bool = False,
+        critic_obs_normalization: bool = False,
+        actor_hidden_dims: tuple[int] | list[int] = [256, 256, 256],
+        critic_hidden_dims: tuple[int] | list[int] = [256, 256, 256],
+        activation: str = "elu",
+        init_noise_std: float = 1.0,
         noise_std_type: str = "scalar",
         state_dependent_std=False,
         **kwargs,
@@ -96,25 +97,29 @@ def __init__(
         # disable args validation for speedup
         Normal.set_default_validate_args(False)
 
-    def reset(self, dones=None):
+    def reset(
+        self,
+        dones: torch.Tensor | None = None,
+        hidden_states: tuple[torch.Tensor | tuple[torch.Tensor] | None] = (None, None),
+    ):
         pass
 
     def forward(self):
         raise NotImplementedError
 
     @property
-    def action_mean(self):
+    def action_mean(self) -> torch.Tensor:
         return self.distribution.mean
 
     @property
-    def action_std(self):
+    def action_std(self) -> torch.Tensor:
         return self.distribution.stddev
 
     @property
-    def entropy(self):
+    def entropy(self) -> torch.Tensor:
         return self.distribution.entropy().sum(dim=-1)
 
-    def _update_distribution(self, obs):
+    def _update_distribution(self, obs: TensorDict):
         if self.state_dependent_std:
             # compute mean and standard deviation
             mean_and_std = self.actor(obs)
@@ -138,50 +143,50 @@ def _update_distribution(self, obs):
         # create distribution
         self.distribution = Normal(mean, std)
 
-    def act(self, obs, **kwargs):
+    def act(self, obs: TensorDict, **kwargs) -> torch.Tensor:
         obs = self.get_actor_obs(obs)
         obs = self.actor_obs_normalizer(obs)
         self._update_distribution(obs)
         return self.distribution.sample()
 
-    def act_inference(self, obs):
+    def act_inference(self, obs: TensorDict) -> torch.Tensor:
         obs = self.get_actor_obs(obs)
         obs = self.actor_obs_normalizer(obs)
         if self.state_dependent_std:
             return self.actor(obs)[..., 0, :]
         else:
             return self.actor(obs)
 
-    def evaluate(self, obs, **kwargs):
+    def evaluate(self, obs: TensorDict, **kwargs) -> torch.Tensor:
         obs = self.get_critic_obs(obs)
         obs = self.critic_obs_normalizer(obs)
         return self.critic(obs)
 
-    def get_actor_obs(self, obs):
+    def get_actor_obs(self, obs: TensorDict) -> torch.Tensor:
         obs_list = [obs[obs_group] for obs_group in self.obs_groups["policy"]]
         return torch.cat(obs_list, dim=-1)
 
-    def get_critic_obs(self, obs):
+    def get_critic_obs(self, obs: TensorDict) -> torch.Tensor:
         obs_list = [obs[obs_group] for obs_group in self.obs_groups["critic"]]
         return torch.cat(obs_list, dim=-1)
 
-    def get_actions_log_prob(self, actions):
+    def get_actions_log_prob(self, actions: torch.Tensor) -> torch.Tensor:
         return self.distribution.log_prob(actions).sum(dim=-1)
 
-    def update_normalization(self, obs):
+    def update_normalization(self, obs: TensorDict):
         if self.actor_obs_normalization:
             actor_obs = self.get_actor_obs(obs)
             self.actor_obs_normalizer.update(actor_obs)
         if self.critic_obs_normalization:
             critic_obs = self.get_critic_obs(obs)
             self.critic_obs_normalizer.update(critic_obs)
 
-    def load_state_dict(self, state_dict, strict=True):
+    def load_state_dict(self, state_dict: dict, strict: bool = True) -> bool:
         """Load the parameters of the actor-critic model.
 
         Args:
-            state_dict (dict): State dictionary of the model.
-            strict (bool): Whether to strictly enforce that the keys in state_dict match the keys returned by this
+            state_dict: State dictionary of the model.
+            strict: Whether to strictly enforce that the keys in state_dict match the keys returned by this
                 module's state_dict() function.
 
         Returns:
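
To make the new annotations concrete, a small usage sketch: observations live in a TensorDict, and obs_groups maps the "policy" and "critic" groups (the keys read by get_actor_obs / get_critic_obs above) to lists of observation keys. The "state" key and all shapes below are made up for illustration; only the constructor and method signatures come from this diff.

    import torch
    from tensordict import TensorDict

    from rsl_rl.modules import ActorCritic

    # Illustrative observation layout: one observation key ("state") feeding both networks.
    num_envs, num_actions = 16, 6
    obs = TensorDict({"state": torch.randn(num_envs, 48)}, batch_size=[num_envs])
    obs_groups = {"policy": ["state"], "critic": ["state"]}

    policy = ActorCritic(
        obs,
        obs_groups,
        num_actions,
        actor_obs_normalization=True,
        actor_hidden_dims=[256, 256, 256],
        activation="elu",
    )

    actions = policy.act(obs)                        # sampled actions, shape (num_envs, num_actions)
    log_prob = policy.get_actions_log_prob(actions)  # shape (num_envs,)
    values = policy.evaluate(obs)                    # critic estimate per environment
    mean_action = policy.act_inference(obs)          # deterministic actions for deployment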
