Skip to content

Commit a4d108a

Browse files
authored
Adds state-dependent standard deviation for the PPO actor (#112)
* Adding option for state-dependent standard deviation out for the PPO actor. * Fix for missing arg in Unflatten
1 parent bc1c7c4 commit a4d108a

File tree

3 files changed

+80
-29
lines changed

3 files changed

+80
-29
lines changed

rsl_rl/modules/actor_critic.py

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def __init__(
2727
activation="elu",
2828
init_noise_std=1.0,
2929
noise_std_type: str = "scalar",
30+
state_dependent_std=False,
3031
**kwargs,
3132
):
3233
if kwargs:
@@ -47,8 +48,12 @@ def __init__(
4748
assert len(obs[obs_group].shape) == 2, "The ActorCritic module only supports 1D observations."
4849
num_critic_obs += obs[obs_group].shape[-1]
4950

51+
self.state_dependent_std = state_dependent_std
5052
# actor
51-
self.actor = MLP(num_actor_obs, num_actions, actor_hidden_dims, activation)
53+
if self.state_dependent_std:
54+
self.actor = MLP(num_actor_obs, [2, num_actions], actor_hidden_dims, activation)
55+
else:
56+
self.actor = MLP(num_actor_obs, num_actions, actor_hidden_dims, activation)
5257
# actor observation normalization
5358
self.actor_obs_normalization = actor_obs_normalization
5459
if actor_obs_normalization:
@@ -69,12 +74,21 @@ def __init__(
6974

7075
# Action noise
7176
self.noise_std_type = noise_std_type
72-
if self.noise_std_type == "scalar":
73-
self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
74-
elif self.noise_std_type == "log":
75-
self.log_std = nn.Parameter(torch.log(init_noise_std * torch.ones(num_actions)))
77+
if self.state_dependent_std:
78+
torch.nn.init.zeros_(self.actor[-2].weight[num_actions:])
79+
if self.noise_std_type == "scalar":
80+
torch.nn.init.constant_(self.actor[-2].bias[num_actions:], init_noise_std)
81+
elif self.noise_std_type == "log":
82+
torch.nn.init.constant_(self.actor[-2].bias[num_actions:], torch.log(torch.tensor(init_noise_std + 1e-7)))
83+
else:
84+
raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
7685
else:
77-
raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
86+
if self.noise_std_type == "scalar":
87+
self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
88+
elif self.noise_std_type == "log":
89+
self.log_std = nn.Parameter(torch.log(init_noise_std * torch.ones(num_actions)))
90+
else:
91+
raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
7892

7993
# Action distribution (populated in update_distribution)
8094
self.distribution = None
@@ -100,15 +114,26 @@ def entropy(self):
100114
return self.distribution.entropy().sum(dim=-1)
101115

102116
def update_distribution(self, obs):
103-
# compute mean
104-
mean = self.actor(obs)
105-
# compute standard deviation
106-
if self.noise_std_type == "scalar":
107-
std = self.std.expand_as(mean)
108-
elif self.noise_std_type == "log":
109-
std = torch.exp(self.log_std).expand_as(mean)
117+
if self.state_dependent_std:
118+
# compute mean and standard deviation
119+
mean_and_std = self.actor(obs)
120+
if self.noise_std_type == "scalar":
121+
mean, std = torch.unbind(mean_and_std, dim=-2)
122+
elif self.noise_std_type == "log":
123+
mean, log_std = torch.unbind(mean_and_std, dim=-2)
124+
std = torch.exp(log_std)
125+
else:
126+
raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
110127
else:
111-
raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
128+
# compute mean
129+
mean = self.actor(obs)
130+
# compute standard deviation
131+
if self.noise_std_type == "scalar":
132+
std = self.std.expand_as(mean)
133+
elif self.noise_std_type == "log":
134+
std = torch.exp(self.log_std).expand_as(mean)
135+
else:
136+
raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
112137
# create distribution
113138
self.distribution = Normal(mean, std)
114139

rsl_rl/modules/actor_critic_recurrent.py

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def __init__(
2828
activation="elu",
2929
init_noise_std=1.0,
3030
noise_std_type: str = "scalar",
31+
state_dependent_std=False,
3132
rnn_type="lstm",
3233
rnn_hidden_dim=256,
3334
rnn_num_layers=1,
@@ -58,9 +59,14 @@ def __init__(
5859
assert len(obs[obs_group].shape) == 2, "The ActorCriticRecurrent module only supports 1D observations."
5960
num_critic_obs += obs[obs_group].shape[-1]
6061

62+
self.state_dependent_std = state_dependent_std
6163
# actor
6264
self.memory_a = Memory(num_actor_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_dim)
63-
self.actor = MLP(rnn_hidden_dim, num_actions, actor_hidden_dims, activation)
65+
if self.state_dependent_std:
66+
self.actor = MLP(rnn_hidden_dim, [2, num_actions], actor_hidden_dims, activation)
67+
else:
68+
self.actor = MLP(rnn_hidden_dim, num_actions, actor_hidden_dims, activation)
69+
6470
# actor observation normalization
6571
self.actor_obs_normalization = actor_obs_normalization
6672
if actor_obs_normalization:
@@ -84,12 +90,21 @@ def __init__(
8490

8591
# Action noise
8692
self.noise_std_type = noise_std_type
87-
if self.noise_std_type == "scalar":
88-
self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
89-
elif self.noise_std_type == "log":
90-
self.log_std = nn.Parameter(torch.log(init_noise_std * torch.ones(num_actions)))
93+
if self.state_dependent_std:
94+
torch.nn.init.zeros_(self.actor[-2].weight[num_actions:])
95+
if self.noise_std_type == "scalar":
96+
torch.nn.init.constant_(self.actor[-2].bias[num_actions:], init_noise_std)
97+
elif self.noise_std_type == "log":
98+
torch.nn.init.constant_(self.actor[-2].bias[num_actions:], torch.log(torch.tensor(init_noise_std + 1e-7)))
99+
else:
100+
raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
91101
else:
92-
raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
102+
if self.noise_std_type == "scalar":
103+
self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
104+
elif self.noise_std_type == "log":
105+
self.log_std = nn.Parameter(torch.log(init_noise_std * torch.ones(num_actions)))
106+
else:
107+
raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
93108

94109
# Action distribution (populated in update_distribution)
95110
self.distribution = None
@@ -116,15 +131,26 @@ def forward(self):
116131
raise NotImplementedError
117132

118133
def update_distribution(self, obs):
119-
# compute mean
120-
mean = self.actor(obs)
121-
# compute standard deviation
122-
if self.noise_std_type == "scalar":
123-
std = self.std.expand_as(mean)
124-
elif self.noise_std_type == "log":
125-
std = torch.exp(self.log_std).expand_as(mean)
134+
if self.state_dependent_std:
135+
# compute mean and standard deviation
136+
mean_and_std = self.actor(obs)
137+
if self.noise_std_type == "scalar":
138+
mean, std = torch.unbind(mean_and_std, dim=-2)
139+
elif self.noise_std_type == "log":
140+
mean, log_std = torch.unbind(mean_and_std, dim=-2)
141+
std = torch.exp(log_std)
142+
else:
143+
raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
126144
else:
127-
raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
145+
# compute mean
146+
mean = self.actor(obs)
147+
# compute standard deviation
148+
if self.noise_std_type == "scalar":
149+
std = self.std.expand_as(mean)
150+
elif self.noise_std_type == "log":
151+
std = torch.exp(self.log_std).expand_as(mean)
152+
else:
153+
raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
128154
# create distribution
129155
self.distribution = Normal(mean, std)
130156

rsl_rl/networks/mlp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def __init__(
7272
total_out_dim = reduce(lambda x, y: x * y, output_dim)
7373
# add a layer to reshape the output to the desired shape
7474
layers.append(nn.Linear(hidden_dims_processed[-1], total_out_dim))
75-
layers.append(nn.Unflatten(output_dim))
75+
layers.append(nn.Unflatten(dim=-1, unflattened_size=output_dim))
7676

7777
# add last activation function if specified
7878
if last_activation_mod is not None:

0 commit comments

Comments
 (0)