3434 MultiKeyCountingEnvPolicy ,
3535 NestedCountingEnv ,
3636)
37- from tensordict .nn import TensorDictModule
37+ from tensordict .nn import TensorDictModule , TensorDictSequential
3838from tensordict .tensordict import assert_allclose_td , TensorDict
3939
4040from torch import nn
@@ -939,17 +939,22 @@ def create_env():
939939 [MultiSyncDataCollector , MultiaSyncDataCollector , SyncDataCollector ],
940940)
941941@pytest .mark .parametrize ("exclude" , [True , False ])
942- def test_excluded_keys (collector_class , exclude ):
942+ @pytest .mark .parametrize ("out_key" , ["_dummy" , ("out" , "_dummy" ), ("_out" , "dummy" )])
943+ def test_excluded_keys (collector_class , exclude , out_key ):
943944 if not exclude and collector_class is not SyncDataCollector :
944945 pytest .skip ("defining _exclude_private_keys is not possible" )
945946
946947 def make_env ():
947- return ContinuousActionVecMockEnv ()
948+ return TransformedEnv ( ContinuousActionVecMockEnv (), InitTracker () )
948949
949950 dummy_env = make_env ()
950951 obs_spec = dummy_env .observation_spec ["observation" ]
951952 policy_module = nn .Linear (obs_spec .shape [- 1 ], dummy_env .action_spec .shape [- 1 ])
952- policy = Actor (policy_module , spec = dummy_env .action_spec )
953+ policy = TensorDictModule (
954+ policy_module , in_keys = ["observation" ], out_keys = ["action" ]
955+ )
956+ copier = TensorDictModule (lambda x : x , in_keys = ["observation" ], out_keys = [out_key ])
957+ policy = TensorDictSequential (policy , copier )
953958 policy_explore = OrnsteinUhlenbeckProcessWrapper (policy )
954959
955960 collector_kwargs = {
@@ -966,11 +971,13 @@ def make_env():
966971 collector = collector_class (** collector_kwargs )
967972 collector ._exclude_private_keys = exclude
968973 for b in collector :
969- keys = b .keys ()
974+ keys = set ( b .keys () )
970975 if exclude :
971976 assert not any (key .startswith ("_" ) for key in keys )
977+ assert out_key not in b .keys (True , True )
972978 else :
973979 assert any (key .startswith ("_" ) for key in keys )
980+ assert out_key in b .keys (True , True )
974981 break
975982 collector .shutdown ()
976983 dummy_env .close ()
0 commit comments