
Commit 0593421

esantorella authored and facebook-github-bot committed
Speed up sensitivity analyses by using batch dimension (facebook#3891)
Summary:
Pull Request resolved: facebook#3891

* Speed up sensitivity computations by using the batch dimension: use an x of shape [n, 1, d] instead of [n, d]. This avoids the costly creation of an (n, n) covariance matrix when we only need its diagonal elements.
* Stop doing those computations in minibatches. Minibatching was introduced to avoid the superlinear memory usage of large batch sizes (Use mini batches in SobolSensitivityGPMean, facebook#1848), which came from the (n, n) covariance matrix we no longer compute.

Reviewed By: Balandat

Differential Revision: D75712208

fbshipit-source-id: 19714227a6f0124064b3f4576ab8bd34eca3c708
1 parent f1daf09 commit 0593421
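As a rough illustration of the shape trick the summary describes — a minimal sketch on a toy SingleTaskGP, not code from this commit:

import torch
from botorch.models import SingleTaskGP

# Toy model; only the posterior shapes matter here.
train_X = torch.rand(20, 3, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)

test_X = torch.rand(1000, 3, dtype=torch.double)

# Joint evaluation at [n, d]: materializes an (n, n) covariance matrix,
# even though sensitivity analysis only needs its diagonal.
joint = model.posterior(test_X)
# joint.mvn.covariance_matrix.shape == (1000, 1000)

# Batched evaluation at [n, 1, d]: n independent q=1 batches, so the
# covariance per batch is 1 x 1 -- exactly the variances we need.
batched = model.posterior(test_X.unsqueeze(-2))
mean = batched.mean.squeeze(-2)          # (1000, 1)
variance = batched.variance.squeeze(-2)  # (1000, 1)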

2 files changed: +26 −33


ax/utils/sensitivity/derivative_measures.py

Lines changed: 6 additions & 10 deletions
@@ -5,7 +5,7 @@

 # pyre-strict

-from collections.abc import Callable
+from collections.abc import Callable, Sequence
 from copy import deepcopy
 from functools import partial
 from typing import Any
@@ -90,16 +90,12 @@ def __init__(
             this list are generated using an integer-valued uniform distribution,
             rather than the default (pseudo-)random continuous uniform distribution.
         """
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.dim = assert_is_instance(model.train_inputs, tuple)[0].shape[-1]
+        self.dim: int = assert_is_instance(model.train_inputs, tuple)[0].shape[-1]
         self.derivative_gp = derivative_gp
         self.kernel_type = kernel_type
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.bootstrap = num_bootstrap_samples > 1
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.num_bootstrap_samples = (
-            num_bootstrap_samples - 1
-        )  # deduct 1 because the first is meant to be the full grid
+        self.bootstrap: bool = num_bootstrap_samples > 1
+        # deduct 1 because the first is meant to be the full grid
+        self.num_bootstrap_samples: int = num_bootstrap_samples - 1
         self.torch_device: torch.device = bounds.device
         if self.derivative_gp and (self.kernel_type is None):
             raise ValueError("Kernel type has to be specified to use derivative GP")
@@ -417,7 +413,7 @@ def aggregation(


 def compute_derivatives_from_model_list(
-    model_list: list[Model],
+    model_list: Sequence[Model],
     bounds: torch.Tensor,
     discrete_features: list[int] | None = None,
     **kwargs: Any,
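The list[Model] → Sequence[Model] change likely matters for type checking: list is invariant, so a list[SingleTaskGP] is rejected where list[Model] is expected, while the covariant Sequence accepts it. A minimal sketch (the function name here is hypothetical, not from the diff):

from collections.abc import Sequence

import torch
from botorch.models import SingleTaskGP
from botorch.models.model import Model


def count_models(model_list: Sequence[Model]) -> int:
    # Accepts list[SingleTaskGP] too, because Sequence is covariant.
    return len(model_list)


gps: list[SingleTaskGP] = [
    SingleTaskGP(
        torch.rand(5, 2, dtype=torch.double),
        torch.rand(5, 1, dtype=torch.double),
    )
]
# With model_list: list[Model], a type checker would flag this call.
count_models(gps)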

ax/utils/sensitivity/sobol_measures.py

Lines changed: 20 additions & 23 deletions
@@ -22,8 +22,8 @@
     compute_derivatives_from_model_list,
     sample_discrete_parameters,
 )
+from botorch.models.gpytorch import GPyTorchModel
 from botorch.models.model import Model, ModelList
-from botorch.posteriors.gpytorch import GPyTorchPosterior
 from botorch.sampling.normal import SobolQMCNormalSampler
 from botorch.utils.sampling import draw_sobol_samples
 from botorch.utils.transforms import is_ensemble, unnormalize
@@ -444,7 +444,7 @@ def ProbitLinkMean(mean: torch.Tensor, var: torch.Tensor) -> torch.Tensor:
 class SobolSensitivityGPMean:
     def __init__(
         self,
-        model: Model,  # TODO: narrow type down. E.g. ModelListGP does not work.
+        model: GPyTorchModel,
         bounds: torch.Tensor,
         num_mc_samples: int = 10**4,
         second_order: bool = False,
@@ -461,7 +461,7 @@ def __init__(
         first order indices, total indices and second order indices (if specified ).

         Args:
-            model: Botorch model
+            model: BoTorch model whose posterior is a `GPyTorchPosterior`.
             bounds: `2 x d` parameter bounds over which to evaluate model sensitivity.
             method: if "predictive mean", the predictive mean is used for indices
                 computation. If "GP samples", posterior sampling is used instead.
@@ -484,28 +484,25 @@ def __init__(
         self.model = model
         self.second_order = second_order
         self.input_qmc = input_qmc
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.bootstrap = num_bootstrap_samples > 1
+        self.bootstrap: bool = num_bootstrap_samples > 1
         self.num_bootstrap_samples = num_bootstrap_samples
         self.num_mc_samples = num_mc_samples

         def input_function(x: Tensor) -> Tensor:
             with torch.no_grad():
-                means, variances = [], []
-                # Since we're only looking at mean & variance, we can freely
-                # use mini-batches.
-                for x_split in x.split(split_size=mini_batch_size):
-                    p = assert_is_instance(
-                        self.model.posterior(x_split),
-                        GPyTorchPosterior,
-                    )
-                    means.append(p.mean)
-                    variances.append(p.variance)
-
-                cat_dim = 1 if is_ensemble(self.model) else 0
-                return link_function(
-                    torch.cat(means, dim=cat_dim), torch.cat(variances, dim=cat_dim)
-                )
+                # We only need variances, not covariances, so we use the batch
+                # dimension, turning x from (*batch_dim, n, d) to
+                # (*batch_dim, n, 1, d)
+                p = self.model.posterior(x.unsqueeze(-2))
+                mean = p.mean.squeeze(-2)
+                variance = p.variance.squeeze(-2)
+                if is_ensemble(self.model):
+                    # If x has shape [n, d],
+                    # the mean will have shape [n, s, m], where 's' is the ensemble
+                    # size. Reshape to [s, n, m]
+                    mean = torch.swapaxes(mean, -2, -3)
+                    variance = torch.swapaxes(variance, -2, -3)
+                return link_function(mean, variance)

         self.sensitivity = SobolSensitivity(
             bounds=bounds,
@@ -796,7 +793,7 @@ def second_order_indices(self) -> Tensor:


 def compute_sobol_indices_from_model_list(
-    model_list: list[Model],
+    model_list: list[GPyTorchModel],
     bounds: Tensor,
     order: str = "first",
     discrete_features: list[int] | None = None,
@@ -974,7 +971,7 @@ def _get_generator_and_digest(

 def _get_model_per_metric(
     generator: LegacyBoTorchGenerator | ModularBoTorchGenerator, metrics: list[str]
-) -> list[Model]:
+) -> list[GPyTorchModel]:
     """For a given TorchGenerator model, returns a list of botorch.models.model.Model
     objects corresponding to - and in the same order as - the given metrics.
     """
@@ -984,7 +981,7 @@ def _get_model_per_metric(
     model_idx = [generator.metric_names.index(m) for m in metrics]
     if not isinstance(gp_model, ModelList):
         if gp_model.num_outputs == 1:  # can accept single output models
-            return [gp_model for _ in model_idx]
+            return [assert_is_instance(gp_model, GPyTorchModel) for _ in model_idx]
         raise NotImplementedError(
             f"type(adapter.generator.model) = {type(gp_model)}, "
             "but only ModelList is supported."
