Skip to content

Commit 824a4a9

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Add error message for one shot acqf in optimize_acqf_discrete (#939)
Summary: <!-- Thank you for sending the PR! We appreciate you spending the time to make BoTorch better. Help us understand your motivation by explaining why you decided to make this change. You can learn more about contributing to BoTorch here: https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md --> ## Motivation As noted in #938, use of one-shot acquisition functions in optimize_acqf_discrete leads to errors down the line that are hard to interpret. This PR adds an explicit error message noting that this is not supported. ### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)? Yes Pull Request resolved: #939 Test Plan: Units. Reviewed By: danielrjiang Differential Revision: D35750109 Pulled By: saitcakmak fbshipit-source-id: c1fc51ea5398613f44554de18b339d5c7d64524f
1 parent 9a93afb commit 824a4a9

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

botorch/optim/optimize.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
OneShotAcquisitionFunction,
1919
)
2020
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
21+
from botorch.exceptions import UnsupportedError
2122
from botorch.generation.gen import gen_candidates_scipy
2223
from botorch.logging import logger
2324
from botorch.optim.initializers import (
@@ -594,6 +595,11 @@ def optimize_acqf_discrete(
594595
- a `q x d`-dim tensor of generated candidates.
595596
- an associated acquisition value.
596597
"""
598+
if isinstance(acq_function, OneShotAcquisitionFunction):
599+
raise UnsupportedError(
600+
"Discrete optimization is not supported for"
601+
"one-shot acquisition functions."
602+
)
597603
choices_batched = choices.unsqueeze(-2)
598604
if q > 1:
599605
candidate_list, acq_value_list = [], []

test/optim/test_optimize.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
AcquisitionFunction,
1414
OneShotAcquisitionFunction,
1515
)
16+
from botorch.exceptions import UnsupportedError
1617
from botorch.optim.optimize import (
1718
_filter_infeasible,
1819
_filter_invalid,
@@ -928,6 +929,14 @@ def test_optimize_acqf_discrete(self):
928929
self.assertTrue(torch.allclose(acq_value, expected_acq_value))
929930
self.assertTrue(torch.allclose(candidates, expected_candidates))
930931

932+
with self.assertRaises(UnsupportedError):
933+
acqf = MockOneShotAcquisitionFunction()
934+
optimize_acqf_discrete(
935+
acq_function=acqf,
936+
q=1,
937+
choices=torch.tensor([[0.5], [0.2]]),
938+
)
939+
931940
def test_optimize_acqf_discrete_local_search(self):
932941
for q, dtype in itertools.product((1, 2), (torch.float, torch.double)):
933942
tkwargs = {"device": self.device, "dtype": dtype}

0 commit comments

Comments
 (0)