1212from typing import Iterator , List , Optional , Tuple
1313
1414import torch
15- from tensordict import TensorDict , TensorDictBase
15+ from tensordict import is_tensor_collection , TensorDict , TensorDictBase
1616
1717from tensordict .nn import TensorDictModule , TensorDictModuleBase , TensorDictParams
1818from torch import nn
@@ -248,6 +248,13 @@ def convert_to_functional(
248248 # For buffers, a cloned expansion (or equivalently a repeat) is returned.
249249
250250 def _compare_and_expand (param ):
251+ if is_tensor_collection (param ):
252+ return param ._apply_nest (
253+ _compare_and_expand ,
254+ batch_size = [expand_dim , * param .shape ],
255+ filter_empty = False ,
256+ call_on_nested = True ,
257+ )
251258 if not isinstance (param , nn .Parameter ):
252259 buffer = param .expand (expand_dim , * param .shape ).clone ()
253260 return buffer
@@ -257,7 +264,7 @@ def _compare_and_expand(param):
257264 # is called:
258265 return expanded_param
259266 else :
260- p_out = param .repeat (expand_dim , * [ 1 for _ in param .shape ] )
267+ p_out = param .expand (expand_dim , * param .shape ). clone ( )
261268 p_out = nn .Parameter (
262269 p_out .uniform_ (
263270 p_out .min ().item (), p_out .max ().item ()
@@ -267,7 +274,9 @@ def _compare_and_expand(param):
267274
268275 params = TensorDictParams (
269276 params .apply (
270- _compare_and_expand , batch_size = [expand_dim , * params .shape ]
277+ _compare_and_expand ,
278+ batch_size = [expand_dim , * params .shape ],
279+ call_on_nested = True ,
271280 ),
272281 no_convert = True ,
273282 )
0 commit comments