Skip to content

Commit 2c9c0b3

Browse files
Authored by Federico Belotti — Fix/bernoulli (#186)
* TF-like Bernoulli mode * pre-commit * Default dmc config
1 parent 9cff9a7 commit 2c9c0b3

File tree

4 files changed

+25
-14
lines changed

4 files changed

+25
-14
lines changed

sheeprl/algos/dreamer_v3/dreamer_v3.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from lightning.fabric import Fabric
1818
from lightning.fabric.wrappers import _FabricModule
1919
from torch import Tensor
20-
from torch.distributions import Bernoulli, Distribution, Independent
20+
from torch.distributions import Distribution, Independent
2121
from torch.optim import Optimizer
2222
from torchmetrics import SumMetric
2323

@@ -27,6 +27,7 @@
2727
from sheeprl.data.buffers import EnvIndependentReplayBuffer, SequentialReplayBuffer
2828
from sheeprl.envs.wrappers import RestartOnException
2929
from sheeprl.utils.distribution import (
30+
BernoulliSafeMode,
3031
MSEDistribution,
3132
OneHotCategoricalValidateArgs,
3233
SymlogDistribution,
@@ -145,7 +146,7 @@ def train(
145146

146147
# Compute the distribution over the terminal steps, if required
147148
pc = Independent(
148-
Bernoulli(logits=world_model.continue_model(latent_states), validate_args=validate_args),
149+
BernoulliSafeMode(logits=world_model.continue_model(latent_states), validate_args=validate_args),
149150
1,
150151
validate_args=validate_args,
151152
)
@@ -229,7 +230,7 @@ def train(
229230
predicted_values = TwoHotEncodingDistribution(critic(imagined_trajectories), dims=1).mean
230231
predicted_rewards = TwoHotEncodingDistribution(world_model.reward_model(imagined_trajectories), dims=1).mean
231232
continues = Independent(
232-
Bernoulli(logits=world_model.continue_model(imagined_trajectories), validate_args=validate_args),
233+
BernoulliSafeMode(logits=world_model.continue_model(imagined_trajectories), validate_args=validate_args),
233234
1,
234235
validate_args=validate_args,
235236
).mode

sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from lightning.fabric.wrappers import _FabricModule, _FabricOptimizer
1313
from omegaconf import DictConfig
1414
from torch import Tensor, nn
15-
from torch.distributions import Bernoulli, Distribution, Independent
15+
from torch.distributions import Distribution, Independent
1616
from torchmetrics import SumMetric
1717

1818
from sheeprl.algos.dreamer_v3.agent import PlayerDV3, WorldModel
@@ -21,6 +21,7 @@
2121
from sheeprl.algos.p2e_dv3.agent import build_agent
2222
from sheeprl.data.buffers import EnvIndependentReplayBuffer, SequentialReplayBuffer
2323
from sheeprl.utils.distribution import (
24+
BernoulliSafeMode,
2425
MSEDistribution,
2526
OneHotCategoricalValidateArgs,
2627
SymlogDistribution,
@@ -161,7 +162,7 @@ def train(
161162

162163
# Compute the distribution over the terminal steps, if required
163164
pc = Independent(
164-
Bernoulli(logits=world_model.continue_model(latent_states.detach()), validate_args=validate_args),
165+
BernoulliSafeMode(logits=world_model.continue_model(latent_states.detach()), validate_args=validate_args),
165166
1,
166167
validate_args=validate_args,
167168
)
@@ -268,7 +269,7 @@ def train(
268269
# Predict values and continues
269270
predicted_values = TwoHotEncodingDistribution(critic["module"](imagined_trajectories), dims=1).mean
270271
continues = Independent(
271-
Bernoulli(logits=world_model.continue_model(imagined_trajectories), validate_args=validate_args),
272+
BernoulliSafeMode(logits=world_model.continue_model(imagined_trajectories), validate_args=validate_args),
272273
1,
273274
validate_args=validate_args,
274275
).mode
@@ -412,7 +413,7 @@ def train(
412413
predicted_values = TwoHotEncodingDistribution(critic_task(imagined_trajectories), dims=1).mean
413414
predicted_rewards = TwoHotEncodingDistribution(world_model.reward_model(imagined_trajectories), dims=1).mean
414415
continues = Independent(
415-
Bernoulli(logits=world_model.continue_model(imagined_trajectories), validate_args=validate_args),
416+
BernoulliSafeMode(logits=world_model.continue_model(imagined_trajectories), validate_args=validate_args),
416417
1,
417418
validate_args=validate_args,
418419
).mode

sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@ seed: 5
1010

1111
# Environment
1212
env:
13-
num_envs: 1
14-
max_episode_steps: 1000
13+
num_envs: 4
14+
max_episode_steps: -1
1515
id: walker_walk
1616
wrapper:
17-
from_vectors: True
17+
from_vectors: False
1818
from_pixels: True
1919

2020
# Checkpoint
@@ -34,9 +34,8 @@ algo:
3434
encoder:
3535
- rgb
3636
mlp_keys:
37-
encoder:
38-
- state
39-
learning_starts: 8000
37+
encoder: []
38+
learning_starts: 1024
4039
train_every: 2
4140
dense_units: 512
4241
mlp_layers: 2

sheeprl/utils/distribution.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import torch
88
import torch.nn.functional as F
99
from torch import Tensor
10-
from torch.distributions import Categorical, Distribution, constraints
10+
from torch.distributions import Bernoulli, Categorical, Distribution, constraints
1111
from torch.distributions.kl import _kl_categorical_categorical, register_kl
1212
from torch.distributions.utils import broadcast_all
1313

@@ -402,3 +402,13 @@ def rsample(self, sample_shape=torch.Size()):
402402
@register_kl(OneHotCategoricalValidateArgs, OneHotCategoricalValidateArgs)
403403
def _kl_onehotcategoricalvalidateargs_onehotcategoricalvalidateargs(p, q):
404404
return _kl_categorical_categorical(p._categorical, q._categorical)
405+
class BernoulliSafeMode(Bernoulli):
    """A :class:`torch.distributions.Bernoulli` with a tie-safe ``mode``.

    The stock ``Bernoulli.mode`` yields NaN wherever ``probs == 0.5``
    (the mode is genuinely ambiguous there). This subclass resolves the
    tie deterministically to 0, i.e. the mode is 1 only where
    ``probs > 0.5`` — matching TensorFlow-style behaviour, which is what
    the Dreamer-V3 continue-model consumers of ``.mode`` expect.

    The constructor is inherited unchanged from ``Bernoulli``
    (``probs=None, logits=None, validate_args=None``); the redundant
    pass-through override has been dropped.
    """

    @property
    def mode(self):
        # (probs > 0.5) is a bool tensor; `.to(self.probs)` casts it back
        # to probs' dtype and device so the result composes with the rest
        # of the pipeline (e.g. Independent(...).mode used as a float mask).
        return (self.probs > 0.5).to(self.probs)

0 commit comments

Comments
 (0)