3030from torchrl .data .replay_buffers .writers import RoundRobinWriter
3131
3232
33- collate_fn_dict = {
34- ListStorage : lambda x : torch .stack (x , 0 ),
35- LazyTensorStorage : lambda x : x ,
36- LazyMemmapStorage : lambda x : x ,
37- None : lambda x : torch .stack (x , 0 ),
38- }
33+ # collate_fn_dict = {
34+ # ListStorage: lambda x: torch.stack(x, 0),
35+ # LazyTensorStorage: lambda x: x,
36+ # LazyMemmapStorage: lambda x: x,
37+ # None: lambda x: torch.stack(x, 0),
38+ # }
3939
4040
4141@pytest .mark .parametrize (
5454@pytest .mark .parametrize ("size" , [3 , 100 ])
5555class TestPrototypeBuffers :
5656 def _get_rb (self , rb_type , size , sampler , writer , storage ):
57- collate_fn = collate_fn_dict [storage ]
5857
5958 if storage is not None :
6059 storage = storage (size )
@@ -65,9 +64,7 @@ def _get_rb(self, rb_type, size, sampler, writer, storage):
6564
6665 sampler = sampler (** sampler_args )
6766 writer = writer ()
68- rb = rb_type (
69- collate_fn = collate_fn , storage = storage , sampler = sampler , writer = writer
70- )
67+ rb = rb_type (storage = storage , sampler = sampler , writer = writer )
7168 return rb
7269
7370 def _get_datum (self , rb_type ):
@@ -192,7 +189,6 @@ def test_prototype_prb(priority_key, contiguous, device):
192189 np .random .seed (0 )
193190 rb = rb_prototype .TensorDictReplayBuffer (
194191 sampler = samplers .PrioritizedSampler (5 , alpha = 0.7 , beta = 0.9 ),
195- collate_fn = None if contiguous else lambda x : torch .stack (x , 0 ),
196192 priority_key = priority_key ,
197193 )
198194 td1 = TensorDict (
@@ -271,7 +267,6 @@ def test_rb_prototype_trajectories(stack):
271267 alpha = 0.7 ,
272268 beta = 0.9 ,
273269 ),
274- collate_fn = lambda x : torch .stack (x , 0 ),
275270 priority_key = "td_error" ,
276271 )
277272 rb .extend (traj_td )
@@ -315,7 +310,6 @@ class TestBuffers:
315310 _default_params_td_prb = {"alpha" : 0.8 , "beta" : 0.9 }
316311
317312 def _get_rb (self , rbtype , size , storage , prefetch ):
318- collate_fn = collate_fn_dict [storage ]
319313 if storage is not None :
320314 storage = storage (size )
321315 if rbtype is ReplayBuffer :
@@ -328,13 +322,7 @@ def _get_rb(self, rbtype, size, storage, prefetch):
328322 params = self ._default_params_td_prb
329323 else :
330324 raise NotImplementedError (rbtype )
331- rb = rbtype (
332- size = size ,
333- storage = storage ,
334- prefetch = prefetch ,
335- collate_fn = collate_fn ,
336- ** params
337- )
325+ rb = rbtype (size = size , storage = storage , prefetch = prefetch , ** params )
338326 return rb
339327
340328 def _get_datum (self , rbtype ):
@@ -460,7 +448,6 @@ def test_prb(priority_key, contiguous, device):
460448 5 ,
461449 alpha = 0.7 ,
462450 beta = 0.9 ,
463- collate_fn = None if contiguous else lambda x : torch .stack (x , 0 ),
464451 priority_key = priority_key ,
465452 )
466453 td1 = TensorDict (
@@ -537,7 +524,6 @@ def test_rb_trajectories(stack):
537524 5 ,
538525 alpha = 0.7 ,
539526 beta = 0.9 ,
540- collate_fn = lambda x : torch .stack (x , 0 ),
541527 priority_key = "td_error" ,
542528 )
543529 rb .extend (traj_td )
@@ -565,10 +551,14 @@ def test_shared_storage_prioritized_sampler():
565551 sampler1 = PrioritizedSampler (max_capacity = n , alpha = 0.7 , beta = 1.1 )
566552
567553 rb0 = rb_prototype .ReplayBuffer (
568- storage = storage , writer = writer , sampler = sampler0 , collate_fn = lambda x : x
554+ storage = storage ,
555+ writer = writer ,
556+ sampler = sampler0 ,
569557 )
570558 rb1 = rb_prototype .ReplayBuffer (
571- storage = storage , writer = writer , sampler = sampler1 , collate_fn = lambda x : x
559+ storage = storage ,
560+ writer = writer ,
561+ sampler = sampler1 ,
572562 )
573563
574564 data = TensorDict ({"a" : torch .arange (50 )}, [50 ])
@@ -593,9 +583,11 @@ def test_legacy_rb_does_not_attach():
593583 storage = LazyMemmapStorage (n )
594584 writer = RoundRobinWriter ()
595585 sampler = RandomSampler ()
596- rb = ReplayBuffer (storage = storage , size = n , prefetch = 0 , collate_fn = lambda x : x )
586+ rb = ReplayBuffer (storage = storage , size = n , prefetch = 0 )
597587 prb = rb_prototype .ReplayBuffer (
598- storage = storage , writer = writer , sampler = sampler , collate_fn = lambda x : x
588+ storage = storage ,
589+ writer = writer ,
590+ sampler = sampler ,
599591 )
600592
601593 assert len (storage ._attached_entities ) == 1
0 commit comments