|
72 | 72 | SafeSequential, |
73 | 73 | WorldModelWrapper, |
74 | 74 | ) |
75 | | -from torchrl.modules.distributions.continuous import ( |
76 | | - NormalParamWrapper, |
77 | | - TanhDelta, |
78 | | - TanhNormal, |
79 | | -) |
| 75 | +from torchrl.modules.distributions.continuous import TanhDelta, TanhNormal |
80 | 76 | from torchrl.modules.models.model_based import ( |
81 | 77 | DreamerActor, |
82 | 78 | ObsDecoder, |
@@ -3462,7 +3458,7 @@ def _create_mock_actor( |
3462 | 3458 | action_spec = BoundedTensorSpec( |
3463 | 3459 | -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) |
3464 | 3460 | ) |
3465 | | - net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) |
| 3461 | + net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) |
3466 | 3462 | module = TensorDictModule( |
3467 | 3463 | net, in_keys=[observation_key], out_keys=["loc", "scale"] |
3468 | 3464 | ) |
@@ -4372,7 +4368,7 @@ def _create_mock_actor( |
4372 | 4368 | ): |
4373 | 4369 | # Actor |
4374 | 4370 | action_spec = OneHotDiscreteTensorSpec(action_dim) |
4375 | | - net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) |
| 4371 | + net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) |
4376 | 4372 | module = TensorDictModule(net, in_keys=[observation_key], out_keys=["logits"]) |
4377 | 4373 | actor = ProbabilisticActor( |
4378 | 4374 | spec=action_spec, |
@@ -4960,7 +4956,7 @@ def _create_mock_actor( |
4960 | 4956 | action_spec = BoundedTensorSpec( |
4961 | 4957 | -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) |
4962 | 4958 | ) |
4963 | | - net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) |
| 4959 | + net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) |
4964 | 4960 | module = TensorDictModule( |
4965 | 4961 | net, in_keys=[observation_key], out_keys=["loc", "scale"] |
4966 | 4962 | ) |
@@ -5655,7 +5651,7 @@ def _create_mock_actor( |
5655 | 5651 | action_spec = BoundedTensorSpec( |
5656 | 5652 | -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) |
5657 | 5653 | ) |
5658 | | - net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) |
| 5654 | + net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) |
5659 | 5655 | module = TensorDictModule( |
5660 | 5656 | net, in_keys=[observation_key], out_keys=["loc", "scale"] |
5661 | 5657 | ) |
@@ -5763,7 +5759,9 @@ def forward(self, obs): |
5763 | 5759 | class ActorClass(nn.Module): |
5764 | 5760 | def __init__(self): |
5765 | 5761 | super().__init__() |
5766 | | - self.linear = NormalParamWrapper(nn.Linear(hidden_dim, 2 * action_dim)) |
| 5762 | + self.linear = nn.Sequential( |
| 5763 | + nn.Linear(hidden_dim, 2 * action_dim), NormalParamExtractor() |
| 5764 | + ) |
5767 | 5765 |
|
5768 | 5766 | def forward(self, hidden): |
5769 | 5767 | return self.linear(hidden) |
@@ -6598,7 +6596,7 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): |
6598 | 6596 | action_spec = BoundedTensorSpec( |
6599 | 6597 | -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) |
6600 | 6598 | ) |
6601 | | - net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) |
| 6599 | + net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) |
6602 | 6600 | module = TensorDictModule( |
6603 | 6601 | net, in_keys=["observation"], out_keys=["loc", "scale"] |
6604 | 6602 | ) |
@@ -7556,7 +7554,7 @@ def _create_mock_actor( |
7556 | 7554 | action_spec = BoundedTensorSpec( |
7557 | 7555 | -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) |
7558 | 7556 | ) |
7559 | | - net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) |
| 7557 | + net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) |
7560 | 7558 | module = TensorDictModule( |
7561 | 7559 | net, in_keys=[observation_key], out_keys=["loc", "scale"] |
7562 | 7560 | ) |
@@ -7593,8 +7591,8 @@ def _create_mock_actor_value(self, batch=2, obs_dim=3, action_dim=4, device="cpu |
7593 | 7591 | -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) |
7594 | 7592 | ) |
7595 | 7593 | base_layer = nn.Linear(obs_dim, 5) |
7596 | | - net = NormalParamWrapper( |
7597 | | - nn.Sequential(base_layer, nn.Linear(5, 2 * action_dim)) |
| 7594 | + net = nn.Sequential( |
| 7595 | + base_layer, nn.Linear(5, 2 * action_dim), NormalParamExtractor() |
7598 | 7596 | ) |
7599 | 7597 | module = TensorDictModule( |
7600 | 7598 | net, in_keys=["observation"], out_keys=["loc", "scale"] |
@@ -8447,7 +8445,7 @@ def _create_mock_actor( |
8447 | 8445 | action_spec = BoundedTensorSpec( |
8448 | 8446 | -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) |
8449 | 8447 | ) |
8450 | | - net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) |
| 8448 | + net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) |
8451 | 8449 | module = TensorDictModule( |
8452 | 8450 | net, in_keys=[observation_key], out_keys=["loc", "scale"] |
8453 | 8451 | ) |
@@ -9144,7 +9142,7 @@ def test_reinforce_value_net( |
9144 | 9142 | batch = 4 |
9145 | 9143 | gamma = 0.9 |
9146 | 9144 | value_net = ValueOperator(nn.Linear(n_obs, 1), in_keys=["observation"]) |
9147 | | - net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) |
| 9145 | + net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) |
9148 | 9146 | module = TensorDictModule( |
9149 | 9147 | net, in_keys=["observation"], out_keys=["loc", "scale"] |
9150 | 9148 | ) |
@@ -9254,7 +9252,7 @@ def test_reinforce_tensordict_keys(self, td_est): |
9254 | 9252 | n_obs = 3 |
9255 | 9253 | n_act = 5 |
9256 | 9254 | value_net = ValueOperator(nn.Linear(n_obs, 1), in_keys=["observation"]) |
9257 | | - net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) |
| 9255 | + net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) |
9258 | 9256 | module = TensorDictModule( |
9259 | 9257 | net, in_keys=["observation"], out_keys=["loc", "scale"] |
9260 | 9258 | ) |
@@ -9448,7 +9446,7 @@ def test_reinforce_notensordict( |
9448 | 9446 | n_act = 5 |
9449 | 9447 | batch = 4 |
9450 | 9448 | value_net = ValueOperator(nn.Linear(n_obs, 1), in_keys=[observation_key]) |
9451 | | - net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) |
| 9449 | + net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) |
9452 | 9450 | module = TensorDictModule( |
9453 | 9451 | net, in_keys=[observation_key], out_keys=["loc", "scale"] |
9454 | 9452 | ) |
@@ -10054,7 +10052,7 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): |
10054 | 10052 | action_spec = BoundedTensorSpec( |
10055 | 10053 | -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) |
10056 | 10054 | ) |
10057 | | - net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) |
| 10055 | + net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) |
10058 | 10056 | module = TensorDictModule( |
10059 | 10057 | net, in_keys=["observation"], out_keys=["loc", "scale"] |
10060 | 10058 | ) |
@@ -10286,7 +10284,7 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): |
10286 | 10284 | action_spec = BoundedTensorSpec( |
10287 | 10285 | -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) |
10288 | 10286 | ) |
10289 | | - net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) |
| 10287 | + net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) |
10290 | 10288 | module = TensorDictModule(net, in_keys=["observation"], out_keys=["param"]) |
10291 | 10289 | actor = ProbabilisticActor( |
10292 | 10290 | module=module, |
@@ -10479,7 +10477,7 @@ def _create_mock_actor( |
10479 | 10477 | action_spec = BoundedTensorSpec( |
10480 | 10478 | -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) |
10481 | 10479 | ) |
10482 | | - net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) |
| 10480 | + net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) |
10483 | 10481 | module = TensorDictModule( |
10484 | 10482 | net, in_keys=[observation_key], out_keys=["loc", "scale"] |
10485 | 10483 | ) |
@@ -11288,7 +11286,7 @@ def _create_mock_actor( |
11288 | 11286 | ): |
11289 | 11287 | # Actor |
11290 | 11288 | action_spec = OneHotDiscreteTensorSpec(action_dim) |
11291 | | - net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) |
| 11289 | + net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) |
11292 | 11290 | module = TensorDictModule(net, in_keys=[observation_key], out_keys=["logits"]) |
11293 | 11291 | actor = ProbabilisticActor( |
11294 | 11292 | spec=action_spec, |
@@ -13989,7 +13987,7 @@ def test_shared_params(dest, expected_dtype, expected_device): |
13989 | 13987 | out_keys=["hidden"], |
13990 | 13988 | ) |
13991 | 13989 | module_action = TensorDictModule( |
13992 | | - NormalParamWrapper(torch.nn.Linear(4, 8)), |
| 13990 | + nn.Sequential(nn.Linear(4, 8), NormalParamExtractor()), |
13993 | 13991 | in_keys=["hidden"], |
13994 | 13992 | out_keys=["loc", "scale"], |
13995 | 13993 | ) |
|
0 commit comments