Skip to content

Commit 9a7c517

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Silence batch_initial_conditions warnings in optimize_acqf_mixed_alternating (#2749)
Summary: We were getting a bunch of warnings due to `raw_samples` and `batch_initial_conditions` not matching, despite the underlying code behaving correctly. Updated the code to avoid unnecessary warnings. This change eliminates 74 warnings in TestOptimizeAcqfMixed. Before: <img width="854" alt="Screenshot 2025-02-18 at 6 36 49 PM" src="https://github.com/user-attachments/assets/eec54ade-14e7-4a97-a474-2d9320858886" /> After: <img width="853" alt="Screenshot 2025-02-18 at 6 37 03 PM" src="https://github.com/user-attachments/assets/0052e9c9-5afb-490e-aa85-b8450d5d3846" /> Pull Request resolved: #2749 Reviewed By: dme65 Differential Revision: D69826047 Pulled By: saitcakmak fbshipit-source-id: ffb20ff4306523c016a377a28af8ac1c76714114
1 parent 052d0d8 commit 9a7c517

File tree

2 files changed

+6
-1
lines changed

2 files changed

+6
-1
lines changed

botorch/optim/optimize.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,10 @@ def __post_init__(self) -> None:
116116
f"shape is {batch_initial_conditions_shape}."
117117
)
118118

119-
if len(batch_initial_conditions_shape) == 2:
119+
if (
120+
len(batch_initial_conditions_shape) == 2
121+
and self.raw_samples is not None
122+
):
120123
warnings.warn(
121124
"If using a 2-dim `batch_initial_conditions` botorch will "
122125
"default to old behavior of ignoring `num_restarts` and just "
@@ -132,6 +135,7 @@ def __post_init__(self) -> None:
132135
len(batch_initial_conditions_shape) == 3
133136
and batch_initial_conditions_shape[0] < self.num_restarts
134137
and batch_initial_conditions_shape[-2] != self.q
138+
and self.raw_samples is not None
135139
):
136140
warnings.warn(
137141
"If using a 3-dim `batch_initial_conditions` where the "

botorch/optim/optimize_mixed.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,7 @@ def continuous_step(
533533
opt_inputs,
534534
q=1,
535535
num_restarts=1,
536+
raw_samples=None,
536537
batch_initial_conditions=current_x.unsqueeze(0),
537538
fixed_features={
538539
**dict(zip(discrete_dims.tolist(), current_x[discrete_dims])),

0 commit comments

Comments
 (0)