Skip to content

Commit 88d77f8

Browse files
committed
Merge branch 'main' of github.com:NFFT/pyANOVAapprox
2 parents 004b998 + 1527350 commit 88d77f8

File tree

7 files changed

+83
-69
lines changed

7 files changed

+83
-69
lines changed

simpleTest/exampleCheb.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,7 @@
1111

1212

1313
def TestFunction(x):
14-
return (
15-
x[0] * x[4]
16-
+ 2
17-
- np.exp(x[3])
18-
+ np.sqrt(x[5] + 3 + x[1])
19-
)
14+
return x[0] * x[4] + 2 - np.exp(x[3]) + np.sqrt(x[5] + 3 + x[1])
2015

2116

2217
rng = np.random.default_rng(1234)
@@ -39,12 +34,15 @@ def TestFunction(x):
3934
b = M / (
4035
math.log10(M) * num
4136
) # number for the number of frequencies if we use logarithmic oversampling and distribute it evenly to all subsets
42-
bw = [math.floor(b / 2) * 2, math.floor(math.sqrt(b) / 2) * 2] # bandwidths (use even numbers)
37+
bw = [
38+
math.floor(b / 2) * 2,
39+
math.floor(math.sqrt(b) / 2) * 2,
40+
] # bandwidths (use even numbers)
4341
# Use all subsets up to ds and use bw[1] many frequences in the the subsets with one element, b[2]^2 many for subsets with two elements and so on
4442
#
4543
########### Variant 2:
4644
# used subsets:
47-
# U = [(), (0,), (1,), (2,), (3,), (4,), (5,),
45+
# U = [(), (0,), (1,), (2,), (3,), (4,), (5,),
4846
# (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (1, 2), (1, 3), (1, 4), (1, 5), (2, 3), (2, 4), (2, 5), (3, 4), (3, 5), (4, 5)]
4947
# Bandwidths for these subsets:
5048
# N = [0 , 100, 100, 100, 100, 100, 100,
@@ -53,7 +51,7 @@ def TestFunction(x):
5351
#
5452
########### Variant 3:
5553
# used subsets:
56-
# U = [(), (0,), (1,), (2,), (3,), (4,), (5,),
54+
# U = [(), (0,), (1,), (2,), (3,), (4,), (5,),
5755
# (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (1, 2), (1, 3), (1, 4), (1, 5), (2, 3), (2, 4), (2, 5), (3, 4), (3, 5), (4, 5)]
5856
# Bandwidths for these subsets:
5957
# N = [(), (100,), (100,), (100,), (100,), (100,), (100,),
@@ -156,7 +154,9 @@ def TestFunction(x):
156154
Umask = np.append(np.array([True]), gsis > 1e-2)
157155
U = [ads.U[i] for i in np.arange(0, len(Umask))[Umask]] # get important subsets
158156
bws = M / (math.log10(M) * (len(U) - 1)) # calculate frequencies per subset
159-
N = [math.floor(bws ** (1 / max(1, len(u))) / 2) * 2 for u in U] # distribute the frequencies evenly and make them even
157+
N = [
158+
math.floor(bws ** (1 / max(1, len(u))) / 2) * 2 for u in U
159+
] # distribute the frequencies evenly and make them even
160160
N[0] = 0
161161

162162
a = ANOVAapprox.approx(

simpleTest/exampleClassification.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,14 @@
99

1010
import pyANOVAapprox as ANOVAapprox
1111

12+
1213
def TestFunction(x):
13-
e = (abs(x[1]+1.0j*x[2])+(np.angle(x[1]+1.0j*x[2])/(math.pi*8))) % 0.25>0.125
14-
return e*2-1
15-
14+
e = (
15+
abs(x[1] + 1.0j * x[2]) + (np.angle(x[1] + 1.0j * x[2]) / (math.pi * 8))
16+
) % 0.25 > 0.125
17+
return e * 2 - 1
18+
19+
1620
rng = np.random.default_rng(1234)
1721

1822
##################################
@@ -85,8 +89,8 @@ def TestFunction(x):
8589
## get classification accuracy ##
8690
#################################
8791

88-
y_approx = ads.evaluate(X=X_test, lam=0.0) # evaluate the classification
89-
acc = sum(np.sign(y_approx) == y_test)/M_test # calculate the accuracity
92+
y_approx = ads.evaluate(X=X_test, lam=0.0) # evaluate the classification
93+
acc = sum(np.sign(y_approx) == y_test) / M_test # calculate the accuracity
9094
print("accuracity = " + str(acc))
9195

9296
###############################################
@@ -156,30 +160,31 @@ def TestFunction(x):
156160
X, y, U, N, "cos", classification=True
157161
) # generate the data structure for the classification
158162
a.approximate(
159-
lam=lambdas,
160-
max_iter=max_iter
163+
lam=lambdas, max_iter=max_iter
161164
) # do the approximation for all specified regularisation parameters
162165

163-
y_approx = a.evaluate(X=X_test, lam=0.0) # evaluate the classification
164-
acc = sum(np.sign(y_approx) == y_test)/M_test # calculate the accuracity
166+
y_approx = a.evaluate(X=X_test, lam=0.0) # evaluate the classification
167+
acc = sum(np.sign(y_approx) == y_test) / M_test # calculate the accuracity
165168
print("accuracity = " + str(acc))
166169

167170
########################
168171
## Evaluate the model ##
169172
########################
170173

171-
#y_approx = a.evaluate(X=X_test) # evaluate the classification at the training points for the regularisation λ_min
172-
y_approx = a.evaluate(X=X_test, lam=0.0) # evaluate the classification at the points X_test for the regularisation λ_min
174+
# y_approx = a.evaluate(X=X_test) # evaluate the classification at the training points for the regularisation λ_min
175+
y_approx = a.evaluate(
176+
X=X_test, lam=0.0
177+
) # evaluate the classification at the points X_test for the regularisation λ_min
173178

174179
plt.figure()
175180
scatter = plt.scatter(
176-
X_test[:,1],
177-
X_test[:,2],
178-
c=np.sign(y_approx),
179-
cmap='winter',
180-
s=100,
181-
alpha=0.8,
182-
edgecolors='black',
183-
linewidths=0.5
181+
X_test[:, 1],
182+
X_test[:, 2],
183+
c=np.sign(y_approx),
184+
cmap="winter",
185+
s=100,
186+
alpha=0.8,
187+
edgecolors="black",
188+
linewidths=0.5,
184189
)
185190
plt.show()

simpleTest/exampleNonPeriodic.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,7 @@
1111

1212

1313
def TestFunction(x):
14-
return (
15-
x[0] * x[4]
16-
+ 2
17-
- np.exp(x[3])
18-
+ np.sqrt(x[5] + 3 + x[1])
19-
)
14+
return x[0] * x[4] + 2 - np.exp(x[3]) + np.sqrt(x[5] + 3 + x[1])
2015

2116

2217
rng = np.random.default_rng(1234)
@@ -39,12 +34,15 @@ def TestFunction(x):
3934
b = M / (
4035
math.log10(M) * num
4136
) # number for the number of frequencies if we use logarithmic oversampling and distribute it evenly to all subsets
42-
bw = [math.floor(b / 2) * 2, math.floor(math.sqrt(b) / 2) * 2] # bandwidths (use even numbers)
37+
bw = [
38+
math.floor(b / 2) * 2,
39+
math.floor(math.sqrt(b) / 2) * 2,
40+
] # bandwidths (use even numbers)
4341
# Use all subsets up to ds and use bw[1] many frequences in the the subsets with one element, b[2]^2 many for subsets with two elements and so on
4442
#
4543
########### Variant 2:
4644
# used subsets:
47-
# U = [(), (0,), (1,), (2,), (3,), (4,), (5,),
45+
# U = [(), (0,), (1,), (2,), (3,), (4,), (5,),
4846
# (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (1, 2), (1, 3), (1, 4), (1, 5), (2, 3), (2, 4), (2, 5), (3, 4), (3, 5), (4, 5)]
4947
# Bandwidths for these subsets:
5048
# N = [0 , 100, 100, 100, 100, 100, 100,
@@ -53,7 +51,7 @@ def TestFunction(x):
5351
#
5452
########### Variant 3:
5553
# used subsets:
56-
# U = [(), (0,), (1,), (2,), (3,), (4,), (5,),
54+
# U = [(), (0,), (1,), (2,), (3,), (4,), (5,),
5755
# (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (1, 2), (1, 3), (1, 4), (1, 5), (2, 3), (2, 4), (2, 5), (3, 4), (3, 5), (4, 5)]
5856
# Bandwidths for these subsets:
5957
# N = [(), (100,), (100,), (100,), (100,), (100,), (100,),
@@ -154,7 +152,9 @@ def TestFunction(x):
154152
Umask = np.append(np.array([True]), gsis > 1e-2)
155153
U = [ads.U[i] for i in np.arange(0, len(Umask))[Umask]] # get important subsets
156154
bws = M / (math.log10(M) * (len(U) - 1)) # calculate frequencies per subset
157-
N = [math.floor(bws ** (1 / max(1, len(u))) / 2) * 2 for u in U] # distribute the frequencies evenly and make them even
155+
N = [
156+
math.floor(bws ** (1 / max(1, len(u))) / 2) * 2 for u in U
157+
] # distribute the frequencies evenly and make them even
158158
N[0] = 0
159159

160160
a = ANOVAapprox.approx(

simpleTest/exampleWavelet.py

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,9 @@
1313
# for 'cos' the samples have to be in [0,1]^d
1414
# for a periodic function use 'per' or wavelets 'chui2', 'chui3', 'chui4' (samples have to be in [-0.5,0.5]^d here, 'chuim' are the Chui-Wang wavelets of order m)
1515

16+
1617
def TestFunction(x): # this function is of the form f_0 + f_1 + f_2 + f_3 + f_4 + f_2,3
17-
return (
18-
2 * abs(x[0])
19-
+ abs(math.sin(math.pi * x[1] * x[2]))
20-
+ np.cos(3 + x[3])
21-
)
18+
return 2 * abs(x[0]) + abs(math.sin(math.pi * x[1] * x[2])) + np.cos(3 + x[3])
2219

2320

2421
rng = np.random.default_rng(1234)
@@ -28,10 +25,10 @@ def TestFunction(x): # this function is of the form f_0 + f_1 + f_2 + f_3 + f_4
2825
##################################
2926

3027
d = 8 # dimension
31-
q = 2 # superposition dimension
28+
q = 2 # superposition dimension
3229
M = 10000 # number of used evaluation points to train the model
3330
M_test = 10000 # number of used evaluation points to test the accuracity the model
34-
N = [5,2] # number of parameters, should be vector of length q:
31+
N = [5, 2] # number of parameters, should be vector of length q:
3532
# for wavelets the total number of parameters scales exponentially, i.e.:
3633
# for q = 1 and N = [N1] the total number of parameters scales like ~O(d*2^N1)
3734
# for q = 2 and N = [N1 , N2] the total number of parameters scales like ~O(d*2^N1) + O(d^2 * N2*2^N2)
@@ -42,18 +39,28 @@ def TestFunction(x): # this function is of the form f_0 + f_1 + f_2 + f_3 + f_4
4239
## Generation of the data ##
4340
############################
4441

45-
if basis =="chui2" or basis =="chui3" or basis =="chui4" or basis =="per":
46-
X = rng.random((M, d)) -0.5 # for perioidic approximation samples have to be in [-0.5,0.5]^d
42+
if basis == "chui2" or basis == "chui3" or basis == "chui4" or basis == "per":
43+
X = (
44+
rng.random((M, d)) - 0.5
45+
) # for perioidic approximation samples have to be in [-0.5,0.5]^d
4746
elif basis == "cos":
4847
X = rng.random((M, d))
4948
y = np.array(
5049
[TestFunction(X[i, :].T) for i in range(M)]
5150
) # evaluate the function at these points
5251

53-
if basis == "chui1" or basis =="chui2" or basis =="chui3" or basis =="chui4" or basis =="per":
54-
X_test = rng.random((M_test, d)) - 0.5 # for perioidic approximation samples have to be in [-0.5,0.5]^d
52+
if (
53+
basis == "chui1"
54+
or basis == "chui2"
55+
or basis == "chui3"
56+
or basis == "chui4"
57+
or basis == "per"
58+
):
59+
X_test = (
60+
rng.random((M_test, d)) - 0.5
61+
) # for perioidic approximation samples have to be in [-0.5,0.5]^d
5562
elif basis == "cos":
56-
X_test = rng.random((M_test, d))
63+
X_test = rng.random((M_test, d))
5764
y_test = np.array(
5865
[TestFunction(X_test[i, :].T) for i in range(M_test)]
5966
) # the same for the test points
@@ -74,8 +81,10 @@ def TestFunction(x): # this function is of the form f_0 + f_1 + f_2 + f_3 + f_4
7481
#######################
7582

7683
### Do sensitivity analysis ####
77-
gsis = ANOVAapprox.get_GSI(anova_model,0.0) #calculates indices for importance of terms (gsis is vector, with indices belonging to terms in anova_model.U)
78-
#gsis_as_dict = ANOVAapprox.get_GSI(anova_model,0.0,dict=true)
84+
gsis = ANOVAapprox.get_GSI(
85+
anova_model, 0.0
86+
) # calculates indices for importance of terms (gsis is vector, with indices belonging to terms in anova_model.U)
87+
# gsis_as_dict = ANOVAapprox.get_GSI(anova_model,0.0,dict=true)
7988

8089
y_min_calc = 10 ** (np.min(np.log10(gsis)) - 0.5)
8190
label = list(anova_model.U[1:])
@@ -104,22 +113,22 @@ def TestFunction(x): # this function is of the form f_0 + f_1 + f_2 + f_3 + f_4
104113
################################
105114

106115
### error analysis ###
107-
mse_train = ANOVAapprox.get_mse(anova_model,lam=0.0)
108-
mse_test = ANOVAapprox.get_mse(anova_model,X_test,y_test, lam=0.0)
116+
mse_train = ANOVAapprox.get_mse(anova_model, lam=0.0)
117+
mse_test = ANOVAapprox.get_mse(anova_model, X_test, y_test, lam=0.0)
109118

110-
print("MSE on test points: " + str(mse_test))
119+
print("MSE on test points: " + str(mse_test))
111120

112121
################################################
113122
## Approximation with better suited index set ##
114123
################################################
115124

116125
U = ANOVAapprox.get_ActiveSet(anova_model, [0.01, 0.01], lam=0.0)
117-
print("Found index-set U: " + str(U) )
118-
anova_model = ANOVAapprox.approx(X, y, U=U, N=[i+2 for i in N] , basis=basis) # increase number of paramers in N for the important terms
126+
print("Found index-set U: " + str(U))
127+
anova_model = ANOVAapprox.approx(
128+
X, y, U=U, N=[i + 2 for i in N], basis=basis
129+
) # increase number of paramers in N for the important terms
119130
anova_model.approximate(lam=lambdas)
120131
print("Total number of used parameters = " + str(len(anova_model.fc[lambdas[0]].vec())))
121-
mse_train = ANOVAapprox.get_mse(anova_model,lam=0.0)
122-
mse_test = ANOVAapprox.get_mse(anova_model,X_test,y_test, lam=0.0)
123-
print("MSE on test points after ANOVA truncation: " + str(mse_test))
124-
125-
132+
mse_train = ANOVAapprox.get_mse(anova_model, lam=0.0)
133+
mse_test = ANOVAapprox.get_mse(anova_model, X_test, y_test, lam=0.0)
134+
print("MSE on test points after ANOVA truncation: " + str(mse_test))

src/pyANOVAapprox/approx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#from pyGroupedTransforms.GroupedTransform import * # TODO: Kann wahrscheinlich weg sobald in pyGroupedTransform GreoupedTransform exportiert wird
1+
# from pyGroupedTransforms.GroupedTransform import * # TODO: Kann wahrscheinlich weg sobald in pyGroupedTransform GreoupedTransform exportiert wird
22

33
from pyANOVAapprox import *
44
from pyANOVAapprox.fista import *

src/pyANOVAapprox/fista.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#from pyGroupedTransforms.GroupedCoefficients import * # TODO: Kann wahrscheinlich weg sobald in pyGroupedTransform GreoupedTransform exportiert wird
1+
# from pyGroupedTransforms.GroupedCoefficients import * # TODO: Kann wahrscheinlich weg sobald in pyGroupedTransform GreoupedTransform exportiert wird
22

33
from pyANOVAapprox import *
44

tests/wav_lsqr.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
sys.path.insert(0, src_aa)
66

77
import numpy as np
8-
from TestFunctionPeriodic import *
98
from pyGroupedTransforms import *
9+
from TestFunctionPeriodic import *
1010

1111
import pyANOVAapprox as ANOVAapprox
1212

@@ -26,8 +26,8 @@
2626
ads = ANOVAapprox.approx(X, y, ds=ds, N=bw, basis="chui2")
2727
ads.approximate(lam=lambdas, solver="lsqr")
2828

29-
print( "AR: "+ str(sum(ANOVAapprox.get_AttributeRanking(ads, 0.0))) )
30-
assert abs( sum(ANOVAapprox.get_AttributeRanking(ads, 0.0)) - 1 ) < 0.0001
29+
print("AR: " + str(sum(ANOVAapprox.get_AttributeRanking(ads, 0.0))))
30+
assert abs(sum(ANOVAapprox.get_AttributeRanking(ads, 0.0)) - 1) < 0.0001
3131

3232
bw = ANOVAapprox.get_orderDependentBW(AS, [4, 4])
3333
aU = ANOVAapprox.approx(X, y, U=AS, N=bw, basis="chui2")
@@ -45,6 +45,6 @@
4545
print("l2 rand U: ", err_l2_rand_U)
4646

4747
assert err_l2_ds < 0.01
48-
assert err_l2_U < 0.01 # maybe restrict to 0.005
48+
assert err_l2_U < 0.01 # maybe restrict to 0.005
4949
assert err_l2_rand_ds < 0.01
5050
assert err_l2_rand_U < 0.01

0 commit comments

Comments
 (0)