1212import numpy as np
1313import pytest
1414import torch
15- from _utils_internal import get_available_devices
15+ from _utils_internal import get_available_devices , make_tc
1616from tensordict import is_tensorclass , tensorclass
1717from tensordict .tensordict import assert_allclose_td , TensorDict , TensorDictBase
1818from torchrl .data import (
@@ -129,9 +129,14 @@ def test_add(self, rb_type, sampler, writer, storage, size):
129129 )
130130 data = self ._get_datum (rb_type )
131131 rb .add (data )
132- s = rb ._storage [0 ]
132+ s = rb .sample (1 )
133+ assert s .ndim , s
134+ s = s [0 ]
133135 if isinstance (s , TensorDictBase ):
134- assert (s == data .select (* s .keys ())).all ()
136+ s = s .select (* data .keys (True ), strict = False )
137+ data = data .select (* s .keys (True ), strict = False )
138+ assert (s == data ).all ()
139+ assert list (s .keys (True , True ))
135140 else :
136141 assert (s == data ).all ()
137142
@@ -373,14 +378,22 @@ def test_prototype_prb(priority_key, contiguous, device):
373378
374379
375380@pytest .mark .parametrize ("stack" , [False , True ])
381+ @pytest .mark .parametrize ("datatype" , ["tc" , "tb" ])
376382@pytest .mark .parametrize ("reduction" , ["min" , "max" , "median" , "mean" ])
377- def test_replay_buffer_trajectories (stack , reduction ):
383+ def test_replay_buffer_trajectories (stack , reduction , datatype ):
378384 traj_td = TensorDict (
379385 {"obs" : torch .randn (3 , 4 , 5 ), "actions" : torch .randn (3 , 4 , 2 )},
380386 batch_size = [3 , 4 ],
381387 )
388+ if datatype == "tc" :
389+ c = make_tc (traj_td )
390+ traj_td = c (** traj_td , batch_size = traj_td .batch_size )
391+ assert is_tensorclass (traj_td )
392+ elif datatype != "tb" :
393+ raise NotImplementedError
394+
382395 if stack :
383- traj_td = torch .stack ([ td . to_tensordict () for td in traj_td ] , 0 )
396+ traj_td = torch .stack (list ( traj_td ) , 0 )
384397
385398 rb = TensorDictReplayBuffer (
386399 sampler = samplers .PrioritizedSampler (
@@ -394,6 +407,10 @@ def test_replay_buffer_trajectories(stack, reduction):
394407 )
395408 rb .extend (traj_td )
396409 sampled_td = rb .sample ()
410+ if datatype == "tc" :
411+ assert is_tensorclass (traj_td )
412+ return
413+
397414 sampled_td .set ("td_error" , torch .rand (sampled_td .shape ))
398415 rb .update_tensordict_priority (sampled_td )
399416 sampled_td = rb .sample (include_info = True )
@@ -510,9 +527,12 @@ def test_add(self, rbtype, storage, size, prefetch):
510527 rb = self ._get_rb (rbtype , storage = storage , size = size , prefetch = prefetch )
511528 data = self ._get_datum (rbtype )
512529 rb .add (data )
513- s = rb ._storage [0 ]
530+ s = rb .sample ( 1 ) [0 ]
514531 if isinstance (s , TensorDictBase ):
515- assert (s == data .select (* s .keys ())).all ()
532+ s = s .select (* data .keys (True ), strict = False )
533+ data = data .select (* s .keys (True ), strict = False )
534+ assert (s == data ).all ()
535+ assert list (s .keys (True , True ))
516536 else :
517537 assert (s == data ).all ()
518538
@@ -649,6 +669,7 @@ def test_prb(priority_key, contiguous, device):
649669 },
650670 batch_size = [3 ],
651671 ).to (device )
672+
652673 rb .extend (td1 )
653674 s = rb .sample ()
654675 assert s .batch_size == torch .Size ([5 ])
@@ -838,17 +859,29 @@ def test_insert_transform():
838859
@pytest.mark.parametrize("transform", transforms)
def test_smoke_replay_buffer_transform(transform):
    """Smoke-test that sampling from the buffer calls the transform's forward.

    A ``TensorDictReplayBuffer`` is built with ``transform`` applied to the
    ``"observation"`` key, a single tensordict is added, and the transform's
    ``forward`` is then swapped for a mock so the call made during
    ``rb.sample()`` can be observed.

    Args:
        transform: transform class (one entry of ``transforms``), instantiated
            here with ``in_keys=["observation"]``.
    """
    rb = TensorDictReplayBuffer(
        transform=transform(in_keys=["observation"]), batch_size=1
    )

    td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 3)}, [])
    rb.add(td)

    # Replace ``forward`` with a mock that hands back one pre-built item.
    # ``side_effect`` is a one-element list, so the mock can be called only
    # once; a second call would raise instead of returning a value.
    forward_mock = mock.Mock()
    forward_mock.side_effect = [td.unsqueeze(0)]
    rb._transform.forward = forward_mock

    rb.sample()
    assert rb._transform.forward.called
852885
853886
854887transforms = [
0 commit comments