@@ -1547,19 +1547,26 @@ def __init__(
15471547 if cat_dim > 0 :
15481548 raise ValueError (self ._CAT_DIM_ERR )
15491549 self .cat_dim = cat_dim
1550+ for in_key in self .in_keys :
1551+ buffer_name = f"_cat_buffers_{ in_key } "
1552+ setattr (
1553+ self ,
1554+ buffer_name ,
1555+ torch .nn .parameter .UninitializedBuffer (
1556+ device = torch .device ("cpu" ), dtype = torch .get_default_dtype ()
1557+ ),
1558+ )
15501559
15511560 def reset (self , tensordict : TensorDictBase ) -> TensorDictBase :
15521561 """Resets _buffers."""
15531562 # Non-batched environments
15541563 if len (tensordict .batch_size ) < 1 or tensordict .batch_size [0 ] == 1 :
15551564 for in_key in self .in_keys :
15561565 buffer_name = f"_cat_buffers_{ in_key } "
1557- try :
1558- buffer = getattr (self , buffer_name )
1559- buffer .fill_ (0.0 )
1560- except AttributeError :
1561- # we'll instantiate later, when needed
1562- pass
1566+ buffer = getattr (self , buffer_name )
1567+ if isinstance (buffer , torch .nn .parameter .UninitializedBuffer ):
1568+ continue
1569+ buffer .fill_ (0.0 )
15631570
15641571 # Batched environments
15651572 else :
@@ -1573,12 +1580,10 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
15731580 )
15741581 for in_key in self .in_keys :
15751582 buffer_name = f"_cat_buffers_{ in_key } "
1576- try :
1577- buffer = getattr (self , buffer_name )
1578- buffer [_reset ] = 0.0
1579- except AttributeError :
1580- # we'll instantiate later, when needed
1581- pass
1583+ buffer = getattr (self , buffer_name )
1584+ if isinstance (buffer , torch .nn .parameter .UninitializedBuffer ):
1585+ continue
1586+ buffer [_reset ] = 0.0
15821587
15831588 return tensordict
15841589
@@ -1587,15 +1592,9 @@ def _make_missing_buffer(self, data, buffer_name):
15871592 d = shape [self .cat_dim ]
15881593 shape [self .cat_dim ] = d * self .N
15891594 shape = torch .Size (shape )
1590- self .register_buffer (
1591- buffer_name ,
1592- torch .zeros (
1593- shape ,
1594- dtype = data .dtype ,
1595- device = data .device ,
1596- ),
1597- )
1598- buffer = getattr (self , buffer_name )
1595+ getattr (self , buffer_name ).materialize (shape )
1596+ buffer = getattr (self , buffer_name ).to (data .dtype ).to (data .device ).zero_ ()
1597+ setattr (self , buffer_name , buffer )
15991598 return buffer
16001599
16011600 def _call (self , tensordict : TensorDictBase ) -> TensorDictBase :
@@ -1605,12 +1604,12 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
16051604 buffer_name = f"_cat_buffers_{ in_key } "
16061605 data = tensordict [in_key ]
16071606 d = data .size (self .cat_dim )
1608- try :
1609- buffer = getattr (self , buffer_name )
1607+ buffer = getattr (self , buffer_name )
1608+ if isinstance (buffer , torch .nn .parameter .UninitializedBuffer ):
1609+ buffer = self ._make_missing_buffer (data , buffer_name )
1610+ else :
16101611 # shift obs 1 position to the right
16111612 buffer .copy_ (torch .roll (buffer , shifts = - d , dims = self .cat_dim ))
1612- except AttributeError :
1613- buffer = self ._make_missing_buffer (data , buffer_name )
16141613 # add new obs
16151614 idx = self .cat_dim
16161615 if idx < 0 :
0 commit comments