
Commit 5b072ce

kayween authored and facebook-github-bot committed
An improved elliptical slice sampling implementation (#2426)
Summary:

## Motivation

Elliptical slice sampling for truncated normal distributions (e.g., [Gessner et al., 2020](https://arxiv.org/abs/1910.09328)) requires constructing the active intervals corresponding to the intersection of the ellipse and the domain. One method of constructing the active intervals is based on likelihood testing, which tests whether each intersection angle is active or not. The current BoTorch implementation follows this idea and takes $\mathcal{O}(m^2 d)$ time per iteration, where $d$ is the dimensionality and $m$ is the number of linear inequality constraints. However, there exists a faster algorithm that computes the active intervals in $\mathcal{O}(m \log m)$ time ([Wu and Gardner, 2024](https://arxiv.org/abs/2407.10449)). This PR implements Algorithm 3 of that paper and dramatically accelerates truncated normal sampling in high dimensions.

In addition, this PR implements batch MCMC to launch multiple Markov chains, which further speeds up sampling by better exploiting GPU parallelism. Users can pass an additional argument `num_chains` to specify how many chains to run in parallel:

```
batch_sampler = LinearEllipticalSliceSampler(
    inequality_constraints=(A, b),
    check_feasibility=True,
    num_chains=100,  # launch 100 Markov chains in parallel
)
samples = batch_sampler.draw(n)  # returns 100 * n samples
```

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes.

Pull Request resolved: #2426

Test Plan:
1. This implementation passes all existing test cases in `test/utils/probability/test_lin_ess.py`.
2. Two test cases [here](https://github.com/pytorch/botorch/blob/0455dc3945c5e89eefa2dfeb7abab8c4f4079b36/test/utils/probability/test_lin_ess.py#L363-L396) had to be modified, because the function `self._find_active_intersections` was replaced with `self._find_active_intersection_angles`, which has a different output format.
3. I have added additional test cases for batch MCMC.

### Some Plots

Attached is a plot demonstrating that this new algorithm accelerates truncated normal sampling in high dimensions. This experiment was run with `torch.float32`; I have observed a similar speed-up with double precision.

<p>
<img src="https://github.com/pytorch/botorch/assets/37524685/9c86d293-69aa-4911-ad7e-9c4c0ead655f" width=49% />
<img src="https://github.com/pytorch/botorch/assets/37524685/004f5fc8-6b88-462d-9a34-7d20b7e59a32" width=49% />
</p>

The following is a plot of samples obtained by running 500 Markov chains in parallel for a univariate truncated normal distribution.

<img src="https://github.com/pytorch/botorch/assets/37524685/a935b106-5376-4339-bf9a-fd03ee02e46a" width=40% />

```
torch.manual_seed(0)

lb, ub = -1, 3  # domain is lb <= x <= ub

A = torch.tensor([[-1.], [1.]], device=device)
b = torch.tensor([-lb, ub], device=device)
x = torch.zeros(1, device=device) + 0.5 * (lb + ub)

sampler = LinearEllipticalSliceSampler(
    inequality_constraints=(A, b.unsqueeze(-1)),
    interior_point=x,
    check_feasibility=False,
    burnin=500,
    thinning=0,
    num_chains=500,
)
samples = sampler.draw(n=500)
```

## Related PRs

N/A. But I am happy to create a new PR (or extend this one) updating the documentation to describe the argument `num_chains`.

Reviewed By: sdaulton

Differential Revision: D59639608

Pulled By: esantorella

fbshipit-source-id: 1da29fd27d89e26a46d1b7429742817c4f1d234e
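As a quick sanity check of the batch API described in the summary above, here is a minimal sketch (the constraint matrices are example values, and it assumes the post-PR behavior where `draw` returns a single tensor):

```python
import torch
from botorch.utils.probability.lin_ess import LinearEllipticalSliceSampler

torch.manual_seed(0)
d = 5
# Example domain: the box -1 <= x_i <= 1, written as A @ x <= b.
A = torch.cat([torch.eye(d), -torch.eye(d)], dim=0)
b = torch.ones(2 * d, 1)

sampler = LinearEllipticalSliceSampler(
    inequality_constraints=(A, b),
    check_feasibility=True,
    num_chains=10,
)
samples = sampler.draw(n=100)  # (100 * 10) x d tensor

# Every row should satisfy all linear inequality constraints.
assert (samples @ A.T <= b.squeeze(-1)).all()
```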
1 parent e7915b1 · commit 5b072ce

File tree

2 files changed: +165 / -150 lines changed

botorch/utils/probability/lin_ess.py

Lines changed: 110 additions & 137 deletions
```diff
@@ -12,22 +12,24 @@
     A. Gessner, O. Kanjilal, and P. Hennig. Integrals over gaussians under
     linear domain constraints. AISTATS 2020.
 
+.. [Wu2024]
+    K. Wu, and J. Gardner. A Fast, Robust Elliptical Slice Sampling Implementation for
+    Linearly Truncated Multivariate Normal Distributions. arXiv:2407.10449. 2024.
 
 This implementation is based (with multiple changes / optimiations) on
 the following implementations based on the algorithm in [Gessner2020]_:
 - https://github.com/alpiges/LinConGauss
 - https://github.com/wjmaddox/pytorch_ess
 
+In addition, the active intervals (from which the angle is sampled) are computed using
+the improved algorithm described in [Wu2024]_:
+https://github.com/kayween/linear-ess
+
 The implementation here differentiates itself from the original implementations with:
 1) Support for fixed feature equality constraints.
 2) Support for non-standard Normal distributions.
 3) Numerical stability improvements, especially relevant for high-dimensional cases.
-
-Notably, this implementation does not rely on an adaptive `delta_theta` parameter in
-order to determine if two neighboring constraint intersection angles `theta` lead to a
-change in the feasibility of the sample. This both simplifies the implementation and
-makes it more robust to numerical imprecisions when two constraint intersection angles
-are close to each other.
+4) Support multiple Markov chains running in parallel.
 """
 
 from __future__ import annotations
```
```diff
@@ -47,7 +49,6 @@ class LinearEllipticalSliceSampler(PolytopeSampler):
     r"""Linear Elliptical Slice Sampler.
 
     Ideas:
-    - Add batch support, broadcasting over parallel chains.
     - Optimize computations if possible, potentially with torch.compile.
     - Extend fixed features constraint to general linear equality constraints.
     """
```
```diff
@@ -64,6 +65,7 @@ def __init__(
         check_feasibility: bool = False,
         burnin: int = 0,
         thinning: int = 0,
+        num_chains: int = 1,
     ) -> None:
         r"""Initialize LinearEllipticalSliceSampler.
```
```diff
@@ -99,6 +101,7 @@
             burnin: Number of samples to generate upon initialization to warm up the
                 sampler.
             thinning: Number of samples to skip before returning a sample in `draw`.
+            num_chains: Number of Markov chains to run in parallel.
 
         This sampler samples from a multivariante Normal `N(mean, covariance_matrix)`
         subject to linear domain constraints `A x <= b` (intersected with box bounds,
```
```diff
@@ -158,10 +161,17 @@
         self._x = self.x0.clone()
         self._z = self._transform(self._x)
 
+        # Expand the shape to (d, num_chains) for running parallel Markov chains.
+        if num_chains > 1:
+            self._z = self._z.expand(-1, num_chains).clone()
+
         # We will need the following repeatedly, let's allocate them once
-        self._zero = torch.zeros(1, **tkwargs)
-        self._nan = torch.tensor(float("nan"), **tkwargs)
-        self._full_angular_range = torch.tensor([0.0, _twopi], **tkwargs)
+        self.zeros = torch.zeros((num_chains, 1), **tkwargs)
+        self.ones = torch.ones((num_chains, 1), **tkwargs)
+        self.indices_batch = torch.arange(
+            num_chains, dtype=torch.int64, device=tkwargs["device"]
+        )
 
         self.check_feasibility = check_feasibility
         self._lifetime_samples = 0
         if burnin > 0:
```
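A note on the `expand(-1, num_chains).clone()` idiom added above: `expand` returns a broadcast view whose columns all alias one storage, so the `clone` is what gives each chain its own independent, writable state. A minimal torch illustration (not BoTorch code):

```python
import torch

z = torch.randn(3, 1)              # d x 1 state of a single chain
view = z.expand(-1, 4)             # d x 4 view: all columns alias z's storage,
                                   # so in-place writes to it are disallowed
chains = z.expand(-1, 4).clone()   # d x 4 tensor: each column owns its memory

chains[:, 0] = 0.0                         # update one chain's state...
assert torch.equal(chains[:, 1], z[:, 0])  # ...the other chains are untouched
```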
```diff
@@ -245,14 +255,14 @@ def lifetime_samples(self) -> int:
         """The total number of samples generated by the sampler during its lifetime."""
         return self._lifetime_samples
 
-    def draw(self, n: int = 1) -> Tuple[Tensor, Tensor]:
+    def draw(self, n: int = 1) -> Tensor:
         r"""Draw samples.
 
         Args:
             n: The number of samples.
 
         Returns:
-            A `n x d`-dim tensor of `n` samples.
+            A `(n * num_chains) x d`-dim tensor of `n * num_chains` samples.
         """
         samples = []
         for _ in range(n):
```
```diff
@@ -265,16 +275,17 @@ def step(self) -> Tensor:
         r"""Take a step, return the new sample, update the internal state.
 
         Returns:
-            A `d x 1`-dim sample from the domain.
+            A `d x num_chains`-dim tensor, where each column is a sample from a Markov
+            chain.
         """
         nu = torch.randn_like(self._z)
         theta = self._draw_angle(nu=nu)
-        z = self._get_cart_coords(nu=nu, theta=theta)
-        self._z[:] = z
-        x = self._untransform(z)
-        self._x[:] = x
+
+        self._z = z = self._get_cart_coords(nu=nu, theta=theta)
+        self._x = x = self._untransform(z)
+
         self._lifetime_samples += 1
-        if self.check_feasibility and (not self._is_feasible(self._x)):
+        if self.check_feasibility and (not self._is_feasible(self._x).all()):
             Axmb = self.A @ self._x - self.b
             violated_indices = Axmb > 0
             raise RuntimeError(
```
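For reference, the move in `step` is the standard elliptical slice sampling update: given the current whitened state $z$ and an auxiliary draw $\nu \sim \mathcal{N}(0, I)$, every candidate lies on the ellipse

$$x(\theta) = z \cos\theta + \nu \sin\theta, \qquad \theta \in [0, 2\pi),$$

and for any fixed $\theta$, $x(\theta)$ is again $\mathcal{N}(0, I)$-distributed, because $z$ and $\nu$ are independent standard normals and $\cos^2\theta + \sin^2\theta = 1$. Sampling $\theta$ uniformly from the feasible arcs therefore leaves the linearly truncated Gaussian invariant.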
```diff
@@ -289,157 +300,119 @@ def _draw_angle(self, nu: Tensor) -> Tensor:
         r"""Draw the rotation angle.
 
         Args:
-            nu: A `d x 1`-dim tensor (the "new" direction, drawn from N(0, I)).
+            nu: A `d x num_chains`-dim tensor (the "new" direction, drawn from N(0, I)).
 
         Returns:
-            A `1`-dim Tensor containing the rotation angle (radians).
+            A `num_chains`-dim Tensor containing the rotation angle (radians).
         """
-        rot_angle, rot_slices = self._find_rotated_intersections(nu)
-        rot_lengths = rot_slices[:, 1] - rot_slices[:, 0]
-        cum_lengths = torch.cumsum(rot_lengths, dim=0)
-        cum_lengths = torch.cat((self._zero, cum_lengths), dim=0)
-        rnd_angle = cum_lengths[-1] * torch.rand(
-            1, device=cum_lengths.device, dtype=cum_lengths.dtype
-        )
-        idx = torch.searchsorted(cum_lengths, rnd_angle) - 1
-        return (rot_slices[idx, 0] + rnd_angle + rot_angle) - cum_lengths[idx]
+        left, right = self._find_active_intersection_angles(nu)
+        left, right = self._trim_intervals(left, right)
+
+        # If left[i, j] <= right[i, j], then [left[i, j], right[i, j]] is an active
+        # interval. On the other hand, if left[i, j] > right[i, j], then they are both
+        # dummy variables and should be discarded. Thus, we clamp their difference so
+        # that they do not contribute to the cumulative length.
+        csum = right.sub(left).clamp(min=0.0).cumsum(dim=-1)
+
+        u = csum[:, -1] * torch.rand(
+            right.size(-2), dtype=right.dtype, device=right.device
+        )
+
+        # The returned index i satisfies csum[i - 1] < u <= csum[i]
+        idx = torch.searchsorted(csum, u.unsqueeze(-1)).squeeze(-1)
+
+        # Do a zero padding so that padded_csum[i] = csum[i - 1]
+        padded_csum = torch.cat([self.zeros, csum], dim=-1)
+
+        return u - padded_csum[self.indices_batch, idx] + left[self.indices_batch, idx]
```
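The new `_draw_angle` samples the angle uniformly from a union of arcs by inverse-CDF sampling on the cumulative arc lengths. A single-chain sketch of the same idea (the helper name and scalar handling are illustrative, not BoTorch API):

```python
import torch

def sample_from_union(left: torch.Tensor, right: torch.Tensor) -> torch.Tensor:
    """Draw theta ~ Uniform over the union of intervals [left[j], right[j]].

    Pairs with left[j] > right[j] are dummies; the clamp gives them zero
    length, mirroring the clamp in `_draw_angle` above.
    """
    lengths = (right - left).clamp(min=0.0)
    csum = lengths.cumsum(dim=-1)
    u = csum[-1] * torch.rand(())              # uniform on [0, total length]
    idx = torch.searchsorted(csum, u)          # which interval u falls into
    padded_csum = torch.cat([torch.zeros(1), csum])
    return left[idx] + (u - padded_csum[idx])  # shift u back into that interval

theta = sample_from_union(torch.tensor([0.0, 3.0]), torch.tensor([1.0, 4.0]))
assert (0.0 <= theta <= 1.0) or (3.0 <= theta <= 4.0)
```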

```diff
     def _get_cart_coords(self, nu: Tensor, theta: Tensor) -> Tensor:
-        r"""Determine location on ellipsoid in cartesian coordinates.
+        r"""Determine location on the ellipse in Cartesian coordinates.
 
         Args:
-            nu: A `d x 1`-dim tensor (the "new" direction, drawn from N(0, I)).
-            theta: A `k`-dim tensor of angles.
+            nu: A `d x num_chains`-dim tensor (the "new" direction, drawn from N(0, I)).
+            theta: A `num_chains`-dim tensor of angles.
 
         Returns:
-            A `d x k`-dim tensor of samples from the domain in cartesian coordinates.
+            A `d x num_chains`-dim tensor of samples from the domain in Cartesian
+            coordinates.
         """
         return self._z * torch.cos(theta) + nu * torch.sin(theta)
 
-    def _find_rotated_intersections(self, nu: Tensor) -> Tuple[Tensor, Tensor]:
-        r"""Finds rotated intersections.
-
-        Rotates the intersections by the rotation angle and makes sure that all
-        angles lie in [0, 2*pi].
-
-        Args:
-            nu: A `d x 1`-dim tensor (the "new" direction, drawn from N(0, I)).
-
-        Returns:
-            A two-tuple containing rotation angle (scalar) and a
-            `num_active / 2 x 2`-dim tensor of shifted angles.
-        """
-        slices = self._find_active_intersections(nu)
-        rot_angle = slices[0]
-        slices = (slices - rot_angle).reshape(-1, 2)
-        # Ensuring that we don't sample within numerical precision of the boundaries
-        # due to resulting instabilities in the constraint satisfaction.
-        eps = 1e-6 if slices.dtype == torch.float32 else 1e-12
-        eps = torch.tensor(eps, dtype=slices.dtype, device=slices.device)
-        eps = eps.minimum(slices.diff(dim=-1).abs() / 4)
-        slices = slices + torch.cat((eps, -eps), dim=-1)
-        # NOTE: The remainder call relies on the epsilon contraction, since the
-        # remainder of _twopi divided by _twopi is zero, not _twopi.
-        return rot_angle, slices.remainder(_twopi)
-
-    def _find_active_intersections(self, nu: Tensor) -> Tensor:
-        """
-        Find angles of those intersections that are at the boundary of the integration
-        domain by adding and subtracting a small angle and evaluating on the ellipse
-        to see if we are on the boundary of the integration domain.
-
-        Args:
-            nu: A `d x 1`-dim tensor (the "new" direction, drawn from N(0, I)).
-
-        Returns:
-            A `num_active`-dim tensor containing the angles of active intersection in
-            increasing order so that activation happens in positive direction. If a
-            slice crosses `theta=0`, the first angle is appended at the end of the
-            tensor. Every element of the returned tensor defines a slice for elliptical
-            slice sampling.
-        """
-        theta = self._find_intersection_angles(nu)
-        theta_active, delta_active = self._active_theta_and_delta(
-            nu=nu,
-            theta=theta,
-        )
-        if theta_active.numel() == 0:
-            theta_active = self._full_angular_range
-            # TODO: What about `self.ellipse_in_domain = False` in the original code?
-        elif delta_active[0] == -1:  # ensuring that the first interval is feasible
-            theta_active = torch.cat((theta_active[1:], theta_active[:1]))
-
-        return theta_active.view(-1)
+    def _trim_intervals(self, left: Tensor, right: Tensor) -> Tuple[Tensor, Tensor]:
+        """Trim the intervals by a small positive constant. This encourages the Markov
+        chain to stay in the interior of the domain.
+        """
+        gap = torch.clamp(right - left, min=0.0)
+        eps = gap.mul(0.25).clamp(max=1e-6 if gap.dtype == torch.float32 else 1e-12)
+
+        return left + eps, right - eps
+
+    def _find_active_intersection_angles(self, nu: Tensor) -> Tuple[Tensor, Tensor]:
+        """Construct the active intersection angles.
+
+        Args:
+            nu: A `d x num_chains`-dim tensor (the "new" direction, drawn from N(0, I)).
+
+        Returns:
+            A tuple (left, right) of two tensors of size `num_chains x m` representing
+            the active intersection angles. For the i-th Markov chain and the j-th
+            constraint, a pair of angles left[i, j] and right[i, j] is active if and
+            only if left[i, j] <= right[i, j]. If left[i, j] > right[i, j], they are
+            inactive and should be ignored.
+        """
+        alpha, beta = self._find_intersection_angles(nu)
+
+        # It's easier to put `num_chains` as the first dimension,
+        # because `torch.searchsorted` only supports searching in the last dimension
+        alpha, beta = alpha.T, beta.T
+
+        srted, indices = torch.sort(alpha, descending=False)
+        cummax = beta[self.indices_batch.unsqueeze(-1), indices].cummax(dim=-1).values
+
+        srted = torch.cat([srted, self.ones * 2 * math.pi], dim=-1)
+        cummax = torch.cat([self.zeros, cummax], dim=-1)
+
+        return cummax, srted
```
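The sort-plus-running-max construction above is the sweep from [Wu2024]_: sort the per-constraint arcs by left endpoint, take a running maximum of the right endpoints, and the gaps between that running max and the next left endpoint form the complement of the union of the arcs. A small worked example with plain numbers (a standalone sketch, not the class method):

```python
import math
import torch

# Per-constraint arcs [alpha_j, beta_j] cut out of [0, 2*pi] by three constraints.
alpha = torch.tensor([1.0, 2.0, 5.0])
beta = torch.tensor([3.0, 4.0, 6.0])

srted, indices = torch.sort(alpha)
cummax = beta[indices].cummax(dim=-1).values              # running max of right endpoints

left = torch.cat([torch.zeros(1), cummax])                # [0., 3., 4., 6.]
right = torch.cat([srted, torch.tensor([2 * math.pi])])   # [1., 2., 5., 6.2832]

# Pairs with left <= right are [0, 1], [4, 5], [6, 2*pi] -- exactly the
# complement of [1, 3] U [2, 4] U [5, 6]; the pair [3, 2] is a dummy.
active = left <= right
```

Pairs where `left > right` (here `[3, 2]`) are precisely the dummy intervals that `_draw_angle` later clamps to zero length.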
```diff
-    def _find_intersection_angles(self, nu: Tensor) -> Tensor:
-        """Compute all of the up to 2*n_ineq_con intersections of the ellipse
-        and the linear constraints.
-
-        For background, see equation (2) in
-        http://proceedings.mlr.press/v108/gessner20a/gessner20a.pdf
-
-        Args:
-            nu: A `d x 1`-dim tensor (the "new" direction, drawn from N(0, I)).
-
-        Returns:
-            An `M`-dim tensor, where `M <= 2 * n_ineq_con` (with `M = n_ineq_con`
-            if all intermediate computations yield finite numbers).
-        """
-        # Compared to the implementation in https://github.com/alpiges/LinConGauss
-        # we need to flip the sign of A b/c the original algorithm considers
-        # A @ x + b >= 0 feasible, whereas we consider A @ x - b <= 0 feasible.
-        g1 = -self._Az @ self._z
-        g2 = -self._Az @ nu
-        r = torch.sqrt(g1**2 + g2**2)
-        phi = 2 * torch.atan(g2 / (r + g1)).squeeze()
-
-        arg = -(self._bz / r).squeeze()
-        # Write NaNs if there is no intersection
-        arg = torch.where(torch.absolute(arg) <= 1, arg, self._nan)
-        # Two solutions per linear constraint, shape of theta: (n_ineq_con, 2)
-        acos_arg = torch.arccos(arg)
-        theta = torch.stack((phi + acos_arg, phi - acos_arg), dim=-1)
-        theta = theta[torch.isfinite(theta)]  # shape: `n_ineq_con - num_not_finite`
-        theta = torch.where(theta < 0, theta + _twopi, theta)  # in [0, 2*pi]
-        return torch.sort(theta).values
-
-    def _active_theta_and_delta(self, nu: Tensor, theta: Tensor) -> Tensor:
-        r"""Determine active indices.
-
-        Args:
-            nu: A `d x 1`-dim tensor (the "new" direction, drawn from N(0, I)).
-            theta: A sorted `M`-dim tensor of intersection angles in [0, 2pi].
-
-        Returns:
-            A tuple of Tensors of active constraint intersection angles `theta_active`,
-            and the change in the feasibility of the points on the ellipse on the left
-            and right of the active intersection angles `delta_active`. `delta_active`
-            is negative if decreasing the angle renders the sample feasible, and
-            positive if increasing the angle renders the sample feasible.
-        """
-        # In order to determine if an angle that gives rise to an intersection with a
-        # constraint boundary leads to a change in the feasibility of the solution,
-        # we evaluate the constraints on the midpoint of the intersection angles.
-        # This gets rid of the `delta_theta` parameter in the original implementation,
-        # which cannot be set universally since it can be both 1) too large, when
-        # the distance in adjacent intersection angles is small, and 2) too small,
-        # when it approaches the numerical precision limit.
-        # The implementation below solves both problems and gets rid of the parameter.
-        if len(theta) < 2:  # if we have no or only a tangential intersection
-            theta_active = torch.tensor([], dtype=theta.dtype, device=theta.device)
-            delta_active = torch.tensor([], dtype=int, device=theta.device)
-            return theta_active, delta_active
-        theta_mid = (theta[:-1] + theta[1:]) / 2  # midpoints of intersection angles
-        last_mid = (theta[:1] + theta[-1:] + _twopi) / 2
-        last_mid = last_mid.where(last_mid < _twopi, last_mid - _twopi)
-        theta_mid = torch.cat((last_mid, theta_mid, last_mid), dim=0)
-        samples_mid = self._get_cart_coords(nu=nu, theta=theta_mid)
-        delta_feasibility = (
-            self._is_feasible(samples_mid, transformed=True).to(dtype=int).diff()
-        )
-        active_indices = delta_feasibility.nonzero()
-        return theta[active_indices], delta_feasibility[active_indices]
+    def _find_intersection_angles(self, nu: Tensor) -> Tuple[Tensor, Tensor]:
+        """Compute all 2 * m intersections of the ellipse and the domain, where
+        `m = n_ineq_con` is the number of inequality constraints defining the domain.
+        If the i-th linear inequality constraint has no intersection with the ellipse,
+        we will create two dummy intersection angles alpha_i = beta_i = 0.
+
+        Args:
+            nu: A `d x num_chains`-dim tensor (the "new" direction, drawn from N(0, I)).
+
+        Returns:
+            A tuple of two tensors with the same size `m x num_chains`. The first tensor
+            represents the smaller intersection angles. The second tensor represents the
+            larger intersection angles.
+        """
+        p = self._Az @ self._z
+        q = self._Az @ nu
+
+        radius = torch.sqrt(p**2 + q**2)
+
+        ratio = self._bz / radius
+
+        has_solution = ratio < 1.0
+
+        arccos = torch.arccos(ratio)
+        arccos[~has_solution] = 0.0
+        arctan = torch.arctan2(q, p)
+
+        theta1 = arctan + arccos
+        theta2 = arctan - arccos
+
+        # translate every angle to [0, 2 * pi]
+        theta1 = theta1 + theta1.lt(0.0) * _twopi
+        theta2 = theta2 + theta2.lt(0.0) * _twopi
+
+        alpha = torch.minimum(theta1, theta2)
+        beta = torch.maximum(theta1, theta2)
+
+        return alpha, beta
 
     def _is_feasible(self, points: Tensor, transformed: bool = False) -> Tensor:
         r"""Returns a Boolean tensor indicating whether the `points` are feasible,
```
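The closed form in the new `_find_intersection_angles` drops out of the ellipse parameterization: substituting $x(\theta) = z\cos\theta + \nu\sin\theta$ into the $i$-th constraint boundary $a_i^\top x(\theta) = b_i$ and writing $p_i = a_i^\top z$, $q_i = a_i^\top \nu$ gives

$$p_i \cos\theta + q_i \sin\theta = r_i \cos(\theta - \varphi_i) = b_i, \qquad r_i = \sqrt{p_i^2 + q_i^2}, \quad \varphi_i = \operatorname{arctan2}(q_i, p_i),$$

so the two intersection angles are $\theta = \varphi_i \pm \arccos(b_i / r_i)$, which exist precisely when $|b_i| \le r_i$. The code only needs to test `ratio < 1.0` (`has_solution`), since a feasible current iterate guarantees $b_i / r_i \ge p_i / r_i \ge -1$. This is exactly the `radius`, `ratio`, `arctan2`, and `theta1`/`theta2` computation in the hunk above.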
