File tree Expand file tree Collapse file tree 1 file changed +17
-1
lines changed Expand file tree Collapse file tree 1 file changed +17
-1
lines changed Original file line number Diff line number Diff line change @@ -1192,7 +1192,23 @@ def _setup_replay_buffer(
11921192 def _setup_policy_and_weights (self , policy : TensorDictModule | Callable ) -> None :
11931193 """Set up policy, wrapped policy, and extract weights."""
11941194 self ._original_policy = policy
1195- policy , self .get_weights_fn = self ._get_policy_and_device (policy = policy )
1195+
1196+ # Check if policy has meta-device parameters (sent from weight sync schemes)
1197+ # In that case, skip device placement - weights will come from the receiver
1198+ has_meta_params = False
1199+ if isinstance (policy , nn .Module ):
1200+ for p in policy .parameters ():
1201+ if p .device .type == "meta" :
1202+ has_meta_params = True
1203+ break
1204+
1205+ if has_meta_params :
1206+ # Skip device placement for meta policies - schemes handle weight application
1207+ # Policy stays as-is, weights will be applied by the receiver
1208+ self .get_weights_fn = lambda : TensorDict .from_module (policy ).data
1209+ else :
1210+ # Normal path: move policy to correct device
1211+ policy , self .get_weights_fn = self ._get_policy_and_device (policy = policy )
11961212
11971213 if not self .trust_policy :
11981214 self .policy = policy
You can’t perform that action at this time.
0 commit comments