Skip to content

Commit 81b1ed5

Browse files
committed
fix the fixes 4
1 parent 3428d3f commit 81b1ed5

File tree

1 file changed

+17
-1
lines changed

1 file changed

+17
-1
lines changed

torchrl/collectors/collectors.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1192,7 +1192,23 @@ def _setup_replay_buffer(
11921192
def _setup_policy_and_weights(self, policy: TensorDictModule | Callable) -> None:
11931193
"""Set up policy, wrapped policy, and extract weights."""
11941194
self._original_policy = policy
1195-
policy, self.get_weights_fn = self._get_policy_and_device(policy=policy)
1195+
1196+
# Check if policy has meta-device parameters (sent from weight sync schemes)
1197+
# In that case, skip device placement - weights will come from the receiver
1198+
has_meta_params = False
1199+
if isinstance(policy, nn.Module):
1200+
for p in policy.parameters():
1201+
if p.device.type == "meta":
1202+
has_meta_params = True
1203+
break
1204+
1205+
if has_meta_params:
1206+
# Skip device placement for meta policies - schemes handle weight application
1207+
# Policy stays as-is; weights will be applied by the receiver
1208+
self.get_weights_fn = lambda: TensorDict.from_module(policy).data
1209+
else:
1210+
# Normal path: move policy to correct device
1211+
policy, self.get_weights_fn = self._get_policy_and_device(policy=policy)
11961212

11971213
if not self.trust_policy:
11981214
self.policy = policy

0 commit comments

Comments
 (0)