Skip to content

Commit 62d1088

Browse files
Syrine Belakariafacebook-github-bot
authored andcommitted
Adding sensitivity analysis synthetic functions (#1355)
Summary: Pull Request resolved: #1355 These functions were buitl specifically for sensitivity analysis Reviewed By: bletham Differential Revision: D38404308 fbshipit-source-id: 73773339616a0aeb8a773b45b79b4d6334fabe6f
1 parent 3cca333 commit 62d1088

File tree

1 file changed

+231
-0
lines changed

1 file changed

+231
-0
lines changed
Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import math
7+
from typing import List, Optional
8+
9+
import torch
10+
11+
from botorch.test_functions.synthetic import SyntheticTestFunction
12+
from torch import Tensor
13+
14+
15+
class Ishigami(SyntheticTestFunction):
16+
r"""Ishigami test function.
17+
18+
three-dimensional function (usually evaluated on `[-pi, pi]^3`):
19+
20+
f(x) = sin(x_1) + a sin(x_2)^2 + b x_3^4 sin(x_1)
21+
22+
Here `a` and `b` are constants where a=7 and b=0.1 or b=0.05
23+
Proposed to test sensitivity analysis methods because it exhibits strong
24+
nonlinearity and nonmonotonicity and a peculiar dependence on x_3.
25+
"""
26+
27+
def __init__(
28+
self, b: float = 0.1, noise_std: Optional[float] = None, negate: bool = False
29+
) -> None:
30+
if b not in (0.1, 0.05):
31+
raise ValueError("b parameter should be 0.1 or 0.05")
32+
self.dim = 3
33+
if b == 0.1:
34+
self.si = [0.3138, 0.4424, 0]
35+
self.si_t = [0.558, 0.442, 0.244]
36+
self.s_ij = [0, 0.244, 0]
37+
self.dgsm_gradient = [-0.0004, -0.0004, -0.0004]
38+
self.dgsm_gradient_abs = [1.9, 4.45, 1.97]
39+
self.dgsm_gradient_square = [7.7, 24.5, 11]
40+
elif b == 0.05:
41+
self.si = [0.218, 0.687, 0]
42+
self.si_t = [0.3131, 0.6868, 0.095]
43+
self.s_ij = [0, 0.094, 0]
44+
self.dgsm_gradient = [-0.0002, -0.0002, -0.0002]
45+
self.dgsm_gradient_abs = [1.26, 4.45, 1.97]
46+
self.dgsm_gradient_square = [2.8, 24.5, 11]
47+
self._bounds = [(-math.pi, math.pi) for _ in range(self.dim)]
48+
self.b = b
49+
self._optimizers = None
50+
super().__init__(noise_std=noise_std, negate=negate)
51+
52+
def compute_dgsm(self, X: Tensor) -> Tensor:
53+
r"""This function can be called separately to estimate the dgsm measure
54+
Those values are already added under self.dgsm_gradient"""
55+
dx_1 = torch.cos(X[..., 0]) * (1 + self.b * (X[..., 2] ** 4))
56+
dx_2 = 14 * torch.cos(X[..., 1]) * torch.sin(X[..., 1])
57+
dx_3 = 0.4 * (X[..., 2] ** 3) * torch.sin(X[..., 0])
58+
gradient_measure = [torch.mean(dx_1), torch.mean(dx_1), torch.mean(dx_1)]
59+
gradient_absolute_measure = [
60+
torch.mean(torch.abs(dx_1)),
61+
torch.mean(torch.abs(dx_2)),
62+
torch.mean(torch.abs(dx_3)),
63+
]
64+
gradient_square_measure = [
65+
torch.mean(torch.pow(dx_1, 2)),
66+
torch.mean(torch.pow(dx_2, 2)),
67+
torch.mean(torch.pow(dx_3, 2)),
68+
]
69+
return gradient_measure, gradient_absolute_measure, gradient_square_measure
70+
71+
def evaluate_true(self, X: Tensor) -> Tensor:
72+
self.to(device=X.device, dtype=X.dtype)
73+
t = (
74+
torch.sin(X[..., 0])
75+
+ 7 * (torch.sin(X[..., 1]) ** 2)
76+
+ self.b * (X[..., 2] ** 4) * torch.sin(X[..., 0])
77+
)
78+
return t
79+
80+
81+
class Gsobol(SyntheticTestFunction):
82+
r"""Gsobol test function.
83+
84+
d-dimensional function (usually evaluated on `[0, 1]^d`):
85+
86+
f(x) = Prod_{i=1}^{d} ((|4x_i-2|+a_i)/(1+a_i)), a_i >=0
87+
88+
common combinations of dimension and a vector:
89+
dim=8, a= [0, 1, 4.5, 9, 99, 99, 99, 99]
90+
dim=6, a=[0, 0.5, 3, 9, 99, 99]
91+
dim = 15, a= [1, 2, 5, 10, 20, 50, 100, 500, 1000, 1000, 1000, 1000, 1000,
92+
1000, 1000]
93+
Proposed to test sensitivity analysis methods
94+
First order Sobol indices have closed form expression S_i=V_i/V with :
95+
V_i= 1/(3(1+a_i)^2)
96+
V= Prod_{i=1}^{d} (1+V_i) - 1
97+
"""
98+
99+
def __init__(
100+
self,
101+
dim: int,
102+
a: List = None,
103+
noise_std: Optional[float] = None,
104+
negate: bool = False,
105+
) -> None:
106+
self.dim = dim
107+
self._bounds = [(0, 1) for _ in range(self.dim)]
108+
if self.dim == 6:
109+
self.a = [0, 0.5, 3, 9, 99, 99]
110+
elif self.dim == 8:
111+
self.a = [0, 1, 4.5, 9, 99, 99, 99, 99]
112+
elif self.dim == 15:
113+
self.a = [
114+
1,
115+
2,
116+
5,
117+
10,
118+
20,
119+
50,
120+
100,
121+
500,
122+
1000,
123+
1000,
124+
1000,
125+
1000,
126+
1000,
127+
1000,
128+
1000,
129+
]
130+
else:
131+
self.a = a
132+
self._optimizers = None
133+
self.optimal_sobol_indicies()
134+
super().__init__(noise_std=noise_std, negate=negate)
135+
136+
def optimal_sobol_indicies(self):
137+
vi = []
138+
for i in range(self.dim):
139+
vi.append(1 / (3 * ((1 + self.a[i]) ** 2)))
140+
self.vi = Tensor(vi)
141+
self.V = torch.prod((1 + self.vi)) - 1
142+
self.si = self.vi / self.V
143+
si_t = []
144+
for i in range(self.dim):
145+
si_t.append(
146+
(
147+
self.vi[i]
148+
* torch.prod(self.vi[:i] + 1)
149+
* torch.prod(self.vi[i + 1 :] + 1)
150+
)
151+
/ self.V
152+
)
153+
self.si_t = Tensor(si_t)
154+
155+
def evaluate_true(self, X: Tensor) -> Tensor:
156+
self.to(device=X.device, dtype=X.dtype)
157+
t = 1
158+
for i in range(self.dim):
159+
t = t * (torch.abs(4 * X[..., i] - 2) + self.a[i]) / (1 + self.a[i])
160+
return t
161+
162+
163+
class Morris(SyntheticTestFunction):
164+
r"""Morris test function.
165+
166+
20-dimensional function (usually evaluated on `[0, 1]^20`):
167+
f(x) = sum_{i=1}^20 beta_i w_i + sum_{i<j}^20 beta_ij w_i w_j
168+
+ sum_{i<j<l}^20 beta_ijl w_i w_j w_l + 5w_1 w_2 w_3 w_4
169+
Proposed to test sensitivity analysis methods
170+
"""
171+
172+
def __init__(self, noise_std: Optional[float] = None, negate: bool = False) -> None:
173+
self.dim = 20
174+
self._bounds = [(0, 1) for _ in range(self.dim)]
175+
self._optimizers = None
176+
self.si = [
177+
0.005,
178+
0.008,
179+
0.017,
180+
0.009,
181+
0.016,
182+
0,
183+
0.069,
184+
0.1,
185+
0.15,
186+
0.1,
187+
0,
188+
0,
189+
0,
190+
0,
191+
0,
192+
0,
193+
0,
194+
0,
195+
0,
196+
0,
197+
]
198+
super().__init__(noise_std=noise_std, negate=negate)
199+
200+
def evaluate_true(self, X: Tensor) -> Tensor:
201+
self.to(device=X.device, dtype=X.dtype)
202+
W = []
203+
t1 = 0
204+
t2 = 0
205+
t3 = 0
206+
for i in range(self.dim):
207+
if i in [2, 4, 6]:
208+
wi = 2 * (1.1 * X[..., i] / (X[..., i] + 0.1) - 0.5)
209+
else:
210+
wi = 2 * (X[..., i] - 0.5)
211+
W.append(wi)
212+
if i < 10:
213+
betai = 20
214+
else:
215+
betai = (-1) ** (i + 1)
216+
t1 = t1 + betai * wi
217+
for i in range(self.dim):
218+
for j in range(i + 1, self.dim):
219+
if i < 6 or j < 6:
220+
beta_ij = -15
221+
else:
222+
beta_ij = (-1) ** (i + j + 2)
223+
t2 = t2 + beta_ij * W[i] * W[j]
224+
for k in range(j + 1, self.dim):
225+
if i < 5 or j < 5 or k < 5:
226+
beta_ijk = -10
227+
else:
228+
beta_ijk = 0
229+
t3 = t3 + beta_ijk * W[i] * W[j] * W[k]
230+
t4 = 5 * W[0] * W[1] * W[2] * W[3]
231+
return t1 + t2 + t3 + t4

0 commit comments

Comments
 (0)