|
33 | 33 | from tensordict.nn import TensorDictSequential
|
34 | 34 | from tensordict.utils import _unravel_key_to_tuple, assert_allclose_td
|
35 | 35 | from torch import multiprocessing as mp, nn, Tensor
|
36 |
| -from torchrl._utils import _replace_last, prod |
| 36 | +from torchrl._utils import _replace_last, prod, set_auto_unwrap_transformed_env |
37 | 37 |
|
38 | 38 | from torchrl.collectors import MultiSyncDataCollector
|
39 | 39 | from torchrl.data import (
|
@@ -9846,6 +9846,40 @@ def test_added_transforms_are_in_eval_mode():
|
9846 | 9846 |
|
9847 | 9847 |
|
9848 | 9848 | class TestTransformedEnv:
|
| 9849 | + @pytest.mark.filterwarnings("error") |
| 9850 | + def test_nested_transformed_env(self): |
| 9851 | + base_env = ContinuousActionVecMockEnv() |
| 9852 | + t1 = RewardScaling(0, 1) |
| 9853 | + t2 = RewardScaling(0, 2) |
| 9854 | + |
| 9855 | + def test_unwrap(): |
| 9856 | + env = TransformedEnv(TransformedEnv(base_env, t1), t2) |
| 9857 | + assert env.base_env is base_env |
| 9858 | + assert isinstance(env.transform, Compose) |
| 9859 | + children = list(env.transform.transforms.children()) |
| 9860 | + assert len(children) == 2 |
| 9861 | + assert children[0].scale == 1 |
| 9862 | + assert children[1].scale == 2 |
| 9863 | + |
| 9864 | + def test_wrap(auto_unwrap=None): |
| 9865 | + env = TransformedEnv( |
| 9866 | + TransformedEnv(base_env, t1), t2, auto_unwrap=auto_unwrap |
| 9867 | + ) |
| 9868 | + assert env.base_env is not base_env |
| 9869 | + assert isinstance(env.base_env.transform, RewardScaling) |
| 9870 | + assert isinstance(env.transform, RewardScaling) |
| 9871 | + |
| 9872 | + with pytest.warns(FutureWarning): |
| 9873 | + test_unwrap() |
| 9874 | + |
| 9875 | + test_wrap(False) |
| 9876 | + |
| 9877 | + with set_auto_unwrap_transformed_env(True): |
| 9878 | + test_unwrap() |
| 9879 | + |
| 9880 | + with set_auto_unwrap_transformed_env(False): |
| 9881 | + test_wrap() |
| 9882 | + |
9849 | 9883 | def test_attr_error(self):
|
9850 | 9884 | class BuggyTransform(Transform):
|
9851 | 9885 | def transform_observation_spec(
|
@@ -9936,20 +9970,6 @@ def test_allow_done_after_reset(self):
|
9936 | 9970 | assert not t1._allow_done_after_reset
|
9937 | 9971 |
|
9938 | 9972 |
|
9939 |
| -def test_nested_transformed_env(): |
9940 |
| - base_env = ContinuousActionVecMockEnv() |
9941 |
| - t1 = RewardScaling(0, 1) |
9942 |
| - t2 = RewardScaling(0, 2) |
9943 |
| - env = TransformedEnv(TransformedEnv(base_env, t1), t2) |
9944 |
| - |
9945 |
| - assert env.base_env is base_env |
9946 |
| - assert isinstance(env.transform, Compose) |
9947 |
| - children = list(env.transform.transforms.children()) |
9948 |
| - assert len(children) == 2 |
9949 |
| - assert children[0].scale == 1 |
9950 |
| - assert children[1].scale == 2 |
9951 |
| - |
9952 |
| - |
9953 | 9973 | def test_transform_parent():
|
9954 | 9974 | base_env = ContinuousActionVecMockEnv()
|
9955 | 9975 | t1 = RewardScaling(0, 1)
|
|
0 commit comments