    compute_derivatives_from_model_list,
    sample_discrete_parameters,
)
+from botorch.models.gpytorch import GPyTorchModel
from botorch.models.model import Model, ModelList
-from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch.utils.sampling import draw_sobol_samples
from botorch.utils.transforms import is_ensemble, unnormalize
@@ -444,7 +444,7 @@ def ProbitLinkMean(mean: torch.Tensor, var: torch.Tensor) -> torch.Tensor:
class SobolSensitivityGPMean:
    def __init__(
        self,
-        model: Model,  # TODO: narrow type down. E.g. ModelListGP does not work.
+        model: GPyTorchModel,
        bounds: torch.Tensor,
        num_mc_samples: int = 10**4,
        second_order: bool = False,
@@ -461,7 +461,7 @@ def __init__(
        first order indices, total indices and second order indices (if specified).

        Args:
-            model: Botorch model
+            model: BoTorch model whose posterior is a `GPyTorchPosterior`.
            bounds: `2 x d` parameter bounds over which to evaluate model sensitivity.
            method: if "predictive mean", the predictive mean is used for indices
                computation. If "GP samples", posterior sampling is used instead.
@@ -484,28 +484,25 @@ def __init__(
        self.model = model
        self.second_order = second_order
        self.input_qmc = input_qmc
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.bootstrap = num_bootstrap_samples > 1
+        self.bootstrap: bool = num_bootstrap_samples > 1
        self.num_bootstrap_samples = num_bootstrap_samples
        self.num_mc_samples = num_mc_samples

        def input_function(x: Tensor) -> Tensor:
            with torch.no_grad():
-                means, variances = [], []
-                # Since we're only looking at mean & variance, we can freely
-                # use mini-batches.
-                for x_split in x.split(split_size=mini_batch_size):
-                    p = assert_is_instance(
-                        self.model.posterior(x_split),
-                        GPyTorchPosterior,
-                    )
-                    means.append(p.mean)
-                    variances.append(p.variance)
-
-                cat_dim = 1 if is_ensemble(self.model) else 0
-                return link_function(
-                    torch.cat(means, dim=cat_dim), torch.cat(variances, dim=cat_dim)
-                )
+                # We only need variances, not covariances, so we use the batch
+                # dimension, turning x from (*batch_dim, n, d) to
+                # (*batch_dim, n, 1, d)
+                p = self.model.posterior(x.unsqueeze(-2))
+                mean = p.mean.squeeze(-2)
+                variance = p.variance.squeeze(-2)
+                if is_ensemble(self.model):
+                    # If x has shape [n, d],
+                    # the mean will have shape [n, s, m], where 's' is the ensemble
+                    # size. Reshape to [s, n, m]
+                    mean = torch.swapaxes(mean, -2, -3)
+                    variance = torch.swapaxes(variance, -2, -3)
+                return link_function(mean, variance)

        self.sensitivity = SobolSensitivity(
            bounds=bounds,
@@ -796,7 +793,7 @@ def second_order_indices(self) -> Tensor:


def compute_sobol_indices_from_model_list(
-    model_list: list[Model],
+    model_list: list[GPyTorchModel],
    bounds: Tensor,
    order: str = "first",
    discrete_features: list[int] | None = None,
@@ -974,7 +971,7 @@ def _get_generator_and_digest(

def _get_model_per_metric(
    generator: LegacyBoTorchGenerator | ModularBoTorchGenerator, metrics: list[str]
-) -> list[Model]:
+) -> list[GPyTorchModel]:
    """For a given TorchGenerator model, returns a list of botorch.models.model.Model
    objects corresponding to - and in the same order as - the given metrics.
    """
@@ -984,7 +981,7 @@ def _get_model_per_metric(
    model_idx = [generator.metric_names.index(m) for m in metrics]
    if not isinstance(gp_model, ModelList):
        if gp_model.num_outputs == 1:  # can accept single output models
-            return [gp_model for _ in model_idx]
+            return [assert_is_instance(gp_model, GPyTorchModel) for _ in model_idx]
        raise NotImplementedError(
            f"type(adapter.generator.model) = {type(gp_model)}, "
            "but only ModelList is supported."