Skip to content

Commit 5f58208

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Fix boundary handling in sample_polytope (#2353)
Summary: Pull Request resolved: #2353 `sample_polytope` would set both `alpha_min = 0` and `alpha_max = 0` when `x` was at the boundary, leading to it getting stuck and returning the same point. Fixes #2351 Reviewed By: Balandat Differential Revision: D57883949 fbshipit-source-id: 48433e94739f60c38cbd028c8dd48b919ebff9c3
1 parent 35b61cd commit 5f58208

File tree

2 files changed

+70
-16
lines changed

2 files changed

+70
-16
lines changed

botorch/utils/sampling.py

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -239,35 +239,59 @@ def sample_polytope(
239239
Returns:
240240
(n, d) dim Tensor containing the resulting samples.
241241
"""
242+
# Check that starting point satisfies the constraints.
243+
if not ((slack := A @ x0 - b) <= 0).all():
244+
raise ValueError(
245+
f"Starting point does not satisfy the constraints. Inputs: {A=},"
246+
f"{b=}, {x0=}, A@x0-b={slack}."
247+
)
248+
# Remove rows where all elements of A are 0. This avoids nan and infs later.
249+
# A may have zero rows in it when this is called from PolytopeSampler
250+
# with equality constraints (which are absorbed into A & b).
251+
non_zero_rows = torch.any(A != 0, dim=-1)
252+
A = A[non_zero_rows]
253+
b = b[non_zero_rows]
254+
242255
n_tot = n + n0
243256
seed = seed if seed is not None else torch.randint(0, 1000000, (1,)).item()
244257
with manual_seed(seed=seed):
245258
rands = torch.rand(n_tot, dtype=A.dtype, device=A.device)
246259

247-
# pre-sample samples from hypersphere
248-
d = x0.size(0)
249-
# uniform samples from unit ball in d dims
250-
# increment seed by +1 to avoid correlation with step size, see #2156 for details
260+
# Sample uniformly from unit hypersphere in d dims.
261+
# Increment seed by +1 to avoid correlation with step size, see #2156 for details.
251262
Rs = sample_hypersphere(
252-
d=d, n=n_tot, dtype=A.dtype, device=A.device, seed=seed + 1
263+
d=x0.shape[0], n=n_tot, dtype=A.dtype, device=A.device, seed=seed + 1
253264
).unsqueeze(-1)
254265

255-
# compute matprods in batch
266+
# Use batch operations for matrix multiplication.
256267
ARs = (A @ Rs).squeeze(-1)
257268
out = torch.empty(n, A.size(-1), dtype=A.dtype, device=A.device)
258269
x = x0.clone()
270+
large_constant = torch.finfo().max
259271
for i, (ar, r, rnd) in enumerate(zip(ARs, Rs, rands)):
260-
# given x, the next point in the chain is x+alpha*r
261-
# it also satisfies A(x+alpha*r)<=b which implies A*alpha*r<=b-Ax
272+
# Given x, the next point in the chain is x+alpha*r.
273+
# It must satisfy A(x+alpha*r)<=b, which implies A*alpha*r<=b-Ax,
262274
# so alpha<=(b-Ax)/ar for ar>0, and alpha>=(b-Ax)/ar for ar<0.
263-
# b - A @ x is always >= 0, clamping for numerical tolerances
275+
# If x is at the boundary, b - Ax = 0. If ar > 0, then we must
276+
# have alpha <= 0. If ar < 0, we must have alpha >= 0.
277+
# ar == 0 is an unlikely event that provides no signal.
278+
# b - A @ x is always >= 0, clamping for numerical tolerances.
264279
w = (b - A @ x).squeeze().clamp(min=0.0) / ar
265-
pos = w >= 0
266-
alpha_max = w[pos].min()
267-
# important to include equality here in cases x is at the boundary
268-
# of the polytope
269-
neg = w <= 0
270-
alpha_min = w[neg].max()
280+
# Find upper bound for alpha. If there are no constraints on
281+
# the upper bound of alpha, set it to a large value.
282+
pos = w > 0
283+
alpha_max = w[pos].min().item() if pos.any() else large_constant
284+
# Find lower bound for alpha.
285+
neg = w < 0
286+
alpha_min = w[neg].max().item() if neg.any() else -large_constant
287+
# Handle the boundary case.
288+
if (w_eq_0 := (w == 0)).any():
289+
# If ar > 0 at the boundary, alpha <= 0.
290+
if w_eq_0.logical_and(ar > 0).any():
291+
alpha_max = min(alpha_max, 0.0)
292+
# If ar < 0 at the boundary, alpha >= 0.
293+
if w_eq_0.logical_and(ar < 0).any():
294+
alpha_min = max(alpha_min, 0.0)
271295
# alpha~Unif[alpha_min, alpha_max]
272296
alpha = alpha_min + rnd * (alpha_max - alpha_min)
273297
x = x + alpha * r

test/utils/test_sampling.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import itertools
1010
import warnings
11+
from abc import ABC
1112
from typing import Any, Dict, Type
1213
from unittest import mock
1314

@@ -29,6 +30,7 @@
2930
optimize_posterior_samples,
3031
PolytopeSampler,
3132
sample_hypersphere,
33+
sample_polytope,
3234
sample_simplex,
3335
sparse_to_dense_constraints,
3436
)
@@ -304,8 +306,36 @@ def test_get_polytope_samples(self):
304306
).draw(15, seed=0)[::3]
305307
self.assertTrue(torch.equal(samps, expected_samps))
306308

309+
def test_sample_polytope_infeasible(self) -> None:
310+
with self.assertRaisesRegex(ValueError, "Starting point does not satisfy"):
311+
sample_polytope(
312+
A=torch.tensor([[0.0, 0.0]]),
313+
b=torch.tensor([[-1.0]]),
314+
x0=torch.tensor([[0.0], [0.0]]),
315+
)
316+
317+
def test_sample_polytope_boundary(self) -> None:
318+
# Check that sample_polytope does not get stuck at the boundary.
319+
# This replicates https://github.com/pytorch/botorch/issues/2351.
320+
samples = sample_polytope(
321+
A=torch.tensor(
322+
[
323+
[-1.0, -1.0],
324+
[0.0, 0.0],
325+
[-1.0, 0.0],
326+
[0.0, -1.0],
327+
[0.0, 0.0],
328+
[1.0, 0.0],
329+
[0.0, 1.0],
330+
]
331+
),
332+
b=torch.tensor([[1.0], [1.0], [1.0], [1.0], [0.0], [0.0], [0.0]]),
333+
x0=torch.tensor([[0.0], [0.0]]),
334+
)
335+
self.assertFalse((samples == 0).all())
336+
307337

308-
class PolytopeSamplerTestBase:
338+
class PolytopeSamplerTestBase(ABC):
309339
sampler_class: Type[PolytopeSampler]
310340
sampler_kwargs: Dict[str, Any] = {}
311341

0 commit comments

Comments
 (0)