@@ -155,19 +155,20 @@ def _map_to_cpu_if_needed(x):
155155def _make_meta_policy (policy : nn .Module ) -> nn .Module :
156156 """Create policy structure with parameters on meta device.
157157
158- This is used when policy_factory is provided but we still want to send
159- the policy structure to workers. The actual weights will be sent via queue .
158+ This is used with weight sync schemes to send policy structure without weights.
159+ The actual weights are distributed by the schemes.
160160
161161 Args:
162162 policy: Policy module to extract structure from.
163163
164164 Returns:
165- A copy of the policy with all parameters on meta device.
165+ A copy of the policy with all parameters on meta device and requires_grad=False.
166166 """
167167
168168 def _cast (p , param_maybe_buffer ):
169169 if isinstance (param_maybe_buffer , Parameter ):
170- return Parameter (p )
170+ # Create parameter without gradients to avoid serialization issues
171+ return Parameter (p , requires_grad = False )
171172 if isinstance (param_maybe_buffer , Buffer ):
172173 return Buffer (p )
173174 return p
@@ -182,19 +183,20 @@ def _cast(p, param_maybe_buffer):
182183def _make_meta_policy (policy : nn .Module ) -> nn .Module : # noqa: F811
183184 """Create policy structure with parameters on meta device.
184185
185- This is used when policy_factory is provided but we still want to send
186- the policy structure to workers. The actual weights will be sent via queue .
186+ This is used with weight sync schemes to send policy structure without weights.
187+ The actual weights are distributed by the schemes.
187188
188189 Args:
189190 policy: Policy module to extract structure from.
190191
191192 Returns:
192- A copy of the policy with all parameters on meta device.
193+ A copy of the policy with all parameters on meta device and requires_grad=False.
193194 """
194195
195196 def _cast (p , param_maybe_buffer ):
196197 if isinstance (param_maybe_buffer , Parameter ):
197- return Parameter (p )
198+ # Create parameter without gradients to avoid serialization issues
199+ return Parameter (p , requires_grad = False )
198200 return p
199201
200202 param_and_buf = TensorDict .from_module (policy , as_module = True )
@@ -3142,8 +3144,11 @@ def _run_processes(self) -> None:
31423144 # Schemes handle weight distribution on worker side
31433145 if any (policy_factory ):
31443146 policy_to_send = None # Factory will create policy in worker
3147+ elif policy is not None :
3148+ # Send meta-device policy (empty structure) - schemes apply weights
3149+ policy_to_send = _make_meta_policy (policy )
31453150 else :
3146- policy_to_send = policy # Stateless - schemes apply weights
3151+ policy_to_send = None
31473152 cm = contextlib .nullcontext ()
31483153 else :
31493154 # With weight updater, use in-place weight replacement
0 commit comments