Skip to content

Commit 35afb9c

Browse files
committed
fix the fixes 2
1 parent e3e0863 commit 35afb9c

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

torchrl/collectors/collectors.py

Lines changed: 14 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -155,19 +155,20 @@ def _map_to_cpu_if_needed(x):
155155
def _make_meta_policy(policy: nn.Module) -> nn.Module:
156156
"""Create policy structure with parameters on meta device.
157157
158-
This is used when policy_factory is provided but we still want to send
159-
the policy structure to workers. The actual weights will be sent via queue.
158+
This is used with weight sync schemes to send policy structure without weights.
159+
The actual weights are distributed by the schemes.
160160
161161
Args:
162162
policy: Policy module to extract structure from.
163163
164164
Returns:
165-
A copy of the policy with all parameters on meta device.
165+
A copy of the policy with all parameters on meta device and requires_grad=False.
166166
"""
167167

168168
def _cast(p, param_maybe_buffer):
169169
if isinstance(param_maybe_buffer, Parameter):
170-
return Parameter(p)
170+
# Create parameter without gradients to avoid serialization issues
171+
return Parameter(p, requires_grad=False)
171172
if isinstance(param_maybe_buffer, Buffer):
172173
return Buffer(p)
173174
return p
@@ -182,19 +183,20 @@ def _cast(p, param_maybe_buffer):
182183
def _make_meta_policy(policy: nn.Module) -> nn.Module: # noqa: F811
183184
"""Create policy structure with parameters on meta device.
184185
185-
This is used when policy_factory is provided but we still want to send
186-
the policy structure to workers. The actual weights will be sent via queue.
186+
This is used with weight sync schemes to send policy structure without weights.
187+
The actual weights are distributed by the schemes.
187188
188189
Args:
189190
policy: Policy module to extract structure from.
190191
191192
Returns:
192-
A copy of the policy with all parameters on meta device.
193+
A copy of the policy with all parameters on meta device and requires_grad=False.
193194
"""
194195

195196
def _cast(p, param_maybe_buffer):
196197
if isinstance(param_maybe_buffer, Parameter):
197-
return Parameter(p)
198+
# Create parameter without gradients to avoid serialization issues
199+
return Parameter(p, requires_grad=False)
198200
return p
199201

200202
param_and_buf = TensorDict.from_module(policy, as_module=True)
@@ -3142,8 +3144,11 @@ def _run_processes(self) -> None:
31423144
# Schemes handle weight distribution on worker side
31433145
if any(policy_factory):
31443146
policy_to_send = None # Factory will create policy in worker
3147+
elif policy is not None:
3148+
# Send meta-device policy (empty structure) - schemes apply weights
3149+
policy_to_send = _make_meta_policy(policy)
31453150
else:
3146-
policy_to_send = policy # Stateless - schemes apply weights
3151+
policy_to_send = None
31473152
cm = contextlib.nullcontext()
31483153
else:
31493154
# With weight updater, use in-place weight replacement

0 commit comments

Comments
 (0)