Skip to content

Commit 4ee9e03

Browse files
dme65facebook-github-bot
authored andcommitted
Update load_state_dict for ModelList to support fully Bayesian models (#1395)
Summary: Pull Request resolved: #1395 Update `ModelList` to initialize `SaasFullyBayesianSingleTaskGP` before loading the `state_dict`. This is needed in order to do cross-validation in MBM. Reviewed By: Balandat, qingfeng10 Differential Revision: D39402322 fbshipit-source-id: a3cb8144585cfd635178d4918084d7391f985262
1 parent e3b5a52 commit 4ee9e03

File tree

2 files changed

+33
-16
lines changed

2 files changed

+33
-16
lines changed

botorch/models/model.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from abc import ABC, abstractmethod
1717
from collections import defaultdict
1818
from copy import deepcopy
19-
from typing import Any, Callable, Dict, Hashable, List, Optional, Union
19+
from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Union
2020

2121
import numpy as np
2222
import torch
@@ -420,3 +420,17 @@ def transform_inputs(self, X: Tensor) -> List[Tensor]:
420420
except AttributeError:
421421
transformed_X_list.append(X)
422422
return transformed_X_list
423+
424+
def load_state_dict(
425+
self, state_dict: Mapping[str, Any], strict: bool = True
426+
) -> None:
427+
"""Initialize the fully Bayesian models before loading the state dict."""
428+
for i, m in enumerate(self.models):
429+
if is_fully_bayesian(m):
430+
filtered_dict = {
431+
k.replace(f"models.{i}.", ""): v
432+
for k, v in state_dict.items()
433+
if k.startswith(f"models.{i}")
434+
}
435+
m.load_state_dict(filtered_dict)
436+
super().load_state_dict(state_dict=state_dict, strict=strict)

test/models/test_fully_bayesian.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -326,23 +326,26 @@ def test_fit_model(self):
326326
self.assertEqual(median_lengthscale.shape, torch.Size([4]))
327327
self.assertEqual(model.num_mcmc_samples, 3)
328328

329-
# Test loading via state dict
330-
state_dict = model.state_dict()
329+
# Check the keys in the state dict
331330
true_keys = EXPECTED_KEYS_NOISE if infer_noise else EXPECTED_KEYS
332-
self.assertEqual(set(state_dict.keys()), set(true_keys))
333-
_, _, _, model_new = self._get_data_and_model(
334-
infer_noise=infer_noise, **tkwargs
335-
)
336-
self.assertEqual(model_new.state_dict(), {})
337-
model_new.load_state_dict(state_dict)
338-
self.assertEqual(model.state_dict().keys(), model_new.state_dict().keys())
339-
for k in model.state_dict().keys():
340-
self.assertTrue(
341-
(model.state_dict()[k] == model_new.state_dict()[k]).all()
331+
self.assertEqual(set(model.state_dict().keys()), set(true_keys))
332+
333+
for i in range(2): # Test loading via state dict
334+
m = model if i == 0 else ModelList(model, deterministic)
335+
state_dict = m.state_dict()
336+
_, _, _, m_new = self._get_data_and_model(
337+
infer_noise=infer_noise, **tkwargs
342338
)
343-
preds1, preds2 = model.posterior(test_X), model_new.posterior(test_X)
344-
self.assertTrue((preds1.mean == preds2.mean).all())
345-
self.assertTrue((preds1.variance == preds2.variance).all())
339+
m_new = m_new if i == 0 else ModelList(m_new, deterministic)
340+
if i == 0:
341+
self.assertEqual(m_new.state_dict(), {})
342+
m_new.load_state_dict(state_dict)
343+
self.assertEqual(m.state_dict().keys(), m_new.state_dict().keys())
344+
for k in m.state_dict().keys():
345+
self.assertTrue((m.state_dict()[k] == m_new.state_dict()[k]).all())
346+
preds1, preds2 = m.posterior(test_X), m_new.posterior(test_X)
347+
self.assertTrue(torch.equal(preds1.mean, preds2.mean))
348+
self.assertTrue(torch.equal(preds1.variance, preds2.variance))
346349

347350
# Make sure the model shapes are set correctly
348351
self.assertEqual(model.pyro_model.train_X.shape, torch.Size([n, d]))

0 commit comments

Comments
 (0)