Skip to content

Commit a107ff9

Browse files
Balandatfacebook-github-bot
authored andcommitted
Utilites for validating data normalization and standardization (#223)
Summary: Adds utilities that make it easy to check whether input data is free of NaNs, normalized (for inputs), and standardized (for targets). Addresses part of #209 - these utilities will need to be called in the various model constructors, but that will be a separate PR. We should add some `debug` setting to `settings.py` (on by default) that calls these checks on the input data with `raise_on_fail=False`. Pull Request resolved: #223 Reviewed By: sdaulton Differential Revision: D16685285 Pulled By: Balandat fbshipit-source-id: eb54f9f4d383a1b3078f253b74f707d9c8213d86
1 parent beca4c8 commit a107ff9

File tree

7 files changed

+193
-16
lines changed

7 files changed

+193
-16
lines changed

botorch/exceptions/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,16 @@
22

33
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
44

5-
from .errors import BotorchError, CandidateGenerationError, UnsupportedError
5+
from .errors import (
6+
BotorchError,
7+
CandidateGenerationError,
8+
InputDataError,
9+
UnsupportedError,
10+
)
611
from .warnings import (
712
BadInitialCandidatesWarning,
813
BotorchWarning,
14+
InputDataWarning,
915
OptimizationWarning,
1016
SamplingWarning,
1117
)
@@ -16,6 +22,8 @@
1622
"CandidateGenerationError",
1723
"UnsupportedError",
1824
"BotorchWarning",
25+
"InputDataWarning",
26+
"InputDataError",
1927
"BadInitialCandidatesWarning",
2028
"OptimizationWarning",
2129
"SamplingWarning",

botorch/exceptions/errors.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,12 @@ class CandidateGenerationError(BotorchError):
1919
pass
2020

2121

22+
class InputDataError(BotorchError):
23+
r"""Exception raised when input data does not comply with conventions."""
24+
25+
pass
26+
27+
2228
class UnsupportedError(BotorchError):
2329
r"""Currently unsupported feature."""
2430

botorch/exceptions/warnings.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,20 @@ class BotorchWarning(Warning):
1313
pass
1414

1515

16-
class OptimizationWarning(BotorchWarning):
17-
r"""Optimization-releated warnings."""
16+
class BadInitialCandidatesWarning(BotorchWarning):
17+
r"""Warning issued if set of initial candidates for optimziation is bad."""
1818

1919
pass
2020

2121

22-
class BadInitialCandidatesWarning(BotorchWarning):
23-
r"""Warning issued if set of initial candidates for optimziation is bad."""
22+
class InputDataWarning(BotorchWarning):
23+
r"""Warning raised when input data does not comply with conventions."""
24+
25+
pass
26+
27+
28+
class OptimizationWarning(BotorchWarning):
29+
r"""Optimization-releated warnings."""
2430

2531
pass
2632

botorch/models/utils.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,15 @@
66
Utiltiy functions for models.
77
"""
88

9+
import warnings
910
from typing import List, Optional, Tuple
1011

1112
import torch
1213
from gpytorch.utils.broadcasting import _mul_broadcast_shape
1314
from torch import Tensor
1415

16+
from ..exceptions import InputDataError, InputDataWarning
17+
1518

1619
def _make_X_full(X: Tensor, output_indices: List[int], tf: int) -> Tensor:
1720
r"""Helper to construct input tensor with task indices.
@@ -107,3 +110,73 @@ def add_output_dim(X: Tensor, original_batch_shape: torch.Size) -> Tuple[Tensor,
107110
X = X.unsqueeze(-3)
108111
output_dim_idx = max(len(original_batch_shape), len(X_batch_shape))
109112
return X, output_dim_idx
113+
114+
115+
def check_no_nans(Z: Tensor) -> None:
116+
r"""Check that tensor does not contain NaN values.
117+
118+
Raises an InputDataError if `Z` contains NaN values.
119+
120+
Args:
121+
Z: The input tensor.
122+
"""
123+
if torch.any(torch.isnan(Z)).item():
124+
raise InputDataError("Input data contains NaN values.")
125+
126+
127+
def check_min_max_scaling(
128+
X: Tensor, strict: bool = False, atol: float = 1e-2, raise_on_fail: bool = False
129+
) -> None:
130+
r"""Check that tensor is normalized to the unit cube.
131+
132+
Args:
133+
X: A `batch_shape x n x d` input tensor. Typically the training inputs
134+
of a model.
135+
strict: If True, require `X` to be scaled to the unit cube (rather than
136+
just to be contained within the unit cube).
137+
atol: The tolerance for the boundary check. Only used if `strict=True`.
138+
raise_on_fail: If True, raise an exception instead of a warning.
139+
"""
140+
with torch.no_grad():
141+
Xmin, Xmax = torch.min(X, dim=-1)[0], torch.max(X, dim=-1)[0]
142+
msg = None
143+
if strict and max(torch.abs(Xmin).max(), torch.abs(Xmax - 1).max()) > atol:
144+
msg = "scaled"
145+
if torch.any(Xmin < -atol) or torch.any(Xmax > 1 + atol):
146+
msg = "contained"
147+
if msg is not None:
148+
msg = (
149+
f"Input data is not {msg} to the unit cube. "
150+
"Please consider min-max scaling the input data."
151+
)
152+
if raise_on_fail:
153+
raise InputDataError(msg)
154+
warnings.warn(msg, InputDataWarning)
155+
156+
157+
def check_standardization(
158+
Y: Tensor,
159+
atol_mean: float = 1e-2,
160+
atol_std: float = 1e-2,
161+
raise_on_fail: bool = False,
162+
) -> None:
163+
r"""Check that tensor is standardized (zero mean, unit variance).
164+
165+
Args:
166+
Y: The input tensor of shape `batch_shape x n x m`. Typically the
167+
train targets of a model. Standardization is checked across the
168+
`n`-dimension.
169+
atol_mean: The tolerance for the mean check.
170+
atol_std: The tolerance for the std check.
171+
raise_on_fail: If True, raise an exception instead of a warning.
172+
"""
173+
with torch.no_grad():
174+
Ymean, Ystd = torch.mean(Y, dim=-2), torch.std(Y, dim=-2)
175+
if torch.abs(Ymean).max() > atol_mean or torch.abs(Ystd - 1).max() > atol_std:
176+
msg = (
177+
"Input data is not standardized. Please consider scaling the "
178+
"input to zero mean and unit variance."
179+
)
180+
if raise_on_fail:
181+
raise InputDataError(msg)
182+
warnings.warn(msg, InputDataWarning)

test/exceptions/test_errors.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from botorch.exceptions.errors import (
88
BotorchError,
99
CandidateGenerationError,
10+
InputDataError,
1011
UnsupportedError,
1112
)
1213

@@ -15,12 +16,15 @@ class TestBotorchExceptions(unittest.TestCase):
1516
def test_botorch_exception_hierarchy(self):
1617
self.assertIsInstance(BotorchError(), Exception)
1718
self.assertIsInstance(CandidateGenerationError(), BotorchError)
19+
self.assertIsInstance(InputDataError(), BotorchError)
1820
self.assertIsInstance(UnsupportedError(), BotorchError)
1921

2022
def test_raise_botorch_exceptions(self):
21-
with self.assertRaises(BotorchError):
22-
raise BotorchError("message")
23-
with self.assertRaises(CandidateGenerationError):
24-
raise CandidateGenerationError("message")
25-
with self.assertRaises(UnsupportedError):
26-
raise UnsupportedError("message")
23+
for ErrorClass in (
24+
BotorchError,
25+
CandidateGenerationError,
26+
InputDataError,
27+
UnsupportedError,
28+
):
29+
with self.assertRaises(ErrorClass):
30+
raise ErrorClass("message")

test/exceptions/test_warnings.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from botorch.exceptions.warnings import (
99
BadInitialCandidatesWarning,
1010
BotorchWarning,
11+
InputDataWarning,
1112
OptimizationWarning,
1213
SamplingWarning,
1314
)
@@ -17,18 +18,20 @@ class TestBotorchWarnings(unittest.TestCase):
1718
def test_botorch_warnings_hierarchy(self):
1819
self.assertIsInstance(BotorchWarning(), Warning)
1920
self.assertIsInstance(BadInitialCandidatesWarning(), BotorchWarning)
21+
self.assertIsInstance(InputDataWarning(), BotorchWarning)
2022
self.assertIsInstance(OptimizationWarning(), BotorchWarning)
2123
self.assertIsInstance(SamplingWarning(), BotorchWarning)
2224

2325
def test_botorch_warnings(self):
2426
for WarningClass in (
2527
BotorchWarning,
2628
BadInitialCandidatesWarning,
29+
InputDataWarning,
2730
OptimizationWarning,
2831
SamplingWarning,
2932
):
30-
with warnings.catch_warnings(record=True) as w:
33+
with warnings.catch_warnings(record=True) as ws:
3134
warnings.warn("message", WarningClass)
32-
self.assertEqual(len(w), 1)
33-
self.assertTrue(issubclass(w[-1].category, WarningClass))
34-
self.assertTrue("message" in str(w[-1].message))
35+
self.assertEqual(len(ws), 1)
36+
self.assertTrue(issubclass(ws[-1].category, WarningClass))
37+
self.assertTrue("message" in str(ws[-1].message))

test/models/test_utils.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,17 @@
33
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
44

55
import unittest
6+
import warnings
67

78
import torch
8-
from botorch.models.utils import add_output_dim, multioutput_to_batch_mode_transform
9+
from botorch.exceptions import InputDataError, InputDataWarning
10+
from botorch.models.utils import (
11+
add_output_dim,
12+
check_min_max_scaling,
13+
check_no_nans,
14+
check_standardization,
15+
multioutput_to_batch_mode_transform,
16+
)
917

1018

1119
class TestMultiOutputToBatchModeTransform(unittest.TestCase):
@@ -80,3 +88,72 @@ def test_add_output_dim(self, cuda=False):
8088
def test_add_output_dim_cuda(self, cuda=False):
8189
if torch.cuda.is_available():
8290
self.test_add_output_dim(cuda=True)
91+
92+
93+
class TestInputDataChecks(unittest.TestCase):
94+
def test_check_no_nans(self):
95+
check_no_nans(torch.tensor([1.0, 2.0]))
96+
with self.assertRaises(InputDataError):
97+
check_no_nans(torch.tensor([1.0, float("nan")]))
98+
99+
def test_check_min_max_scaling(self):
100+
# check unscaled input in unit cube
101+
X = 0.1 + 0.8 * torch.rand(4, 2, 3)
102+
with warnings.catch_warnings(record=True) as ws:
103+
check_min_max_scaling(X=X)
104+
self.assertFalse(any(issubclass(w.category, InputDataWarning) for w in ws))
105+
check_min_max_scaling(X=X, raise_on_fail=True)
106+
with warnings.catch_warnings(record=True) as ws:
107+
check_min_max_scaling(X=X, strict=True)
108+
self.assertTrue(any(issubclass(w.category, InputDataWarning) for w in ws))
109+
self.assertTrue(any("not scaled" in str(w.message) for w in ws))
110+
with self.assertRaises(InputDataError):
111+
check_min_max_scaling(X=X, strict=True, raise_on_fail=True)
112+
# check proper input
113+
Xmin, Xmax = X.min(dim=-1, keepdim=True)[0], X.max(dim=-1, keepdim=True)[0]
114+
Xstd = (X - Xmin) / (Xmax - Xmin)
115+
with warnings.catch_warnings(record=True) as ws:
116+
check_min_max_scaling(X=Xstd)
117+
self.assertFalse(any(issubclass(w.category, InputDataWarning) for w in ws))
118+
check_min_max_scaling(X=Xstd, raise_on_fail=True)
119+
with warnings.catch_warnings(record=True) as ws:
120+
check_min_max_scaling(X=Xstd, strict=True)
121+
self.assertFalse(any(issubclass(w.category, InputDataWarning) for w in ws))
122+
check_min_max_scaling(X=Xstd, strict=True, raise_on_fail=True)
123+
# check violation
124+
X[0, 0, 0] = 2
125+
with warnings.catch_warnings(record=True) as ws:
126+
check_min_max_scaling(X=X)
127+
self.assertTrue(any(issubclass(w.category, InputDataWarning) for w in ws))
128+
self.assertTrue(any("not contained" in str(w.message) for w in ws))
129+
with self.assertRaises(InputDataError):
130+
check_min_max_scaling(X=X, raise_on_fail=True)
131+
with warnings.catch_warnings(record=True) as ws:
132+
check_min_max_scaling(X=X, strict=True)
133+
self.assertTrue(any(issubclass(w.category, InputDataWarning) for w in ws))
134+
self.assertTrue(any("not contained" in str(w.message) for w in ws))
135+
with self.assertRaises(InputDataError):
136+
check_min_max_scaling(X=X, strict=True, raise_on_fail=True)
137+
138+
def test_check_standardization(self):
139+
Y = torch.randn(3, 4, 2)
140+
# check standardized input
141+
Yst = (Y - Y.mean(dim=-2, keepdim=True)) / Y.std(dim=-2, keepdim=True)
142+
with warnings.catch_warnings(record=True) as ws:
143+
check_standardization(Y=Yst)
144+
self.assertFalse(any(issubclass(w.category, InputDataWarning) for w in ws))
145+
check_standardization(Y=Yst, raise_on_fail=True)
146+
# check nonzero mean
147+
with warnings.catch_warnings(record=True) as ws:
148+
check_standardization(Y=Yst + 1)
149+
self.assertTrue(any(issubclass(w.category, InputDataWarning) for w in ws))
150+
self.assertTrue(any("not standardized" in str(w.message) for w in ws))
151+
with self.assertRaises(InputDataError):
152+
check_standardization(Y=Yst + 1, raise_on_fail=True)
153+
# check non-unit variance
154+
with warnings.catch_warnings(record=True) as ws:
155+
check_standardization(Y=Yst * 2)
156+
self.assertTrue(any(issubclass(w.category, InputDataWarning) for w in ws))
157+
self.assertTrue(any("not standardized" in str(w.message) for w in ws))
158+
with self.assertRaises(InputDataError):
159+
check_standardization(Y=Yst * 2, raise_on_fail=True)

0 commit comments

Comments
 (0)