Skip to content

Commit 474e837

Browse files
author
Vincent Moens
authored
[Refactor] Deprecate NormalParamWrapper (#2308)
1 parent 94abb50 commit 474e837

File tree

20 files changed

+122
-106
lines changed

20 files changed

+122
-106
lines changed

sota-implementations/redq/utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
ActorCriticOperator,
5858
ActorValueOperator,
5959
NoisyLinear,
60-
NormalParamWrapper,
60+
NormalParamExtractor,
6161
SafeModule,
6262
SafeSequential,
6363
)
@@ -483,10 +483,12 @@ def make_redq_model(
483483
}
484484

485485
if not gSDE:
486-
actor_net = NormalParamWrapper(
486+
actor_net = nn.Sequential(
487487
actor_net,
488-
scale_mapping=f"biased_softplus_{default_policy_scale}",
489-
scale_lb=cfg.network.scale_lb,
488+
NormalParamExtractor(
489+
scale_mapping=f"biased_softplus_{default_policy_scale}",
490+
scale_lb=cfg.network.scale_lb,
491+
),
490492
)
491493
actor_module = SafeModule(
492494
actor_net,

test/test_cost.py

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,7 @@
7272
SafeSequential,
7373
WorldModelWrapper,
7474
)
75-
from torchrl.modules.distributions.continuous import (
76-
NormalParamWrapper,
77-
TanhDelta,
78-
TanhNormal,
79-
)
75+
from torchrl.modules.distributions.continuous import TanhDelta, TanhNormal
8076
from torchrl.modules.models.model_based import (
8177
DreamerActor,
8278
ObsDecoder,
@@ -3462,7 +3458,7 @@ def _create_mock_actor(
34623458
action_spec = BoundedTensorSpec(
34633459
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
34643460
)
3465-
net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
3461+
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
34663462
module = TensorDictModule(
34673463
net, in_keys=[observation_key], out_keys=["loc", "scale"]
34683464
)
@@ -4372,7 +4368,7 @@ def _create_mock_actor(
43724368
):
43734369
# Actor
43744370
action_spec = OneHotDiscreteTensorSpec(action_dim)
4375-
net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
4371+
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
43764372
module = TensorDictModule(net, in_keys=[observation_key], out_keys=["logits"])
43774373
actor = ProbabilisticActor(
43784374
spec=action_spec,
@@ -4960,7 +4956,7 @@ def _create_mock_actor(
49604956
action_spec = BoundedTensorSpec(
49614957
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
49624958
)
4963-
net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
4959+
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
49644960
module = TensorDictModule(
49654961
net, in_keys=[observation_key], out_keys=["loc", "scale"]
49664962
)
@@ -5655,7 +5651,7 @@ def _create_mock_actor(
56555651
action_spec = BoundedTensorSpec(
56565652
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
56575653
)
5658-
net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
5654+
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
56595655
module = TensorDictModule(
56605656
net, in_keys=[observation_key], out_keys=["loc", "scale"]
56615657
)
@@ -5763,7 +5759,9 @@ def forward(self, obs):
57635759
class ActorClass(nn.Module):
57645760
def __init__(self):
57655761
super().__init__()
5766-
self.linear = NormalParamWrapper(nn.Linear(hidden_dim, 2 * action_dim))
5762+
self.linear = nn.Sequential(
5763+
nn.Linear(hidden_dim, 2 * action_dim), NormalParamExtractor()
5764+
)
57675765

57685766
def forward(self, hidden):
57695767
return self.linear(hidden)
@@ -6598,7 +6596,7 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
65986596
action_spec = BoundedTensorSpec(
65996597
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
66006598
)
6601-
net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
6599+
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
66026600
module = TensorDictModule(
66036601
net, in_keys=["observation"], out_keys=["loc", "scale"]
66046602
)
@@ -7556,7 +7554,7 @@ def _create_mock_actor(
75567554
action_spec = BoundedTensorSpec(
75577555
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
75587556
)
7559-
net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
7557+
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
75607558
module = TensorDictModule(
75617559
net, in_keys=[observation_key], out_keys=["loc", "scale"]
75627560
)
@@ -7593,8 +7591,8 @@ def _create_mock_actor_value(self, batch=2, obs_dim=3, action_dim=4, device="cpu
75937591
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
75947592
)
75957593
base_layer = nn.Linear(obs_dim, 5)
7596-
net = NormalParamWrapper(
7597-
nn.Sequential(base_layer, nn.Linear(5, 2 * action_dim))
7594+
net = nn.Sequential(
7595+
base_layer, nn.Linear(5, 2 * action_dim), NormalParamExtractor()
75987596
)
75997597
module = TensorDictModule(
76007598
net, in_keys=["observation"], out_keys=["loc", "scale"]
@@ -8447,7 +8445,7 @@ def _create_mock_actor(
84478445
action_spec = BoundedTensorSpec(
84488446
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
84498447
)
8450-
net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
8448+
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
84518449
module = TensorDictModule(
84528450
net, in_keys=[observation_key], out_keys=["loc", "scale"]
84538451
)
@@ -9144,7 +9142,7 @@ def test_reinforce_value_net(
91449142
batch = 4
91459143
gamma = 0.9
91469144
value_net = ValueOperator(nn.Linear(n_obs, 1), in_keys=["observation"])
9147-
net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act))
9145+
net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
91489146
module = TensorDictModule(
91499147
net, in_keys=["observation"], out_keys=["loc", "scale"]
91509148
)
@@ -9254,7 +9252,7 @@ def test_reinforce_tensordict_keys(self, td_est):
92549252
n_obs = 3
92559253
n_act = 5
92569254
value_net = ValueOperator(nn.Linear(n_obs, 1), in_keys=["observation"])
9257-
net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act))
9255+
net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
92589256
module = TensorDictModule(
92599257
net, in_keys=["observation"], out_keys=["loc", "scale"]
92609258
)
@@ -9448,7 +9446,7 @@ def test_reinforce_notensordict(
94489446
n_act = 5
94499447
batch = 4
94509448
value_net = ValueOperator(nn.Linear(n_obs, 1), in_keys=[observation_key])
9451-
net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act))
9449+
net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
94529450
module = TensorDictModule(
94539451
net, in_keys=[observation_key], out_keys=["loc", "scale"]
94549452
)
@@ -10054,7 +10052,7 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
1005410052
action_spec = BoundedTensorSpec(
1005510053
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
1005610054
)
10057-
net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
10055+
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
1005810056
module = TensorDictModule(
1005910057
net, in_keys=["observation"], out_keys=["loc", "scale"]
1006010058
)
@@ -10286,7 +10284,7 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
1028610284
action_spec = BoundedTensorSpec(
1028710285
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
1028810286
)
10289-
net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
10287+
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
1029010288
module = TensorDictModule(net, in_keys=["observation"], out_keys=["param"])
1029110289
actor = ProbabilisticActor(
1029210290
module=module,
@@ -10479,7 +10477,7 @@ def _create_mock_actor(
1047910477
action_spec = BoundedTensorSpec(
1048010478
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
1048110479
)
10482-
net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
10480+
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
1048310481
module = TensorDictModule(
1048410482
net, in_keys=[observation_key], out_keys=["loc", "scale"]
1048510483
)
@@ -11288,7 +11286,7 @@ def _create_mock_actor(
1128811286
):
1128911287
# Actor
1129011288
action_spec = OneHotDiscreteTensorSpec(action_dim)
11291-
net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
11289+
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
1129211290
module = TensorDictModule(net, in_keys=[observation_key], out_keys=["logits"])
1129311291
actor = ProbabilisticActor(
1129411292
spec=action_spec,
@@ -13989,7 +13987,7 @@ def test_shared_params(dest, expected_dtype, expected_device):
1398913987
out_keys=["hidden"],
1399013988
)
1399113989
module_action = TensorDictModule(
13992-
NormalParamWrapper(torch.nn.Linear(4, 8)),
13990+
nn.Sequential(nn.Linear(4, 8), NormalParamExtractor()),
1399313991
in_keys=["hidden"],
1399413992
out_keys=["loc", "scale"],
1399513993
)

test/test_exploration.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@
3131
from torchrl.envs.transforms.transforms import gSDENoise, InitTracker, TransformedEnv
3232
from torchrl.envs.utils import set_exploration_type
3333
from torchrl.modules import SafeModule, SafeSequential
34-
from torchrl.modules.distributions import TanhNormal
35-
from torchrl.modules.distributions.continuous import (
34+
from torchrl.modules.distributions import (
3635
IndependentNormal,
37-
NormalParamWrapper,
36+
NormalParamExtractor,
37+
TanhNormal,
3838
)
3939
from torchrl.modules.models.exploration import LazygSDEModule
4040
from torchrl.modules.tensordict_module.actors import (
@@ -236,7 +236,9 @@ def test_ou(
236236
self, device, interface, d_obs=4, d_act=6, batch=32, n_steps=100, seed=0
237237
):
238238
torch.manual_seed(seed)
239-
net = NormalParamWrapper(nn.Linear(d_obs, 2 * d_act)).to(device)
239+
net = nn.Sequential(nn.Linear(d_obs, 2 * d_act), NormalParamExtractor()).to(
240+
device
241+
)
240242
module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
241243
action_spec = BoundedTensorSpec(-torch.ones(d_act), torch.ones(d_act), (d_act,))
242244
policy = ProbabilisticActor(
@@ -308,7 +310,9 @@ def test_collector(self, device, parallel_spec, probabilistic, interface, seed=0
308310
action_spec = ContinuousActionVecMockEnv(device=device).action_spec
309311
d_act = action_spec.shape[-1]
310312
if probabilistic:
311-
net = NormalParamWrapper(nn.LazyLinear(2 * d_act)).to(device)
313+
net = nn.Sequential(nn.LazyLinear(2 * d_act), NormalParamExtractor()).to(
314+
device
315+
)
312316
module = SafeModule(
313317
net,
314318
in_keys=["observation"],
@@ -449,7 +453,9 @@ def test_additivegaussian_sd(
449453
if interface == "module":
450454
exploratory_policy = AdditiveGaussianModule(action_spec).to(device)
451455
else:
452-
net = NormalParamWrapper(nn.Linear(d_obs, 2 * d_act)).to(device)
456+
net = nn.Sequential(nn.Linear(d_obs, 2 * d_act), NormalParamExtractor()).to(
457+
device
458+
)
453459
module = SafeModule(
454460
net,
455461
in_keys=["observation"],
@@ -531,7 +537,9 @@ def test_additivegaussian(
531537
pytest.skip("module raises an error if given spec=None")
532538

533539
torch.manual_seed(seed)
534-
net = NormalParamWrapper(nn.Linear(d_obs, 2 * d_act)).to(device)
540+
net = nn.Sequential(nn.Linear(d_obs, 2 * d_act), NormalParamExtractor()).to(
541+
device
542+
)
535543
module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
536544
action_spec = BoundedTensorSpec(
537545
-torch.ones(d_act, device=device),
@@ -593,7 +601,7 @@ def test_collector(self, device, parallel_spec, interface, seed=0):
593601
else:
594602
action_spec = ContinuousActionVecMockEnv(device=device).action_spec
595603
d_act = action_spec.shape[-1]
596-
net = NormalParamWrapper(nn.LazyLinear(2 * d_act)).to(device)
604+
net = nn.Sequential(nn.LazyLinear(2 * d_act), NormalParamExtractor()).to(device)
597605
module = SafeModule(
598606
net,
599607
in_keys=["observation"],
@@ -658,7 +666,7 @@ def test_gsde(
658666
else:
659667
in_keys = ["observation"]
660668
model = torch.nn.LazyLinear(action_dim * 2, device=device)
661-
wrapper = NormalParamWrapper(model)
669+
wrapper = nn.Sequential(model, NormalParamExtractor())
662670
module = SafeModule(wrapper, in_keys=in_keys, out_keys=["loc", "scale"])
663671
distribution_class = TanhNormal
664672
distribution_kwargs = {"low": -bound, "high": bound}

test/test_tensordictmodules.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
LSTMModule,
3535
MLP,
3636
MultiStepActorWrapper,
37-
NormalParamWrapper,
37+
NormalParamExtractor,
3838
OnlineDTActor,
3939
ProbabilisticActor,
4040
SafeModule,
@@ -201,7 +201,7 @@ def test_stateful_probabilistic(self, safe, spec_type, lazy, exp_mode, out_keys)
201201

202202
in_keys = ["in"]
203203
net = SafeModule(
204-
module=NormalParamWrapper(net),
204+
module=nn.Sequential(net, NormalParamExtractor()),
205205
spec=None,
206206
in_keys=in_keys,
207207
out_keys=out_keys,
@@ -363,7 +363,7 @@ def test_stateful_probabilistic(self, safe, spec_type, lazy):
363363
net1 = nn.Linear(3, 4)
364364
dummy_net = nn.Linear(4, 4)
365365
net2 = nn.Linear(4, 4 * param_multiplier)
366-
net2 = NormalParamWrapper(net2)
366+
net2 = nn.Sequential(net2, NormalParamExtractor())
367367

368368
if spec_type is None:
369369
spec = None
@@ -474,11 +474,11 @@ def test_sequential_partial(self, stack):
474474
net1 = nn.Linear(3, 4)
475475

476476
net2 = nn.Linear(4, 4 * param_multiplier)
477-
net2 = NormalParamWrapper(net2)
477+
net2 = nn.Sequential(net2, NormalParamExtractor())
478478
net2 = SafeModule(net2, in_keys=["b"], out_keys=["loc", "scale"])
479479

480480
net3 = nn.Linear(4, 4 * param_multiplier)
481-
net3 = NormalParamWrapper(net3)
481+
net3 = nn.Sequential(net3, NormalParamExtractor())
482482
net3 = SafeModule(net3, in_keys=["c"], out_keys=["loc", "scale"])
483483

484484
spec = BoundedTensorSpec(-0.1, 0.1, 4)

torchrl/modules/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
IndependentNormal,
1212
MaskedCategorical,
1313
MaskedOneHotCategorical,
14+
NormalParamExtractor,
1415
NormalParamWrapper,
1516
OneHotCategorical,
1617
ReparamGradientStrategy,

torchrl/modules/distributions/__init__.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
from tensordict.nn import NormalParamExtractor
7+
68
from .continuous import (
7-
__all__ as _all_continuous,
89
Delta,
910
IndependentNormal,
1011
NormalParamWrapper,
@@ -13,14 +14,22 @@
1314
TruncatedNormal,
1415
)
1516
from .discrete import (
16-
__all__ as _all_discrete,
1717
MaskedCategorical,
1818
MaskedOneHotCategorical,
1919
OneHotCategorical,
2020
ReparamGradientStrategy,
2121
)
2222

2323
distributions_maps = {
24-
distribution_class.lower(): eval(distribution_class)
25-
for distribution_class in _all_continuous + _all_discrete
24+
str(dist).lower(): dist
25+
for dist in (
26+
Delta,
27+
IndependentNormal,
28+
TanhDelta,
29+
TanhNormal,
30+
TruncatedNormal,
31+
MaskedCategorical,
32+
MaskedOneHotCategorical,
33+
OneHotCategorical,
34+
)
2635
}

torchrl/modules/distributions/continuous.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,6 @@
2525
)
2626
from torchrl.modules.utils import mappings
2727

28-
__all__ = [
29-
"NormalParamWrapper",
30-
"TanhNormal",
31-
"Delta",
32-
"TanhDelta",
33-
"TruncatedNormal",
34-
"IndependentNormal",
35-
]
36-
3728
# speeds up distribution construction
3829
D.Distribution.set_default_validate_args(False)
3930

@@ -153,6 +144,10 @@ def __init__(
153144
scale_mapping: str = "biased_softplus_1.0",
154145
scale_lb: Number = 1e-4,
155146
) -> None:
147+
warnings.warn(
148+
"The NormalParamWrapper class will be deprecated in v0.7 in favor of :class:`~tensordict.nn.NormalParamExtractor`.",
149+
category=DeprecationWarning,
150+
)
156151
super().__init__()
157152
self.operator = operator
158153
self.scale_mapping = scale_mapping
@@ -759,7 +754,10 @@ def mean(self) -> torch.Tensor:
759754
raise AttributeError("TanhDelta mean has not analytical form.")
760755

761756

762-
def uniform_sample_delta(dist: Delta, size=None) -> torch.Tensor:
757+
def _uniform_sample_delta(dist: Delta, size=None) -> torch.Tensor:
763758
if size is None:
764759
size = torch.Size([])
765760
return torch.randn_like(dist.sample(size))
761+
762+
763+
uniform_sample_delta = _uniform_sample_delta

0 commit comments

Comments
 (0)