
Commit 2dda2e6

bletham authored and facebook-github-bot committed
Add doc and tests for sensitivity functions (#1361)
Summary:
Pull Request resolved: #1361

Adds a doc entry and tests for the sensitivity functions introduced in D38404308 (62d1088)

Reviewed By: Balandat

Differential Revision: D38923825

fbshipit-source-id: 2507b2160f04c746dd5a1024e5fd6179557f0728
1 parent 7bef251 commit 2dda2e6

File tree

4 files changed: +134 -23 lines changed

botorch/test_functions/sensitivity_analysis.py

Lines changed: 69 additions & 21 deletions
@@ -4,7 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import math
-from typing import List, Optional
+from typing import List, Optional, Tuple
 
 import torch
 
@@ -27,6 +27,13 @@ class Ishigami(SyntheticTestFunction):
     def __init__(
         self, b: float = 0.1, noise_std: Optional[float] = None, negate: bool = False
     ) -> None:
+        r"""
+        Args:
+            b: the b constant, should be 0.1 or 0.05.
+            noise_std: Standard deviation of the observation noise.
+            negate: If True, negate the objective.
+        """
+        self._optimizers = None
         if b not in (0.1, 0.05):
             raise ValueError("b parameter should be 0.1 or 0.05")
         self.dim = 3
@@ -46,25 +53,41 @@ def __init__(
         self.dgsm_gradient_square = [2.8, 24.5, 11]
         self._bounds = [(-math.pi, math.pi) for _ in range(self.dim)]
         self.b = b
-        self._optimizers = None
         super().__init__(noise_std=noise_std, negate=negate)
 
-    def compute_dgsm(self, X: Tensor) -> Tensor:
-        r"""This function can be called separately to estimate the dgsm measure
-        Those values are already added under self.dgsm_gradient"""
+    @property
+    def _optimal_value(self) -> float:
+        raise NotImplementedError
+
+    def compute_dgsm(self, X: Tensor) -> Tuple[List[float], List[float], List[float]]:
+        r"""Compute derivative global sensitivity measures.
+
+        This function can be called separately to estimate the dgsm measures.
+        The exact global integrals of these values are already available as the
+        attributes dgsm_gradient, dgsm_gradient_abs, and dgsm_gradient_square.
+
+        Args:
+            X: Set of points at which to compute derivative measures.
+
+        Returns: The average gradient, absolute gradient, and square gradients.
+        """
         dx_1 = torch.cos(X[..., 0]) * (1 + self.b * (X[..., 2] ** 4))
         dx_2 = 14 * torch.cos(X[..., 1]) * torch.sin(X[..., 1])
         dx_3 = 0.4 * (X[..., 2] ** 3) * torch.sin(X[..., 0])
-        gradient_measure = [torch.mean(dx_1), torch.mean(dx_1), torch.mean(dx_1)]
+        gradient_measure = [
+            torch.mean(dx_1).item(),
+            torch.mean(dx_2).item(),
+            torch.mean(dx_3).item(),
+        ]
         gradient_absolute_measure = [
-            torch.mean(torch.abs(dx_1)),
-            torch.mean(torch.abs(dx_2)),
-            torch.mean(torch.abs(dx_3)),
+            torch.mean(torch.abs(dx_1)).item(),
+            torch.mean(torch.abs(dx_2)).item(),
+            torch.mean(torch.abs(dx_3)).item(),
         ]
         gradient_square_measure = [
-            torch.mean(torch.pow(dx_1, 2)),
-            torch.mean(torch.pow(dx_2, 2)),
-            torch.mean(torch.pow(dx_3, 2)),
+            torch.mean(torch.pow(dx_1, 2)).item(),
+            torch.mean(torch.pow(dx_2, 2)).item(),
+            torch.mean(torch.pow(dx_3, 2)).item(),
         ]
         return gradient_measure, gradient_absolute_measure, gradient_square_measure
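For orientation, a minimal usage sketch of the revised compute_dgsm (not part of the commit; the sample count and the uniform sampling over the Ishigami domain are illustrative assumptions):

    import math
    import torch
    from botorch.test_functions.sensitivity_analysis import Ishigami

    f = Ishigami(b=0.1)
    # Uniform Monte Carlo samples over the domain [-pi, pi]^3.
    X = (2 * torch.rand(1024, 3) - 1) * math.pi
    # Each returned measure is a list of three floats, one per input dimension.
    grad, grad_abs, grad_sq = f.compute_dgsm(X)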

@@ -83,17 +106,20 @@ class Gsobol(SyntheticTestFunction):
 
     d-dimensional function (usually evaluated on `[0, 1]^d`):
 
-        f(x) = Prod_{i=1}^{d} ((|4x_i-2|+a_i)/(1+a_i)), a_i >= 0
+        f(x) = Prod_{i=1}\^{d} ((\|4x_i-2\|+a_i)/(1+a_i)), a_i >= 0
 
     common combinations of dimension and a vector:
+
         dim=8, a=[0, 1, 4.5, 9, 99, 99, 99, 99]
         dim=6, a=[0, 0.5, 3, 9, 99, 99]
-        dim=15, a=[1, 2, 5, 10, 20, 50, 100, 500, 1000, 1000, 1000, 1000, 1000,
-            1000, 1000]
+        dim=15, a=[1, 2, 5, 10, 20, 50, 100, 500, 1000, ..., 1000]
+
     Proposed to test sensitivity analysis methods.
     First order Sobol indices have closed form expression S_i = V_i/V with:
-        V_i = 1/(3(1+a_i)^2)
-        V = Prod_{i=1}^{d} (1+V_i) - 1
+
+        V_i = 1/(3(1+a_i)\^2)
+        V = Prod_{i=1}\^{d} (1+V_i) - 1
+
     """
 
     def __init__(
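As a quick worked check of this closed form (an illustrative sketch, not part of the commit), using the non-standard dim=3, a=[1, 2, 3] case that the new tests also exercise:

    # First-order Sobol indices for Gsobol with a = [1, 2, 3].
    a = [1.0, 2.0, 3.0]
    vi = [1.0 / (3 * (1 + ai) ** 2) for ai in a]  # V_i = 1/(3(1+a_i)^2)
    v = 1.0
    for v_i in vi:
        v *= 1 + v_i
    v -= 1  # V = Prod_{i=1}^{d} (1+V_i) - 1
    si = [v_i / v for v_i in vi]  # S_i = V_i / V, approx. [0.567, 0.252, 0.142]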
@@ -103,6 +129,14 @@ def __init__(
         noise_std: Optional[float] = None,
         negate: bool = False,
     ) -> None:
+        r"""
+        Args:
+            dim: Dimensionality of the problem. If 6, 8, or 15, will use standard a.
+            a: a parameter, unless dim is 6, 8, or 15.
+            noise_std: Standard deviation of observation noise.
+            negate: Return negative of the function.
+        """
+        self._optimizers = None
         self.dim = dim
         self._bounds = [(0, 1) for _ in range(self.dim)]
         if self.dim == 6:
@@ -129,10 +163,13 @@ def __init__(
             ]
         else:
             self.a = a
-        self._optimizers = None
         self.optimal_sobol_indicies()
         super().__init__(noise_std=noise_std, negate=negate)
 
+    @property
+    def _optimal_value(self) -> float:
+        raise NotImplementedError
+
     def optimal_sobol_indicies(self):
         vi = []
         for i in range(self.dim):
@@ -164,15 +201,22 @@ class Morris(SyntheticTestFunction):
     r"""Morris test function.
 
     20-dimensional function (usually evaluated on `[0, 1]^20`):
-        f(x) = sum_{i=1}^20 beta_i w_i + sum_{i<j}^20 beta_ij w_i w_j
-            + sum_{i<j<l}^20 beta_ijl w_i w_j w_l + 5w_1 w_2 w_3 w_4
+
+        f(x) = sum_{i=1}\^20 beta_i w_i + sum_{i<j}\^20 beta_ij w_i w_j
+            + sum_{i<j<l}\^20 beta_ijl w_i w_j w_l + 5w_1 w_2 w_3 w_4
+
     Proposed to test sensitivity analysis methods.
     """
 
     def __init__(self, noise_std: Optional[float] = None, negate: bool = False) -> None:
+        r"""
+        Args:
+            noise_std: Standard deviation of observation noise.
+            negate: Return negative of the function.
+        """
+        self._optimizers = None
         self.dim = 20
         self._bounds = [(0, 1) for _ in range(self.dim)]
-        self._optimizers = None
         self.si = [
             0.005,
             0.008,
@@ -197,6 +241,10 @@ def __init__(self, noise_std: Optional[float] = None, negate: bool = False) -> N
             ]
         super().__init__(noise_std=noise_std, negate=negate)
 
+    @property
+    def _optimal_value(self) -> float:
+        raise NotImplementedError
+
     def evaluate_true(self, X: Tensor) -> Tensor:
         self.to(device=X.device, dtype=X.dtype)
         W = []
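A short usage sketch for Morris (assumed, not part of the commit); calling the instance goes through the standard BaseTestProblem forward, while evaluate_true is the noiseless evaluation:

    import torch
    from botorch.test_functions.sensitivity_analysis import Morris

    f = Morris(noise_std=0.1)
    X = torch.rand(4, 20)  # four points in [0, 1]^20
    Z_noisy = f(X)  # forward pass, adds observation noise with std 0.1
    Z_exact = f.evaluate_true(X)  # noiseless values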

botorch/test_functions/synthetic.py

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@
 class SyntheticTestFunction(BaseTestProblem):
     r"""Base class for synthetic test functions."""
 
-    _optimizers: List[Tuple[float, ...]]
+    _optimizers: Optional[List[Tuple[float, ...]]]
     _optimal_value: float
     num_objectives: int = 1
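The loosened annotation reflects that the new sensitivity functions set _optimizers to None. A minimal sketch of the resulting behavior (inferred from the diff above and the tests below, not code from the commit):

    from botorch.test_functions.sensitivity_analysis import Ishigami

    f = Ishigami()
    assert f._optimizers is None  # valid now that the attribute is Optional
    f.optimal_value  # raises NotImplementedError via the new _optimal_value property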

sphinx/source/test_functions.rst

Lines changed: 6 additions & 1 deletion
@@ -29,4 +29,9 @@ Multi-Objective Synthetic Test Functions
 Multi-Objective Multi-Fidelity Synthetic Test Functions
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 .. automodule:: botorch.test_functions.multi_objective_multi_fidelity
-    :members:
+    :members:
+
+Sensitivity Analysis Test Functions
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. automodule:: botorch.test_functions.sensitivity_analysis
+    :members:
test/test_functions/test_sensitivity_analysis.py

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from botorch.test_functions.sensitivity_analysis import Gsobol, Ishigami, Morris
+from botorch.utils.testing import BotorchTestCase
+
+
+class TestIshigami(BotorchTestCase):
+    def testFunction(self):
+        with self.assertRaises(ValueError):
+            Ishigami(b=0.33)
+        f = Ishigami(b=0.1)
+        self.assertEqual(f.b, 0.1)
+        f = Ishigami(b=0.05)
+        self.assertEqual(f.b, 0.05)
+        X = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]])
+        m1, m2, m3 = f.compute_dgsm(X)
+        for m in [m1, m2, m3]:
+            self.assertEqual(len(m), 3)
+        Z = f.evaluate_true(X)
+        Ztrue = torch.tensor([5.8401, 7.4245])
+        self.assertTrue(torch.allclose(Z, Ztrue, atol=1e-3))
+        self.assertIsNone(f._optimizers)
+        with self.assertRaises(NotImplementedError):
+            f.optimal_value
+
+
+class TestGsobol(BotorchTestCase):
+    def testFunction(self):
+        for dim in [6, 8, 15]:
+            f = Gsobol(dim=dim)
+            self.assertIsNotNone(f.a)
+            self.assertEqual(len(f.a), dim)
+        f = Gsobol(dim=3, a=[1, 2, 3])
+        self.assertEqual(f.a, [1, 2, 3])
+        X = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]])
+        Z = f.evaluate_true(X)
+        Ztrue = torch.tensor([2.5, 21.0])
+        self.assertTrue(torch.allclose(Z, Ztrue, atol=1e-3))
+        self.assertIsNone(f._optimizers)
+        with self.assertRaises(NotImplementedError):
+            f.optimal_value
+
+
+class TestMorris(BotorchTestCase):
+    def testFunction(self):
+        f = Morris()
+        X = torch.stack((torch.zeros(20), torch.ones(20)))
+        Z = f.evaluate_true(X)
+        Ztrue = torch.tensor([5163.0, -8137.0])
+        self.assertTrue(torch.allclose(Z, Ztrue, atol=1e-3))
+        self.assertIsNone(f._optimizers)
+        with self.assertRaises(NotImplementedError):
+            f.optimal_value
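Assuming the test path above and a standard repository checkout, the new tests could be run with, e.g.:

    python -m pytest test/test_functions/test_sensitivity_analysis.py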
