1313
1414import warnings
1515from abc import ABC , abstractmethod
16- from typing import Any , Callable , Optional , Union
16+ from typing import Callable , Optional , Union
1717
1818import torch
1919from botorch import settings
@@ -38,7 +38,9 @@ class CostAwareUtility(Module, ABC):
3838 """
3939
4040 @abstractmethod
41- def forward (self , X : Tensor , deltas : Tensor , ** kwargs : Any ) -> Tensor :
41+ def forward (
42+ self , X : Tensor , deltas : Tensor , sampler : Optional [MCSampler ] = None
43+ ) -> Tensor :
4244 r"""Evaluate the cost-aware utility on the candidates and improvements.
4345
4446 Args:
@@ -47,6 +49,8 @@ def forward(self, X: Tensor, deltas: Tensor, **kwargs: Any) -> Tensor:
4749 deltas: A `num_fantasies x batch_shape`-dim Tensor of `num_fantasy`
4850 samples from the marginal improvement in utility over the
4951 current state at `X` for each t-batch.
52+ sampler: A sampler used for sampling from the posterior of the cost
53+ model. Some subclasses ignore this argument.
5054
5155 Returns:
5256 A `num_fantasies x batch_shape`-dim Tensor of cost-transformed utilities.
@@ -67,7 +71,9 @@ def __init__(self, cost: Callable[[Tensor, Tensor], Tensor]) -> None:
6771 super ().__init__ ()
6872 self ._cost_callable : Callable [[Tensor , Tensor ], Tensor ] = cost
6973
70- def forward (self , X : Tensor , deltas : Tensor , ** kwargs : Any ) -> Tensor :
74+ def forward (
75+ self , X : Tensor , deltas : Tensor , sampler : Optional [MCSampler ] = None
76+ ) -> Tensor :
7177 r"""Evaluate the cost function on the candidates and improvements.
7278
7379 Args:
@@ -76,6 +82,7 @@ def forward(self, X: Tensor, deltas: Tensor, **kwargs: Any) -> Tensor:
7682 deltas: A `num_fantasies x batch_shape`-dim Tensor of `num_fantasy`
7783 samples from the marginal improvement in utility over the
7884 current state at `X` for each t-batch.
85+ sampler: Ignored.
7986
8087 Returns:
8188 A `num_fantasies x batch_shape`-dim Tensor of cost-weighted utilities.
@@ -143,7 +150,7 @@ def __init__(
143150 cost_objective = GenericMCObjective (lambda Y , X : Y .sum (dim = - 1 ))
144151
145152 self .cost_model = cost_model
146- self .cost_objective = cost_objective
153+ self .cost_objective : MCAcquisitionObjective = cost_objective
147154 self ._use_mean = use_mean
148155 self ._min_cost = min_cost
149156
@@ -153,7 +160,6 @@ def forward(
153160 deltas : Tensor ,
154161 sampler : Optional [MCSampler ] = None ,
155162 X_evaluation_mask : Optional [Tensor ] = None ,
156- ** kwargs : Any ,
157163 ) -> Tensor :
158164 r"""Evaluate the cost function on the candidates and improvements. Note
159165 that negative values of `deltas` are instead scaled by the cost, and not
0 commit comments