|
3 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved |
4 | 4 |
|
5 | 5 | import unittest |
| 6 | +import warnings |
6 | 7 |
|
7 | 8 | import torch |
8 | | -from botorch.models.utils import add_output_dim, multioutput_to_batch_mode_transform |
| 9 | +from botorch.exceptions import InputDataError, InputDataWarning |
| 10 | +from botorch.models.utils import ( |
| 11 | + add_output_dim, |
| 12 | + check_min_max_scaling, |
| 13 | + check_no_nans, |
| 14 | + check_standardization, |
| 15 | + multioutput_to_batch_mode_transform, |
| 16 | +) |
9 | 17 |
|
10 | 18 |
|
11 | 19 | class TestMultiOutputToBatchModeTransform(unittest.TestCase): |
@@ -80,3 +88,72 @@ def test_add_output_dim(self, cuda=False): |
80 | 88 | def test_add_output_dim_cuda(self, cuda=False): |
81 | 89 | if torch.cuda.is_available(): |
82 | 90 | self.test_add_output_dim(cuda=True) |
| 91 | + |
| 92 | + |
| 93 | +class TestInputDataChecks(unittest.TestCase): |
| 94 | + def test_check_no_nans(self): |
| 95 | + check_no_nans(torch.tensor([1.0, 2.0])) |
| 96 | + with self.assertRaises(InputDataError): |
| 97 | + check_no_nans(torch.tensor([1.0, float("nan")])) |
| 98 | + |
| 99 | + def test_check_min_max_scaling(self): |
| 100 | + # check unscaled input in unit cube |
| 101 | + X = 0.1 + 0.8 * torch.rand(4, 2, 3) |
| 102 | + with warnings.catch_warnings(record=True) as ws: |
| 103 | + check_min_max_scaling(X=X) |
| 104 | + self.assertFalse(any(issubclass(w.category, InputDataWarning) for w in ws)) |
| 105 | + check_min_max_scaling(X=X, raise_on_fail=True) |
| 106 | + with warnings.catch_warnings(record=True) as ws: |
| 107 | + check_min_max_scaling(X=X, strict=True) |
| 108 | + self.assertTrue(any(issubclass(w.category, InputDataWarning) for w in ws)) |
| 109 | + self.assertTrue(any("not scaled" in str(w.message) for w in ws)) |
| 110 | + with self.assertRaises(InputDataError): |
| 111 | + check_min_max_scaling(X=X, strict=True, raise_on_fail=True) |
| 112 | + # check proper input |
| 113 | + Xmin, Xmax = X.min(dim=-1, keepdim=True)[0], X.max(dim=-1, keepdim=True)[0] |
| 114 | + Xstd = (X - Xmin) / (Xmax - Xmin) |
| 115 | + with warnings.catch_warnings(record=True) as ws: |
| 116 | + check_min_max_scaling(X=Xstd) |
| 117 | + self.assertFalse(any(issubclass(w.category, InputDataWarning) for w in ws)) |
| 118 | + check_min_max_scaling(X=Xstd, raise_on_fail=True) |
| 119 | + with warnings.catch_warnings(record=True) as ws: |
| 120 | + check_min_max_scaling(X=Xstd, strict=True) |
| 121 | + self.assertFalse(any(issubclass(w.category, InputDataWarning) for w in ws)) |
| 122 | + check_min_max_scaling(X=Xstd, strict=True, raise_on_fail=True) |
| 123 | + # check violation |
| 124 | + X[0, 0, 0] = 2 |
| 125 | + with warnings.catch_warnings(record=True) as ws: |
| 126 | + check_min_max_scaling(X=X) |
| 127 | + self.assertTrue(any(issubclass(w.category, InputDataWarning) for w in ws)) |
| 128 | + self.assertTrue(any("not contained" in str(w.message) for w in ws)) |
| 129 | + with self.assertRaises(InputDataError): |
| 130 | + check_min_max_scaling(X=X, raise_on_fail=True) |
| 131 | + with warnings.catch_warnings(record=True) as ws: |
| 132 | + check_min_max_scaling(X=X, strict=True) |
| 133 | + self.assertTrue(any(issubclass(w.category, InputDataWarning) for w in ws)) |
| 134 | + self.assertTrue(any("not contained" in str(w.message) for w in ws)) |
| 135 | + with self.assertRaises(InputDataError): |
| 136 | + check_min_max_scaling(X=X, strict=True, raise_on_fail=True) |
| 137 | + |
| 138 | + def test_check_standardization(self): |
| 139 | + Y = torch.randn(3, 4, 2) |
| 140 | + # check standardized input |
| 141 | + Yst = (Y - Y.mean(dim=-2, keepdim=True)) / Y.std(dim=-2, keepdim=True) |
| 142 | + with warnings.catch_warnings(record=True) as ws: |
| 143 | + check_standardization(Y=Yst) |
| 144 | + self.assertFalse(any(issubclass(w.category, InputDataWarning) for w in ws)) |
| 145 | + check_standardization(Y=Yst, raise_on_fail=True) |
| 146 | + # check nonzero mean |
| 147 | + with warnings.catch_warnings(record=True) as ws: |
| 148 | + check_standardization(Y=Yst + 1) |
| 149 | + self.assertTrue(any(issubclass(w.category, InputDataWarning) for w in ws)) |
| 150 | + self.assertTrue(any("not standardized" in str(w.message) for w in ws)) |
| 151 | + with self.assertRaises(InputDataError): |
| 152 | + check_standardization(Y=Yst + 1, raise_on_fail=True) |
| 153 | + # check non-unit variance |
| 154 | + with warnings.catch_warnings(record=True) as ws: |
| 155 | + check_standardization(Y=Yst * 2) |
| 156 | + self.assertTrue(any(issubclass(w.category, InputDataWarning) for w in ws)) |
| 157 | + self.assertTrue(any("not standardized" in str(w.message) for w in ws)) |
| 158 | + with self.assertRaises(InputDataError): |
| 159 | + check_standardization(Y=Yst * 2, raise_on_fail=True) |
0 commit comments