Skip to content

Commit c9ccec6

Browse files
SebastianAment authored and facebook-github-bot committed
Improving test coverage of UnifiedSkewNormal code (#1408)
Summary:
Pull Request resolved: #1408

This commit improves the test coverage of the code located in botorch/utils/probability. For the current coverage without this commit, [see here](https://app.codecov.io/gh/pytorch/botorch/pull/1394).

Reviewed By: j-wilson

Differential Revision: D39556258

fbshipit-source-id: 7a52dd36e326cca879ff02d56ab2461051759d4d
1 parent 0623e8a commit c9ccec6

File tree

14 files changed

+510
-94
lines changed

14 files changed

+510
-94
lines changed

botorch/utils/probability/bvn.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -92,16 +92,16 @@ def bvn(r: Tensor, xl: Tensor, yl: Tensor, xu: Tensor, yu: Tensor) -> Tensor:
9292
def bvnu(r: Tensor, h: Tensor, k: Tensor) -> Tensor:
9393
r"""Solves for `P(x > h, y > k)` where `x` and `y` are standard bivariate normal
9494
random variables with correlation coefficient `r`. In [Genz2004bvnt]_, this is (1)
95-
```
96-
L(h, k, r) = P(x < -h, y < -k)
97-
= 1/(a 2\pi) \int_{h}^{\infty} \int_{k}^{\infty} f(x, y, r) dy dx,
98-
```
95+
96+
`L(h, k, r) = P(x < -h, y < -k) \
97+
= 1/(a 2\pi) \int_{h}^{\infty} \int_{k}^{\infty} f(x, y, r) dy dx,`
98+
9999
where `f(x, y, r) = e^{-1/(2a^2) (x^2 - 2rxy + y^2)}` and `a = (1 - r^2)^{1/2}`.
100100
101101
[Genz2004bvnt]_ report the following integation scheme incurs a maximum of 5e-16
102-
error when run in double precision: if |r| >= 0.925, use a 20-point quadrature rule
103-
on a 5th order Taylor expansion; else, numerically integrate in polar coordinates
104-
using no more than 20 quadrature points.
102+
error when run in double precision: if `|r| >= 0.925`, use a 20-point quadrature
103+
rule on a 5th order Taylor expansion; else, numerically integrate in polar
104+
coordinates using no more than 20 quadrature points.
105105
106106
Args:
107107
r: Tensor of correlation coefficients.
@@ -137,10 +137,10 @@ def _bvnu_polar(
137137
r: Tensor, h: Tensor, k: Tensor, num_points: Optional[int] = None
138138
) -> Tensor:
139139
r"""Solves for `P(x > h, y > k)` by integrating in polar coordinates as
140-
```
141-
L(h, k, r) = \Phi(-h)\Phi(-k) + 1/(2\pi) \int_{0}^{sin^{-1}(r)} f(t) dt
142-
f(t) = e^{-0.5 cos(t)^{-2} (h^2 + k^2 - 2hk sin(t))}
143-
```
140+
141+
`L(h, k, r) = \Phi(-h)\Phi(-k) + 1/(2\pi) \int_{0}^{sin^{-1}(r)} f(t) dt \
142+
f(t) = e^{-0.5 cos(t)^{-2} (h^2 + k^2 - 2hk sin(t))}`
143+
144144
For details, see Section 2.2 of [Genz2004bvnt]_.
145145
"""
146146
if num_points is None:
@@ -168,12 +168,13 @@ def _bvnu_taylor(r: Tensor, h: Tensor, k: Tensor, num_points: int = 20) -> Tenso
168168
r"""Solves for `P(x > h, y > k)` via Taylor expansion.
169169
170170
Per Section 2.3 of [Genz2004bvnt]_, the bvnu equation (1) may be rewritten as
171-
```
172-
L(h, k, r) = L(h, k, s) - s/(2\pi) \int_{0}^{a} f(x) dx
173-
f(x) = (1 - x^2){-1/2} e^{-0.5 ((h - sk)/ x)^2} e^{-shk/(1 + (1 - x^2)^{1/2})},
174-
```
171+
172+
`L(h, k, r) = L(h, k, s) - s/(2\pi) \int_{0}^{a} f(x) dx \
173+
f(x) = (1 - x^2){-1/2} e^{-0.5 ((h - sk)/ x)^2} e^{-shk/(1 + (1 - x^2)^{1/2})},`
174+
175175
where `s = sign(r)` and `a = sqrt(1 - r^{2})`. The term `L(h, k, s)` is analytic.
176-
The second integral is approximated via Taylor expansion.
176+
The second integral is approximated via Taylor expansion. See Sections 2.3 and
177+
2.4 of [Genz2004bvnt]_.
177178
"""
178179
_0, _1, _ni2, _i2pi, _sq2pi = get_constants_like(
179180
values=(0, 1, -0.5, _inv_2pi, _sqrt_2pi), ref=r
@@ -246,13 +247,13 @@ def bvnmom(
246247
r"""Computes the expected values of truncated, bivariate normal random variables.
247248
248249
Let `x` and `y` be a pair of standard bivariate normal random variables having
249-
correlation `r`. This function computes `E([x,y] | [xl,yl] < [x,y] < [xu,yu])`.
250+
correlation `r`. This function computes `E([x,y] \| [xl,yl] < [x,y] < [xu,yu])`.
250251
251252
Following [Muthen1990moments]_ equations (4) and (5), we have
252-
```
253-
E(x | [xl, yl] < [x, y] < [xu, yu])
254-
= Z^{-1} \phi(xl) P(yl < y < yu | x=xl) - \phi(xu) P(yl < y < yu | x=xu)
255-
```
253+
254+
`E(x \| [xl, yl] < [x, y] < [xu, yu]) \
255+
= Z^{-1} \phi(xl) P(yl < y < yu \| x=xl) - \phi(xu) P(yl < y < yu \| x=xu),`
256+
256257
where `Z = P([xl, yl] < [x, y] < [xu, yu])` and `\phi` is the standard normal PDF.
257258
258259
Args:
@@ -264,7 +265,8 @@ def bvnmom(
264265
p: Tensor of probabilities `P(xl < x < xu, yl < y < yu)`, same shape as `r`.
265266
266267
Returns:
267-
`E(x | [xl, yl] < [x, y] < [xu, yu])` and `E(y | [xl, yl] < [x, y] < [xu, yu])`.
268+
`E(x \| [xl, yl] < [x, y] < [xu, yu])` and
269+
`E(y \| [xl, yl] < [x, y] < [xu, yu])`.
268270
"""
269271
if not (r.shape == xl.shape == xu.shape == yl.shape == yu.shape):
270272
raise UnsupportedError("Arguments to `bvn` must have the same shape.")

botorch/utils/probability/lin_ess.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -94,13 +94,13 @@ def __init__(
9494
try:
9595
covariance_root = torch.linalg.cholesky(covariance_matrix)
9696
except RuntimeError as e:
97-
if "positive-definite" in str(e):
98-
raise ValueError(
97+
raise_e = e
98+
if "positive-definite" in str(raise_e):
99+
raise_e = ValueError(
99100
"Covariance matrix is not positive definite. "
100101
"Currently only non-degenerate distributions are supported."
101102
)
102-
else:
103-
raise e
103+
raise raise_e
104104
self._covariance_root = covariance_root
105105
self._x = self.x0.clone() # state of the sampler ("current point")
106106
# We will need the following repeatedly, let's allocate them once
@@ -216,11 +216,12 @@ def _find_active_intersections(self, nu: Tensor) -> Tensor:
216216
nu=nu, theta=theta, delta_theta=_delta_theta
217217
)
218218
theta_active = theta[active_directions.nonzero()]
219-
219+
delta_theta = _delta_theta
220220
while theta_active.numel() % 2 == 1:
221221
# Almost tangential ellipses, reduce delta_theta
222+
delta_theta /= 10
222223
active_directions = self._index_active(
223-
theta=theta, nu=nu, delta_theta=0.1 * _delta_theta
224+
theta=theta, nu=nu, delta_theta=delta_theta
224225
)
225226
theta_active = theta[active_directions.nonzero()]
226227

@@ -236,6 +237,9 @@ def _find_intersection_angles(self, nu: Tensor) -> Tensor:
236237
"""Compute all of the up to 2*n_ineq_con intersections of the ellipse
237238
and the linear constraints.
238239
240+
For background, see equation (2) in
241+
http://proceedings.mlr.press/v108/gessner20a/gessner20a.pdf
242+
239243
Args:
240244
nu: A `d x 1`-dim tensor (the "new" direction, drawn from N(0, I)).
241245
@@ -264,7 +268,7 @@ def _find_intersection_angles(self, nu: Tensor) -> Tensor:
264268
return torch.sort(theta).values
265269

266270
def _index_active(
267-
self, nu: Tensor, theta: Tensor, delta_theta: float = 1e-4
271+
self, nu: Tensor, theta: Tensor, delta_theta: float = _delta_theta
268272
) -> Tensor:
269273
r"""Determine active indices.
270274

botorch/utils/probability/linalg.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,8 @@ def augment_cholesky(
5050
raise ValueError("One and only one of `Kba` or `Lba` must be provided.")
5151

5252
if jitter is not None:
53-
diag = Kbb.diagonal(dim1=-2, dim2=-1)
5453
Kbb = Kbb.clone()
55-
Kbb.fill_diagonal_(diag + jitter)
54+
Kbb.diagonal(dim1=-2, dim2=-1).add_(jitter)
5655

5756
if Lba is None:
5857
Lba = torch.linalg.solve_triangular(
@@ -62,7 +61,7 @@ def augment_cholesky(
6261
Lbb, info = torch.linalg.cholesky_ex(Kbb - Lba @ Lba.transpose(-2, -1))
6362
if info.any():
6463
raise NotPSDError(
65-
"Schur complement of `K` with respect to `Kaa` not PSD for the given"
64+
"Schur complement of `K` with respect to `Kaa` not PSD for the given "
6665
"Cholesky factor `Laa`"
6766
f"{'.' if jitter is None else f' and nugget jitter={jitter}.'}"
6867
)
@@ -85,19 +84,19 @@ def __post_init__(self, validate_init: bool = True):
8584

8685
if self.tril.shape[-2] != self.tril.shape[-1]:
8786
raise ValueError(
88-
f"Expected square matrices but `matrix` has shape {self.tril.shape}."
87+
f"Expected square matrices but `matrix` has shape `{self.tril.shape}`."
8988
)
9089

9190
if self.perm.shape != self.tril.shape[:-1]:
9291
raise ValueError(
9392
f"`perm` of shape `{self.perm.shape}` incompatible with "
94-
f"`matrix` of shape `{self.tril.shape}."
93+
f"`matrix` of shape `{self.tril.shape}`."
9594
)
9695

9796
if self.diag is not None and self.diag.shape != self.tril.shape[:-1]:
9897
raise ValueError(
9998
f"`diag` of shape `{self.diag.shape}` incompatible with "
100-
f"`matrix` of shape `{self.tril.shape}."
99+
f"`matrix` of shape `{self.tril.shape}`."
101100
)
102101

103102
def __getitem__(self, key: Any) -> PivotedCholesky:
@@ -135,9 +134,8 @@ def pivot_(self, pivot: LongTensor) -> None:
135134
# Perform basic swaps
136135
for key in ("perm", "diag"):
137136
tnsr = getattr(self, key, None)
138-
if tnsr is None:
139-
continue
140-
swap_along_dim_(tnsr, i=self.step, j=pivot, dim=pivot.ndim)
137+
if tnsr is not None:
138+
swap_along_dim_(tnsr, i=self.step, j=pivot, dim=tnsr.ndim - 1)
141139

142140
# Perform matrix swaps; prealloacte buffers for row/column linear indices
143141
size2 = size**2

botorch/utils/probability/truncated_multivariate_normal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,4 +145,4 @@ def expand(
145145
return new
146146

147147
def __repr__(self) -> str:
148-
return super().__repr__()[:-1] + f"bounds: {self.bounds.shape})"
148+
return super().__repr__()[:-1] + f", bounds: {self.bounds.shape})"

botorch/utils/probability/unified_skew_normal.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,16 @@
77
from __future__ import annotations
88

99
from inspect import getmembers
10-
from typing import Optional, Sequence
10+
from typing import Optional, Sequence, Union
1111

1212
import torch
1313
from botorch.utils.probability.linalg import augment_cholesky, block_matrix_concat
1414
from botorch.utils.probability.mvnxpb import MVNXPB
1515
from botorch.utils.probability.truncated_multivariate_normal import (
1616
TruncatedMultivariateNormal,
1717
)
18+
from linear_operator.operators import LinearOperator
19+
from linear_operator.utils.errors import NotPSDError
1820
from torch import Tensor
1921
from torch.distributions.multivariate_normal import Distribution, MultivariateNormal
2022
from torch.distributions.utils import lazy_property
@@ -28,7 +30,7 @@ def __init__(
2830
self,
2931
trunc: TruncatedMultivariateNormal,
3032
gauss: MultivariateNormal,
31-
cross_covariance_matrix: Tensor,
33+
cross_covariance_matrix: Union[Tensor, LinearOperator],
3234
validate_args: Optional[bool] = None,
3335
):
3436
r"""Unified Skew Normal distribution of `Y | a < X < b` for jointly Gaussian
@@ -52,7 +54,10 @@ def __init__(
5254
f"{len(trunc.event_shape)}-dimensional `trunc` incompatible with"
5355
f"{len(gauss.event_shape)}-dimensional `gauss`."
5456
)
55-
57+
# LinearOperator currently doesn't support torch.linalg.solve_triangular,
58+
# so for the time being, we cast the operator to dense here
59+
if isinstance(cross_covariance_matrix, LinearOperator):
60+
cross_covariance_matrix = cross_covariance_matrix.to_dense()
5661
try:
5762
batch_shape = torch.broadcast_shapes(trunc.batch_shape, gauss.batch_shape)
5863
except RuntimeError as e:
@@ -66,13 +71,21 @@ def __init__(
6671
self.trunc = trunc
6772
self.gauss = gauss
6873
self.cross_covariance_matrix = cross_covariance_matrix
69-
if validate_args:
74+
if self._validate_args:
7075
try:
76+
# calling _orthogonalized_gauss first makes the following call
77+
# _orthogonalized_gauss.scale_tril which is used by self.rsample
7178
self._orthogonalized_gauss
7279
self.scale_tril
73-
except RuntimeError as e:
74-
if "positive-definite" in str(e):
75-
raise ValueError(
80+
except Exception as e:
81+
# error could be thrown by linalg.augment_cholesky (NotPSDError)
82+
# or torch.linalg.cholesky (with "positive-definite" in the message)
83+
if (
84+
isinstance(e, NotPSDError)
85+
or "positive-definite" in str(e)
86+
or "PositiveDefinite" in str(e)
87+
):
88+
e = ValueError(
7689
"UnifiedSkewNormal is only well-defined for positive definite"
7790
" joint covariance matrices."
7891
)
@@ -158,7 +171,10 @@ def expand(
158171
elif isinstance(obj, Distribution):
159172
new_obj = obj.expand(batch_shape=batch_shape)
160173
else:
161-
raise TypeError
174+
raise TypeError(
175+
f"Type {type(obj)} of UnifiedSkewNormal's lazy property "
176+
f"{name} not supported."
177+
)
162178

163179
setattr(new, name, new_obj)
164180
return new
@@ -203,12 +219,6 @@ def _orthogonalized_gauss(self) -> MultivariateNormal:
203219
parameters["covariance_matrix"] = (
204220
self.gauss.covariance_matrix - beta.transpose(-1, -2) @ beta
205221
)
206-
return MultivariateNormal(
207-
loc=torch.zeros_like(self.gauss.loc),
208-
scale_tril=self.scale_tril[..., -n:, -n:],
209-
validate_args=self._validate_args,
210-
)
211-
212222
return MultivariateNormal(**parameters, validate_args=self._validate_args)
213223

214224
@lazy_property

sphinx/source/acquisition.rst

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -141,32 +141,32 @@ Utilities
141141
-------------------------------------------
142142

143143
Fixed Feature Acquisition Function
144-
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
144+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
145145
.. automodule:: botorch.acquisition.fixed_feature
146146
:members:
147147

148148
Constructors for Acquisition Function Input Arguments
149-
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
149+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
150150
.. automodule:: botorch.acquisition.input_constructors
151151
:members:
152152

153153
Penalized Acquisition Function Wrapper
154-
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
154+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
155155
.. automodule:: botorch.acquisition.penalized
156156
:members:
157157

158158
Proximal Acquisition Function Wrapper
159-
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
159+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
160160
.. automodule:: botorch.acquisition.proximal
161161
:members:
162162

163163
General Utilities for Acquisition Functions
164-
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
164+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
165165
.. automodule:: botorch.acquisition.utils
166166
:members:
167167

168168

169169
Multi-Objective Utilities for Acquisition Functions
170-
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
170+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
171171
.. automodule:: botorch.acquisition.multi_objective.utils
172172
:members:

sphinx/source/utils.rst

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,16 @@ Feasible Volume
7272
.. automodule:: botorch.utils.feasible_volume
7373
:members:
7474

75+
Constants
76+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
77+
.. automodule:: botorch.utils.constants
78+
:members:
79+
80+
Safe Math
81+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
82+
.. automodule:: botorch.utils.safe_math
83+
:members:
84+
7585
Multi-Objective Utilities
7686
-------------------------------------------
7787

@@ -114,3 +124,41 @@ Scalarization
114124
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
115125
.. automodule:: botorch.utils.multi_objective.scalarization
116126
:members:
127+
128+
Probability Utilities
129+
-------------------------------------------
130+
131+
Multivariate Gaussian Probabilities via Bivariate Conditioning
132+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
133+
.. automodule:: botorch.utils.probability.mvnxpb
134+
:members:
135+
136+
Truncated Multivariate Normal Distribution
137+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
138+
.. automodule:: botorch.utils.probability.truncated_multivariate_normal
139+
:members:
140+
141+
Unified Skew Normal Distribution
142+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
143+
.. automodule:: botorch.utils.probability.unified_skew_normal
144+
:members:
145+
146+
Bivariate Normal Probabilities and Statistics
147+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
148+
.. automodule:: botorch.utils.probability.bvn
149+
:members:
150+
151+
Elliptic Slice Sampler with Linear Constraints
152+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
153+
.. automodule:: botorch.utils.probability.lin_ess
154+
:members:
155+
156+
Linear Algebra Helpers
157+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
158+
.. automodule:: botorch.utils.probability.linalg
159+
:members:
160+
161+
Probability Helpers
162+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
163+
.. automodule:: botorch.utils.probability.utils
164+
:members:

0 commit comments

Comments
 (0)