Skip to content

Commit 3428d3f

Browse files
committed
fix the fixes 3
1 parent 35afb9c commit 3428d3f

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

torchrl/collectors/collectors.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2783,8 +2783,9 @@ def _setup_multi_policy_and_weights(
27832783
)
27842784

27852785
# Extract weights from policy
2786+
# Use .data to avoid gradient tracking (can't serialize tensors with requires_grad)
27862787
weights = (
2787-
TensorDict.from_module(policy)
2788+
TensorDict.from_module(policy, as_module=True).data
27882789
if isinstance(policy, nn.Module)
27892790
else TensorDict()
27902791
)

0 commit comments

Comments
 (0)