1010from torchrl .envs .common import _EnvWrapper
1111from torchrl .envs .utils import step_tensordict
1212
13+ __all__ = ["GymLikeEnv" , "default_info_dict_reader" ]
14+
15+
class default_info_dict_reader:
    """Default info-key reader.

    Copies a fixed set of keys from a step ``info`` dictionary into a
    tensordict. In cases where keys can be directly written to a tensordict
    (mostly if they abide by the tensordict shape), one simply needs to
    indicate the keys to be registered during instantiation.

    Args:
        keys (str or iterable of str, optional): key or keys to copy from the
            info dictionary. A single string is treated as one key.
            Defaults to no keys, in which case calling the reader leaves the
            tensordict untouched.

    Examples:
        >>> from torchrl.envs import GymWrapper, default_info_dict_reader
        >>> reader = default_info_dict_reader(["my_info_key"])
        >>> # assuming "some_env-v0" returns a dict with a key "my_info_key"
        >>> env = GymWrapper(gym.make("some_env-v0"))
        >>> env.set_info_dict_reader(info_dict_reader=reader)
        >>> tensordict = env.reset()
        >>> tensordict = env.rand_step(tensordict)
        >>> assert "my_info_key" in tensordict.keys()

    """

    def __init__(self, keys=None):
        if keys is None:
            keys = []
        elif isinstance(keys, str):
            # A bare string is a single key, not an iterable of
            # one-character keys.
            keys = [keys]
        else:
            # Materialize once so that one-shot iterables (e.g. generators)
            # are not exhausted after the first call.
            keys = list(keys)
        self.keys = keys

    # Annotations are quoted so they are not evaluated when the class body
    # is executed.
    def __call__(self, info_dict: dict, tensordict: "_TensorDict") -> "_TensorDict":
        """Copy every registered key present in ``info_dict`` into ``tensordict``.

        Registered keys that are missing from ``info_dict`` are skipped.

        Args:
            info_dict (dict): the info dictionary to read from.
            tensordict: the container written to in-place via item assignment.

        Returns:
            The same ``tensordict`` object that was passed in.
        """
        for key in self.keys:
            if key in info_dict:
                tensordict[key] = info_dict[key]
        return tensordict
46+
1347
1448class GymLikeEnv (_EnvWrapper ):
15- info_keys = []
49+ _info_dict_reader : callable
1650
1751 """
1852 A gym-like env is an environment whose behaviour is similar to gym environments in what
@@ -25,7 +59,7 @@ class GymLikeEnv(_EnvWrapper):
2559
2660 where the outputs are the observation, reward and done state respectively.
2761 In this implementation, the info output is discarded (but specific keys can be read
28- by updating the `"info_keys" ` class attribute ).
62+ by registering an info_dict_reader, see the `set_info_dict_reader` method).
2963
3064 By default, the first output is written at the "next_observation" key-value pair in the output tensordict, unless
3165 the first output is a dictionary. In that case, each observation output will be put at the corresponding
@@ -65,9 +99,7 @@ def _step(self, tensordict: _TensorDict) -> _TensorDict:
6599 )
66100 tensordict_out .set ("reward" , reward )
67101 tensordict_out .set ("done" , done )
68- for key in self .info_keys :
69- data = info [0 ][key ]
70- tensordict_out .set (key , data )
102+ self .info_dict_reader (info , tensordict_out )
71103
72104 self .current_tensordict = step_tensordict (tensordict_out )
73105 return tensordict_out
@@ -100,6 +132,42 @@ def _output_transform(self, step_outputs_tuple: Tuple) -> Tuple:
100132 )
101133 return step_outputs_tuple
102134
def set_info_dict_reader(self, info_dict_reader: callable) -> "GymLikeEnv":
    """Register the callable used to read the step info dictionary.

    The callable receives the info dictionary and the tensordict returned
    by the step function, and writes values in an ad-hoc manner from one
    to the other.

    Args:
        info_dict_reader (callable): a callable taking an input dictionary
            and an output tensordict as arguments. This function should
            modify the tensordict in-place.

    Returns:
        The same environment, with the reader registered, so that the call
        can be chained.

    Examples:
        >>> from torchrl.envs import GymWrapper, default_info_dict_reader
        >>> reader = default_info_dict_reader(["my_info_key"])
        >>> # assuming "some_env-v0" returns a dict with a key "my_info_key"
        >>> env = GymWrapper(gym.make("some_env-v0")).set_info_dict_reader(info_dict_reader=reader)
        >>> tensordict = env.reset()
        >>> tensordict = env.rand_step(tensordict)
        >>> assert "my_info_key" in tensordict.keys()

    """
    # The return annotation is quoted: the class name is not bound yet while
    # the class body is being executed, so an unquoted reference is only
    # safe under postponed annotation evaluation.
    self.info_dict_reader = info_dict_reader
    return self
160+
@property
def info_dict_reader(self):
    """Callable that writes entries of the step info into a tensordict.

    If no reader has been stored yet, a ``default_info_dict_reader`` built
    with no arguments is created, cached on the instance and returned.
    """
    try:
        # Direct lookup: avoids building the full ``__dir__()`` listing on
        # every access and does not go through any ``__getattr__`` fallback.
        return object.__getattribute__(self, "_info_dict_reader")
    except AttributeError:
        self._info_dict_reader = default_info_dict_reader()
        return self._info_dict_reader

@info_dict_reader.setter
def info_dict_reader(self, value: callable):
    """Store ``value`` as the reader returned by ``info_dict_reader``."""
    self._info_dict_reader = value
170+
103171 def __repr__ (self ) -> str :
104172 return (
105173 f"{ self .__class__ .__name__ } (env={ self ._env } , batch_size={ self .batch_size } )"
0 commit comments