|
12 | 12 |
|
13 | 13 | import dataclasses |
14 | 14 | import warnings |
15 | | -from collections.abc import Callable |
| 15 | +from collections.abc import Callable, Mapping |
16 | 16 | from typing import Any |
17 | 17 |
|
18 | 18 | import torch |
@@ -77,7 +77,7 @@ class OptimizeAcqfInputs: |
77 | 77 | inequality_constraints: list[tuple[Tensor, Tensor, float]] | None |
78 | 78 | equality_constraints: list[tuple[Tensor, Tensor, float]] | None |
79 | 79 | nonlinear_inequality_constraints: list[tuple[Callable, bool]] | None |
80 | | - fixed_features: dict[int, float] | None |
| 80 | + fixed_features: Mapping[int, float | Tensor] | None |
81 | 81 | post_processing_func: Callable[[Tensor], Tensor] | None |
82 | 82 | batch_initial_conditions: Tensor | None |
83 | 83 | return_best_only: bool |
@@ -603,7 +603,7 @@ def optimize_acqf( |
603 | 603 | inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None, |
604 | 604 | equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None, |
605 | 605 | nonlinear_inequality_constraints: list[tuple[Callable, bool]] | None = None, |
606 | | - fixed_features: dict[int, float] | None = None, |
| 606 | + fixed_features: Mapping[int, float | Tensor] | None = None, |
607 | 607 | post_processing_func: Callable[[Tensor], Tensor] | None = None, |
608 | 608 | batch_initial_conditions: Tensor | None = None, |
609 | 609 | return_best_only: bool = True, |
@@ -692,7 +692,13 @@ def optimize_acqf( |
692 | 692 | is set to 1, which will be done automatically if not specified in |
693 | 693 | ``options``. |
694 | 694 | fixed_features: A map ``{feature_index: value}`` for features that |
695 | | - should be fixed to a particular value during generation. |
| 695 | + should be fixed to a particular value during generation. The value |
| 696 | + can be a float, in which case the feature is fixed across the |
| 697 | + entire batch, or a Tensor, in which case the feature can be fixed |
| 698 | + to different values for each batch element (used for batched |
| 699 | + optimization with different fixed features per restart). When |
| 700 | + passing tensors as values, they should have shape ``b`` or |
| 701 | + ``b x q``. |
696 | 702 | post_processing_func: A function that post-processes an optimization |
697 | 703 | result appropriately (i.e., according to ``round-trip`` |
698 | 704 | transformations). |
@@ -824,7 +830,7 @@ def optimize_acqf_cyclic( |
824 | 830 | options: dict[str, bool | float | int | str] | None = None, |
825 | 831 | inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None, |
826 | 832 | equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None, |
827 | | - fixed_features: dict[int, float] | None = None, |
| 833 | + fixed_features: Mapping[int, float | Tensor] | None = None, |
828 | 834 | post_processing_func: Callable[[Tensor], Tensor] | None = None, |
829 | 835 | batch_initial_conditions: Tensor | None = None, |
830 | 836 | cyclic_options: dict[str, bool | float | int | str] | None = None, |
@@ -856,7 +862,10 @@ def optimize_acqf_cyclic( |
856 | 862 | with each tuple encoding an inequality constraint of the form |
857 | 863 | ``\sum_i (X[indices[i]] * coefficients[i]) = rhs`` |
858 | 864 | fixed_features: A map ``{feature_index: value}`` for features that |
859 | | - should be fixed to a particular value during generation. |
| 865 | + should be fixed to a particular value during generation. The value |
| 866 | + can be a float, in which case the feature is fixed across the |
| 867 | + entire batch, or a Tensor, in which case the feature can be fixed |
| 868 | + to different values for each batch element. |
860 | 869 | post_processing_func: A function that post-processes an optimization |
861 | 870 | result appropriately (i.e., according to ``round-trip`` |
862 | 871 | transformations). |
|
0 commit comments