Skip to content

Commit fa9811f

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Back out "Removing custom BlockDiagLazyTensor logic when using Standardize" (#1383)
Summary: Pull Request resolved: #1383 With this change the `Multi_objective_multi_fidelity_BO` tutorial is taking much much longer to run (4s vs 50s). Reverting this to avoid unexpected regressions in other places. Reviewed By: Balandat, SebastianAment Differential Revision: D39361738 fbshipit-source-id: 831df5033c94e4d562313d9e932b0febc18235cb
1 parent 836fde0 commit fa9811f

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

botorch/models/transforms/outcome.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,11 @@
3333
)
3434
from botorch.posteriors import GPyTorchPosterior, Posterior, TransformedPosterior
3535
from botorch.utils.transforms import normalize_indices
36-
from linear_operator.operators import CholLinearOperator, DiagLinearOperator
36+
from linear_operator.operators import (
37+
BlockDiagLinearOperator,
38+
CholLinearOperator,
39+
DiagLinearOperator,
40+
)
3741
from torch import Tensor
3842
from torch.nn import Module, ModuleDict
3943

@@ -382,8 +386,19 @@ def untransform_posterior(self, posterior: Posterior) -> Posterior:
382386
else:
383387
lcv = mvn.lazy_covariance_matrix
384388
scale_fac = scale_fac.expand(lcv.shape[:-1])
385-
scale_mat = DiagLinearOperator(scale_fac)
386-
covar_tf = scale_mat @ lcv @ scale_mat
389+
# TODO: Remove the custom logic with next GPyTorch release (T126095032).
390+
if isinstance(lcv, BlockDiagLinearOperator):
391+
# Keep the block diag structure of lcv.
392+
base_lcv = lcv.base_linear_op
393+
scale_mat = DiagLinearOperator(
394+
scale_fac.view(*scale_fac.shape[:-1], lcv.num_blocks, -1)
395+
)
396+
base_lcv_tf = scale_mat @ base_lcv @ scale_mat
397+
covar_tf = BlockDiagLinearOperator(base_linear_op=base_lcv_tf)
398+
else:
399+
# allow batch-evaluation of the model
400+
scale_mat = DiagLinearOperator(scale_fac)
401+
covar_tf = scale_mat @ lcv @ scale_mat
387402

388403
kwargs = {"interleaved": mvn._interleaved} if posterior._is_mt else {}
389404
mvn_tf = mvn.__class__(mean=mean_tf, covariance_matrix=covar_tf, **kwargs)

0 commit comments

Comments
 (0)