Skip to content

Commit 1484c89

Browse files
SebastianAmentfacebook-github-bot
authored andcommitted
Removing custom BlockDiagLazyTensor logic when using Standardize (Take 2) (#1414)
Summary: Pull Request resolved: #1414 Due to [this linear operator PR](cornellius-gp/linear_operator#14), we should now be able to remove the custom logic in `Standardize` without performance impact. Reviewed By: saitcakmak Differential Revision: D39746709 fbshipit-source-id: 286b092f073861cb52d409ef85ff3dc9047bae4a
1 parent a6cc512 commit 1484c89

File tree

1 file changed

+3
-18
lines changed

1 file changed

+3
-18
lines changed

botorch/models/transforms/outcome.py

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,7 @@
3333
)
3434
from botorch.posteriors import GPyTorchPosterior, Posterior, TransformedPosterior
3535
from botorch.utils.transforms import normalize_indices
36-
from linear_operator.operators import (
37-
BlockDiagLinearOperator,
38-
CholLinearOperator,
39-
DiagLinearOperator,
40-
)
36+
from linear_operator.operators import CholLinearOperator, DiagLinearOperator
4137
from torch import Tensor
4238
from torch.nn import Module, ModuleDict
4339

@@ -386,19 +382,8 @@ def untransform_posterior(self, posterior: Posterior) -> Posterior:
386382
else:
387383
lcv = mvn.lazy_covariance_matrix
388384
scale_fac = scale_fac.expand(lcv.shape[:-1])
389-
# TODO: Remove the custom logic with next GPyTorch release (T126095032).
390-
if isinstance(lcv, BlockDiagLinearOperator):
391-
# Keep the block diag structure of lcv.
392-
base_lcv = lcv.base_linear_op
393-
scale_mat = DiagLinearOperator(
394-
scale_fac.view(*scale_fac.shape[:-1], lcv.num_blocks, -1)
395-
)
396-
base_lcv_tf = scale_mat @ base_lcv @ scale_mat
397-
covar_tf = BlockDiagLinearOperator(base_linear_op=base_lcv_tf)
398-
else:
399-
# allow batch-evaluation of the model
400-
scale_mat = DiagLinearOperator(scale_fac)
401-
covar_tf = scale_mat @ lcv @ scale_mat
385+
scale_mat = DiagLinearOperator(scale_fac)
386+
covar_tf = scale_mat @ lcv @ scale_mat
402387

403388
kwargs = {"interleaved": mvn._interleaved} if posterior._is_mt else {}
404389
mvn_tf = mvn.__class__(mean=mean_tf, covariance_matrix=covar_tf, **kwargs)

0 commit comments

Comments
 (0)