
Commit 2d87f90

ItsMrLin authored and facebook-github-bot committed
Allowing custom dimensionality and improved gradient stability in ModifiedFixedSingleSampleModel (#1732)
Summary:
Pull Request resolved: #1732

- Allowing explicitly specifying the dimensionality instead of inferring it from the base model.
- Adding jitter before the sqrt call so that when variance is close to zero, the gradient does not explode.

Reviewed By: Ryan-Rhys

Differential Revision: D43899380

fbshipit-source-id: 51d7ad75810320626f24887e13453b57453164da
1 parent 4a76513 commit 2d87f90

2 files changed: +39 -3 lines changed


botorch/models/deterministic.py

Lines changed: 24 additions & 3 deletions
@@ -178,18 +178,39 @@ class FixedSingleSampleModel(DeterministicModel):
     We assume the outcomes are uncorrelated here.
     """
 
-    def __init__(self, model: Model, w: Optional[Tensor] = None) -> None:
+    def __init__(
+        self,
+        model: Model,
+        w: Optional[Tensor] = None,
+        dim: Optional[int] = None,
+        jitter: Optional[float] = 1e-8,
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[torch.device] = None,
+    ) -> None:
         r"""
         Args:
             model: The base model.
             w: A 1-d tensor with length model.num_outputs.
                 If None, draw it from a standard normal distribution.
+            dim: Dimensionality of w.
+                If None and w is not provided, draw w samples of size model.num_outputs.
+            jitter: Jitter value to be added for numerical stability, 1e-8 by default.
+            dtype: The dtype for w, if specified.
+            device: The device for w, if specified.
         """
         super().__init__()
         self.model = model
         self._num_outputs = model.num_outputs
-        self.w = torch.randn(model.num_outputs)
+        self.jitter = jitter
+        if w is None:
+            self.w = (
+                torch.randn(model.num_outputs, dtype=dtype, device=device)
+                if dim is None
+                else torch.randn(dim, dtype=dtype, device=device)
+            )
+        else:
+            self.w = w
 
     def forward(self, X: Tensor) -> Tensor:
         post = self.model.posterior(X)
-        return post.mean + post.variance.sqrt() * self.w.to(X)
+        return post.mean + torch.sqrt(post.variance + self.jitter) * self.w.to(X)
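Why the jitter goes inside the sqrt: the derivative of sqrt(x) is 1 / (2 * sqrt(x)), which diverges as the posterior variance approaches zero, whereas sqrt(x + jitter) keeps the gradient bounded by 1 / (2 * sqrt(jitter)). A minimal standalone sketch (not part of this commit) demonstrating the effect with PyTorch autograd:

    import torch

    # Variance very close to zero, as can happen at or near a GP's training points.
    variance = torch.tensor([1e-20], requires_grad=True)

    # Without jitter: gradient is 1 / (2 * sqrt(1e-20)) = 5e9, and inf at exactly zero.
    torch.sqrt(variance).sum().backward()
    print(variance.grad)  # tensor([5.0000e+09])

    variance.grad = None

    # With jitter: gradient is capped near 1 / (2 * sqrt(1e-8)) = 5e3.
    torch.sqrt(variance + 1e-8).sum().backward()
    print(variance.grad)  # tensor([5000.])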

test/models/test_deterministic.py

Lines changed: 15 additions & 0 deletions
@@ -166,6 +166,7 @@ def test_FixedSingleSampleModel(self):
         model = SingleTaskGP(train_X=train_X, train_Y=train_Y)
         fss_model = FixedSingleSampleModel(model=model)
 
+        # test without specifying w and dim
         test_X = torch.rand(2, 3)
         w = fss_model.w
         post = model.posterior(test_X)
@@ -175,6 +176,20 @@ def test_FixedSingleSampleModel(self):
 
         self.assertTrue(hasattr(fss_model, "num_outputs"))
 
+        # test specifying w explicitly
+        w = torch.randn(4)
+        fss_model = FixedSingleSampleModel(model=model, w=w)
+        self.assertTrue(fss_model.w.shape == w.shape)
+        # test dim when w is also provided
+        dim = 5
+        fss_model = FixedSingleSampleModel(model=model, w=w, dim=dim)
+        # dim should be ignored in favor of the provided w
+        self.assertTrue(fss_model.w.shape == w.shape)
+        # test dim when no w is provided
+        fss_model = FixedSingleSampleModel(model=model, dim=dim)
+        # w should be drawn with the specified dimensionality
+        self.assertTrue(fss_model.w.shape == torch.Size([dim]))
+
         # check w dtype conversion
         train_X_double = torch.rand(2, 3, dtype=torch.double)
         train_Y_double = torch.rand(2, 2, dtype=torch.double)
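A hedged usage sketch following the test above, showing the new constructor arguments (dim is only honored when w is not passed, matching the logic in the diff; shapes here are illustrative):

    import torch
    from botorch.models import SingleTaskGP
    from botorch.models.deterministic import FixedSingleSampleModel

    train_X = torch.rand(10, 3, dtype=torch.double)
    train_Y = torch.rand(10, 2, dtype=torch.double)
    model = SingleTaskGP(train_X=train_X, train_Y=train_Y)

    # Draw w with an explicit dimensionality; dim should match the output
    # dimension the sample is broadcast against (2 here).
    fss_model = FixedSingleSampleModel(model=model, dim=2, dtype=torch.double)

    # Evaluates mean + sqrt(variance + jitter) * w at the test points.
    test_X = torch.rand(5, 3, dtype=torch.double)
    values = fss_model(test_X)  # shape: 5 x 2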
