@@ -2620,6 +2620,8 @@ class CatFrames(ObservationTransform):
26202620 reset indicator. Must be unique. If not provided, defaults to the
26212621 only reset key of the parent environment (if it has only one)
26222622 and raises an exception otherwise.
2623+ done_key (NestedKey, optional): the done key to be used as partial
2624+ done indicator. Must be unique. If not provided, defaults to ``"done"``.
26232625
26242626 Examples:
26252627 >>> from torchrl.envs.libs.gym import GymEnv
@@ -2700,6 +2702,7 @@ def __init__(
27002702 padding_value = 0 ,
27012703 as_inverse = False ,
27022704 reset_key : NestedKey | None = None ,
2705+ done_key : NestedKey | None = None ,
27032706 ):
27042707 if in_keys is None :
27052708 in_keys = IMAGE_KEYS
@@ -2733,6 +2736,19 @@ def __init__(
27332736 # keeps track of calls to _reset since it's only _call that will populate the buffer
27342737 self .as_inverse = as_inverse
27352738 self .reset_key = reset_key
2739+ self .done_key = done_key
2740+
@property
def done_key(self):
    """Done key used as the partial-done indicator.

    Defaults to ``"done"`` when none was provided; the resolved value is
    cached in ``self._done_key`` on first access.
    """
    key = self.__dict__.get("_done_key", None)
    if key is None:
        key = "done"
        self._done_key = key
    return key

@done_key.setter
def done_key(self, value):
    # Store the user-provided key; the getter falls back to "done" if None.
    self._done_key = value
27362752
27372753 @property
27382754 def reset_key (self ):
@@ -2829,15 +2845,6 @@ def _call(self, tensordict: TensorDictBase, _reset=None) -> TensorDictBase:
28292845 # make linter happy. An exception has already been raised
28302846 raise NotImplementedError
28312847
2832- # # this duplicates the code below, but only for _reset values
2833- # if _all:
2834- # buffer.copy_(torch.roll(buffer_reset, shifts=-d, dims=dim))
2835- # buffer_reset = buffer
2836- # else:
2837- # buffer_reset = buffer[_reset] = torch.roll(
2838- # buffer_reset, shifts=-d, dims=dim
2839- # )
2840- # add new obs
28412848 if self .dim < 0 :
28422849 n = buffer_reset .ndimension () + self .dim
28432850 else :
@@ -2906,69 +2913,136 @@ def unfolding(self, tensordict: TensorDictBase) -> TensorDictBase:
29062913 if i != tensordict .ndim - 1 :
29072914 tensordict = tensordict .transpose (tensordict .ndim - 1 , i )
29082915 # first sort the in_keys with strings and non-strings
2909- in_keys = list (
2910- zip (
2911- (in_key , out_key )
2912- for in_key , out_key in zip (self .in_keys , self .out_keys )
2913- if isinstance (in_key , str ) or len (in_key ) == 1
2914- )
2915- )
2916- in_keys += list (
2917- zip (
2918- (in_key , out_key )
2919- for in_key , out_key in zip (self .in_keys , self .out_keys )
2920- if not isinstance (in_key , str ) and not len (in_key ) == 1
2916+ keys = [
2917+ (in_key , out_key )
2918+ for in_key , out_key in zip (self .in_keys , self .out_keys )
2919+ if isinstance (in_key , str )
2920+ ]
2921+ keys += [
2922+ (in_key , out_key )
2923+ for in_key , out_key in zip (self .in_keys , self .out_keys )
2924+ if not isinstance (in_key , str )
2925+ ]
2926+
def unfold_done(done, N):
    """Build per-window reset masks from trailing-time-dim done flags.

    Args:
        done: boolean done tensor; its dim ``tensordict.ndim - 1`` is the
            time dimension (closure over the enclosing ``tensordict``).
        N: window length used for unfolding.

    Returns:
        A tuple ``(reset_unfold, reset)`` where ``reset_unfold`` has an
        extra trailing dim of size ``N`` and marks, within each window,
        positions that are followed by a reset later in the window (i.e.
        frames belonging to a previous trajectory), and ``reset`` marks
        the first step of each trajectory slice (first step forced to 1).
    """
    prefix = (slice(None),) * (tensordict.ndim - 1)
    # Pad with N-1 zeros plus a leading one (the very first frame counts
    # as a reset) and shift dones right by one: done[t] resets frame t+1.
    reset = torch.cat(
        [
            torch.zeros_like(done[prefix + (slice(N - 1),)]),
            torch.ones_like(done[prefix + (slice(1),)]),
            done[prefix + (slice(None, -1),)],
        ],
        tensordict.ndim - 1,
    )
    reset_unfold = reset.unfold(tensordict.ndim - 1, N, 1)
    # Backward cumulative OR over the window dim: position i is masked iff
    # any reset occurs strictly after i within the same window.
    acc = [torch.zeros_like(reset_unfold[..., -1])]
    for flags in reversed(reset_unfold.unbind(-1)):
        acc.append(flags | acc[-1])
    # Drop the all-zeros seed; re-reverse to restore window order.
    reset_unfold = torch.stack(list(reversed(acc))[1:], -1)
    reset = reset[prefix + (slice(N - 1, None),)]
    reset[prefix + (0,)] = 1
    return reset_unfold, reset
2947+
2948+ done = tensordict .get (("next" , self .done_key ))
2949+ done_mask , reset = unfold_done (done , self .N )
2950+
2951+ for in_key , out_key in keys :
29242952 # check if we have an obs in "next" that has already been processed.
29252953 # If so, we must add an offset
2926- data = tensordict .get (in_key )
2954+ data_orig = data = tensordict .get (in_key )
2955+ n_feat = data_orig .shape [data .ndim + self .dim ]
2956+ first_val = None
29272957 if isinstance (in_key , tuple ) and in_key [0 ] == "next" :
29282958 # let's get the out_key we have already processed
2929- prev_out_key = dict (zip (self .in_keys , self .out_keys ))[in_key [1 ]]
2930- prev_val = tensordict .get (prev_out_key )
2931- # the first item is located along `dim+1` at the last index of the
2932- # first time index
2933- idx = (
2934- [slice (None )] * (tensordict .ndim - 1 )
2935- + [0 ]
2936- + [..., - 1 ]
2937- + [slice (None )] * (abs (self .dim ) - 1 )
2959+ prev_out_key = dict (zip (self .in_keys , self .out_keys )).get (
2960+ in_key [1 ], None
29382961 )
2939- first_val = prev_val [tuple (idx )].unsqueeze (tensordict .ndim - 1 )
2940- data0 = [first_val ] * (self .N - 1 )
2941- if self .padding == "constant" :
2942- data0 = [
2943- torch .full_like (elt , self .padding_value ) for elt in data0 [:- 1 ]
2944- ] + data0 [- 1 :]
2945- elif self .padding == "same" :
2946- pass
2947- else :
2948- # make linter happy. An exception has already been raised
2949- raise NotImplementedError
2950- elif self .padding == "same" :
2951- idx = [slice (None )] * (tensordict .ndim - 1 ) + [0 ]
2952- data0 = [data [tuple (idx )].unsqueeze (tensordict .ndim - 1 )] * (self .N - 1 )
2953- elif self .padding == "constant" :
2954- idx = [slice (None )] * (tensordict .ndim - 1 ) + [0 ]
2955- data0 = [
2956- torch .full_like (data [tuple (idx )], self .padding_value ).unsqueeze (
2957- tensordict .ndim - 1
2962+ if prev_out_key is not None :
2963+ prev_val = tensordict .get (prev_out_key )
2964+ # n_feat = prev_val.shape[data.ndim + self.dim] // self.N
2965+ first_val = prev_val .unflatten (
2966+ data .ndim + self .dim , (self .N , n_feat )
29582967 )
2959- ] * (self .N - 1 )
2960- else :
2961- # make linter happy. An exception has already been raised
2962- raise NotImplementedError
2968+
2969+ idx = [slice (None )] * (tensordict .ndim - 1 ) + [0 ]
2970+ data0 = [
2971+ torch .full_like (data [tuple (idx )], self .padding_value ).unsqueeze (
2972+ tensordict .ndim - 1
2973+ )
2974+ ] * (self .N - 1 )
29632975
29642976 data = torch .cat (data0 + [data ], tensordict .ndim - 1 )
29652977
29662978 data = data .unfold (tensordict .ndim - 1 , self .N , 1 )
2979+
2980+ # Place -1 dim at self.dim place before squashing
2981+ done_mask_expand = expand_as_right (done_mask , data )
29672982 data = data .permute (
2968- * range (0 , data .ndim + self .dim ),
2983+ * range (0 , data .ndim + self .dim - 1 ),
2984+ - 1 ,
2985+ * range (data .ndim + self .dim - 1 , data .ndim - 1 ),
2986+ )
2987+ done_mask_expand = done_mask_expand .permute (
2988+ * range (0 , done_mask_expand .ndim + self .dim - 1 ),
29692989 - 1 ,
2970- * range (data .ndim + self .dim , data .ndim - 1 ),
2990+ * range (done_mask_expand .ndim + self .dim - 1 , done_mask_expand .ndim - 1 ),
29712991 )
2992+ if self .padding != "same" :
2993+ data = torch .where (done_mask_expand , self .padding_value , data )
2994+ else :
2995+ # TODO: This is a pretty bad implementation, could be
2996+ # made more efficient but it works!
2997+ reset_vals = list (data_orig [reset .squeeze (- 1 )].unbind (0 ))
2998+ j_ = float ("inf" )
2999+ reps = []
3000+ d = data .ndim + self .dim - 1
3001+ for j in done_mask_expand .sum (d ).sum (d ).view (- 1 ) // n_feat :
3002+ if j > j_ :
3003+ reset_vals = reset_vals [1 :]
3004+ reps .extend ([reset_vals [0 ]] * int (j ))
3005+ j_ = j
3006+ reps = torch .stack (reps )
3007+ data = torch .masked_scatter (data , done_mask_expand , reps .reshape (- 1 ))
3008+
3009+ if first_val is not None :
3010+ # Aggregate reset along last dim
3011+ reset = reset .any (- 1 , True )
3012+ rexp = reset .expand (* reset .shape [:- 1 ], n_feat )
3013+ rexp = torch .cat (
3014+ [
3015+ torch .zeros_like (
3016+ data0 [0 ].repeat_interleave (
3017+ len (data0 ), dim = tensordict .ndim - 1
3018+ ),
3019+ dtype = torch .bool ,
3020+ ),
3021+ rexp ,
3022+ ],
3023+ tensordict .ndim - 1 ,
3024+ )
3025+ rexp = rexp .unfold (tensordict .ndim - 1 , self .N , 1 )
3026+ rexp_orig = rexp
3027+ rexp = torch .cat ([rexp [..., 1 :], torch .zeros_like (rexp [..., - 1 :])], - 1 )
3028+ if self .padding == "same" :
3029+ rexp_orig = rexp_orig .flip (- 1 ).cumsum (- 1 ).flip (- 1 ).bool ()
3030+ rexp = rexp .flip (- 1 ).cumsum (- 1 ).flip (- 1 ).bool ()
3031+ rexp_orig = torch .cat (
3032+ [torch .zeros_like (rexp_orig [..., - 1 :]), rexp_orig [..., 1 :]], - 1
3033+ )
3034+ rexp = rexp .permute (
3035+ * range (0 , rexp .ndim + self .dim - 1 ),
3036+ - 1 ,
3037+ * range (rexp .ndim + self .dim - 1 , rexp .ndim - 1 ),
3038+ )
3039+ rexp_orig = rexp_orig .permute (
3040+ * range (0 , rexp_orig .ndim + self .dim - 1 ),
3041+ - 1 ,
3042+ * range (rexp_orig .ndim + self .dim - 1 , rexp_orig .ndim - 1 ),
3043+ )
3044+ data [rexp ] = first_val [rexp_orig ]
3045+ data = data .flatten (data .ndim + self .dim - 1 , data .ndim + self .dim )
29723046 tensordict .set (out_key , data )
29733047 if tensordict_orig is not tensordict :
29743048 tensordict_orig = tensordict .transpose (tensordict .ndim - 1 , i )