1010
1111import torch
1212from tensordict .tensordict import LazyStackedTensorDict , TensorDictBase
13+ from tensordict .utils import expand_right
1314
1415from torchrl .data .utils import DEVICE_TYPING
1516
@@ -351,7 +352,11 @@ def _get_priority(self, tensordict: TensorDictBase) -> Optional[torch.Tensor]:
351352 tensordict = tensordict .clone (recurse = False )
352353 tensordict .batch_size = []
353354 try :
354- priority = tensordict .get (self .priority_key ).item ()
355+ priority = tensordict .get (self .priority_key )
356+ if priority .numel () > 1 :
357+ priority = _reduce (priority , self ._sampler .reduction )
358+ else :
359+ priority = priority .item ()
355360 except ValueError :
356361 raise ValueError (
357362 f"Found a priority key of size"
@@ -378,7 +383,16 @@ def extend(self, tensordicts: Union[List, TensorDictBase]) -> torch.Tensor:
378383 tensordicts = tensordicts .clone (recurse = False )
379384 else :
380385 tensordicts = tensordicts .contiguous ()
386+ # we keep track of the batch size to reinstantiate it when sampling
387+ if "_batch_size" in tensordicts .keys ():
388+ raise KeyError (
389+ "conflicting key '_batch_size'. Consider removing from data."
390+ )
391+ shape = torch .tensor (tensordicts .batch_size [1 :]).expand (
392+ tensordicts .batch_size [0 ], tensordicts .batch_dims - 1
393+ )
381394 tensordicts .batch_size = tensordicts .batch_size [:1 ]
395+ tensordicts .set ("_batch_size" , shape )
382396 tensordicts .set (
383397 "index" ,
384398 torch .zeros (
@@ -406,7 +420,13 @@ def update_tensordict_priority(self, data: TensorDictBase) -> None:
406420 dtype = torch .float ,
407421 device = data .device ,
408422 )
409- self .update_priority (data .get ("index" ), priority )
423+ # if the index shape does not match the priority shape, we have expanded it.
424+ # we just take the first value
425+ index = data .get ("index" )
426+ while index .shape != priority .shape :
427+ # reduce index
428+ index = index [..., 0 ]
429+ self .update_priority (index , priority )
410430
411431 def sample (
412432 self , batch_size : int , include_info : bool = False , return_info : bool = False
@@ -429,6 +449,18 @@ def sample(
429449 if include_info :
430450 for k , v in info .items ():
431451 data .set (k , torch .tensor (v , device = data .device ), inplace = True )
452+ if "_batch_size" in data .keys ():
453+ # we need to reset the batch-size
454+ shape = data .pop ("_batch_size" )
455+ shape = shape [0 ]
456+ shape = torch .Size ([data .shape [0 ], * shape ])
457+ # we may need to update some values in the data
458+ for key , value in data .items ():
459+ if value .ndim >= len (shape ):
460+ continue
461+ value = expand_right (value , shape )
462+ data .set (key , value )
463+ data .batch_size = shape
432464 if return_info :
433465 return data , info
434466 return data
@@ -462,6 +494,9 @@ class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer):
462494 using multithreading.
463495 transform (Transform, optional): Transform to be executed when sample() is called.
464496 To chain transforms use the :obj:`Compose` class.
497+ reduction (str, optional): the reduction method for multidimensional
498+ tensordicts (ie stored trajectories). Can be one of "max", "min",
499+ "median" or "mean".
465500 """
466501
467502 def __init__ (
@@ -475,10 +510,13 @@ def __init__(
475510 pin_memory : bool = False ,
476511 prefetch : Optional [int ] = None ,
477512 transform : Optional ["Transform" ] = None , # noqa-F821
513+ reduction : Optional [str ] = "max" ,
478514 ) -> None :
479515 if storage is None :
480516 storage = ListStorage (max_size = 1_000 )
481- sampler = PrioritizedSampler (storage .max_size , alpha , beta , eps )
517+ sampler = PrioritizedSampler (
518+ storage .max_size , alpha , beta , eps , reduction = reduction
519+ )
482520 super (TensorDictPrioritizedReplayBuffer , self ).__init__ (
483521 priority_key = priority_key ,
484522 storage = storage ,
@@ -539,3 +577,16 @@ def __call__(self, list_of_tds):
539577 else :
540578 torch .stack (list_of_tds , 0 , out = self .out )
541579 return self .out
580+
581+
582+ def _reduce (tensor : torch .Tensor , reduction : str ):
583+ """Reduces a tensor given the reduction method."""
584+ if reduction == "max" :
585+ return tensor .max ().item ()
586+ elif reduction == "min" :
587+ return tensor .min ().item ()
588+ elif reduction == "mean" :
589+ return tensor .mean ().item ()
590+ elif reduction == "median" :
591+ return tensor .median ().item ()
592+ raise NotImplementedError (f"Unknown reduction method { reduction } " )
0 commit comments