Skip to content

Commit 37de228

Browse files
saitcakmakmeta-codesync[bot]
authored andcommitted
Update fixed_features type hints and docstrings in OptimizeAcqfInputs (#3240)
Summary: The fixed_features parameter in OptimizeAcqfInputs, optimize_acqf, and optimize_acqf_cyclic now documents that values can be float (fixed across the batch) or Tensor (different fixed values per batch element, for batched optimization with different fixed features per restart). Closes #3111 Pull Request resolved: #3240 Reviewed By: esantorella Differential Revision: D97314412 Pulled By: saitcakmak fbshipit-source-id: cf302d3eb19be7ddd495ff3d803b245acb371d1b
1 parent 050f9ec commit 37de228

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

botorch/optim/optimize.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import dataclasses
1414
import warnings
15-
from collections.abc import Callable
15+
from collections.abc import Callable, Mapping
1616
from typing import Any
1717

1818
import torch
@@ -77,7 +77,7 @@ class OptimizeAcqfInputs:
7777
inequality_constraints: list[tuple[Tensor, Tensor, float]] | None
7878
equality_constraints: list[tuple[Tensor, Tensor, float]] | None
7979
nonlinear_inequality_constraints: list[tuple[Callable, bool]] | None
80-
fixed_features: dict[int, float] | None
80+
fixed_features: Mapping[int, float | Tensor] | None
8181
post_processing_func: Callable[[Tensor], Tensor] | None
8282
batch_initial_conditions: Tensor | None
8383
return_best_only: bool
@@ -603,7 +603,7 @@ def optimize_acqf(
603603
inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
604604
equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
605605
nonlinear_inequality_constraints: list[tuple[Callable, bool]] | None = None,
606-
fixed_features: dict[int, float] | None = None,
606+
fixed_features: Mapping[int, float | Tensor] | None = None,
607607
post_processing_func: Callable[[Tensor], Tensor] | None = None,
608608
batch_initial_conditions: Tensor | None = None,
609609
return_best_only: bool = True,
@@ -692,7 +692,13 @@ def optimize_acqf(
692692
is set to 1, which will be done automatically if not specified in
693693
``options``.
694694
fixed_features: A map ``{feature_index: value}`` for features that
695-
should be fixed to a particular value during generation.
695+
should be fixed to a particular value during generation. The value
696+
can be a float, in which case the feature is fixed across the
697+
entire batch, or a Tensor, in which case the feature can be fixed
698+
to different values for each batch element (used for batched
699+
optimization with different fixed features per restart). When
700+
passing tensors as values, they should have shape ``b`` or
701+
``b x q``.
696702
post_processing_func: A function that post-processes an optimization
697703
result appropriately (i.e., according to ``round-trip``
698704
transformations).
@@ -824,7 +830,7 @@ def optimize_acqf_cyclic(
824830
options: dict[str, bool | float | int | str] | None = None,
825831
inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
826832
equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
827-
fixed_features: dict[int, float] | None = None,
833+
fixed_features: Mapping[int, float | Tensor] | None = None,
828834
post_processing_func: Callable[[Tensor], Tensor] | None = None,
829835
batch_initial_conditions: Tensor | None = None,
830836
cyclic_options: dict[str, bool | float | int | str] | None = None,
@@ -856,7 +862,10 @@ def optimize_acqf_cyclic(
856862
with each tuple encoding an inequality constraint of the form
857863
``\sum_i (X[indices[i]] * coefficients[i]) = rhs``
858864
fixed_features: A map ``{feature_index: value}`` for features that
859-
should be fixed to a particular value during generation.
865+
should be fixed to a particular value during generation. The value
866+
can be a float, in which case the feature is fixed across the
867+
entire batch, or a Tensor, in which case the feature can be fixed
868+
to different values for each batch element.
860869
post_processing_func: A function that post-processes an optimization
861870
result appropriately (i.e., according to ``round-trip``
862871
transformations).

0 commit comments

Comments
 (0)