|
46 | 46 | from torchrl._utils import _standardize
|
47 | 47 | from torchrl.data import Bounded, Categorical, Composite, MultiOneHot, OneHot, Unbounded
|
48 | 48 | from torchrl.data.postprocs.postprocs import MultiStep
|
49 |
| -from torchrl.envs import EnvBase |
| 49 | +from torchrl.envs import EnvBase, GymEnv, InitTracker, SerialEnv |
| 50 | +from torchrl.envs.libs.gym import _has_gym |
50 | 51 | from torchrl.envs.model_based.dreamer import DreamerEnv
|
51 | 52 | from torchrl.envs.transforms import TensorDictPrimer, TransformedEnv
|
52 | 53 | from torchrl.envs.utils import exploration_type, ExplorationType, set_exploration_type
|
53 | 54 | from torchrl.modules import (
|
54 | 55 | DistributionalQValueActor,
|
| 56 | + GRUModule, |
| 57 | + LSTMModule, |
55 | 58 | OneHotCategorical,
|
56 | 59 | QValueActor,
|
57 | 60 | recurrent_mode,
|
58 | 61 | SafeSequential,
|
| 62 | + set_recurrent_mode, |
59 | 63 | WorldModelWrapper,
|
60 | 64 | )
|
61 | 65 | from torchrl.modules.distributions.continuous import TanhDelta, TanhNormal
|
|
146 | 150 | dtype_fixture,
|
147 | 151 | get_available_devices,
|
148 | 152 | get_default_devices,
|
| 153 | + PENDULUM_VERSIONED, |
149 | 154 | )
|
150 | 155 | from pytorch.rl.test.mocking_classes import ContinuousActionConvMockEnv
|
151 | 156 | else:
|
|
154 | 159 | dtype_fixture,
|
155 | 160 | get_available_devices,
|
156 | 161 | get_default_devices,
|
| 162 | + PENDULUM_VERSIONED, |
157 | 163 | )
|
158 | 164 | from mocking_classes import ContinuousActionConvMockEnv
|
159 | 165 |
|
@@ -13755,6 +13761,79 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
|
13755 | 13761 |
|
13756 | 13762 |
|
13757 | 13763 | class TestValues:
|
    @pytest.mark.skipif(not _has_gym, reason="requires gym")
    @pytest.mark.parametrize("module", ["lstm", "gru"])
    def test_gae_recurrent(self, module):
        # Checks that shifted=True and False provide the same result in GAE when an LSTM is used
        #
        # Purpose: GAE computes values either by calling the value network once
        # on a "shifted" copy of the data (shifted=True) or separately on the
        # root and "next" entries (shifted=False). With a stateful recurrent
        # value network the two paths must still produce identical advantages.
        #
        # Args:
        #     module: which recurrent cell to place in front of the value/policy
        #         heads; one of "lstm" or "gru" (parametrized above).

        # Batched env (2 workers); InitTracker adds the reset-flag entry that
        # the recurrent modules consume to know when to zero their carry.
        env = SerialEnv(
            2,
            [
                functools.partial(
                    TransformedEnv, GymEnv(PENDULUM_VERSIONED()), InitTracker()
                )
                for _ in range(2)
            ],
        )
        # Seed env and torch RNG *before* building modules: parameter init
        # below draws from the global RNG, so statement order matters here.
        env.set_seed(0)
        torch.manual_seed(0)
        if module == "lstm":
            # LSTM carries two recurrent states (hidden + cell).
            # NOTE(review): python_based=True presumably selects the
            # pure-python cell implementation — confirm against LSTMModule docs.
            recurrent_module = LSTMModule(
                input_size=env.observation_spec["observation"].shape[-1],
                hidden_size=64,
                in_keys=["observation", "rs_h", "rs_c"],
                out_keys=["intermediate", ("next", "rs_h"), ("next", "rs_c")],
                python_based=True,
                dropout=0,  # no dropout: keeps both GAE passes deterministic
            )
        elif module == "gru":
            # GRU carries a single recurrent state.
            recurrent_module = GRUModule(
                input_size=env.observation_spec["observation"].shape[-1],
                hidden_size=64,
                in_keys=["observation", "rs_h"],
                out_keys=["intermediate", ("next", "rs_h")],
                python_based=True,
                dropout=0,  # no dropout: keeps both GAE passes deterministic
            )
        else:
            raise NotImplementedError
        # eval() so train-only behaviors (e.g. dropout, were it nonzero) are off.
        recurrent_module.eval()
        mlp_value = MLP(num_cells=[64], out_features=1)
        # Value net: shared recurrent trunk -> value head ("state_value").
        value_net = Seq(
            recurrent_module,
            Mod(mlp_value, in_keys=["intermediate"], out_keys=["state_value"]),
        )
        mlp_policy = MLP(num_cells=[64], out_features=1)
        # Policy net reuses the *same* recurrent_module instance as the value
        # net, so both consume identical recurrent-state entries.
        policy_net = Seq(
            recurrent_module,
            Mod(mlp_policy, in_keys=["intermediate"], out_keys=["action"]),
        )
        # Primer pre-populates the recurrent-state entries ("rs_h"/"rs_c") in
        # the env's tensordicts so the rollout can thread them through steps.
        env = env.append_transform(recurrent_module.make_tensordict_primer())
        # Long rollout without stopping at done boundaries, so the data
        # contains mid-trajectory resets the recurrent modules must handle.
        vals = env.rollout(1000, policy_net, break_when_any_done=False)
        # One forward pass on a copy — presumably to materialize the lazy
        # layers of mlp_value before GAE queries it; output is discarded.
        # TODO(review): confirm this is only for lazy-module initialization.
        value_net(vals.copy())

        # Shifted
        # Path 1: single value-net call on the time-shifted data.
        gae_shifted = GAE(
            gamma=0.9,
            lmbda=0.99,
            value_network=value_net,
            shifted=True,
        )
        # Recurrent mode makes the modules process whole sequences rather
        # than single steps.
        with set_recurrent_mode(True):
            r0 = gae_shifted(vals.copy())
            a0 = r0["advantage"]

        # Path 2: separate root/"next" value calls; vmap is deactivated —
        # NOTE(review): presumably because stateful recurrent modules are
        # incompatible with vmap batching — confirm.
        gae = GAE(
            gamma=0.9,
            lmbda=0.99,
            value_network=value_net,
            shifted=False,
            deactivate_vmap=True,
        )
        with set_recurrent_mode(True):
            r1 = gae(vals.copy())
            a1 = r1["advantage"]
        # Both computation paths must agree on the advantage estimate.
        torch.testing.assert_close(a0, a1)
13758 | 13837 | @pytest.mark.parametrize("device", get_default_devices())
|
13759 | 13838 | @pytest.mark.parametrize("gamma", [0.1, 0.5, 0.99])
|
13760 | 13839 | @pytest.mark.parametrize("lmbda", [0.1, 0.5, 0.99])
|
|
0 commit comments