torchrl/modules/models/model_based.py (89 changes: 71 additions & 18 deletions)
@@ -47,6 +47,8 @@ class DreamerActor(nn.Module):
             Defaults to 5.0.
         std_min_val (:obj:`float`, optional): Minimum value of the standard deviation.
             Defaults to 1e-4.
+        device (torch.device, optional): Device to create the module on.
+            Defaults to None (uses default device).
     """
 
     def __init__(
@@ -57,13 +59,15 @@ def __init__(
         activation_class=nn.ELU,
         std_bias=5.0,
         std_min_val=1e-4,
+        device=None,
     ):
         super().__init__()
         self.backbone = MLP(
             out_features=2 * out_features,
             depth=depth,
             num_cells=num_cells,
             activation_class=activation_class,
+            device=device,
         )
         self.backbone.append(
             NormalParamExtractor(
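
A quick usage sketch of the new `device` argument (assuming torchrl's usual `DreamerActor.forward(state, belief)` convention; the 30/200 widths are illustrative Dreamer defaults, not part of this diff):

    import torch
    from torchrl.modules.models.model_based import DreamerActor

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Concrete layers are allocated on `device` at construction; no .to(device) pass needed.
    actor = DreamerActor(out_features=6, device=device)
    state = torch.randn(8, 30, device=device)    # stochastic state
    belief = torch.randn(8, 200, device=device)  # deterministic belief
    loc, scale = actor(state, belief)            # NormalParamExtractor splits the 2 * out_features head
    print(loc.shape, scale.shape)                # torch.Size([8, 6]) twice
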
@@ -88,9 +92,13 @@ class ObsEncoder(nn.Module):
         channels (int, optional): Number of hidden units in the first layer.
             Defaults to 32.
         num_layers (int, optional): Depth of the network. Defaults to 4.
+        in_channels (int, optional): Number of input channels. If None, uses LazyConv2d.
+            Defaults to None for backward compatibility.
+        device (torch.device, optional): Device to create the module on.
+            Defaults to None (uses default device).
     """
 
-    def __init__(self, channels=32, num_layers=4, depth=None):
+    def __init__(self, channels=32, num_layers=4, in_channels=None, depth=None, device=None):
         if depth is not None:
             warnings.warn(
                 f"The depth argument in {type(self)} will soon be deprecated and "
@@ -102,14 +110,19 @@ def __init__(self, channels=32, num_layers=4, depth=None):
         if num_layers < 1:
             raise RuntimeError("num_layers cannot be smaller than 1.")
         super().__init__()
+        # Use explicit Conv2d if in_channels provided, else LazyConv2d for backward compat
+        if in_channels is not None:
+            first_conv = nn.Conv2d(in_channels, channels, 4, stride=2, device=device)
+        else:
+            first_conv = nn.LazyConv2d(channels, 4, stride=2, device=device)
         layers = [
-            nn.LazyConv2d(channels, 4, stride=2),
+            first_conv,
             nn.ReLU(),
         ]
         k = 1
         for _ in range(1, num_layers):
             layers += [
-                nn.Conv2d(channels * k, channels * (k * 2), 4, stride=2),
+                nn.Conv2d(channels * k, channels * (k * 2), 4, stride=2, device=device),
                 nn.ReLU(),
             ]
             k = k * 2
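
A sketch of the two construction modes this enables; the 64x64 RGB input is the standard Dreamer setting and is only illustrative here:

    import torch
    from torchrl.modules.models.model_based import ObsEncoder

    # Explicit: all shapes known up front, so the module is fully materialized.
    enc = ObsEncoder(channels=32, num_layers=4, in_channels=3, device="cpu")
    # Lazy (backward compatible): in_channels is inferred on the first call.
    enc_lazy = ObsEncoder(channels=32, num_layers=4)

    pixels = torch.randn(8, 3, 64, 64)
    print(enc(pixels).shape, enc_lazy(pixels).shape)  # (8, 1024) each: 32 * 8 channels over a 2 x 2 map
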
@@ -140,9 +153,13 @@ class ObsDecoder(nn.Module):
         num_layers (int, optional): Depth of the network. Defaults to 4.
         kernel_sizes (int or list of int, optional): the kernel_size of each layer.
             Defaults to ``[5, 5, 6, 6]`` if num_layers is 4, else ``[5] * num_layers``.
+        latent_dim (int, optional): Input dimension (state_dim + rnn_hidden_dim).
+            If None, uses LazyLinear. Defaults to None for backward compatibility.
+        device (torch.device, optional): Device to create the module on.
+            Defaults to None (uses default device).
     """
 
-    def __init__(self, channels=32, num_layers=4, kernel_sizes=None, depth=None):
+    def __init__(self, channels=32, num_layers=4, kernel_sizes=None, latent_dim=None, depth=None, device=None):
         if depth is not None:
             warnings.warn(
                 f"The depth argument in {type(self)} will soon be deprecated and "
@@ -155,8 +172,14 @@ def __init__(self, channels=32, num_layers=4, kernel_sizes=None, depth=None):
             raise RuntimeError("num_layers cannot be smaller than 1.")
 
         super().__init__()
+        # Use explicit Linear if latent_dim provided, else LazyLinear for backward compat
+        linear_out = channels * 8 * 2 * 2
+        if latent_dim is not None:
+            first_linear = nn.Linear(latent_dim, linear_out, device=device)
+        else:
+            first_linear = nn.LazyLinear(linear_out, device=device)
         self.state_to_latent = nn.Sequential(
-            nn.LazyLinear(channels * 8 * 2 * 2),
+            first_linear,
             nn.ReLU(),
         )
         if kernel_sizes is None and num_layers == 4:
@@ -167,23 +190,24 @@
             kernel_sizes = [kernel_sizes] * num_layers
         layers = [
             nn.ReLU(),
-            nn.ConvTranspose2d(channels, 3, kernel_sizes[-1], stride=2),
+            nn.ConvTranspose2d(channels, 3, kernel_sizes[-1], stride=2, device=device),
         ]
         kernel_sizes = kernel_sizes[:-1]
         k = 1
         for j in range(1, num_layers):
             if j != num_layers - 1:
                 layers = [
                     nn.ConvTranspose2d(
-                        channels * k * 2, channels * k, kernel_sizes[-1], stride=2
+                        channels * k * 2, channels * k, kernel_sizes[-1], stride=2, device=device
                     ),
                 ] + layers
                 kernel_sizes = kernel_sizes[:-1]
                 k = k * 2
                 layers = [nn.ReLU()] + layers
             else:
+                # Use explicit ConvTranspose2d - input is linear_out (channels * 8 * 2 * 2) from state_to_latent
                 layers = [
-                    nn.LazyConvTranspose2d(channels * k, kernel_sizes[-1], stride=2)
+                    nn.ConvTranspose2d(linear_out, channels * k, kernel_sizes[-1], stride=2, device=device)
                 ] + layers
 
         self.decoder = nn.Sequential(*layers)
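
The `linear_out = channels * 8 * 2 * 2` width is what makes the last branch non-lazy: `forward` reshapes the latent to `(-1, linear_out, 1, 1)` before the transposed convolutions, so the innermost `ConvTranspose2d` always sees `linear_out` input channels. A sketch with Dreamer's default sizes (`latent_dim = state_dim + rnn_hidden_dim = 230`, an assumption for illustration):

    import torch
    from torchrl.modules.models.model_based import ObsDecoder

    dec = ObsDecoder(channels=32, num_layers=4, latent_dim=230, device="cpu")
    state = torch.randn(8, 30)
    belief = torch.randn(8, 200)
    print(dec(state, belief).shape)  # torch.Size([8, 3, 64, 64]) with the default [5, 5, 6, 6] kernels
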
@@ -290,6 +314,10 @@ class RSSMPrior(nn.Module):
             Defaults to 30.
         scale_lb (:obj:`float`, optional): Lower bound of the scale of the state distribution.
             Defaults to 0.1.
+        action_dim (int, optional): Dimension of the action. If provided along with state_dim,
+            uses explicit Linear instead of LazyLinear. Defaults to None for backward compatibility.
+        device (torch.device, optional): Device to create the module on.
+            Defaults to None (uses default device).
 
 
     """
@@ -301,16 +329,23 @@ def __init__(
         rnn_hidden_dim=200,
         state_dim=30,
         scale_lb=0.1,
+        action_dim=None,
+        device=None,
     ):
         super().__init__()
 
-        # Prior
-        self.rnn = GRUCell(hidden_dim, rnn_hidden_dim)
-        self.action_state_projector = nn.Sequential(nn.LazyLinear(hidden_dim), nn.ELU())
+        # Prior - use explicit Linear if action_dim provided, else LazyLinear
+        self.rnn = GRUCell(hidden_dim, rnn_hidden_dim, device=device)
+        if action_dim is not None:
+            projector_in = state_dim + action_dim
+            first_linear = nn.Linear(projector_in, hidden_dim, device=device)
+        else:
+            first_linear = nn.LazyLinear(hidden_dim, device=device)
+        self.action_state_projector = nn.Sequential(first_linear, nn.ELU())
         self.rnn_to_prior_projector = nn.Sequential(
-            nn.Linear(hidden_dim, hidden_dim),
+            nn.Linear(hidden_dim, hidden_dim, device=device),
             nn.ELU(),
-            nn.Linear(hidden_dim, 2 * state_dim),
+            nn.Linear(hidden_dim, 2 * state_dim, device=device),
             NormalParamExtractor(
                 scale_lb=scale_lb,
                 scale_mapping="softplus",
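
A sketch of the explicit-shape path. The projector consumes `torch.cat([state, action], -1)`, so its input width is `state_dim + action_dim`; the return convention below is an assumption based on how torchrl's RSSM rollout uses this module:

    import torch
    from torchrl.modules.models.model_based import RSSMPrior

    prior = RSSMPrior(hidden_dim=200, rnn_hidden_dim=200, state_dim=30,
                      action_dim=6, device="cpu")
    state = torch.randn(8, 30)
    belief = torch.randn(8, 200)
    action = torch.randn(8, 6)  # projector input width: 30 + 6 = 36
    prior_mean, prior_std, next_state, next_belief = prior(state, belief, action)
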
@@ -330,7 +365,14 @@ def forward(self, state, belief, action):
             belief = belief.unsqueeze(0)
             action_state = action_state.unsqueeze(0)
             unsqueeze = True
-        belief = self.rnn(action_state, belief)
+
+        # GRUCell can have issues with bfloat16 autocast on some GPU/cuBLAS combinations.
+        # Run the RNN in full precision to avoid CUBLAS_STATUS_INVALID_VALUE errors.
+        dtype = action_state.dtype
+        device_type = action_state.device.type
+        with torch.amp.autocast(device_type=device_type, enabled=False):
+            belief = self.rnn(action_state.float(), belief.float() if belief is not None else None)
+        belief = belief.to(dtype)
         if unsqueeze:
             belief = belief.squeeze(0)
 
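
The same workaround in isolation, runnable on CPU, showing the disable-autocast-then-cast-back pattern:

    import torch
    from torch import nn

    cell = nn.GRUCell(200, 200)  # float32 weights
    x = torch.randn(8, 200, dtype=torch.bfloat16)
    h = torch.randn(8, 200, dtype=torch.bfloat16)

    dtype = x.dtype
    # Run the cell in full precision regardless of any ambient autocast context...
    with torch.amp.autocast(device_type=x.device.type, enabled=False):
        h_new = cell(x.float(), h.float())
    # ...then restore the caller's dtype so downstream ops see what they expect.
    h_new = h_new.to(dtype)
    print(h_new.dtype)  # torch.bfloat16
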
@@ -354,15 +396,27 @@ class RSSMPosterior(nn.Module):
             Defaults to 30.
         scale_lb (:obj:`float`, optional): Lower bound of the scale of the state distribution.
             Defaults to 0.1.
+        rnn_hidden_dim (int, optional): Dimension of the belief/rnn hidden state.
+            If provided along with obs_embed_dim, uses explicit Linear. Defaults to None.
+        obs_embed_dim (int, optional): Dimension of the observation embedding.
+            If provided along with rnn_hidden_dim, uses explicit Linear. Defaults to None.
+        device (torch.device, optional): Device to create the module on.
+            Defaults to None (uses default device).
 
     """
 
-    def __init__(self, hidden_dim=200, state_dim=30, scale_lb=0.1):
+    def __init__(self, hidden_dim=200, state_dim=30, scale_lb=0.1, rnn_hidden_dim=None, obs_embed_dim=None, device=None):
         super().__init__()
+        # Use explicit Linear if both dims provided, else LazyLinear for backward compat
+        if rnn_hidden_dim is not None and obs_embed_dim is not None:
+            projector_in = rnn_hidden_dim + obs_embed_dim
+            first_linear = nn.Linear(projector_in, hidden_dim, device=device)
+        else:
+            first_linear = nn.LazyLinear(hidden_dim, device=device)
         self.obs_rnn_to_post_projector = nn.Sequential(
-            nn.LazyLinear(hidden_dim),
+            first_linear,
             nn.ELU(),
-            nn.Linear(hidden_dim, 2 * state_dim),
+            nn.Linear(hidden_dim, 2 * state_dim, device=device),
             NormalParamExtractor(
                 scale_lb=scale_lb,
                 scale_mapping="softplus",
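
A sketch; `obs_embed_dim=1024` matches the ObsEncoder output width for `channels=32`, `num_layers=4` on 64x64 frames (an illustrative pairing, not mandated by this diff):

    import torch
    from torchrl.modules.models.model_based import RSSMPosterior

    post = RSSMPosterior(hidden_dim=200, state_dim=30,
                         rnn_hidden_dim=200, obs_embed_dim=1024, device="cpu")
    belief = torch.randn(8, 200)
    obs_embedding = torch.randn(8, 1024)
    mean, std, state = post(belief, obs_embedding)
    print(mean.shape, std.shape, state.shape)  # three (8, 30) tensors
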
@@ -374,6 +428,5 @@ def forward(self, belief, obs_embedding):
         posterior_mean, posterior_std = self.obs_rnn_to_post_projector(
             torch.cat([belief, obs_embedding], dim=-1)
         )
-        # post_std = post_std + 0.1
         state = posterior_mean + torch.randn_like(posterior_std) * posterior_std
         return posterior_mean, posterior_std, state
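
The sampling line above is the reparameterization trick: a N(mean, std^2) draw written as mean + eps * std with eps ~ N(0, 1), so gradients flow through both distribution parameters while the randomness stays in eps. A minimal illustration:

    import torch

    mean = torch.zeros(4, requires_grad=True)
    std = torch.ones(4, requires_grad=True)
    eps = torch.randn(4)               # stop-gradient noise, as with randn_like above
    state = mean + eps * std
    state.sum().backward()
    print(mean.grad)                   # all ones: d(state)/d(mean) = 1
    print(torch.equal(std.grad, eps))  # True: d(state)/d(std) = eps
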