|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | +# |
| 4 | +# This source code is licensed under the MIT license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +r""" |
| 8 | +Methods for computing bivariate normal probabilities and statistics. |
| 9 | +
|
| 10 | +.. [Genz2004bvnt] |
| 11 | + A. Genz. Numerical computation of rectangular bivariate and trivariate normal and |
| 12 | + t probabilities. Statistics and Computing, 2004. |
| 13 | +
|
| 14 | +.. [Muthen1990moments] |
| 15 | + B. Muthen. Moments of the censored and truncated bivariate normal distribution. |
| 16 | + British Journal of Mathematical and Statistical Psychology, 1990. |
| 17 | +""" |
| 18 | + |
| 19 | +from __future__ import annotations |
| 20 | + |
| 21 | +from math import pi as _pi |
| 22 | +from typing import Optional, Tuple |
| 23 | + |
| 24 | +import torch |
| 25 | +from botorch.exceptions import UnsupportedError |
| 26 | +from botorch.utils.probability.utils import ( |
| 27 | + case_dispatcher, |
| 28 | + get_constants_like, |
| 29 | + leggauss, |
| 30 | + ndtr as Phi, |
| 31 | + phi, |
| 32 | + STANDARDIZED_RANGE, |
| 33 | +) |
| 34 | +from botorch.utils.safe_math import ( |
| 35 | + div as safe_div, |
| 36 | + exp as safe_exp, |
| 37 | + mul as safe_mul, |
| 38 | + sub as safe_sub, |
| 39 | +) |
| 40 | +from torch import Tensor |
| 41 | + |
| 42 | +# Some useful constants |
| 43 | +_inf = float("inf") |
| 44 | +_2pi = 2 * _pi |
| 45 | +_sqrt_2pi = _2pi**0.5 |
| 46 | +_inv_2pi = 1 / _2pi |
| 47 | + |
| 48 | + |
| 49 | +def bvn(r: Tensor, xl: Tensor, yl: Tensor, xu: Tensor, yu: Tensor) -> Tensor: |
| 50 | + r"""A function for computing bivariate normal probabilities. |
| 51 | +
|
| 52 | + Calculates `P(xl < x < xu, yl < y < yu)` where `x` and `y` are bivariate normal with |
| 53 | + unit variance and correlation coefficient `r`. See Section 2.4 of [Genz2004bvnt]_. |
| 54 | +
|
| 55 | + This method uses a sign flip trick to improve numerical performance. Many of `bvnu`s |
| 56 | + internal branches rely on evaluations `Phi(-bound)`. For `a < b < 0`, the term |
| 57 | + `Phi(-a) - Phi(-b)` goes to zero faster than `Phi(b) - Phi(a)` because |
| 58 | + `finfo(dtype).epsneg` is typically much larger than `finfo(dtype).tiny`. In these |
| 59 | + cases, flipping the sign can prevent situations where `bvnu(...) - bvnu(...)` would |
| 60 | + otherwise be zero due to round-off error. |
| 61 | +
|
| 62 | + Args: |
| 63 | + r: Tensor of correlation coefficients. |
| 64 | + xl: Tensor of lower bounds for `x`, same shape as `r`. |
| 65 | + yl: Tensor of lower bounds for `y`, same shape as `r`. |
| 66 | + xu: Tensor of upper bounds for `x`, same shape as `r`. |
| 67 | + yu: Tensor of upper bounds for `y`, same shape as `r`. |
| 68 | +
|
| 69 | + Returns: |
| 70 | + Tensor of probabilities `P(xl < x < xu, yl < y < yu)`. |
| 71 | +
|
| 72 | + """ |
| 73 | + if not (r.shape == xl.shape == xu.shape == yl.shape == yu.shape): |
| 74 | + raise UnsupportedError("Arguments to `bvn` must have the same shape.") |
| 75 | + |
| 76 | + # Sign flip trick |
| 77 | + _0, _1, _2 = get_constants_like(values=(0, 1, 2), ref=r) |
| 78 | + flip_x = xl.abs() > xu # is xl more negative than xu is positive? |
| 79 | + flip_y = yl.abs() > yu |
| 80 | + flip = (flip_x & (~flip_y | yu.isinf())) | (flip_y & (~flip_x | xu.isinf())) |
| 81 | + if flip.any(): # symmetric calls to `bvnu` below makes swapping bounds unnecessary |
| 82 | + sign = _1 - _2 * flip.to(dtype=r.dtype) |
| 83 | + xl = sign * xl # becomes `-xu` if flipped |
| 84 | + xu = sign * xu # becomes `-xl` |
| 85 | + yl = sign * yl # becomes `-yu` |
| 86 | + yu = sign * yu # becomes `-yl` |
| 87 | + |
| 88 | + p = bvnu(r, xl, yl) - bvnu(r, xu, yl) - bvnu(r, xl, yu) + bvnu(r, xu, yu) |
| 89 | + return p.clip(_0, _1) |
| 90 | + |
| 91 | + |
| 92 | +def bvnu(r: Tensor, h: Tensor, k: Tensor) -> Tensor: |
| 93 | + r"""Solves for `P(x > h, y > k)` where `x` and `y` are standard bivariate normal |
| 94 | + random variables with correlation coefficient `r`. In [Genz2004bvnt]_, this is (1) |
| 95 | + ``` |
| 96 | + L(h, k, r) = P(x < -h, y < -k) |
| 97 | + = 1/(a 2\pi) \int_{h}^{\infty} \int_{k}^{\infty} f(x, y, r) dy dx, |
| 98 | + ``` |
| 99 | + where `f(x, y, r) = e^{-1/(2a^2) (x^2 - 2rxy + y^2)}` and `a = (1 - r^2)^{1/2}`. |
| 100 | +
|
| 101 | + [Genz2004bvnt]_ report the following integation scheme incurs a maximum of 5e-16 |
| 102 | + error when run in double precision: if |r| >= 0.925, use a 20-point quadrature rule |
| 103 | + on a 5th order Taylor expansion; else, numerically integrate in polar coordinates |
| 104 | + using no more than 20 quadrature points. |
| 105 | +
|
| 106 | + Args: |
| 107 | + r: Tensor of correlation coefficients. |
| 108 | + h: Tensor of negative upper bounds for `x`, same shape as `r`. |
| 109 | + k: Tensor of negative upper bounds for `y`, same shape as `r`. |
| 110 | +
|
| 111 | + Returns: |
| 112 | + A tensor of probabilities `P(x > h, y > k)`. |
| 113 | + """ |
| 114 | + if not (r.shape == h.shape == k.shape): |
| 115 | + raise UnsupportedError("Arguments to `bvnu` must have the same shape.") |
| 116 | + _0, _1, lower, upper = get_constants_like((0, 1) + STANDARDIZED_RANGE, r) |
| 117 | + x_free = h < lower |
| 118 | + y_free = k < lower |
| 119 | + return case_dispatcher( |
| 120 | + out=torch.empty_like(r), |
| 121 | + cases=( # Special cases admitting closed-form solutions |
| 122 | + (lambda: (h > upper) | (k > upper), lambda mask: _0), |
| 123 | + (lambda: x_free & y_free, lambda mask: _1), |
| 124 | + (lambda: x_free, lambda mask: Phi(-k[mask])), |
| 125 | + (lambda: y_free, lambda mask: Phi(-h[mask])), |
| 126 | + (lambda: r == _0, lambda mask: Phi(-h[mask]) * Phi(-k[mask])), |
| 127 | + ( # For |r| >= 0.925, use a Taylor approximation |
| 128 | + lambda: r.abs() >= get_constants_like(0.925, r), |
| 129 | + lambda m: _bvnu_taylor(r[m], h[m], k[m]), |
| 130 | + ), |
| 131 | + ), # For |r| < 0.925, integrate in polar coordinates. |
| 132 | + default=lambda mask: _bvnu_polar(r[mask], h[mask], k[mask]), |
| 133 | + ) |
| 134 | + |
| 135 | + |
| 136 | +def _bvnu_polar( |
| 137 | + r: Tensor, h: Tensor, k: Tensor, num_points: Optional[int] = None |
| 138 | +) -> Tensor: |
| 139 | + r"""Solves for `P(x > h, y > k)` by integrating in polar coordinates as |
| 140 | + ``` |
| 141 | + L(h, k, r) = \Phi(-h)\Phi(-k) + 1/(2\pi) \int_{0}^{sin^{-1}(r)} f(t) dt |
| 142 | + f(t) = e^{-0.5 cos(t)^{-2} (h^2 + k^2 - 2hk sin(t))} |
| 143 | + ``` |
| 144 | + For details, see Section 2.2 of [Genz2004bvnt]_. |
| 145 | + """ |
| 146 | + if num_points is None: |
| 147 | + mar = r.abs().max() |
| 148 | + num_points = 6 if mar < 0.3 else 12 if mar < 0.75 else 20 |
| 149 | + |
| 150 | + _0, _1, _i2, _i2pi = get_constants_like(values=(0, 1, 0.5, _inv_2pi), ref=r) |
| 151 | + |
| 152 | + x, w = leggauss(num_points, dtype=r.dtype, device=r.device) |
| 153 | + x = x + _1 |
| 154 | + asin_r = _i2 * torch.asin(r) |
| 155 | + sin_asrx = (asin_r.unsqueeze(-1) * x).sin() |
| 156 | + |
| 157 | + _h = h.unsqueeze(-1) |
| 158 | + _k = k.unsqueeze(-1) |
| 159 | + vals = safe_exp( |
| 160 | + safe_sub(safe_mul(sin_asrx, _h * _k), _i2 * (_h.square() + _k.square())) |
| 161 | + / (_1 - sin_asrx.square()) |
| 162 | + ) |
| 163 | + probs = Phi(-h) * Phi(-k) + _i2pi * asin_r * (vals @ w) |
| 164 | + return probs.clip(min=_0, max=_1) # necessary due to "safe" handling of inf |
| 165 | + |
| 166 | + |
| 167 | +def _bvnu_taylor(r: Tensor, h: Tensor, k: Tensor, num_points: int = 20) -> Tensor: |
| 168 | + r"""Solves for `P(x > h, y > k)` via Taylor expansion. |
| 169 | +
|
| 170 | + Per Section 2.3 of [Genz2004bvnt]_, the bvnu equation (1) may be rewritten as |
| 171 | + ``` |
| 172 | + L(h, k, r) = L(h, k, s) - s/(2\pi) \int_{0}^{a} f(x) dx |
| 173 | + f(x) = (1 - x^2){-1/2} e^{-0.5 ((h - sk)/ x)^2} e^{-shk/(1 + (1 - x^2)^{1/2})}, |
| 174 | + ``` |
| 175 | + where `s = sign(r)` and `a = sqrt(1 - r^{2})`. The term `L(h, k, s)` is analytic. |
| 176 | + The second integral is approximated via Taylor expansion. |
| 177 | + """ |
| 178 | + _0, _1, _ni2, _i2pi, _sq2pi = get_constants_like( |
| 179 | + values=(0, 1, -0.5, _inv_2pi, _sqrt_2pi), ref=r |
| 180 | + ) |
| 181 | + |
| 182 | + x, w = leggauss(num_points, dtype=r.dtype, device=r.device) |
| 183 | + x = x + _1 |
| 184 | + |
| 185 | + s = get_constants_like(2, r) * (r > _0).to(r) - _1 # sign of `r` where sign(0) := 1 |
| 186 | + sk = s * k |
| 187 | + skh = sk * h |
| 188 | + comp_r2 = _1 - r.square() |
| 189 | + |
| 190 | + a = comp_r2.clip(min=0).sqrt() |
| 191 | + b = safe_sub(h, sk) |
| 192 | + b2 = b.square() |
| 193 | + c = get_constants_like(1 / 8, r) * (get_constants_like(4, r) - skh) |
| 194 | + d = get_constants_like(1 / 80, r) * (get_constants_like(12, r) - skh) |
| 195 | + |
| 196 | + # ---- Solve for `L(h, k, s)` |
| 197 | + int_from_0_to_s = case_dispatcher( |
| 198 | + out=torch.empty_like(r), |
| 199 | + cases=[(lambda: r > _0, lambda mask: Phi(-torch.maximum(h[mask], k[mask])))], |
| 200 | + default=lambda mask: (Phi(sk[mask]) - Phi(h[mask])).clip(min=_0), |
| 201 | + ) |
| 202 | + |
| 203 | + # ---- Solve for `s/(2\pi) \int_{0}^{a} f(x) dx` |
| 204 | + # Analytic part |
| 205 | + _a0 = _ni2 * (safe_div(b2, comp_r2) + skh) |
| 206 | + _a1 = c * get_constants_like(1 / 3, r) * (_1 - d * b2) |
| 207 | + _a2 = _1 - b2 * _a1 |
| 208 | + abs_b = b.abs() |
| 209 | + analytic_part = torch.subtract( # analytic part of solution |
| 210 | + a * (_a2 + comp_r2 * _a1 + c * d * comp_r2.square()) * safe_exp(_a0), |
| 211 | + _sq2pi * Phi(safe_div(-abs_b, a)) * abs_b * _a2 * safe_exp(_ni2 * skh), |
| 212 | + ) |
| 213 | + |
| 214 | + # Quadrature part |
| 215 | + _b2 = b2.unsqueeze(-1) |
| 216 | + _skh = skh.unsqueeze(-1) |
| 217 | + _q0 = get_constants_like(0.25, r) * comp_r2.unsqueeze(-1) * x.square() |
| 218 | + _q1 = (_1 - _q0).sqrt() |
| 219 | + _q2 = _ni2 * (_b2 / _q0 + _skh) |
| 220 | + |
| 221 | + _b2 = b2.unsqueeze(-1) |
| 222 | + _c = c.unsqueeze(-1) |
| 223 | + _d = d.unsqueeze(-1) |
| 224 | + vals = (_ni2 * (_b2 / _q0 + _skh)).exp() * torch.subtract( |
| 225 | + _1 + _c * _q0 * (_1 + get_constants_like(5, r) * _d * _q0), |
| 226 | + safe_exp(_ni2 * _q0 / (_1 + _q1).square() * _skh) / _q1, |
| 227 | + ) |
| 228 | + mask = _q2 > get_constants_like(-100, r) |
| 229 | + if not mask.all(): |
| 230 | + vals[~mask] = _0 |
| 231 | + quadrature_part = _ni2 * a * (vals @ w) |
| 232 | + |
| 233 | + # Return `P(x > h, y > k)` |
| 234 | + int_from_0_to_a = _i2pi * s * (analytic_part + quadrature_part) |
| 235 | + return (int_from_0_to_s - int_from_0_to_a).clip(min=_0, max=_1) |
| 236 | + |
| 237 | + |
| 238 | +def bvnmom( |
| 239 | + r: Tensor, |
| 240 | + xl: Tensor, |
| 241 | + yl: Tensor, |
| 242 | + xu: Tensor, |
| 243 | + yu: Tensor, |
| 244 | + p: Optional[Tensor] = None, |
| 245 | +) -> Tuple[Tensor, Tensor]: |
| 246 | + r"""Computes the expected values of truncated, bivariate normal random variables. |
| 247 | +
|
| 248 | + Let `x` and `y` be a pair of standard bivariate normal random variables having |
| 249 | + correlation `r`. This function computes `E([x,y] | [xl,yl] < [x,y] < [xu,yu])`. |
| 250 | +
|
| 251 | + Following [Muthen1990moments]_ equations (4) and (5), we have |
| 252 | + ``` |
| 253 | + E(x | [xl, yl] < [x, y] < [xu, yu]) |
| 254 | + = Z^{-1} \phi(xl) P(yl < y < yu | x=xl) - \phi(xu) P(yl < y < yu | x=xu) |
| 255 | + ``` |
| 256 | + where `Z = P([xl, yl] < [x, y] < [xu, yu])` and `\phi` is the standard normal PDF. |
| 257 | +
|
| 258 | + Args: |
| 259 | + r: Tensor of correlation coefficients. |
| 260 | + xl: Tensor of lower bounds for `x`, same shape as `r`. |
| 261 | + xu: Tensor of upper bounds for `x`, same shape as `r`. |
| 262 | + yl: Tensor of lower bounds for `y`, same shape as `r`. |
| 263 | + yu: Tensor of upper bounds for `y`, same shape as `r`. |
| 264 | + p: Tensor of probabilities `P(xl < x < xu, yl < y < yu)`, same shape as `r`. |
| 265 | +
|
| 266 | + Returns: |
| 267 | + `E(x | [xl, yl] < [x, y] < [xu, yu])` and `E(y | [xl, yl] < [x, y] < [xu, yu])`. |
| 268 | + """ |
| 269 | + if not (r.shape == xl.shape == xu.shape == yl.shape == yu.shape): |
| 270 | + raise UnsupportedError("Arguments to `bvn` must have the same shape.") |
| 271 | + |
| 272 | + if p is None: |
| 273 | + p = bvn(r=r, xl=xl, xu=xu, yl=yl, yu=yu) |
| 274 | + |
| 275 | + corr = r[..., None, None] |
| 276 | + istd = (1 - corr.square()).rsqrt() |
| 277 | + lower = torch.stack([xl, yl], -1) |
| 278 | + upper = torch.stack([xu, yu], -1) |
| 279 | + bounds = torch.stack([lower, upper], -1) |
| 280 | + deltas = safe_mul(corr, bounds) |
| 281 | + |
| 282 | + # Compute densities and conditional probabilities |
| 283 | + density_at_bounds = phi(bounds) |
| 284 | + prob_given_bounds = Phi( |
| 285 | + safe_mul(istd, safe_sub(upper.flip(-1).unsqueeze(-1), deltas)) |
| 286 | + ) - Phi(safe_mul(istd, safe_sub(lower.flip(-1).unsqueeze(-1), deltas))) |
| 287 | + |
| 288 | + # Evaluate Muthen's formula |
| 289 | + p_diffs = -(density_at_bounds * prob_given_bounds).diff().squeeze(-1) |
| 290 | + moments = (1 / p).unsqueeze(-1) * (p_diffs + r.unsqueeze(-1) * p_diffs.flip(-1)) |
| 291 | + return moments.unbind(-1) |
0 commit comments