@@ -297,7 +297,15 @@ def set(
     def get(self, index: Union[int, Sequence[int], slice]) -> Any:
         if isinstance(index, (INT_CLASSES, slice)):
             return self._storage[index]
+        elif isinstance(index, tuple):
+            if len(index) > 1:
+                raise RuntimeError(
+                    f"{type(self).__name__} can only be indexed with one-length tuples."
+                )
+            return self.get(index[0])
         else:
+            if isinstance(index, torch.Tensor) and index.device.type != "cpu":
+                index = index.cpu().tolist()
             return [self._storage[i] for i in index]

     def __len__(self):
@@ -353,6 +361,77 @@ def contains(self, item):
         raise NotImplementedError(f"type {type(item)} is not supported yet.")


+class LazyStackStorage(ListStorage):
+    """A ListStorage that returns LazyStackedTensorDict instances.
+
+    This storage allows for heterogeneous structures to be indexed as a single `TensorDict` representation.
+    It uses :class:`~tensordict.LazyStackedTensorDict` which operates on non-contiguous lists of tensordicts,
+    lazily stacking items when queried.
+    This means that this storage is going to be fast to sample but data access may be slow (as it requires a stack).
+    Tensors of heterogeneous shapes can also be stored within the storage and stacked together.
+    Because the storage is represented as a list, the number of tensors to store in memory will grow linearly with
+    the size of the buffer.
+
+    If possible, nested tensors can also be created via :meth:`~tensordict.LazyStackedTensorDict.densify`
+    (see :mod:`~torch.nested`).
+
+    Args:
+        max_size (int, optional): the maximum number of elements stored in the storage.
+            If not provided, an unlimited storage is created.
+
+    Keyword Args:
+        compilable (bool, optional): if ``True``, the storage will be made compatible with :func:`~torch.compile`
+            at the cost of not being executable in multiprocessed settings.
+        stack_dim (int, optional): the stack dimension in terms of TensorDict batch sizes. Defaults to `-1`.
+
+    Examples:
+        >>> import torch
+        >>> from torchrl.data import ReplayBuffer, LazyStackStorage
+        >>> from tensordict import TensorDict
+        >>> _ = torch.manual_seed(0)
+        >>> rb = ReplayBuffer(storage=LazyStackStorage(max_size=1000, stack_dim=-1))
+        >>> data0 = TensorDict(a=torch.randn((10,)), b=torch.rand(4), c="a string!")
+        >>> data1 = TensorDict(a=torch.randn((11,)), b=torch.rand(4), c="another string!")
+        >>> _ = rb.add(data0)
+        >>> _ = rb.add(data1)
+        >>> rb.sample(10)
+        LazyStackedTensorDict(
+            fields={
+                a: Tensor(shape=torch.Size([10, -1]), device=cpu, dtype=torch.float32, is_shared=False),
+                b: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
+                c: NonTensorStack(
+                    ['another string!', 'another string!', 'another st...,
+                    batch_size=torch.Size([10]),
+                    device=None)},
+            exclusive_fields={
+            },
+            batch_size=torch.Size([10]),
+            device=None,
+            is_shared=False,
+            stack_dim=0)
+    """
+
+    def __init__(
+        self,
+        max_size: int | None = None,
+        *,
+        compilable: bool = False,
+        stack_dim: int = -1,
+    ):
+        super().__init__(max_size=max_size, compilable=compilable)
+        self.stack_dim = stack_dim
+
+    def get(self, index: Union[int, Sequence[int], slice]) -> Any:
+        out = super().get(index=index)
+        if isinstance(out, list):
+            stack_dim = self.stack_dim
+            if stack_dim < 0:
+                stack_dim = out[0].ndim + 1 + stack_dim
+            out = LazyStackedTensorDict(*out, stack_dim=stack_dim)
+            return out
+        return out
+
+
 class TensorStorage(Storage):
     """A storage for tensors and tensordicts.
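A minimal sketch of the nested-tensor conversion mentioned in the LazyStackStorage docstring above. It assumes LazyStackedTensorDict.densify() behaves as described there (see torch.nested); its exact keyword arguments may differ across tensordict versions.

    import torch
    from tensordict import TensorDict
    from torchrl.data import LazyStackStorage, ReplayBuffer

    rb = ReplayBuffer(storage=LazyStackStorage(max_size=100))
    rb.add(TensorDict(a=torch.randn(10)))
    rb.add(TensorDict(a=torch.randn(11)))  # heterogeneous shapes across items
    sample = rb.sample(2)                  # LazyStackedTensorDict; stacking stays lazy
    dense = sample.densify()               # ragged fields become torch.nested tensors where possible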