@@ -1227,27 +1227,53 @@ def _setup_policy_and_weights(self, policy: TensorDictModule | Callable) -> None
12271227 else :
12281228 self .policy = self ._wrapped_policy = policy
12291229
1230- # Extract policy weights
1231- if isinstance (self ._wrapped_policy , nn .Module ):
1230+ # Extract policy weights from the uncompiled policy
1231+ # Access _wrapped_policy_uncompiled directly to avoid triggering compilation
1232+ if isinstance (self ._wrapped_policy_uncompiled , nn .Module ):
12321233 self .policy_weights = TensorDict .from_module (
1233- self ._wrapped_policy , as_module = True
1234+ self ._wrapped_policy_uncompiled , as_module = True
12341235 ).data
12351236 else :
12361237 self .policy_weights = TensorDict ()
12371238
1238- # Apply compilation/cudagraph
1239- if self .compiled_policy :
1240- self ._wrapped_policy = compile_with_warmup (
1241- self ._wrapped_policy , ** self .compiled_policy_kwargs
1239+ # If policy doesn't have meta params, compile immediately
1240+ # Otherwise, defer until first use (after weights are loaded)
1241+ if not has_meta_params and (self .compiled_policy or self .cudagraphed_policy ):
1242+ self ._wrapped_policy_compiled = self ._compile_wrapped_policy (
1243+ self ._wrapped_policy_uncompiled
12421244 )
1245+
1246+ def _compile_wrapped_policy (self , policy ):
1247+ """Apply compilation and/or cudagraph to a policy."""
1248+ if self .compiled_policy :
1249+ policy = compile_with_warmup (policy , ** self .compiled_policy_kwargs )
12431250 if self .cudagraphed_policy :
1244- self . _wrapped_policy = CudaGraphModule (
1245- self . _wrapped_policy ,
1251+ policy = CudaGraphModule (
1252+ policy ,
12461253 in_keys = [],
12471254 out_keys = [],
12481255 device = self .policy_device ,
12491256 ** self .cudagraphed_policy_kwargs ,
12501257 )
1258+ return policy
1259+
1260+ @property
1261+ def _wrapped_policy (self ):
1262+ """Returns the compiled policy, compiling it lazily if needed."""
1263+ if (policy := self ._wrapped_policy_compiled ) is None :
1264+ if self .compiled_policy or self .cudagraphed_policy :
1265+ policy = self ._wrapped_policy_compiled = self ._compile_wrapped_policy (
1266+ self ._wrapped_policy_uncompiled
1267+ )
1268+ else :
1269+ policy = self ._wrapped_policy_compiled = self ._wrapped_policy_uncompiled
1270+ return policy
1271+
1272+ @_wrapped_policy .setter
1273+ def _wrapped_policy (self , value ):
1274+ """Allow setting the wrapped policy during initialization."""
1275+ self ._wrapped_policy_uncompiled = value
1276+ self ._wrapped_policy_compiled = None
12511277
12521278 def _apply_env_device (self ) -> None :
12531279 """Apply device to environment if specified."""
@@ -1425,22 +1451,57 @@ def _maybe_make_final_rollout(self, make_rollout: bool):
14251451 # erase all devices
14261452 self ._final_rollout .clear_device_ ()
14271453
1454+ # Check if policy has meta-device parameters (not yet initialized)
1455+ has_meta_params = False
1456+ if hasattr (self , "_wrapped_policy_uncompiled" ) and isinstance (
1457+ self ._wrapped_policy_uncompiled , nn .Module
1458+ ):
1459+ for p in self ._wrapped_policy_uncompiled .parameters ():
1460+ if p .device .type == "meta" :
1461+ has_meta_params = True
1462+ break
1463+
14281464 # If the policy has a valid spec, we use it
14291465 self ._policy_output_keys = set ()
14301466 if (
14311467 make_rollout
1432- and hasattr (self ._wrapped_policy , "spec" )
1433- and self ._wrapped_policy .spec is not None
1434- and all (v is not None for v in self ._wrapped_policy .spec .values (True , True ))
1468+ and hasattr (
1469+ self ._wrapped_policy_uncompiled
1470+ if has_meta_params
1471+ else self ._wrapped_policy ,
1472+ "spec" ,
1473+ )
1474+ and (
1475+ self ._wrapped_policy_uncompiled
1476+ if has_meta_params
1477+ else self ._wrapped_policy
1478+ ).spec
1479+ is not None
1480+ and all (
1481+ v is not None
1482+ for v in (
1483+ self ._wrapped_policy_uncompiled
1484+ if has_meta_params
1485+ else self ._wrapped_policy
1486+ ).spec .values (True , True )
1487+ )
14351488 ):
14361489 if any (
14371490 key not in self ._final_rollout .keys (isinstance (key , tuple ))
1438- for key in self ._wrapped_policy .spec .keys (True , True )
1491+ for key in (
1492+ self ._wrapped_policy_uncompiled
1493+ if has_meta_params
1494+ else self ._wrapped_policy
1495+ ).spec .keys (True , True )
14391496 ):
14401497 # if policy spec is non-empty, all the values are not None and the keys
14411498 # match the out_keys we assume the user has given all relevant information
14421499 # the policy could have more keys than the env:
1443- policy_spec = self ._wrapped_policy .spec
1500+ policy_spec = (
1501+ self ._wrapped_policy_uncompiled
1502+ if has_meta_params
1503+ else self ._wrapped_policy
1504+ ).spec
14441505 if policy_spec .ndim < self ._final_rollout .ndim :
14451506 policy_spec = policy_spec .expand (self ._final_rollout .shape )
14461507 for key , spec in policy_spec .items (True , True ):
@@ -1450,10 +1511,32 @@ def _maybe_make_final_rollout(self, make_rollout: bool):
14501511 self ._final_rollout .set (key , spec .zero ())
14511512 elif (
14521513 not make_rollout
1453- and hasattr (self ._wrapped_policy , "out_keys" )
1454- and self ._wrapped_policy .out_keys
1514+ and hasattr (
1515+ self ._wrapped_policy_uncompiled
1516+ if has_meta_params
1517+ else self ._wrapped_policy ,
1518+ "out_keys" ,
1519+ )
1520+ and (
1521+ self ._wrapped_policy_uncompiled
1522+ if has_meta_params
1523+ else self ._wrapped_policy
1524+ ).out_keys
14551525 ):
1456- self ._policy_output_keys = list (self ._wrapped_policy .out_keys )
1526+ self ._policy_output_keys = list (
1527+ (
1528+ self ._wrapped_policy_uncompiled
1529+ if has_meta_params
1530+ else self ._wrapped_policy
1531+ ).out_keys
1532+ )
1533+ elif has_meta_params :
1534+ # Policy has meta params and no spec/out_keys - defer initialization
1535+ # Mark that we need to initialize later when weights are loaded
1536+ self ._policy_output_keys = set ()
1537+ if make_rollout :
1538+ # We'll populate keys on first actual rollout after weights are loaded
1539+ self ._final_rollout_needs_init = True
14571540 else :
14581541 if make_rollout :
14591542 # otherwise, we perform a small number of steps with the policy to
0 commit comments