@@ -21,6 +21,7 @@
     arXiv:1901.08275v1, 2019
 """
 
+from copy import deepcopy
 from math import log
 from typing import Callable, Optional
 
@@ -297,7 +298,6 @@ def _compute_information_gain(
         H1_hat = H1_bar - beta * (H0_bar - H0)
         ig = H0 - H1_hat  # batch_shape x num_fantasies
         ig = ig.permute(-1, *range(ig.dim() - 1))  # num_fantasies x batch_shape
-
         return ig
 
 
@@ -382,13 +382,29 @@ def __init__(
         self.cost_aware_utility = cost_aware_utility
         self.expand = expand
         self.project = project
+        self._cost_sampler = None
+
         # @TODO make sure fidelity_dims align in project, expand & cost_aware_utility
         # It seems very difficult due to the current way of handling project/expand
 
         # resample max values after initializing self.project
         # so that the max value samples are at the highest fidelity
         self._sample_max_values()
 
+    @property
+    def cost_sampler(self):
+        if self._cost_sampler is None:
+            # Note: Using the deepcopy here is essential. Removing it poses a
+            # problem if the base model and the cost model have a different
+            # number of outputs or test points (this would be caused by
+            # expand), as that would trigger re-sampling the base samples in
+            # the fantasy sampler. By cloning the sampler here, the right
+            # thing happens if the sizes are compatible; if they are not,
+            # samples will be drawn using different base samples, but at
+            # least the state of the fantasy sampler is left unchanged.
+            self._cost_sampler = deepcopy(self.fantasies_sampler)
+        return self._cost_sampler
+
     @t_batch_mode_transform(expected_q=1)
     def forward(self, X: Tensor) -> Tensor:
         r"""Evaluates `qMultifidelityMaxValueEntropy` at the design points `X`
@@ -422,8 +438,8 @@ def forward(self, X: Tensor) -> Tensor:
         ig = self._compute_information_gain(
             X=X_expand, mean_M=mean_M, variance_M=variance_M, covar_mM=covar_mM
         )
-
-        return self.cost_aware_utility(X, ig).mean(dim=0)  # average over the fantasies
+        ig = self.cost_aware_utility(X=X, deltas=ig, sampler=self.cost_sampler)
+        return ig.mean(dim=0)  # average over the fantasies
 
 
 def _sample_max_value_Thompson(
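The rewritten return path re-weights the per-fantasy information gain by cost before averaging over the fantasy dimension (dim=0, after the permute in _compute_information_gain). A hedged sketch of that composition, writing an inverse-cost weighting out directly instead of calling a cost model (the helper name and the precomputed expected_cost tensor are assumptions for illustration):

import torch


def cost_weighted_gain(ig: torch.Tensor, expected_cost: torch.Tensor) -> torch.Tensor:
    # ig: num_fantasies x batch_shape information gain (the `deltas`);
    # expected_cost: positive expected evaluation cost, broadcastable to ig.
    weighted = ig / expected_cost.clamp_min(1e-6)
    return weighted.mean(dim=0)  # average over the fantasies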