Skip to content

Commit 5abab4d

Browse files
jduerholtfacebook-github-bot
authored andcommitted
Hotfix/polytopesampler seed (#1968)
Summary: <!-- Thank you for sending the PR! We appreciate you spending the time to make BoTorch better. Help us understand your motivation by explaining why you decided to make this change. You can learn more about contributing to BoTorch here: https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md --> ## Motivation During bugfixing, I came around the issue that setting a seed in `get_polytope_samples` is not leading to the same samples when called several times with the same seed. This PR fixes it. ### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)? Yes. Pull Request resolved: #1968 Test Plan: Unit tests. Reviewed By: saitcakmak Differential Revision: D47993712 Pulled By: Balandat fbshipit-source-id: 95b3e548e78609a5d6593addbba0cf7585a76120
1 parent 9915f8a commit 5abab4d

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

botorch/utils/sampling.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,9 @@ def sample_polytope(
249249
# pre-sample samples from hypersphere
250250
d = x0.size(0)
251251
# uniform samples from unit ball in d dims
252-
Rs = sample_hypersphere(d=d, n=n_tot, dtype=A.dtype, device=A.device).unsqueeze(-1)
252+
Rs = sample_hypersphere(
253+
d=d, n=n_tot, dtype=A.dtype, device=A.device, seed=seed
254+
).unsqueeze(-1)
253255

254256
# compute matprods in batch
255257
ARs = (A @ Rs).squeeze(-1)

test/utils/test_sampling.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,29 @@ def test_sample_polytope(self):
348348
self.assertTrue((more_samples <= bounds[1]).all())
349349
self.assertTrue((more_samples >= bounds[0]).all())
350350

351+
def test_sample_polytope_with_seed(self):
352+
for dtype in (torch.float, torch.double):
353+
A = self.A.to(dtype)
354+
b = self.b.to(dtype)
355+
x0 = self.x0.to(dtype)
356+
bounds = self.bounds.to(dtype)
357+
for interior_point in [x0, None]:
358+
sampler1 = self.sampler_class(
359+
inequality_constraints=(A, b),
360+
bounds=bounds,
361+
interior_point=interior_point,
362+
**self.sampler_kwargs,
363+
)
364+
sampler2 = self.sampler_class(
365+
inequality_constraints=(A, b),
366+
bounds=bounds,
367+
interior_point=interior_point,
368+
**self.sampler_kwargs,
369+
)
370+
samples1 = sampler1.draw(n=10, seed=42)
371+
samples2 = sampler2.draw(n=10, seed=42)
372+
self.assertTrue(torch.allclose(samples1, samples2))
373+
351374
def test_sample_polytope_with_eq_constraints(self):
352375
for dtype in (torch.float, torch.double):
353376
A = self.A.to(dtype)

0 commit comments

Comments
 (0)