Skip to content

Commit f6de438

Browse files
Balandat authored and facebook-github-bot committed
Add input scaling check to standard models (#267)
Summary: Introduces a `settings.validate_input_scaling` flag that when active results in the input data being checked for normalization/standardization. In this PR this check is on by default, we may want to make it optional if this is too much of a hassle in practice. Right now as used in the models this by default emits warnings and raises errors only on `NaN` values or negative variances. Addresses #208 Pull Request resolved: #267 Reviewed By: sdaulton Differential Revision: D17398454 Pulled By: Balandat fbshipit-source-id: ddb819a6612971a5464218aa1eca54ee281e1366
1 parent e26f973 commit f6de438

File tree

5 files changed

+105
-0
lines changed

5 files changed

+105
-0
lines changed

botorch/models/gp_regression.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
from ..sampling.samplers import MCSampler
3131
from .gpytorch import BatchedMultiOutputGPyTorchModel
32+
from .utils import validate_input_scaling
3233

3334

3435
MIN_INFERRED_NOISE_LEVEL = 1e-4
@@ -75,6 +76,7 @@ def __init__(
7576
>>> train_Y = torch.sin(train_X).sum(dim=1, keepdim=True)
7677
>>> model = SingleTaskGP(train_X, train_Y)
7778
"""
79+
validate_input_scaling(train_X=train_X, train_Y=train_Y)
7880
self._validate_tensor_args(X=train_X, Y=train_Y)
7981
self._set_dimensions(train_X=train_X, train_Y=train_Y)
8082
train_X, train_Y, _ = self._transform_tensor_args(X=train_X, Y=train_Y)
@@ -143,6 +145,7 @@ def __init__(self, train_X: Tensor, train_Y: Tensor, train_Yvar: Tensor) -> None
143145
>>> train_Yvar = torch.full_like(train_Y, 0.2)
144146
>>> model = FixedNoiseGP(train_X, train_Y, train_Yvar)
145147
"""
148+
validate_input_scaling(train_X=train_X, train_Y=train_Y, train_Yvar=train_Yvar)
146149
self._validate_tensor_args(X=train_X, Y=train_Y, Yvar=train_Yvar)
147150
self._set_dimensions(train_X=train_X, train_Y=train_Y)
148151
train_X, train_Y, train_Yvar = self._transform_tensor_args(
@@ -238,6 +241,7 @@ def __init__(self, train_X: Tensor, train_Y: Tensor, train_Yvar: Tensor) -> None
238241
>>> train_Yvar = 0.1 + se * torch.rand_like(train_Y)
239242
>>> model = HeteroskedasticSingleTaskGP(train_X, train_Y, train_Yvar)
240243
"""
244+
validate_input_scaling(train_X=train_X, train_Y=train_Y, train_Yvar=train_Yvar)
241245
self._validate_tensor_args(X=train_X, Y=train_Y, Yvar=train_Yvar)
242246
self._set_dimensions(train_X=train_X, train_Y=train_Y)
243247
noise_likelihood = GaussianLikelihood(

botorch/models/multitask.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def __init__(
7070
>>> train_Y = torch.cat(f1(X1), f2(X2))
7171
>>> model = MultiTaskGP(train_X, train_Y, task_feature=-1)
7272
"""
73+
# TODO: Validate input normalization/scaling
7374
if train_X.ndimension() != 2:
7475
# Currently, batch mode MTGPs are blocked upstream in GPyTorch
7576
raise ValueError(f"Unsupported shape {train_X.shape} for train_X.")

botorch/models/utils.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from gpytorch.utils.broadcasting import _mul_broadcast_shape
1414
from torch import Tensor
1515

16+
from .. import settings
1617
from ..exceptions import InputDataError, InputDataWarning
1718

1819

@@ -179,3 +180,42 @@ def check_standardization(
179180
if raise_on_fail:
180181
raise InputDataError(msg)
181182
warnings.warn(msg, InputDataWarning)
183+
184+
185+
def validate_input_scaling(
    train_X: Tensor,
    train_Y: Tensor,
    train_Yvar: Optional[Tensor] = None,
    raise_on_fail: bool = False,
) -> None:
    r"""Helper function to validate input data to models.

    Args:
        train_X: A `n x d` or `batch_shape x n x d` (batch mode) tensor of
            training features.
        train_Y: A `n x m` or `batch_shape x n x m` (batch mode) tensor of
            training observations.
        train_Yvar: A `n x m` or `batch_shape x n x m` (batch mode)
            tensor of observed measurement noise.
        raise_on_fail: If True, raise an error instead of emitting a warning
            (only for normalization/standardization checks, an error is always
            raised if NaN values are present).

    This function is typically called inside the constructor of standard BoTorch
    models. It validates the following:
    (i) none of the inputs contain NaN values
    (ii) the training data (`train_X`) is normalized to the unit cube
    (iii) the training targets (`train_Y`) are standardized (zero mean, unit var)
    No checks (other than the NaN check) are performed for observed variances
    (`train_Yvar`) at this point.
    """
    # Respect the global settings flag: skip all validation when disabled.
    if settings.validate_input_scaling.off():
        return
    check_no_nans(train_X)
    check_no_nans(train_Y)
    if train_Yvar is not None:
        check_no_nans(train_Yvar)
        # Negative variances are nonsensical, so always a hard error
        # (independent of `raise_on_fail`).
        if torch.any(train_Yvar < 0):
            raise InputDataError("Input data contains negative variances.")
    check_min_max_scaling(X=train_X, raise_on_fail=raise_on_fail)
    check_standardization(Y=train_Y, raise_on_fail=raise_on_fail)

botorch/settings.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,18 @@ class debug(_Flag):
7474
def _set_state(cls, state: bool) -> None:
7575
cls._state = state
7676
suppress_botorch_warnings(suppress=not cls._state)
77+
78+
79+
class validate_input_scaling(_Flag):
    r"""Flag for validating input normalization/standardization.

    When set to `True`, standard botorch models will validate (up to reasonable
    tolerance) that
    (i) none of the inputs contain NaN values
    (ii) the training data (`train_X`) is normalized to the unit cube
    (iii) the training targets (`train_Y`) are standardized (zero mean, unit var)
    No checks (other than the NaN check) are performed for observed variances
    (`train_Yvar`) at this point.
    """

    # Validation is on by default; use `settings.validate_input_scaling(False)`
    # as a context manager to disable it.
    _state: bool = True

test/models/test_utils.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
check_no_nans,
1414
check_standardization,
1515
multioutput_to_batch_mode_transform,
16+
validate_input_scaling,
1617
)
1718
from botorch.utils.testing import BotorchTestCase
1819

@@ -156,3 +157,47 @@ def test_check_standardization(self):
156157
self.assertTrue(any("not standardized" in str(w.message) for w in ws))
157158
with self.assertRaises(InputDataError):
158159
check_standardization(Y=Yst * 2, raise_on_fail=True)
160+
161+
def test_validate_input_scaling(self):
    # Features deliberately outside the unit cube; targets already ~N(0, 1).
    train_X = 2 + torch.rand(3, 4, 3)
    train_Y = torch.randn(3, 4, 2)
    # With the validation flag disabled, no scaling checks should run.
    with settings.validate_input_scaling(False), settings.debug(True):
        with warnings.catch_warnings(record=True) as logs:
            validate_input_scaling(train_X=train_X, train_Y=train_Y)
        scaling_warnings = [
            w for w in logs if issubclass(w.category, InputDataWarning)
        ]
        self.assertFalse(any(scaling_warnings))
    # With the flag enabled (the default), badly scaled inputs warn.
    with settings.debug(True), warnings.catch_warnings(record=True) as logs:
        validate_input_scaling(train_X=train_X, train_Y=train_Y)
    self.assertTrue(
        any(issubclass(w.category, InputDataWarning) for w in logs)
    )
    # `raise_on_fail=True` upgrades those warnings to errors.
    with settings.debug(True):
        with self.assertRaises(InputDataError):
            validate_input_scaling(
                train_X=train_X, train_Y=train_Y, raise_on_fail=True
            )
    # Min-max normalize X and standardize Y; validation should then be silent.
    x_min = train_X.min(dim=-1, keepdim=True)[0]
    x_max = train_X.max(dim=-1, keepdim=True)[0]
    train_X_std = (train_X - x_min) / (x_max - x_min)
    y_mean = train_Y.mean(dim=-2, keepdim=True)
    y_std = train_Y.std(dim=-2, keepdim=True)
    train_Y_std = (train_Y - y_mean) / y_std
    with settings.debug(True), warnings.catch_warnings(record=True) as logs:
        validate_input_scaling(train_X=train_X_std, train_Y=train_Y_std)
    self.assertFalse(
        any(issubclass(w.category, InputDataWarning) for w in logs)
    )
    # A single negative observed variance must raise an error.
    train_Yvar = torch.rand_like(train_Y_std)
    train_Yvar[0, 0, 1] = -0.5
    with settings.debug(True):
        with self.assertRaises(InputDataError):
            validate_input_scaling(
                train_X=train_X_std, train_Y=train_Y_std, train_Yvar=train_Yvar
            )
    # NaN values must raise an error as well.
    train_X_std[0, 0, 0] = float("nan")
    with settings.debug(True):
        with self.assertRaises(InputDataError):
            validate_input_scaling(train_X=train_X_std, train_Y=train_Y_std)

0 commit comments

Comments
 (0)