Skip to content

Commit 9987d92

Browse files
Vincent Moens (vmoens)
authored and committed
[BugFix] Fix batch-size expansion in functionalization (#1959)
1 parent bfb4037 commit 9987d92

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

torchrl/objectives/common.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from typing import Iterator, List, Optional, Tuple
1313

1414
import torch
15-
from tensordict import TensorDict, TensorDictBase
15+
from tensordict import is_tensor_collection, TensorDict, TensorDictBase
1616

1717
from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictParams
1818
from torch import nn
@@ -248,6 +248,13 @@ def convert_to_functional(
248248
# For buffers, a cloned expansion (or equivalently a repeat) is returned.
249249

250250
def _compare_and_expand(param):
251+
if is_tensor_collection(param):
252+
return param._apply_nest(
253+
_compare_and_expand,
254+
batch_size=[expand_dim, *param.shape],
255+
filter_empty=False,
256+
call_on_nested=True,
257+
)
251258
if not isinstance(param, nn.Parameter):
252259
buffer = param.expand(expand_dim, *param.shape).clone()
253260
return buffer
@@ -257,7 +264,7 @@ def _compare_and_expand(param):
257264
# is called:
258265
return expanded_param
259266
else:
260-
p_out = param.repeat(expand_dim, *[1 for _ in param.shape])
267+
p_out = param.expand(expand_dim, *param.shape).clone()
261268
p_out = nn.Parameter(
262269
p_out.uniform_(
263270
p_out.min().item(), p_out.max().item()
@@ -267,7 +274,9 @@ def _compare_and_expand(param):
267274

268275
params = TensorDictParams(
269276
params.apply(
270-
_compare_and_expand, batch_size=[expand_dim, *params.shape]
277+
_compare_and_expand,
278+
batch_size=[expand_dim, *params.shape],
279+
call_on_nested=True,
271280
),
272281
no_convert=True,
273282
)

0 commit comments

Comments (0)