Skip to content

Commit 286012d

Browse files
dme65facebook-github-bot
authored andcommitted
Support loading a state dict for SaasFullyBayesianSingleTaskGP (#1120)
Summary: X-link: facebook/Ax#1120 Pull Request resolved: #1384 `SaasFullyBayesianSingleTaskGP` currently doesn't support `load_state_dict` since the model is only initialized after fitting, so we can't load a state dict into a model that hasn't been fitted. This diff modifies `load_state_dict` to initialize the model with some dummy samples before loading the state dict. Reviewed By: saitcakmak, Balandat Differential Revision: D39358160 fbshipit-source-id: d923688bd307ae1a0bdb85e015491c50ab8e7203
1 parent fe2cbdc commit 286012d

File tree

2 files changed

+75
-2
lines changed

2 files changed

+75
-2
lines changed

botorch/models/fully_bayesian.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
import math
3535
from abc import abstractmethod
36-
from typing import Any, Dict, List, Optional, Tuple, Union
36+
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
3737

3838
import pyro
3939
import torch
@@ -134,7 +134,7 @@ def load_mcmc_samples(
134134
class SaasPyroModel(PyroModel):
135135
r"""Implementation of the sparse axis-aligned subspace priors (SAAS) model.
136136
137-
The SAAS model uses sparsity-inducing priors to identift the most important
137+
The SAAS model uses sparsity-inducing priors to identify the most important
138138
parameters. This model is suitable for high-dimensional BO with potentially
139139
hundreds of tunable parameters. See [Eriksson2021saasbo]_ for more details.
140140
@@ -422,6 +422,41 @@ def load_mcmc_samples(self, mcmc_samples: Dict[str, Tensor]) -> None:
422422
self.likelihood,
423423
) = self.pyro_model.load_mcmc_samples(mcmc_samples=mcmc_samples)
424424

425+
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
426+
r"""Custom logic for loading the state dict.
427+
428+
The standard approach of calling `load_state_dict` currently doesn't play well
429+
with the `SaasFullyBayesianSingleTaskGP` since the `mean_module`, `covar_module`
430+
and `likelihood` aren't initialized until the model has been fitted. The reason
431+
for this is that we don't know the number of MCMC samples until NUTS is called.
432+
Given the state dict, we can initialize a new model with some dummy samples and
433+
then load the state dict into this model. This currently only works for a
434+
`SaasPyroModel` and supporting more Pyro models likely requires moving the model
435+
construction logic into the Pyro model itself.
436+
"""
437+
438+
if not isinstance(self.pyro_model, SaasPyroModel):
439+
raise NotImplementedError("load_state_dict only works for SaasPyroModel")
440+
raw_mean = state_dict["mean_module.raw_constant"]
441+
num_mcmc_samples = len(raw_mean)
442+
dim = self.pyro_model.train_X.shape[-1]
443+
tkwargs = {"device": raw_mean.device, "dtype": raw_mean.dtype}
444+
# Load some dummy samples
445+
mcmc_samples = {
446+
"mean": torch.ones(num_mcmc_samples, **tkwargs),
447+
"lengthscale": torch.ones(num_mcmc_samples, dim, **tkwargs),
448+
"outputscale": torch.ones(num_mcmc_samples, **tkwargs),
449+
}
450+
if self.pyro_model.train_Yvar is None:
451+
mcmc_samples["noise"] = torch.ones(num_mcmc_samples, **tkwargs)
452+
(
453+
self.mean_module,
454+
self.covar_module,
455+
self.likelihood,
456+
) = self.pyro_model.load_mcmc_samples(mcmc_samples=mcmc_samples)
457+
# Load the actual samples from the state dict
458+
super().load_state_dict(state_dict=state_dict, strict=strict)
459+
425460
def forward(self, X: Tensor) -> MultivariateNormal:
426461
self._check_if_fitted()
427462
return super().forward(X.unsqueeze(MCMC_DIM))

test/models/test_fully_bayesian.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,22 @@
5252
from linear_operator.operators import to_linear_operator
5353

5454

55+
EXPECTED_KEYS = [
56+
"mean_module.raw_constant",
57+
"covar_module.raw_outputscale",
58+
"covar_module.base_kernel.raw_lengthscale",
59+
"covar_module.base_kernel.raw_lengthscale_constraint.lower_bound",
60+
"covar_module.base_kernel.raw_lengthscale_constraint.upper_bound",
61+
"covar_module.raw_outputscale_constraint.lower_bound",
62+
"covar_module.raw_outputscale_constraint.upper_bound",
63+
]
64+
EXPECTED_KEYS_NOISE = EXPECTED_KEYS + [
65+
"likelihood.noise_covar.raw_noise",
66+
"likelihood.noise_covar.raw_noise_constraint.lower_bound",
67+
"likelihood.noise_covar.raw_noise_constraint.upper_bound",
68+
]
69+
70+
5571
class CustomPyroModel(PyroModel):
5672
def sample(self) -> None:
5773
pass
@@ -310,6 +326,24 @@ def test_fit_model(self):
310326
self.assertEqual(median_lengthscale.shape, torch.Size([4]))
311327
self.assertEqual(model.num_mcmc_samples, 3)
312328

329+
# Test loading via state dict
330+
state_dict = model.state_dict()
331+
true_keys = EXPECTED_KEYS_NOISE if infer_noise else EXPECTED_KEYS
332+
self.assertEqual(set(state_dict.keys()), set(true_keys))
333+
_, _, _, model_new = self._get_data_and_model(
334+
infer_noise=infer_noise, **tkwargs
335+
)
336+
self.assertEqual(model_new.state_dict(), {})
337+
model_new.load_state_dict(state_dict)
338+
self.assertEqual(model.state_dict().keys(), model_new.state_dict().keys())
339+
for k in model.state_dict().keys():
340+
self.assertTrue(
341+
(model.state_dict()[k] == model_new.state_dict()[k]).all()
342+
)
343+
preds1, preds2 = model.posterior(test_X), model_new.posterior(test_X)
344+
self.assertTrue((preds1.mean == preds2.mean).all())
345+
self.assertTrue((preds1.variance == preds2.variance).all())
346+
313347
# Make sure the model shapes are set correctly
314348
self.assertEqual(model.pyro_model.train_X.shape, torch.Size([n, d]))
315349
self.assertTrue(torch.allclose(model.pyro_model.train_X, train_X))
@@ -520,6 +554,10 @@ def test_custom_pyro_model(self):
520554
train_Yvar=train_Yvar,
521555
pyro_model=CustomPyroModel(),
522556
)
557+
with self.assertRaisesRegex(
558+
NotImplementedError, "load_state_dict only works for SaasPyroModel"
559+
):
560+
model.load_state_dict({})
523561
self.assertIsInstance(model.pyro_model, CustomPyroModel)
524562
self.assertTrue(torch.allclose(model.pyro_model.train_X, train_X))
525563
self.assertTrue(torch.allclose(model.pyro_model.train_Y, train_Y))

0 commit comments

Comments
 (0)