Skip to content

Commit 0623e8a

Browse files
James Wilsonfacebook-github-bot
authored andcommitted
bvn, MVNXPB, TruncatedMultivariateNormal, and UnifiedSkewNormal (#1394)
Summary: Pull Request resolved: #1394 Introduces `utils/probability` submodule with the following offers: - `bvn`: Methods for computing bivariate normal probabilities and moments. - `MVNXPB`: Approximate solver for Multivariate Normal CDF. - `LinearEllipticalSliceSampler`: Class for sampling trMVN random variables. - `TruncatedMultivariateNormal`: Truncated multivariate normal Distribution class - `UnifiedSkewNormal`: Unified skew normal Distribution class Reviewed By: Balandat Differential Revision: D39326106 fbshipit-source-id: dd485cedb5d8899e159fa462e103222e55177ab6
1 parent 98503e4 commit 0623e8a

20 files changed

+3504
-0
lines changed

botorch/utils/constants.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from __future__ import annotations
8+
9+
from functools import lru_cache
10+
from numbers import Number
11+
from typing import Iterator, Optional, Tuple, Union
12+
13+
import torch
14+
from torch import Tensor
15+
16+
17+
@lru_cache(maxsize=None)
18+
def get_constants(
19+
values: Union[Number, Iterator[Number]],
20+
device: Optional[torch.device] = None,
21+
dtype: Optional[torch.dtype] = None,
22+
) -> Union[Tensor, Tuple[Tensor, ...]]:
23+
r"""Returns scalar-valued Tensors containing each of the given constants.
24+
Used to expedite tensor operations involving scalar arithmetic. Note that
25+
the returned Tensors should not be modified in-place."""
26+
if isinstance(values, Number):
27+
return torch.full((), values, dtype=dtype, device=device)
28+
29+
return tuple(torch.full((), val, dtype=dtype, device=device) for val in values)
30+
31+
32+
def get_constants_like(
33+
values: Union[Number, Iterator[Number]],
34+
ref: Tensor,
35+
) -> Union[Tensor, Iterator[Tensor]]:
36+
return get_constants(values, device=ref.device, dtype=ref.dtype)
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from botorch.utils.probability.bvn import bvn, bvnmom
8+
from botorch.utils.probability.lin_ess import LinearEllipticalSliceSampler
9+
from botorch.utils.probability.mvnxpb import MVNXPB
10+
from botorch.utils.probability.truncated_multivariate_normal import (
11+
TruncatedMultivariateNormal,
12+
)
13+
from botorch.utils.probability.unified_skew_normal import UnifiedSkewNormal
14+
from botorch.utils.probability.utils import ndtr
15+
16+
17+
__all__ = [
18+
"bvn",
19+
"bvnmom",
20+
"LinearEllipticalSliceSampler",
21+
"MVNXPB",
22+
"ndtr",
23+
"TruncatedMultivariateNormal",
24+
"UnifiedSkewNormal",
25+
]

botorch/utils/probability/bvn.py

Lines changed: 291 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,291 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
r"""
8+
Methods for computing bivariate normal probabilities and statistics.
9+
10+
.. [Genz2004bvnt]
11+
A. Genz. Numerical computation of rectangular bivariate and trivariate normal and
12+
t probabilities. Statistics and Computing, 2004.
13+
14+
.. [Muthen1990moments]
15+
B. Muthen. Moments of the censored and truncated bivariate normal distribution.
16+
British Journal of Mathematical and Statistical Psychology, 1990.
17+
"""
18+
19+
from __future__ import annotations
20+
21+
from math import pi as _pi
22+
from typing import Optional, Tuple
23+
24+
import torch
25+
from botorch.exceptions import UnsupportedError
26+
from botorch.utils.probability.utils import (
27+
case_dispatcher,
28+
get_constants_like,
29+
leggauss,
30+
ndtr as Phi,
31+
phi,
32+
STANDARDIZED_RANGE,
33+
)
34+
from botorch.utils.safe_math import (
35+
div as safe_div,
36+
exp as safe_exp,
37+
mul as safe_mul,
38+
sub as safe_sub,
39+
)
40+
from torch import Tensor
41+
42+
# Some useful constants
43+
_inf = float("inf")
44+
_2pi = 2 * _pi
45+
_sqrt_2pi = _2pi**0.5
46+
_inv_2pi = 1 / _2pi
47+
48+
49+
def bvn(r: Tensor, xl: Tensor, yl: Tensor, xu: Tensor, yu: Tensor) -> Tensor:
50+
r"""A function for computing bivariate normal probabilities.
51+
52+
Calculates `P(xl < x < xu, yl < y < yu)` where `x` and `y` are bivariate normal with
53+
unit variance and correlation coefficient `r`. See Section 2.4 of [Genz2004bvnt]_.
54+
55+
This method uses a sign flip trick to improve numerical performance. Many of `bvnu`s
56+
internal branches rely on evaluations `Phi(-bound)`. For `a < b < 0`, the term
57+
`Phi(-a) - Phi(-b)` goes to zero faster than `Phi(b) - Phi(a)` because
58+
`finfo(dtype).epsneg` is typically much larger than `finfo(dtype).tiny`. In these
59+
cases, flipping the sign can prevent situations where `bvnu(...) - bvnu(...)` would
60+
otherwise be zero due to round-off error.
61+
62+
Args:
63+
r: Tensor of correlation coefficients.
64+
xl: Tensor of lower bounds for `x`, same shape as `r`.
65+
yl: Tensor of lower bounds for `y`, same shape as `r`.
66+
xu: Tensor of upper bounds for `x`, same shape as `r`.
67+
yu: Tensor of upper bounds for `y`, same shape as `r`.
68+
69+
Returns:
70+
Tensor of probabilities `P(xl < x < xu, yl < y < yu)`.
71+
72+
"""
73+
if not (r.shape == xl.shape == xu.shape == yl.shape == yu.shape):
74+
raise UnsupportedError("Arguments to `bvn` must have the same shape.")
75+
76+
# Sign flip trick
77+
_0, _1, _2 = get_constants_like(values=(0, 1, 2), ref=r)
78+
flip_x = xl.abs() > xu # is xl more negative than xu is positive?
79+
flip_y = yl.abs() > yu
80+
flip = (flip_x & (~flip_y | yu.isinf())) | (flip_y & (~flip_x | xu.isinf()))
81+
if flip.any(): # symmetric calls to `bvnu` below makes swapping bounds unnecessary
82+
sign = _1 - _2 * flip.to(dtype=r.dtype)
83+
xl = sign * xl # becomes `-xu` if flipped
84+
xu = sign * xu # becomes `-xl`
85+
yl = sign * yl # becomes `-yu`
86+
yu = sign * yu # becomes `-yl`
87+
88+
p = bvnu(r, xl, yl) - bvnu(r, xu, yl) - bvnu(r, xl, yu) + bvnu(r, xu, yu)
89+
return p.clip(_0, _1)
90+
91+
92+
def bvnu(r: Tensor, h: Tensor, k: Tensor) -> Tensor:
93+
r"""Solves for `P(x > h, y > k)` where `x` and `y` are standard bivariate normal
94+
random variables with correlation coefficient `r`. In [Genz2004bvnt]_, this is (1)
95+
```
96+
L(h, k, r) = P(x < -h, y < -k)
97+
= 1/(a 2\pi) \int_{h}^{\infty} \int_{k}^{\infty} f(x, y, r) dy dx,
98+
```
99+
where `f(x, y, r) = e^{-1/(2a^2) (x^2 - 2rxy + y^2)}` and `a = (1 - r^2)^{1/2}`.
100+
101+
[Genz2004bvnt]_ report the following integation scheme incurs a maximum of 5e-16
102+
error when run in double precision: if |r| >= 0.925, use a 20-point quadrature rule
103+
on a 5th order Taylor expansion; else, numerically integrate in polar coordinates
104+
using no more than 20 quadrature points.
105+
106+
Args:
107+
r: Tensor of correlation coefficients.
108+
h: Tensor of negative upper bounds for `x`, same shape as `r`.
109+
k: Tensor of negative upper bounds for `y`, same shape as `r`.
110+
111+
Returns:
112+
A tensor of probabilities `P(x > h, y > k)`.
113+
"""
114+
if not (r.shape == h.shape == k.shape):
115+
raise UnsupportedError("Arguments to `bvnu` must have the same shape.")
116+
_0, _1, lower, upper = get_constants_like((0, 1) + STANDARDIZED_RANGE, r)
117+
x_free = h < lower
118+
y_free = k < lower
119+
return case_dispatcher(
120+
out=torch.empty_like(r),
121+
cases=( # Special cases admitting closed-form solutions
122+
(lambda: (h > upper) | (k > upper), lambda mask: _0),
123+
(lambda: x_free & y_free, lambda mask: _1),
124+
(lambda: x_free, lambda mask: Phi(-k[mask])),
125+
(lambda: y_free, lambda mask: Phi(-h[mask])),
126+
(lambda: r == _0, lambda mask: Phi(-h[mask]) * Phi(-k[mask])),
127+
( # For |r| >= 0.925, use a Taylor approximation
128+
lambda: r.abs() >= get_constants_like(0.925, r),
129+
lambda m: _bvnu_taylor(r[m], h[m], k[m]),
130+
),
131+
), # For |r| < 0.925, integrate in polar coordinates.
132+
default=lambda mask: _bvnu_polar(r[mask], h[mask], k[mask]),
133+
)
134+
135+
136+
def _bvnu_polar(
137+
r: Tensor, h: Tensor, k: Tensor, num_points: Optional[int] = None
138+
) -> Tensor:
139+
r"""Solves for `P(x > h, y > k)` by integrating in polar coordinates as
140+
```
141+
L(h, k, r) = \Phi(-h)\Phi(-k) + 1/(2\pi) \int_{0}^{sin^{-1}(r)} f(t) dt
142+
f(t) = e^{-0.5 cos(t)^{-2} (h^2 + k^2 - 2hk sin(t))}
143+
```
144+
For details, see Section 2.2 of [Genz2004bvnt]_.
145+
"""
146+
if num_points is None:
147+
mar = r.abs().max()
148+
num_points = 6 if mar < 0.3 else 12 if mar < 0.75 else 20
149+
150+
_0, _1, _i2, _i2pi = get_constants_like(values=(0, 1, 0.5, _inv_2pi), ref=r)
151+
152+
x, w = leggauss(num_points, dtype=r.dtype, device=r.device)
153+
x = x + _1
154+
asin_r = _i2 * torch.asin(r)
155+
sin_asrx = (asin_r.unsqueeze(-1) * x).sin()
156+
157+
_h = h.unsqueeze(-1)
158+
_k = k.unsqueeze(-1)
159+
vals = safe_exp(
160+
safe_sub(safe_mul(sin_asrx, _h * _k), _i2 * (_h.square() + _k.square()))
161+
/ (_1 - sin_asrx.square())
162+
)
163+
probs = Phi(-h) * Phi(-k) + _i2pi * asin_r * (vals @ w)
164+
return probs.clip(min=_0, max=_1) # necessary due to "safe" handling of inf
165+
166+
167+
def _bvnu_taylor(r: Tensor, h: Tensor, k: Tensor, num_points: int = 20) -> Tensor:
168+
r"""Solves for `P(x > h, y > k)` via Taylor expansion.
169+
170+
Per Section 2.3 of [Genz2004bvnt]_, the bvnu equation (1) may be rewritten as
171+
```
172+
L(h, k, r) = L(h, k, s) - s/(2\pi) \int_{0}^{a} f(x) dx
173+
f(x) = (1 - x^2){-1/2} e^{-0.5 ((h - sk)/ x)^2} e^{-shk/(1 + (1 - x^2)^{1/2})},
174+
```
175+
where `s = sign(r)` and `a = sqrt(1 - r^{2})`. The term `L(h, k, s)` is analytic.
176+
The second integral is approximated via Taylor expansion.
177+
"""
178+
_0, _1, _ni2, _i2pi, _sq2pi = get_constants_like(
179+
values=(0, 1, -0.5, _inv_2pi, _sqrt_2pi), ref=r
180+
)
181+
182+
x, w = leggauss(num_points, dtype=r.dtype, device=r.device)
183+
x = x + _1
184+
185+
s = get_constants_like(2, r) * (r > _0).to(r) - _1 # sign of `r` where sign(0) := 1
186+
sk = s * k
187+
skh = sk * h
188+
comp_r2 = _1 - r.square()
189+
190+
a = comp_r2.clip(min=0).sqrt()
191+
b = safe_sub(h, sk)
192+
b2 = b.square()
193+
c = get_constants_like(1 / 8, r) * (get_constants_like(4, r) - skh)
194+
d = get_constants_like(1 / 80, r) * (get_constants_like(12, r) - skh)
195+
196+
# ---- Solve for `L(h, k, s)`
197+
int_from_0_to_s = case_dispatcher(
198+
out=torch.empty_like(r),
199+
cases=[(lambda: r > _0, lambda mask: Phi(-torch.maximum(h[mask], k[mask])))],
200+
default=lambda mask: (Phi(sk[mask]) - Phi(h[mask])).clip(min=_0),
201+
)
202+
203+
# ---- Solve for `s/(2\pi) \int_{0}^{a} f(x) dx`
204+
# Analytic part
205+
_a0 = _ni2 * (safe_div(b2, comp_r2) + skh)
206+
_a1 = c * get_constants_like(1 / 3, r) * (_1 - d * b2)
207+
_a2 = _1 - b2 * _a1
208+
abs_b = b.abs()
209+
analytic_part = torch.subtract( # analytic part of solution
210+
a * (_a2 + comp_r2 * _a1 + c * d * comp_r2.square()) * safe_exp(_a0),
211+
_sq2pi * Phi(safe_div(-abs_b, a)) * abs_b * _a2 * safe_exp(_ni2 * skh),
212+
)
213+
214+
# Quadrature part
215+
_b2 = b2.unsqueeze(-1)
216+
_skh = skh.unsqueeze(-1)
217+
_q0 = get_constants_like(0.25, r) * comp_r2.unsqueeze(-1) * x.square()
218+
_q1 = (_1 - _q0).sqrt()
219+
_q2 = _ni2 * (_b2 / _q0 + _skh)
220+
221+
_b2 = b2.unsqueeze(-1)
222+
_c = c.unsqueeze(-1)
223+
_d = d.unsqueeze(-1)
224+
vals = (_ni2 * (_b2 / _q0 + _skh)).exp() * torch.subtract(
225+
_1 + _c * _q0 * (_1 + get_constants_like(5, r) * _d * _q0),
226+
safe_exp(_ni2 * _q0 / (_1 + _q1).square() * _skh) / _q1,
227+
)
228+
mask = _q2 > get_constants_like(-100, r)
229+
if not mask.all():
230+
vals[~mask] = _0
231+
quadrature_part = _ni2 * a * (vals @ w)
232+
233+
# Return `P(x > h, y > k)`
234+
int_from_0_to_a = _i2pi * s * (analytic_part + quadrature_part)
235+
return (int_from_0_to_s - int_from_0_to_a).clip(min=_0, max=_1)
236+
237+
238+
def bvnmom(
239+
r: Tensor,
240+
xl: Tensor,
241+
yl: Tensor,
242+
xu: Tensor,
243+
yu: Tensor,
244+
p: Optional[Tensor] = None,
245+
) -> Tuple[Tensor, Tensor]:
246+
r"""Computes the expected values of truncated, bivariate normal random variables.
247+
248+
Let `x` and `y` be a pair of standard bivariate normal random variables having
249+
correlation `r`. This function computes `E([x,y] | [xl,yl] < [x,y] < [xu,yu])`.
250+
251+
Following [Muthen1990moments]_ equations (4) and (5), we have
252+
```
253+
E(x | [xl, yl] < [x, y] < [xu, yu])
254+
= Z^{-1} \phi(xl) P(yl < y < yu | x=xl) - \phi(xu) P(yl < y < yu | x=xu)
255+
```
256+
where `Z = P([xl, yl] < [x, y] < [xu, yu])` and `\phi` is the standard normal PDF.
257+
258+
Args:
259+
r: Tensor of correlation coefficients.
260+
xl: Tensor of lower bounds for `x`, same shape as `r`.
261+
xu: Tensor of upper bounds for `x`, same shape as `r`.
262+
yl: Tensor of lower bounds for `y`, same shape as `r`.
263+
yu: Tensor of upper bounds for `y`, same shape as `r`.
264+
p: Tensor of probabilities `P(xl < x < xu, yl < y < yu)`, same shape as `r`.
265+
266+
Returns:
267+
`E(x | [xl, yl] < [x, y] < [xu, yu])` and `E(y | [xl, yl] < [x, y] < [xu, yu])`.
268+
"""
269+
if not (r.shape == xl.shape == xu.shape == yl.shape == yu.shape):
270+
raise UnsupportedError("Arguments to `bvn` must have the same shape.")
271+
272+
if p is None:
273+
p = bvn(r=r, xl=xl, xu=xu, yl=yl, yu=yu)
274+
275+
corr = r[..., None, None]
276+
istd = (1 - corr.square()).rsqrt()
277+
lower = torch.stack([xl, yl], -1)
278+
upper = torch.stack([xu, yu], -1)
279+
bounds = torch.stack([lower, upper], -1)
280+
deltas = safe_mul(corr, bounds)
281+
282+
# Compute densities and conditional probabilities
283+
density_at_bounds = phi(bounds)
284+
prob_given_bounds = Phi(
285+
safe_mul(istd, safe_sub(upper.flip(-1).unsqueeze(-1), deltas))
286+
) - Phi(safe_mul(istd, safe_sub(lower.flip(-1).unsqueeze(-1), deltas)))
287+
288+
# Evaluate Muthen's formula
289+
p_diffs = -(density_at_bounds * prob_given_bounds).diff().squeeze(-1)
290+
moments = (1 / p).unsqueeze(-1) * (p_diffs + r.unsqueeze(-1) * p_diffs.flip(-1))
291+
return moments.unbind(-1)

0 commit comments

Comments
 (0)