@@ -2682,7 +2682,6 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
26822682
26832683 episode_specs = {}
26842684 if isinstance (reward_spec , CompositeSpec ):
2685-
26862685 # If reward_spec is a CompositeSpec, all in_keys should be keys of reward_spec
26872686 if not all (k in reward_spec .keys () for k in self .in_keys ):
26882687 raise KeyError ("Not all in_keys are present in ´reward_spec´" )
@@ -2697,7 +2696,6 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
26972696 episode_specs .update ({out_key : episode_spec })
26982697
26992698 else :
2700-
27012699 # If reward_spec is not a CompositeSpec, the only in_key should be ´reward´
27022700 if not set (self .in_keys ) == {"reward" }:
27032701 raise KeyError (
@@ -2882,3 +2880,106 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
28822880 if key in self .selected_keys
28832881 }
28842882 )
2883+
2884+
class TimeMaxPool(Transform):
    """Take the maximum value in each position over the last T observations.

    This transform takes the maximum value in each position for all ``in_keys``
    tensors over the last ``T`` time steps, using one rolling buffer per key.

    Args:
        in_keys (sequence of str, optional): input keys on which the max pool
            will be applied. Defaults to ``["observation"]`` if left empty.
        out_keys (sequence of str, optional): output keys where the output will
            be written. Defaults to ``in_keys`` if left empty.
        T (int, optional): number of time steps over which to apply max pooling.
            Must be >= 1. Defaults to 1.

    Raises:
        ValueError: if ``T`` is lower than 1, or if ``in_keys`` and
            ``out_keys`` have different lengths.
    """

    inplace = False
    invertible = False

    def __init__(
        self,
        in_keys: Optional[Sequence[str]] = None,
        out_keys: Optional[Sequence[str]] = None,
        T: int = 1,
    ):
        if in_keys is None:
            in_keys = ["observation"]
        super().__init__(in_keys=in_keys, out_keys=out_keys)
        if T < 1:
            raise ValueError(
                "TimeMaxPool T parameter should have a value greater or equal to one."
            )
        if len(self.in_keys) != len(self.out_keys):
            raise ValueError(
                "TimeMaxPool in_keys and out_keys don't have the same number of elements"
            )
        self.buffer_size = T
        # One lazily-initialized buffer per input key: shape/dtype/device are
        # unknown until the first observation is seen, so start uninitialized
        # and materialize in _call (see _make_missing_buffer).
        for in_key in self.in_keys:
            buffer_name = f"_maxpool_buffer_{in_key}"
            setattr(
                self,
                buffer_name,
                torch.nn.parameter.UninitializedBuffer(
                    device=torch.device("cpu"), dtype=torch.get_default_dtype()
                ),
            )

    def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Resets the max-pooling buffers.

        Non-batched environments zero the whole buffer; batched environments
        zero only the entries flagged by the ``"_reset"`` mask (defaulting to
        a full reset when the mask is absent). Buffers that were never
        materialized are skipped.
        """
        # Non-batched environments
        if len(tensordict.batch_size) < 1 or tensordict.batch_size[0] == 1:
            for in_key in self.in_keys:
                buffer_name = f"_maxpool_buffer_{in_key}"
                buffer = getattr(self, buffer_name)
                if isinstance(buffer, torch.nn.parameter.UninitializedBuffer):
                    # never materialized: nothing to reset
                    continue
                buffer.fill_(0.0)

        # Batched environments
        else:
            _reset = tensordict.get(
                "_reset",
                torch.ones(
                    tensordict.batch_size,
                    dtype=torch.bool,
                    device=tensordict.device,
                ),
            )
            for in_key in self.in_keys:
                buffer_name = f"_maxpool_buffer_{in_key}"
                buffer = getattr(self, buffer_name)
                if isinstance(buffer, torch.nn.parameter.UninitializedBuffer):
                    continue
                # buffer has shape (T, *batch_size, *obs_shape): dim 0 is
                # time, so the reset mask indexes the batch dims after it.
                buffer[:, _reset] = 0.0

        return tensordict

    def _make_missing_buffer(self, data, buffer_name):
        """Materialize an uninitialized buffer to (T, *data.shape), matching
        the observation's dtype and device, zero-filled."""
        buffer = getattr(self, buffer_name)
        buffer.materialize((self.buffer_size,) + data.shape)
        buffer = buffer.to(data.dtype).to(data.device).zero_()
        setattr(self, buffer_name, buffer)

    def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Write the max over the last T observations to each out_key."""
        for in_key, out_key in zip(self.in_keys, self.out_keys):
            data = tensordict[in_key]
            # Lazy init of buffers on first use
            buffer_name = f"_maxpool_buffer_{in_key}"
            buffer = getattr(self, buffer_name)
            if isinstance(buffer, torch.nn.parameter.UninitializedBuffer):
                self._make_missing_buffer(data, buffer_name)
                buffer = getattr(self, buffer_name)
            # shift obs 1 position to the right
            buffer.copy_(torch.roll(buffer, shifts=1, dims=0))
            # add new obs
            buffer[0].copy_(data)
            # apply max pooling over the time dimension
            pooled_tensor, _ = buffer.max(dim=0)
            # add to tensordict
            tensordict.set(out_key, pooled_tensor)

        return tensordict

    @_apply_to_composite
    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
        # Max pooling over time preserves the per-step observation spec.
        return observation_spec
0 commit comments