Commit 82536af

brunzema authored and meta-codesync[bot] committed
Adding Kaiming/He initialization for the VBLL mean (#3053)
Summary:
Hi everyone :) I have a minor update to the VBLLs. This update does not change the default behavior of the current implementation.

## Motivation

In the VBLL repo, we observed improved performance on regression tasks when applying Kaiming/He initialization to the VBLL mean (VectorInstitute/vbll@fdc6ad0). This has not yet been tested extensively in BO tasks, but we still wanted to migrate this option here to the community branch. The default initialization remains unchanged, but one can choose to switch to the new init, which is now also in the VBLL repo.

### Have you read the [Contributing Guidelines on pull requests](https://github.com/meta-pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes

Pull Request resolved: #3053

Test Plan: Added tests to maintain coverage.

## Related PRs

- Original VBLL PR: #2754
- Commit with Kaiming/He initialization in VBLL repo: VectorInstitute/vbll@fdc6ad0

Reviewed By: mpolson64

Differential Revision: D85254301

Pulled By: Balandat

fbshipit-source-id: 0ba4c3f6d3374bc2b0d8a8ee103432c20421460c
1 parent 583e8cb commit 82536af
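The change boils down to scaling the standard-normal draw for the last-layer mean by sqrt(2 / fan_in). A minimal standalone sketch of the two options (plain torch/numpy with illustrative dimensions, not the BoTorch API):

```python
import numpy as np
import torch

out_features, in_features = 1, 64  # illustrative shapes for the last layer

# Default behavior (unchanged): W_mean ~ N(0, 1)
W_default = torch.randn(out_features, in_features)

# Kaiming/He initialization: W_mean ~ N(0, 2 / fan_in) with fan_in = in_features
W_kaiming = torch.randn(out_features, in_features) * np.sqrt(2.0 / in_features)
```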

File tree: 4 files changed, +106 −6 lines changed

botorch_community/models/vbll_helper.py

Lines changed: 29 additions & 4 deletions
```diff
@@ -16,7 +16,6 @@
 from typing import Callable
 
 import numpy as np
-
 import torch
 import torch.nn as nn
 
@@ -361,6 +360,7 @@ def __init__(
         out_features,
         regularization_weight,
         parameterization="dense",
+        mean_initialization=None,
         prior_scale=1.0,
         wishart_scale=1e-2,
         cov_rank=None,
@@ -381,10 +381,18 @@ def __init__(
         parameterization : str
             Parameterization of covariance matrix.
             Currently supports {'dense', 'diagonal', 'lowrank', 'dense_precision'}
+        mean_initialization : str or None
+            Initialization method for the mean of the weights.
+            Supports {'kaiming', None}. If None, weights are initialized from
+            a standard normal distribution. Defaults to None.
         prior_scale : float
             Scale of prior covariance matrix
         wishart_scale : float
             Scale of Wishart prior on noise covariance
+        cov_rank : int or None
+            For 'lowrank' parameterization, the rank of the covariance matrix.
+        clamp_noise_init : bool
+            Whether to clamp the noise initialization to be positive.
         dof : float
             Degrees of freedom of Wishart prior on noise covariance
         """
@@ -412,9 +420,26 @@ def __init__(
 
         # last layer distribution
         self.W_dist = get_parameterization(parameterization)
-        self.W_mean = nn.Parameter(
-            torch.randn(out_features, in_features, dtype=self.dtype)
-        )
+
+        if mean_initialization is None:
+            self.W_mean = nn.Parameter(
+                torch.randn(out_features, in_features, dtype=self.dtype)
+            )
+        elif mean_initialization == "kaiming":
+            self.W_mean = nn.Parameter(
+                torch.randn(out_features, in_features, dtype=self.dtype)
+                * np.sqrt(2.0 / in_features)
+            )
+        elif isinstance(mean_initialization, str):
+            raise ValueError(
+                f"Unknown initialization method: {mean_initialization!r}. "
+                f"Supported methods: 'kaiming'"
+            )
+        else:
+            raise TypeError(
+                f"mean_initialization must be a string or None, "
+                f"got {type(mean_initialization).__name__}"
+            )
 
         if parameterization == "diagonal":
             self.W_logdiag = nn.Parameter(
```
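One observation (mine, not part of the diff): the manual `np.sqrt(2.0 / in_features)` factor reproduces the standard deviation that PyTorch's built-in `torch.nn.init.kaiming_normal_` uses in `fan_in` mode with a ReLU gain, so the two are equivalent in distribution:

```python
import torch
import torch.nn as nn

out_features, in_features = 1, 64  # illustrative shapes

# Built-in Kaiming/He init: std = gain / sqrt(fan_in), with gain = sqrt(2) for ReLU
W = torch.empty(out_features, in_features)
nn.init.kaiming_normal_(W, mode="fan_in", nonlinearity="relu")

# Same distribution as the explicit scaling in the diff above
W_manual = torch.randn(out_features, in_features) * (2.0 / in_features) ** 0.5
```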

botorch_community/models/vblls.py

Lines changed: 10 additions & 2 deletions
```diff
@@ -25,10 +25,8 @@
 from botorch.logging import logger
 from botorch.posteriors import Posterior
 from botorch_community.models.blls import AbstractBLLModel
-
 from botorch_community.models.vbll_helper import DenseNormal, Normal, Regression
 from botorch_community.posteriors.bll_posterior import BLLPosterior
-
 from gpytorch.distributions import MultivariateNormal
 from torch import Tensor
 from torch.optim import Optimizer
@@ -85,6 +83,7 @@ def __init__(
         num_layers: int = 3,
         parameterization: str = "dense",
         cov_rank: int | None = None,
+        mean_initialization: str | None = None,
         prior_scale: float = 1.0,
         wishart_scale: float = 0.01,
         clamp_noise_init: bool = True,
@@ -103,6 +102,10 @@ def __init__(
             num_layers: Number of hidden layers in the MLP. Defaults to 3.
             parameterization: Parameterization of the posterior covariance of the last
                 layer. Supports {'dense', 'diagonal', 'lowrank', 'dense_precision'}.
+            cov_rank: For 'lowrank' parameterization, the rank of the covariance matrix.
+            mean_initialization: Initialization method for the mean of the weights in
+                the last layer. Supports {'kaiming', None}. If None, weights are
+                initialized from a standard normal distribution. Defaults to None.
             prior_scale: Scaling factor for the prior distribution in the Bayesian last
                 layer. Defaults to 1.0.
             wishart_scale: Scaling factor for the Wishart prior in the Bayesian last
@@ -177,6 +180,7 @@ def __init__(
             parameterization=parameterization,
             cov_rank=cov_rank,
             clamp_noise_init=clamp_noise_init,
+            mean_initialization=mean_initialization,
         ).to(dtype=torch.float64, device=self.device)
 
     def forward(self, x: Tensor) -> Tensor:
@@ -253,6 +257,10 @@ def __init__(self, *args, **kwargs):
     def backbone(self):
         return self.model.backbone
 
+    @property
+    def head(self):
+        return self.model.head
+
     def sample(self, sample_shape: torch.Size | None = None) -> nn.Module:
         """Create posterior sample networks of the VBLL model. Note that posterior
         samples, we first sample from the posterior distribution of the last layer and
```
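For downstream users, a usage sketch pieced together from the signature above and the tests below (the dimensions are illustrative):

```python
from botorch_community.models.vblls import VBLLModel

# Opt in to Kaiming/He initialization of the last-layer mean;
# mean_initialization=None (the default) keeps the previous standard-normal init.
model = VBLLModel(
    in_features=4,
    hidden_features=32,
    num_layers=3,
    out_features=1,
    mean_initialization="kaiming",
)
```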

test_community/models/test_vbll_helper.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import torch
+
 from botorch.utils.testing import BotorchTestCase
 from botorch_community.models.vbll_helper import (
     DenseNormal,
```

test_community/models/test_vblls.py

Lines changed: 66 additions & 0 deletions
```diff
@@ -5,8 +5,11 @@
 # LICENSE file in the root directory of this source tree.
 
 import copy
+from unittest.mock import patch
 
+import numpy as np
 import torch
+
 from botorch.utils.testing import BotorchTestCase
 from botorch_community.models.blls import AbstractBLLModel
 from botorch_community.models.vblls import VBLLModel
@@ -82,6 +85,69 @@ def test_initialization(self) -> None:
                 parameterization="lowrank",  # lowrank requires cov_rank
             )
 
+    def test_mean_initialization(self):
+        """Test different mean_initialization options."""
+        d, num_hidden, num_outputs, num_layers = 2, 3, 1, 4
+
+        torch.manual_seed(0)
+        model = VBLLModel(
+            in_features=d,
+            hidden_features=num_hidden,
+            num_layers=num_layers,
+            out_features=num_outputs,
+            mean_initialization=None,
+        )
+
+        # fix seeds to see if mean init is the same
+        torch.manual_seed(0)
+        model2 = VBLLModel(
+            in_features=d,
+            hidden_features=num_hidden,
+            num_layers=num_layers,
+            out_features=num_outputs,
+        )
+
+        self.assertTrue(
+            torch.allclose(model.head.W_mean, model2.head.W_mean, atol=1e-6),
+            "mean_initialization=None should be equivalent to default initialization.",
+        )
+
+        # Test kaiming initialization; check that np.sqrt is called
+        with patch("numpy.sqrt", wraps=np.sqrt) as mock_sqrt:
+            model = VBLLModel(
+                in_features=d,
+                hidden_features=num_hidden,
+                num_layers=num_layers,
+                out_features=num_outputs,
+                mean_initialization="kaiming",
+            )
+
+        # Verify that np.sqrt was called with the correct argument
+        mock_sqrt.assert_called_once_with(2.0 / num_hidden)
+
+        # Test invalid string initialization
+        with self.assertRaises(ValueError) as cm:
+            model = VBLLModel(
+                in_features=d,
+                hidden_features=num_hidden,
+                num_layers=num_layers,
+                out_features=num_outputs,
+                mean_initialization="invalid",
+            )
+        self.assertIn("Unknown initialization method", str(cm.exception))
+        self.assertIn("kaiming", str(cm.exception))
+
+        # Test invalid type (not string or None)
+        with self.assertRaises(TypeError) as cm:
+            model = VBLLModel(
+                in_features=d,
+                hidden_features=num_hidden,
+                num_layers=num_layers,
+                out_features=num_outputs,
+                mean_initialization=["kaiming"],
+            )
+        self.assertIn("must be a string or None", str(cm.exception))
+
     def test_backbone_initialization(self) -> None:
         d, num_hidden = 4, 3
         test_backbone = torch.nn.Sequential(
```
