Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3537
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 3 Cancelled Jobs as of commit 5e2e9a6 with merge base 4d2c3cb.
NEW FAILURE - The following job has failed:
CANCELLED JOBS - The following jobs were cancelled. Please retry:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
| Prefix | Label Applied | Example |
|---|---|---|
| [BugFix] | BugFix | [BugFix] Fix memory leak in collector |
| [Feature] | Feature | [Feature] Add new optimizer |
| [Doc] or [Docs] | Documentation | [Doc] Update installation guide |
| [Refactor] | Refactoring | [Refactor] Clean up module imports |
| [CI] | CI | [CI] Fix workflow permissions |
| [Test] or [Tests] | Tests | [Tests] Add unit tests for buffer |
| [Environment] or [Environments] | Environments | [Environments] Add Gymnasium support |
| [Data] | Data | [Data] Fix replay buffer sampling |
| [Performance] or [Perf] | Performance | [Performance] Optimize tensor ops |
| [BC-Breaking] | bc breaking | [BC-Breaking] Remove deprecated API |
| [Deprecation] | Deprecation | [Deprecation] Mark old function |
| [Quality] | Quality | [Quality] Fix typos and add codespell |
Note: Common variations like singular/plural are supported (e.g., [Doc] or [Docs]).
vmoens left a comment
Excellent first attempt!
Let's try to move most of this to core!
torchrl/modules/models/
gp.py # BoTorchGPWorldModel (renamed: GPWorldModel?)
rbf_controller.py # RBFController
torchrl/objectives/
pilco.py # SaturatingCost (the generic cost module)
sota-implementations/pilco/
pilco.py # Training loop (stays here)
utils.py # make_env, pendulum_cost (thin wrappers, stays here)
config.yaml # Config (stays here)
Missing tests:
- Unit tests for RBFController moment matching (forward pass, squash_sin)
- Unit tests for BoTorchGPWorldModel (fit, deterministic_forward, uncertain_forward)
- At minimum, a smoke test for the full PILCO loop (see the sota-implementations CI workflow)
- Numerical validation against the reference implementation (the author credits nrontsis/PILCO) if possible; ok if not
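As a starting point for the controller tests, here is a sketch of the kind of property a squashing test could assert. The `sin_squash` function below is a stand-in based on the standard PILCO sine squashing, not the PR's actual `squash_sin`; the real test should call the PR's method and additionally check the moment-matched mean/variance shapes.

```python
import torch

def sin_squash(x: torch.Tensor, max_action: float) -> torch.Tensor:
    # PILCO-style sine squashing: smooth and bounded, which is what makes
    # moment matching through the nonlinearity tractable.
    return max_action * torch.sin(x)

def test_squash_bounds() -> None:
    # Even for wildly out-of-range inputs, outputs stay in [-max_action, max_action].
    x = torch.randn(1000) * 10.0
    u = sin_squash(x, max_action=2.0)
    assert float(u.abs().max()) <= 2.0

test_squash_bounds()
```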
There are no docs: no docstrings on any class or method beyond the one-line pendulum_cost docstring. For core components, all public methods need proper docstrings with shapes documented (especially the moment matching formulas, which are dense linear algebra). The docs must also be linked in docs/source/reference/...
Avoid single-letter variables unless they're loop indices (for i in range(...)). Single-letter names (m, s, c, B, D, L, U, Q, t, z) are used heavily throughout the moment matching code. They follow the paper notation, which is fine, but in core each block should carry a comment referencing the corresponding equation in the paper.
The policy_for_env closure (pilco.py lines 166-200) is an ad-hoc bridge between the Gaussian policy interface and a standard env that expects deterministic actions. In core, this should be a proper transform or wrapper (e.g., MeanActionSelector or similar) rather than a closure rebuilt every epoch.
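A minimal sketch of what such a wrapper could look like. The name `MeanActionSelector` and the assumption that the policy returns an `(action_mean, action_var)` pair are taken from the review suggestion above, not from an existing torchrl API; the core version would presumably be a TensorDictModule or env transform instead of a bare nn.Module.

```python
import torch
from torch import nn

class MeanActionSelector(nn.Module):
    """Hypothetical wrapper: expose the mean of a Gaussian policy as a
    deterministic action, replacing a closure rebuilt every epoch."""

    def __init__(self, policy: nn.Module) -> None:
        super().__init__()
        self.policy = policy

    def forward(self, observation: torch.Tensor) -> torch.Tensor:
        # Assumes the wrapped policy returns (action_mean, action_var);
        # only the mean is handed to the env.
        action_mean, _ = self.policy(observation)
        return action_mean
```

Because the wrapper is a module, it is constructed once and reused across epochs, and it shows up in `state_dict()` / `to(device)` like any other component.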
return (1.0 - det_term * torch.exp(exp_term)).sum(dim=1)

class BoTorchGPWorldModel(nn.Module):
If properly documented, I'm happy with having this in core!
return observation_mean + delta_mean, torch.diag_embed(delta_std**2)

class ImaginedEnv(ModelBasedEnvBase):
Ditto, maybe we want this in core.
How different is it from DreamerEnv? Can we blend the two together? (ok if we want to keep them separated)
return out

class RBFController(nn.Module):
Ditto, happy to have it in core.
for a in range(self.obs_dim):
    for b in range(self.obs_dim):
a lot in here can be vectorized
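For instance, a per-pair trace like the one in the quoted hunk can usually collapse into a single einsum. The shapes below are illustrative, not the PR's actual ones; the point is that `trace(inv_K[a] @ Q[a, b])` over all `(a, b)` pairs is one contraction.

```python
import torch

# Illustrative shapes: per-dimension kernel inverses and pairwise Q matrices,
# as in the moment-matching double loop.
obs_dim, n = 3, 5
inv_K = torch.randn(obs_dim, n, n)
Q = torch.randn(obs_dim, obs_dim, n, n)

# Loop version, mirroring the quoted code: one trace per (a, b) pair.
loop_out = torch.zeros(obs_dim, obs_dim)
for a in range(obs_dim):
    for b in range(obs_dim):
        loop_out[a, b] = torch.diagonal(inv_K[a] @ Q[a, b]).sum()

# Vectorized: trace(inv_K[a] @ Q[a, b]) = sum_ij inv_K[a, i, j] * Q[a, b, j, i]
vec_out = torch.einsum("aij,abji->ab", inv_K, Q)
```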
else:
    return self.deterministic_forward(action, observation)

def freeze_and_detach(self) -> None:
invK_Q = torch.matmul(inv_K[a].unsqueeze(0), Q_ab)
trace_val = torch.diagonal(invK_Q, dim1=-2, dim2=-1).sum(-1)

pred_cov[:, a, a] += variances[a] - trace_val + noises[a].item()
Avoid .item(). Call .tolist() beforehand if absolutely necessary; I think a plain tensor should work here. .item() breaks torch.compile and forces a CUDA sync.
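A toy version of the same update, with made-up shapes, showing that the 0-dim tensor can be added directly and broadcasting does the rest:

```python
import torch

# Illustrative shapes: batch of predictive covariances, per-dim noise.
pred_cov = torch.zeros(4, 3, 3)
variances = torch.ones(3)
trace_val = torch.full((4,), 0.5)
noises = torch.tensor([0.1, 0.2, 0.3])

a = 1
# noises[a] stays a 0-dim tensor: no host sync, no graph break under
# torch.compile, and broadcasting handles the batch dimension.
pred_cov[:, a, a] += variances[a] - trace_val + noises[a]
```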
t = torch.linalg.solve(B_mat, iN.mT).mT

exp_term = torch.exp(-0.5 * torch.sum(iN * t, dim=-1))
detB = torch.linalg.det(B_mat)
torch.linalg.slogdet would be numerically more stable (and is already partially used via the Cholesky log-det pattern elsewhere in the same code)
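A small demonstration of the equivalence (the matrix here is a toy stand-in, not the PR's B_mat): for an SPD matrix, `det(B) ** -0.5` can be computed entirely in log space via slogdet, which avoids the under/overflow of the intermediate determinant as dimensions grow.

```python
import torch

# Toy SPD matrix with a tiny determinant; in float32 and higher dims this
# kind of det underflows quickly.
B_mat = torch.eye(3, dtype=torch.float64) * 1e-4

sign, logabsdet = torch.linalg.slogdet(B_mat)
stable = torch.exp(-0.5 * logabsdet)      # det(B) ** -0.5, log-domain
naive = torch.linalg.det(B_mat) ** -0.5   # same value, but via the raw det
```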
scaled_exp = torch.exp(-torch.sum(inv_N * t, dim=-1) / 2)
lb = scaled_exp * beta.unsqueeze(0)

det_B = torch.linalg.det(B_mat)
Ditto: let's think in logs!
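Concretely, an expression of the shape `exp(-q/2) / sqrt(det(B))` can be fused into one exponential over a log-domain sum, so the raw determinant never materializes (toy values below, not the PR's):

```python
import torch

# Fusing the exponential and the determinant into a single log-domain exp.
B_mat = torch.eye(2, dtype=torch.float64) * 0.01
q = torch.tensor(3.0, dtype=torch.float64)

_, logdetB = torch.linalg.slogdet(B_mat)
fused = torch.exp(-0.5 * q - 0.5 * logdetB)
naive = torch.exp(-0.5 * q) / torch.sqrt(torch.linalg.det(B_mat))
```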
batch_size, num_train_pts, num_train_pts
)

det_R_ab = torch.linalg.det(R_ab)
from botorch.fit import fit_gpytorch_mll

from botorch.models import ModelListGP, SingleTaskGP
botorch / gpytorch need to be added to the repo's dependencies as optional deps
Hi @PSXBRosa, thank you for your work implementing PILCO for torchrl. A few days ago I opened a discussion about adding MC-PILCO (discussion #3538), and reading through @vmoens' review it's clear we're both going to depend on the same core primitives. A couple of options as I see it:
Happy to go with whatever works best for you. If you want to discuss further, feel free to reply here or reach out on Discord (cabesamotora)!
Hi @alektebel, thanks for reaching out. #1 seems like the cleanest approach. I'm in the middle of a move right now, so I haven't had much time for the PR this week, but I've started on vmoens' comments and plan to have the current issues resolved by the end of next week. My only progress so far is moving the loss to core; I can push this as a WIP if you'd like to take a look. How do you envision building on top of the current classes? Do you have a specific inheritance plan in mind? I'll ping you on Discord.
Description
This PR introduces the implementation of the PILCO (Probabilistic Inference for Learning Control) algorithm to TorchRL.
Key details of the implementation:
Motivation and Context
PILCO is a highly sample-efficient model-based reinforcement learning algorithm, making it a valuable addition to the library's algorithm suite.
Closes #3513
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an x in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help!