Skip to content

Commit a6665f8

Browse files
add convolution networks for actor critic
1 parent f80d475 commit a6665f8

File tree

4 files changed

+461
-2
lines changed

4 files changed

+461
-2
lines changed

rsl_rl/modules/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
"""Definitions for neural-network components for RL-agents."""
77

88
from .actor_critic import ActorCritic
9+
from .actor_critic_conv2d import ActorCriticConv2d
910
from .actor_critic_recurrent import ActorCriticRecurrent
1011
from .normalizer import EmpiricalNormalization
1112
from .rnd import RandomNetworkDistillation
1213

13-
__all__ = ["ActorCritic", "ActorCriticRecurrent", "EmpiricalNormalization", "RandomNetworkDistillation"]
14+
__all__ = ["ActorCritic", "ActorCriticConv2d", "ActorCriticRecurrent", "EmpiricalNormalization", "RandomNetworkDistillation"]
Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
2+
# All rights reserved.
3+
#
4+
# SPDX-License-Identifier: BSD-3-Clause
5+
6+
import torch
7+
import torch.nn as nn
8+
from torch.distributions import Normal
9+
10+
from rsl_rl.utils import resolve_nn_activation
11+
12+
13+
class ResidualBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by batch normalization.

    Both convolutions use ``padding=1`` and keep the channel count, so the output has the
    same shape as the input and the input can be added back as a skip connection.

    Args:
        channels: Number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()
        # Settings shared by both convolutions; they preserve the spatial size.
        conv_kwargs = {"kernel_size": 3, "padding": 1}
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``."""
        hidden = self.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        # Skip connection: add the untouched input in place before the final activation.
        hidden += x
        return self.relu(hidden)
32+
33+
34+
class ConvolutionalNetwork(nn.Module):
    """Network that fuses a flat proprioceptive vector with an image processed by a conv stack.

    The input to :meth:`forward` is a 2D tensor whose last ``prod(image_input_shape)`` columns
    hold the flattened image and whose remaining leading columns hold the proprioceptive values.
    The image goes through the convolutional stack, a linear projection and a layer norm; the
    result is concatenated with the proprioceptive values and fed to an MLP.

    Args:
        proprio_input_dim: Number of proprioceptive values per sample (the columns preceding the image).
        output_dim: Size of the MLP output.
        image_input_shape: Image shape as ``(C, H, W)``.
        conv_layers_params: Non-empty sequence of dicts, one per convolution. Each dict must contain
            ``"out_channels"`` and may contain ``"kernel_size"`` (default 3), ``"stride"`` (default 1)
            and ``"padding"`` (default 0).
        hidden_dims: Non-empty sequence of hidden layer sizes for the MLP.
        activation_fn: Activation module instance inserted after every hidden linear layer. The same
            instance is reused at every position.
        conv_linear_output_size: Size of the linear projection applied to the flattened conv features.

    Raises:
        ValueError: If ``conv_layers_params`` or ``hidden_dims`` is empty.
    """

    def __init__(
        self,
        proprio_input_dim,
        output_dim,
        image_input_shape,
        conv_layers_params,
        hidden_dims,
        activation_fn,
        conv_linear_output_size,
    ):
        super().__init__()

        # Fail early with a clear message instead of an IndexError deep inside the builders.
        if not conv_layers_params:
            raise ValueError("conv_layers_params must contain at least one layer specification.")
        if not hidden_dims:
            raise ValueError("hidden_dims must contain at least one hidden layer size.")

        self.image_input_shape = image_input_shape  # (C, H, W)
        self.image_obs_size = torch.prod(torch.tensor(self.image_input_shape)).item()
        self.proprio_obs_size = proprio_input_dim
        self.input_dim = self.proprio_obs_size + self.image_obs_size
        self.activation_fn = activation_fn

        # Build conv network and get its output size
        self.conv_net = self.build_conv_net(conv_layers_params)
        # Probe the output size in eval mode. In training mode a forward pass would
        # (a) update the BatchNorm running statistics with the all-zero dummy image and
        # (b) raise for a batch of one whenever the conv stack reduces the image to 1x1.
        was_training = self.conv_net.training
        self.conv_net.eval()
        try:
            with torch.no_grad():
                dummy_image = torch.zeros(1, *self.image_input_shape)
                conv_output = self.conv_net(dummy_image)
        finally:
            self.conv_net.train(was_training)
        self.image_feature_size = conv_output.view(1, -1).shape[1]

        # Build the connection layers between conv net and mlp
        self.conv_linear = nn.Linear(self.image_feature_size, conv_linear_output_size)
        self.layernorm = nn.LayerNorm(conv_linear_output_size)

        # Build the mlp
        self.mlp = nn.Sequential(
            nn.Linear(self.proprio_obs_size + conv_linear_output_size, hidden_dims[0]),
            self.activation_fn,
            *[
                layer
                for dim in zip(hidden_dims[:-1], hidden_dims[1:])
                for layer in (nn.Linear(dim[0], dim[1]), self.activation_fn)
            ],
            nn.Linear(hidden_dims[-1], output_dim),
        )

        # Initialize the weights
        self._initialize_weights()

    def build_conv_net(self, conv_layers_params):
        """Assemble the convolutional stack described by ``conv_layers_params``.

        Every layer except the last is ``Conv2d -> BatchNorm2d -> ReLU``, followed by a
        :class:`ResidualBlock` for all but the first layer (an ``nn.Identity`` placeholder is used
        for the first one). The last layer is ``Conv2d -> BatchNorm2d`` without an activation.

        Args:
            conv_layers_params: Sequence of per-layer dicts (see the class docstring).

        Returns:
            An ``nn.Sequential`` holding the assembled layers.
        """
        layers = []
        in_channels = self.image_input_shape[0]
        for idx, params in enumerate(conv_layers_params[:-1]):
            layers.extend([
                nn.Conv2d(
                    in_channels,
                    params["out_channels"],
                    kernel_size=params.get("kernel_size", 3),
                    stride=params.get("stride", 1),
                    padding=params.get("padding", 0),
                ),
                nn.BatchNorm2d(params["out_channels"]),
                nn.ReLU(inplace=True),
                ResidualBlock(params["out_channels"]) if idx > 0 else nn.Identity(),
            ])
            in_channels = params["out_channels"]
        last_params = conv_layers_params[-1]
        layers.append(
            nn.Conv2d(
                in_channels,
                last_params["out_channels"],
                kernel_size=last_params.get("kernel_size", 3),
                stride=last_params.get("stride", 1),
                padding=last_params.get("padding", 0),
            )
        )
        layers.append(nn.BatchNorm2d(last_params["out_channels"]))
        return nn.Sequential(*layers)

    def _initialize_weights(self):
        """Apply the initialization scheme to the conv stack, the projection layers and the MLP."""
        # Conv stack: Kaiming-normal conv weights, unit-scale / zero-shift batch norm.
        for m in self.conv_net.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Projection from the conv features to the MLP input.
        nn.init.kaiming_normal_(self.conv_linear.weight, mode="fan_out", nonlinearity="tanh")
        nn.init.constant_(self.conv_linear.bias, 0)
        nn.init.constant_(self.layernorm.weight, 1.0)
        nn.init.constant_(self.layernorm.bias, 0.0)

        # MLP: orthogonal weights with a small gain and zero biases.
        for layer in self.mlp:
            if isinstance(layer, nn.Linear):
                nn.init.orthogonal_(layer.weight, gain=0.01)
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)

    def forward(self, observations):
        """Compute the network output for a batch of flat observations.

        Args:
            observations: 2D tensor; the last ``image_obs_size`` columns are the flattened image
                and the columns before them are the proprioceptive values.

        Returns:
            The MLP output, with one row per sample and ``output_dim`` columns.
        """
        # Split the flat observation into its proprioceptive and image parts.
        proprio_obs = observations[:, : -self.image_obs_size]
        image_obs = observations[:, -self.image_obs_size :]

        batch_size = image_obs.size(0)
        image = image_obs.view(batch_size, *self.image_input_shape)

        conv_features = self.conv_net(image)
        flattened_conv_features = conv_features.view(batch_size, -1)
        normalized_conv_output = self.layernorm(self.conv_linear(flattened_conv_features))
        combined_input = torch.cat([proprio_obs, normalized_conv_output], dim=1)
        output = self.mlp(combined_input)
        return output
140+
141+
142+
class ActorCriticConv2d(nn.Module):
    """Actor-critic module whose actor and critic are both :class:`ConvolutionalNetwork` instances.

    The actor output is used as the mean of a diagonal Gaussian over actions; the standard
    deviation is a learnable parameter that does not depend on the observation. The critic
    outputs a single value per sample.

    Args:
        num_actor_obs: Passed to the actor network as ``proprio_input_dim``.
        num_critic_obs: Passed to the critic network as ``proprio_input_dim``.
        num_actions: Number of action dimensions (actor output size).
        image_input_shape: Image shape as ``(C, H, W)``, shared by actor and critic.
        conv_layers_params: Convolution layer specifications, shared by actor and critic.
        conv_linear_output_size: Size of the projection after the conv stack, shared by both networks.
        actor_hidden_dims: Hidden layer sizes of the actor MLP.
        critic_hidden_dims: Hidden layer sizes of the critic MLP.
        activation: Activation name resolved through ``resolve_nn_activation``.
        init_noise_std: Initial value of the action standard deviation.
        **kwargs: Accepted and not used by this class.
    """

    is_recurrent = False

    def __init__(
        self,
        num_actor_obs,
        num_critic_obs,
        num_actions,
        image_input_shape,
        conv_layers_params,
        conv_linear_output_size,
        actor_hidden_dims,
        critic_hidden_dims,
        activation="elu",
        init_noise_std=1.0,
        **kwargs,
    ):
        super().__init__()

        self.image_input_shape = image_input_shape  # (C, H, W)
        self.activation_fn = resolve_nn_activation(activation)

        # Settings that are identical for the actor and the critic network.
        shared_settings = {
            "image_input_shape": image_input_shape,
            "conv_layers_params": conv_layers_params,
            "activation_fn": self.activation_fn,
            "conv_linear_output_size": conv_linear_output_size,
        }

        # Policy network: outputs the action mean.
        self.actor = ConvolutionalNetwork(
            proprio_input_dim=num_actor_obs,
            output_dim=num_actions,
            hidden_dims=actor_hidden_dims,
            **shared_settings,
        )

        # Value network: outputs a single scalar per sample.
        self.critic = ConvolutionalNetwork(
            proprio_input_dim=num_critic_obs,
            output_dim=1,
            hidden_dims=critic_hidden_dims,
            **shared_settings,
        )

        print(f"Modified Actor Network: {self.actor}")
        print(f"Modified Critic Network: {self.critic}")

        # Learnable, observation-independent action noise.
        initial_std = init_noise_std * torch.ones(num_actions)
        self.std = nn.Parameter(initial_std)
        # Filled in by update_distribution() before any of the distribution accessors are used.
        self.distribution = None
        # Skipping argument validation in Normal makes distribution construction cheaper.
        Normal.set_default_validate_args(False)

    def reset(self, dones=None):
        """Do nothing; this module keeps no per-episode state."""
        pass

    def forward(self):
        """Not supported; use :meth:`act`, :meth:`act_inference` or :meth:`evaluate` instead."""
        raise NotImplementedError

    @property
    def action_mean(self):
        """Mean of the most recently built action distribution."""
        dist = self.distribution
        return dist.mean

    @property
    def action_std(self):
        """Standard deviation of the most recently built action distribution."""
        dist = self.distribution
        return dist.stddev

    @property
    def entropy(self):
        """Entropy of the current action distribution, summed over the last dimension."""
        per_dim_entropy = self.distribution.entropy()
        return per_dim_entropy.sum(dim=-1)

    def update_distribution(self, observations):
        """Rebuild the action distribution from the actor output for ``observations``."""
        self.distribution = Normal(self.actor(observations), self.std)

    def act(self, observations, **kwargs):
        """Refresh the action distribution for ``observations`` and return a sample from it."""
        self.update_distribution(observations)
        sampled_actions = self.distribution.sample()
        return sampled_actions

    def get_actions_log_prob(self, actions):
        """Return the log-probability of ``actions`` under the current distribution, summed over the last dimension."""
        per_dim_log_prob = self.distribution.log_prob(actions)
        return per_dim_log_prob.sum(dim=-1)

    def act_inference(self, observations):
        """Return the actor output (action mean) without sampling."""
        return self.actor(observations)

    def evaluate(self, critic_observations, **kwargs):
        """Return the critic output for ``critic_observations``."""
        return self.critic(critic_observations)

rsl_rl/runners/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@
66
"""Implementation of runners for environment-agent interaction."""
77

88
from .on_policy_runner import OnPolicyRunner
9+
from .on_policy_runner_conv2d import OnPolicyRunnerConv2d
910

10-
__all__ = ["OnPolicyRunner"]
11+
__all__ = ["OnPolicyRunner", "OnPolicyRunnerConv2d"]

0 commit comments

Comments
 (0)