An improved elliptical slice sampling implementation (#2426)
Summary:
## Motivation
Elliptical slice sampling for truncated normal distributions (e.g., [Gessner et al., 2020](https://arxiv.org/abs/1910.09328)) requires constructing the active intervals, i.e., the intersection of the ellipse with the domain. One method for constructing the active intervals is based on likelihood testing, which checks whether each intersection angle is active. The current BoTorch implementation follows this idea and takes $\mathcal{O}(m^2 d)$ time per iteration, where $d$ is the dimensionality and $m$ is the number of linear inequality constraints.
However, there is a faster algorithm that computes the active intervals in $\mathcal{O}(m \log m)$ time ([Wu and Gardner, 2024](https://arxiv.org/abs/2407.10449)). This PR implements Algorithm 3 of that paper and dramatically accelerates truncated normal sampling in high dimensions.
In addition, this PR implements batch MCMC, which launches multiple Markov chains and further speeds up sampling by making better use of GPU parallelism. Users can pass an additional argument `chains` to specify how many chains to run in parallel:
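To give some intuition for why sorting is enough, here is a minimal pure-Python sketch. It is neither Algorithm 3 from the paper nor the code in this PR; it is a generic sort-and-sweep, and the function name `feasible_arcs` and the inputs `p = A x`, `q = A nu` are hypothetical names chosen for illustration. On the ellipse $x(\theta) = x \cos\theta + \nu \sin\theta$, each constraint $a_i^\top x(\theta) \le b_i$ can be written as $r_i \cos(\theta - \varphi_i) \le b_i$, so it blocks at most one arc. The sketch assumes $\theta = 0$ (the current point) is strictly feasible.

```python
import math


def feasible_arcs(p, q, b):
    """Feasible arcs of theta in [0, 2*pi] for p[i]*cos(theta) + q[i]*sin(theta) <= b[i].

    Hypothetical helper for illustration. Assumes theta = 0 is strictly
    feasible, i.e. p[i] < b[i] for every constraint i.
    """
    two_pi = 2.0 * math.pi
    blocked = []  # infeasible arcs (start, end) with 0 < start < end < 2*pi
    for p_i, q_i, b_i in zip(p, q, b):
        r = math.hypot(p_i, q_i)
        if r <= b_i:
            continue  # the ellipse never crosses this constraint
        phi = math.atan2(q_i, p_i)
        alpha = math.acos(max(-1.0, b_i / r))  # half-width of the blocked arc
        start = (phi - alpha) % two_pi
        # theta = 0 is feasible, so the blocked arc cannot wrap around 2*pi
        blocked.append((start, start + 2.0 * alpha))

    blocked.sort()  # the O(m log m) step
    feasible = []
    cursor = 0.0  # left end of the current feasible gap
    for start, end in blocked:
        if start > cursor:
            feasible.append((cursor, start))
        cursor = max(cursor, end)
    feasible.append((cursor, two_pi))
    return feasible


# Univariate domain -1 <= x <= 3 with x = 1 and nu = 4, so p = A x and q = A nu.
arcs = feasible_arcs(p=[-1.0, 1.0], q=[-4.0, 4.0], b=[1.0, 3.0])
```

The first and last arcs returned are two halves of the same arc through $\theta = 0$; a sampler would draw $\theta$ uniformly from the union of the arcs.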
```
batch_sampler = LinearEllipticalSliceSampler(
    inequality_constraints=(A, b),
    check_feasibility=True,
    chains=100,  # launch 100 Markov chains in parallel
)
samples = batch_sampler.draw(n)  # returns 100 * n samples
```
### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?
Yes.
Pull Request resolved: #2426
Test Plan:
1. This implementation passes all existing test cases in `test/utils/probability/test_lin_ess.py`.
2. However, two test cases [here](https://github.com/pytorch/botorch/blob/0455dc3945c5e89eefa2dfeb7abab8c4f4079b36/test/utils/probability/test_lin_ess.py#L363-L396) had to be modified because the function `self._find_active_intersections` is replaced with `self._find_active_intersection_angles`, which has a different output format.
3. I have added additional test cases for batch MCMC.
### Some Plots
Attached is a plot demonstrating that the new algorithm accelerates truncated normal sampling in high dimensions. This experiment was run with `torch.float32`. I have observed a similar speed-up with double precision.
<p>
<img src="https://github.com/pytorch/botorch/assets/37524685/9c86d293-69aa-4911-ad7e-9c4c0ead655f" width=49% />
<img src="https://github.com/pytorch/botorch/assets/37524685/004f5fc8-6b88-462d-9a34-7d20b7e59a32" width=49% />
</p>
The following is a plot of samples obtained by running 500 Markov chains in parallel for a univariate truncated normal distribution.
<img src="https://github.com/pytorch/botorch/assets/37524685/a935b106-5376-4339-bf9a-fd03ee02e46a" width=40% />
```
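import torch
from botorch.utils.probability.lin_ess import LinearEllipticalSliceSampler

# use the GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(0)
lb, ub = -1.0, 3.0  # floats, so that b has the same dtype as A
# domain is lb <= x <= ub, written as A x <= b
A = torch.tensor([[-1.], [1.]], device=device)
b = torch.tensor([-lb, ub], device=device)
# start every chain at the midpoint of the domain
x = torch.zeros(1, device=device) + 0.5 * (lb + ub)
sampler = LinearEllipticalSliceSampler(
    inequality_constraints=(A, b.unsqueeze(-1)),
    interior_point=x,
    check_feasibility=False,
    burnin=500,
    thinning=0,
    chains=500,
)
samples = sampler.draw(n=500)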
```
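One way to sanity-check samples like these is to compare their empirical mean and variance against the closed-form moments of a truncated normal. The sketch below assumes the sampler targets a standard normal restricted to `[lb, ub]` (an assumption about the default mean and covariance, not something verified here); `truncated_std_normal_moments` is a hypothetical helper, not part of BoTorch.

```python
import math


def std_normal_pdf(t):
    return math.exp(-0.5 * t * t) / math.sqrt(2.0 * math.pi)


def std_normal_cdf(t):
    return 0.5 * (1.0 + math.erf(t / math.sqrt(2.0)))


def truncated_std_normal_moments(lb, ub):
    """Mean and variance of N(0, 1) truncated to [lb, ub]."""
    mass = std_normal_cdf(ub) - std_normal_cdf(lb)
    mean = (std_normal_pdf(lb) - std_normal_pdf(ub)) / mass
    var = 1.0 + (lb * std_normal_pdf(lb) - ub * std_normal_pdf(ub)) / mass - mean**2
    return mean, var


# Reference values for the domain used above; compare them with
# samples.mean() and samples.var() computed from the drawn samples.
mean, var = truncated_std_normal_moments(-1.0, 3.0)
```

With enough samples, the empirical moments should agree with these reference values up to Monte Carlo error.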
## Related PRs
N/A. I am happy to update the documentation to describe the `chains` argument, either in this PR or in a new one.
Reviewed By: sdaulton
Differential Revision: D59639608
Pulled By: esantorella
fbshipit-source-id: 1da29fd27d89e26a46d1b7429742817c4f1d234e