Skip to content

Commit 9cd4dea

Browse files
jduerholtfacebook-github-bot
authored andcommitted
Add cache_root option for qNEI in get_acquisition_function (#1608)
Summary: <!-- Thank you for sending the PR! We appreciate you spending the time to make BoTorch better. Help us understand your motivation by explaining why you decided to make this change. You can learn more about contributing to BoTorch here: https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md --> ## Motivation As discussed in #1604, this PR adds the possibility to setup qNEI with `cache_root=False` via the `get_acquistion` method, as it is already possible for qNEHVI. ### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)? Yes. Pull Request resolved: #1608 Test Plan: Unit tests. Reviewed By: Balandat Differential Revision: D42346514 Pulled By: saitcakmak fbshipit-source-id: 63d010c17cdca4147b7efe2cce5dc5cb62da4caa
1 parent 056e657 commit 9cd4dea

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

botorch/acquisition/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ def get_acquisition_function(
131131
X_pending=X_pending,
132132
prune_baseline=kwargs.get("prune_baseline", False),
133133
marginalize_dim=kwargs.get("marginalize_dim"),
134+
cache_root=kwargs.get("cache_root", True),
134135
)
135136
elif acquisition_function_name == "qSR":
136137
return monte_carlo.qSimpleRegret(

test/acquisition/test_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,23 @@ def test_GetQNEI(self, mock_acqf):
217217
self.assertEqual(sampler.sample_shape, torch.Size([self.mc_samples]))
218218
self.assertEqual(sampler.seed, 1)
219219
self.assertEqual(kwargs["marginalize_dim"], 0)
220+
self.assertEqual(kwargs["cache_root"], True)
221+
# test with cache_root = False
222+
acqf = get_acquisition_function(
223+
acquisition_function_name="qNEI",
224+
model=self.model,
225+
objective=self.objective,
226+
X_observed=self.X_observed,
227+
X_pending=self.X_pending,
228+
mc_samples=self.mc_samples,
229+
seed=self.seed,
230+
marginalize_dim=0,
231+
cache_root=False,
232+
)
233+
self.assertTrue(acqf == mock_acqf.return_value)
234+
self.assertTrue(mock_acqf.call_count, 1)
235+
args, kwargs = mock_acqf.call_args
236+
self.assertEqual(kwargs["cache_root"], False)
220237
# test with non-qmc, no X_pending
221238
acqf = get_acquisition_function(
222239
acquisition_function_name="qNEI",

0 commit comments

Comments
 (0)