|
33 | 33 | ) |
34 | 34 | from botorch.posteriors import GPyTorchPosterior, Posterior, TransformedPosterior |
35 | 35 | from botorch.utils.transforms import normalize_indices |
36 | | -from linear_operator.operators import CholLinearOperator, DiagLinearOperator |
| 36 | +from linear_operator.operators import ( |
| 37 | + BlockDiagLinearOperator, |
| 38 | + CholLinearOperator, |
| 39 | + DiagLinearOperator, |
| 40 | +) |
37 | 41 | from torch import Tensor |
38 | 42 | from torch.nn import Module, ModuleDict |
39 | 43 |
|
@@ -382,8 +386,19 @@ def untransform_posterior(self, posterior: Posterior) -> Posterior: |
382 | 386 | else: |
383 | 387 | lcv = mvn.lazy_covariance_matrix |
384 | 388 | scale_fac = scale_fac.expand(lcv.shape[:-1]) |
385 | | - scale_mat = DiagLinearOperator(scale_fac) |
386 | | - covar_tf = scale_mat @ lcv @ scale_mat |
| 389 | + # TODO: Remove the custom logic with next GPyTorch release (T126095032). |
| 390 | + if isinstance(lcv, BlockDiagLinearOperator): |
| 391 | + # Keep the block diag structure of lcv. |
| 392 | + base_lcv = lcv.base_linear_op |
| 393 | + scale_mat = DiagLinearOperator( |
| 394 | + scale_fac.view(*scale_fac.shape[:-1], lcv.num_blocks, -1) |
| 395 | + ) |
| 396 | + base_lcv_tf = scale_mat @ base_lcv @ scale_mat |
| 397 | + covar_tf = BlockDiagLinearOperator(base_linear_op=base_lcv_tf) |
| 398 | + else: |
| 399 | + # allow batch-evaluation of the model |
| 400 | + scale_mat = DiagLinearOperator(scale_fac) |
| 401 | + covar_tf = scale_mat @ lcv @ scale_mat |
387 | 402 |
|
388 | 403 | kwargs = {"interleaved": mvn._interleaved} if posterior._is_mt else {} |
389 | 404 | mvn_tf = mvn.__class__(mean=mean_tf, covariance_matrix=covar_tf, **kwargs) |
|
0 commit comments