Skip to content

Commit 33777eb

Browse files
committed
lazy compile
1 parent 81b1ed5 commit 33777eb

File tree

1 file changed

+100
-17
lines changed

1 file changed

+100
-17
lines changed

torchrl/collectors/collectors.py

Lines changed: 100 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,27 +1227,53 @@ def _setup_policy_and_weights(self, policy: TensorDictModule | Callable) -> None
12271227
else:
12281228
self.policy = self._wrapped_policy = policy
12291229

1230-
# Extract policy weights
1231-
if isinstance(self._wrapped_policy, nn.Module):
1230+
# Extract policy weights from the uncompiled policy
1231+
# Access _wrapped_policy_uncompiled directly to avoid triggering compilation
1232+
if isinstance(self._wrapped_policy_uncompiled, nn.Module):
12321233
self.policy_weights = TensorDict.from_module(
1233-
self._wrapped_policy, as_module=True
1234+
self._wrapped_policy_uncompiled, as_module=True
12341235
).data
12351236
else:
12361237
self.policy_weights = TensorDict()
12371238

1238-
# Apply compilation/cudagraph
1239-
if self.compiled_policy:
1240-
self._wrapped_policy = compile_with_warmup(
1241-
self._wrapped_policy, **self.compiled_policy_kwargs
1239+
# If policy doesn't have meta params, compile immediately
1240+
# Otherwise, defer until first use (after weights are loaded)
1241+
if not has_meta_params and (self.compiled_policy or self.cudagraphed_policy):
1242+
self._wrapped_policy_compiled = self._compile_wrapped_policy(
1243+
self._wrapped_policy_uncompiled
12421244
)
1245+
1246+
def _compile_wrapped_policy(self, policy):
1247+
"""Apply compilation and/or cudagraph to a policy."""
1248+
if self.compiled_policy:
1249+
policy = compile_with_warmup(policy, **self.compiled_policy_kwargs)
12431250
if self.cudagraphed_policy:
1244-
self._wrapped_policy = CudaGraphModule(
1245-
self._wrapped_policy,
1251+
policy = CudaGraphModule(
1252+
policy,
12461253
in_keys=[],
12471254
out_keys=[],
12481255
device=self.policy_device,
12491256
**self.cudagraphed_policy_kwargs,
12501257
)
1258+
return policy
1259+
1260+
@property
1261+
def _wrapped_policy(self):
1262+
"""Returns the compiled policy, compiling it lazily if needed."""
1263+
if (policy := self._wrapped_policy_compiled) is None:
1264+
if self.compiled_policy or self.cudagraphed_policy:
1265+
policy = self._wrapped_policy_compiled = self._compile_wrapped_policy(
1266+
self._wrapped_policy_uncompiled
1267+
)
1268+
else:
1269+
policy = self._wrapped_policy_compiled = self._wrapped_policy_uncompiled
1270+
return policy
1271+
1272+
@_wrapped_policy.setter
1273+
def _wrapped_policy(self, value):
1274+
"""Allow setting the wrapped policy during initialization."""
1275+
self._wrapped_policy_uncompiled = value
1276+
self._wrapped_policy_compiled = None
12511277

12521278
def _apply_env_device(self) -> None:
12531279
"""Apply device to environment if specified."""
@@ -1425,22 +1451,57 @@ def _maybe_make_final_rollout(self, make_rollout: bool):
14251451
# erase all devices
14261452
self._final_rollout.clear_device_()
14271453

1454+
# Check if policy has meta-device parameters (not yet initialized)
1455+
has_meta_params = False
1456+
if hasattr(self, "_wrapped_policy_uncompiled") and isinstance(
1457+
self._wrapped_policy_uncompiled, nn.Module
1458+
):
1459+
for p in self._wrapped_policy_uncompiled.parameters():
1460+
if p.device.type == "meta":
1461+
has_meta_params = True
1462+
break
1463+
14281464
# If the policy has a valid spec, we use it
14291465
self._policy_output_keys = set()
14301466
if (
14311467
make_rollout
1432-
and hasattr(self._wrapped_policy, "spec")
1433-
and self._wrapped_policy.spec is not None
1434-
and all(v is not None for v in self._wrapped_policy.spec.values(True, True))
1468+
and hasattr(
1469+
self._wrapped_policy_uncompiled
1470+
if has_meta_params
1471+
else self._wrapped_policy,
1472+
"spec",
1473+
)
1474+
and (
1475+
self._wrapped_policy_uncompiled
1476+
if has_meta_params
1477+
else self._wrapped_policy
1478+
).spec
1479+
is not None
1480+
and all(
1481+
v is not None
1482+
for v in (
1483+
self._wrapped_policy_uncompiled
1484+
if has_meta_params
1485+
else self._wrapped_policy
1486+
).spec.values(True, True)
1487+
)
14351488
):
14361489
if any(
14371490
key not in self._final_rollout.keys(isinstance(key, tuple))
1438-
for key in self._wrapped_policy.spec.keys(True, True)
1491+
for key in (
1492+
self._wrapped_policy_uncompiled
1493+
if has_meta_params
1494+
else self._wrapped_policy
1495+
).spec.keys(True, True)
14391496
):
14401497
# if policy spec is non-empty, all the values are not None and the keys
14411498
# match the out_keys we assume the user has given all relevant information
14421499
# the policy could have more keys than the env:
1443-
policy_spec = self._wrapped_policy.spec
1500+
policy_spec = (
1501+
self._wrapped_policy_uncompiled
1502+
if has_meta_params
1503+
else self._wrapped_policy
1504+
).spec
14441505
if policy_spec.ndim < self._final_rollout.ndim:
14451506
policy_spec = policy_spec.expand(self._final_rollout.shape)
14461507
for key, spec in policy_spec.items(True, True):
@@ -1450,10 +1511,32 @@ def _maybe_make_final_rollout(self, make_rollout: bool):
14501511
self._final_rollout.set(key, spec.zero())
14511512
elif (
14521513
not make_rollout
1453-
and hasattr(self._wrapped_policy, "out_keys")
1454-
and self._wrapped_policy.out_keys
1514+
and hasattr(
1515+
self._wrapped_policy_uncompiled
1516+
if has_meta_params
1517+
else self._wrapped_policy,
1518+
"out_keys",
1519+
)
1520+
and (
1521+
self._wrapped_policy_uncompiled
1522+
if has_meta_params
1523+
else self._wrapped_policy
1524+
).out_keys
14551525
):
1456-
self._policy_output_keys = list(self._wrapped_policy.out_keys)
1526+
self._policy_output_keys = list(
1527+
(
1528+
self._wrapped_policy_uncompiled
1529+
if has_meta_params
1530+
else self._wrapped_policy
1531+
).out_keys
1532+
)
1533+
elif has_meta_params:
1534+
# Policy has meta params and no spec/out_keys - defer initialization
1535+
# Mark that we need to initialize later when weights are loaded
1536+
self._policy_output_keys = set()
1537+
if make_rollout:
1538+
# We'll populate keys on first actual rollout after weights are loaded
1539+
self._final_rollout_needs_init = True
14571540
else:
14581541
if make_rollout:
14591542
# otherwise, we perform a small number of steps with the policy to

0 commit comments

Comments
 (0)