Skip to content

Commit d13d00c

Browse files
Daniel Jiangfacebook-github-bot
authored andcommitted
make sure cost is not close to zero (#340)
Summary: Pull Request resolved: #340 to avoid numerical issues (although this does not seem to be causing issues currently after examining some runs) Reviewed By: Balandat Differential Revision: D18604233 fbshipit-source-id: b7f15c81321fcbeb905ae696601dfc4e7bcaf0d4
1 parent c426b0a commit d13d00c

File tree

2 files changed

+17
-4
lines changed

2 files changed

+17
-4
lines changed

botorch/acquisition/cost_aware.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def __init__(
9292
cost_model: Model,
9393
use_mean: bool = True,
9494
cost_objective: Optional[MCAcquisitionObjective] = None,
95+
min_cost: float = 1e-2,
9596
) -> None:
9697
r"""Cost-aware utility that weights increase in utiltiy by inverse cost.
9798
@@ -106,7 +107,8 @@ def __init__(
106107
posterior samples from the cost model. This can be used e.g. to
107108
un-transform predictions/samples of a cost model fit on the
108109
log-transformed cost (often done to ensure non-negativity).
109-
110+
min_cost: A value used to clamp the cost samples so that they are not
111+
too close to zero, which may cause numerical issues.
110112
Returns:
111113
The inverse-cost-weighted utiltiy.
112114
"""
@@ -116,6 +118,7 @@ def __init__(
116118
self.cost_model = cost_model
117119
self.cost_objective = cost_objective
118120
self._use_mean = use_mean
121+
self._min_cost = min_cost
119122

120123
def forward(
121124
self,
@@ -157,9 +160,9 @@ def forward(
157160
"Encountered negative cost values in InverseCostWeightedUtility",
158161
CostAwareWarning,
159162
)
160-
# clamp and sum cost across elements of the q-batch - this will be of
161-
# shape `num_fantasies x batch_shape` or `batch_shape`
162-
cost = cost.clamp_min(0.0).sum(dim=-1)
163+
# clamp (away from zero) and sum cost across elements of the q-batch -
164+
# this will be of shape `num_fantasies x batch_shape` or `batch_shape`
165+
cost = cost.clamp_min(self._min_cost).sum(dim=-1)
163166

164167
# if we are doing inverse weighting on the sample level, clamp numerator.
165168
if not self._use_mean:

test/acquisition/test_cost_aware.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,13 @@ def test_InverseCostWeightedUtility(self):
8383
self.assertTrue(
8484
torch.equal(ratios, deltas / samples.squeeze(-1).sum(dim=-1))
8585
)
86+
87+
# test min cost
88+
mm = MockModel(MockPosterior(mean=mean))
89+
icwu = InverseCostWeightedUtility(mm, min_cost=1.5)
90+
ratios = icwu(X, deltas)
91+
self.assertTrue(
92+
torch.equal(
93+
ratios, deltas / mean.clamp_min(1.5).squeeze(-1).sum(dim=-1)
94+
)
95+
)

0 commit comments

Comments
 (0)