1313# for 'cos' the samples have to be in [0,1]^d
1414# for a periodic function use 'per' or wavelets 'chui2', 'chui3', 'chui4' (samples have to be in [-0.5,0.5]^d here, 'chuim' are the Chui-Wang wavelets of order m)
1515
16+
1617def TestFunction (x ): # this function is of the form f_0 + f_1 + f_2 + f_3 + f_4 + f_2,3
17- return (
18- 2 * abs (x [0 ])
19- + abs (math .sin (math .pi * x [1 ] * x [2 ]))
20- + np .cos (3 + x [3 ])
21- )
18+ return 2 * abs (x [0 ]) + abs (math .sin (math .pi * x [1 ] * x [2 ])) + np .cos (3 + x [3 ])
2219
2320
2421rng = np .random .default_rng (1234 )
@@ -28,10 +25,10 @@ def TestFunction(x): # this function is of the form f_0 + f_1 + f_2 + f_3 + f_4
2825##################################
2926
3027d = 8 # dimension
31- q = 2 # superposition dimension
28+ q = 2 # superposition dimension
3229M = 10000 # number of used evaluation points to train the model
3330M_test = 10000 # number of used evaluation points to test the accuracity the model
34- N = [5 ,2 ] # number of parameters, should be vector of length q:
31+ N = [5 , 2 ] # number of parameters, should be vector of length q:
3532# for wavelets the total number of parameters scales exponentially, i.e.:
3633# for q = 1 and N = [N1] the total number of parameters scales like ~O(d*2^N1)
3734# for q = 2 and N = [N1 , N2] the total number of parameters scales like ~O(d*2^N1) + O(d^2 * N2*2^N2)
@@ -42,18 +39,28 @@ def TestFunction(x): # this function is of the form f_0 + f_1 + f_2 + f_3 + f_4
4239## Generation of the data ##
4340############################
4441
45- if basis == "chui2" or basis == "chui3" or basis == "chui4" or basis == "per" :
46- X = rng .random ((M , d )) - 0.5 # for perioidic approximation samples have to be in [-0.5,0.5]^d
42+ if basis == "chui2" or basis == "chui3" or basis == "chui4" or basis == "per" :
43+ X = (
44+ rng .random ((M , d )) - 0.5
45+ ) # for perioidic approximation samples have to be in [-0.5,0.5]^d
4746elif basis == "cos" :
4847 X = rng .random ((M , d ))
4948y = np .array (
5049 [TestFunction (X [i , :].T ) for i in range (M )]
5150) # evaluate the function at these points
5251
53- if basis == "chui1" or basis == "chui2" or basis == "chui3" or basis == "chui4" or basis == "per" :
54- X_test = rng .random ((M_test , d )) - 0.5 # for perioidic approximation samples have to be in [-0.5,0.5]^d
52+ if (
53+ basis == "chui1"
54+ or basis == "chui2"
55+ or basis == "chui3"
56+ or basis == "chui4"
57+ or basis == "per"
58+ ):
59+ X_test = (
60+ rng .random ((M_test , d )) - 0.5
61+ ) # for perioidic approximation samples have to be in [-0.5,0.5]^d
5562elif basis == "cos" :
56- X_test = rng .random ((M_test , d ))
63+ X_test = rng .random ((M_test , d ))
5764y_test = np .array (
5865 [TestFunction (X_test [i , :].T ) for i in range (M_test )]
5966) # the same for the test points
@@ -74,8 +81,10 @@ def TestFunction(x): # this function is of the form f_0 + f_1 + f_2 + f_3 + f_4
7481#######################
7582
7683### Do sensitivity analysis ####
77- gsis = ANOVAapprox .get_GSI (anova_model ,0.0 ) #calculates indices for importance of terms (gsis is vector, with indices belonging to terms in anova_model.U)
78- #gsis_as_dict = ANOVAapprox.get_GSI(anova_model,0.0,dict=true)
84+ gsis = ANOVAapprox .get_GSI (
85+ anova_model , 0.0
86+ ) # calculates indices for importance of terms (gsis is vector, with indices belonging to terms in anova_model.U)
87+ # gsis_as_dict = ANOVAapprox.get_GSI(anova_model,0.0,dict=true)
7988
8089y_min_calc = 10 ** (np .min (np .log10 (gsis )) - 0.5 )
8190label = list (anova_model .U [1 :])
@@ -104,22 +113,22 @@ def TestFunction(x): # this function is of the form f_0 + f_1 + f_2 + f_3 + f_4
104113################################
105114
106115### error analysis ###
107- mse_train = ANOVAapprox .get_mse (anova_model ,lam = 0.0 )
108- mse_test = ANOVAapprox .get_mse (anova_model ,X_test ,y_test , lam = 0.0 )
116+ mse_train = ANOVAapprox .get_mse (anova_model , lam = 0.0 )
117+ mse_test = ANOVAapprox .get_mse (anova_model , X_test , y_test , lam = 0.0 )
109118
110- print ("MSE on test points: " + str (mse_test ))
119+ print ("MSE on test points: " + str (mse_test ))
111120
112121################################################
113122## Approximation with better suited index set ##
114123################################################
115124
116125U = ANOVAapprox .get_ActiveSet (anova_model , [0.01 , 0.01 ], lam = 0.0 )
117- print ("Found index-set U: " + str (U ) )
118- anova_model = ANOVAapprox .approx (X , y , U = U , N = [i + 2 for i in N ] , basis = basis ) # increase number of paramers in N for the important terms
126+ print ("Found index-set U: " + str (U ))
127+ anova_model = ANOVAapprox .approx (
128+ X , y , U = U , N = [i + 2 for i in N ], basis = basis
129+ ) # increase number of paramers in N for the important terms
119130anova_model .approximate (lam = lambdas )
120131print ("Total number of used parameters = " + str (len (anova_model .fc [lambdas [0 ]].vec ())))
121- mse_train = ANOVAapprox .get_mse (anova_model ,lam = 0.0 )
122- mse_test = ANOVAapprox .get_mse (anova_model ,X_test ,y_test , lam = 0.0 )
123- print ("MSE on test points after ANOVA truncation: " + str (mse_test ))
124-
125-
132+ mse_train = ANOVAapprox .get_mse (anova_model , lam = 0.0 )
133+ mse_test = ANOVAapprox .get_mse (anova_model , X_test , y_test , lam = 0.0 )
134+ print ("MSE on test points after ANOVA truncation: " + str (mse_test ))
0 commit comments