Skip to content

Commit 3cb9a4e

Browse files
authored
added bandwidth detection
Develop
2 parents 268204d + 38e6dca commit 3cb9a4e

21 files changed

+1107
-365
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ jobs:
2727
pip install pyGroupedTransforms
2828
pip install numpy
2929
pip install scipy
30+
pip install matplotlib
3031
- name: Run tests
3132
run: |
3233
python -m tests.run_tests

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ dist/
66
*.py[cod]
77
__pycache__/
88
venv/
9+
venv2/
910

1011
bin/
1112
/lib64/
@@ -16,3 +17,4 @@ src/pyNFFT3/lib/AVX2/libgomp-1.dll
1617
src/pyNFFT3/lib/AVX2/libgcc_s_seh-1.dll
1718
src/pyNFFT3/lib/AVX2/libwinpthread-1.dll
1819
simpleTest/venv/
20+
bandwidth_detection.py

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,4 @@ Requirements
5353
- pyGroupedTransforms 0.1.0 or greater
5454
- NumPy 2.0.0 or greater
5555
- SciPy 1.16.0 or greater
56+
- Matplotlib 3.5 or greater

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
44

55
[project]
66
name = "pyANOVAapprox"
7-
version = "1.0.0"
7+
version = "2.0.0"
88
authors = [
99
{ name="Felix Wirth", email="fwi012001@gmail.com" },
1010
]
@@ -15,9 +15,10 @@ description = "Approximation Package for High-Dimensional Functions"
1515
readme = "README.md"
1616
requires-python = ">=3.9"
1717
dependencies = [
18-
"pyGroupedTransforms>=0.1.0",
18+
"pyGroupedTransforms>=1.1.0",
1919
"numpy>=2.0.0",
2020
"scipy>=1.16.0",
21+
"matplotlib>=3.5"
2122
]
2223

2324
[project.urls]

simpleTest/TestFunctionPeriodic.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
#!/usr/bin/env python
2+
# coding: utf-8
3+
4+
# In[ ]:
5+
6+
AS = [(), (0,), (1,), (2,), (3,), (4,), (5,), (0, 2), (1, 4), (3, 5)]
7+
8+
import numpy as np
9+
10+
# Coefficients C[0] = C₁, C[1] = C₂, C[2] = C₃
11+
# (Julia is 1-based; Python is 0-based)
12+
C = np.array([np.sqrt(0.75), np.sqrt(315 / 604), np.sqrt(277200 / 655177)])
13+
14+
# --- B-spline definitions ---
15+
16+
17+
def b_spline_2(x):
18+
C_2 = C[0]
19+
if 0.0 <= x < 0.5:
20+
return C_2 * 4 * x
21+
elif 0.5 <= x < 1.0:
22+
return C_2 * 4 * (1 - x)
23+
else:
24+
raise ValueError("B-spline 2: x out of range [0,1)")
25+
26+
27+
def b_spline_4(x):
28+
C_4 = C[1]
29+
if 0.0 <= x < 0.25:
30+
return C_4 * (128 / 3) * x**3
31+
elif 0.25 <= x < 0.5:
32+
return C_4 * (8 / 3 - 32 * x + 128 * x**2 - 128 * x**3)
33+
elif 0.5 <= x < 0.75:
34+
return C_4 * (-88 / 3 - 256 * x**2 + 160 * x + 128 * x**3)
35+
elif 0.75 <= x < 1.0:
36+
return C_4 * (128 / 3 - 128 * x + 128 * x**2 - (128 / 3) * x**3)
37+
else:
38+
raise ValueError("B-spline 4: x out of range [0,1)")
39+
40+
41+
def b_spline_6(x):
42+
C_6 = C[2]
43+
if 0.0 <= x < 1.0 / 6:
44+
return C_6 * (1944 / 5) * x**5
45+
elif 1.0 / 6 <= x < 2.0 / 6:
46+
return C_6 * (
47+
3 / 10 - 9 * x + 108 * x**2 - 648 * x**3 + 1944 * x**4 - 1944 * x**5
48+
)
49+
elif 2.0 / 6 <= x < 0.5:
50+
return C_6 * (
51+
-237 / 10 + 351 * x - 2052 * x**2 + 5832 * x**3 - 7776 * x**4 + 3888 * x**5
52+
)
53+
elif 0.5 <= x < 4.0 / 6:
54+
return C_6 * (
55+
2193 / 10
56+
+ 7668 * x**2
57+
- 2079 * x
58+
+ 11664 * x**4
59+
- 13608 * x**3
60+
- 3888 * x**5
61+
)
62+
elif 4.0 / 6 <= x < 5.0 / 6:
63+
return C_6 * (
64+
-5487 / 10
65+
+ 3681 * x
66+
- 9612 * x**2
67+
+ 12312 * x**3
68+
- 7776 * x**4
69+
+ 1944 * x**5
70+
)
71+
elif 5.0 / 6 <= x < 1.0:
72+
return C_6 * (
73+
1944 / 5
74+
- 1944 * x
75+
+ 3888 * x**2
76+
- 3888 * x**3
77+
+ 1944 * x**4
78+
- (1944 / 5) * x**5
79+
)
80+
else:
81+
raise ValueError("B-spline 6: x out of range [0,1)")
82+
83+
84+
# --- Block structure ---
85+
86+
m1 = [0, 2] # 1-based [1, 3]
87+
m2 = [1, 4] # 2-based [2, 5]
88+
m3 = [3, 5] # 3-based [4, 6]
89+
90+
# --- Transformed function f(x) ---
91+
92+
93+
def trans(x):
94+
return x + 1 if x < 0 else x
95+
96+
97+
def f(x):
98+
x = np.asarray(x)
99+
if x.shape != (6,):
100+
raise ValueError("f(x): Input must be 6-dimensional.")
101+
if np.any(x < -0.5) or np.any(x > 0.5):
102+
raise ValueError("f(x): All entries must be in [-0.5, 0.5].")
103+
104+
xT = np.where(x < 0, x + 1, x) # vectorized 'trans'
105+
return (
106+
np.prod([b_spline_2(xT[i]) for i in m1])
107+
+ np.prod([b_spline_4(xT[i]) for i in m2])
108+
+ np.prod([b_spline_6(xT[i]) for i in m3])
109+
)
110+
111+
112+
# --- sinc and b(k, r) function ---
113+
114+
115+
def sinc(x):
116+
return 1.0 if x == 0.0 else np.sin(x) / x
117+
118+
119+
def b(k, r):
120+
idx = r // 2 - 1 # Adjust for Python 0-based indexing
121+
return C[idx] * (sinc(np.pi * k / r) ** r) * np.cos(np.pi * k)
122+
123+
124+
# --- Fourier coefficients fc(k) ---
125+
126+
127+
def fc(k):
128+
if len(k) != 6:
129+
raise ValueError("fc(k): k must be 6-dimensional.")
130+
131+
ind = np.array([int(ki != 0) for ki in k])
132+
133+
b2_block = np.sum(ind) == np.sum(ind[m1])
134+
b4_block = np.sum(ind) == np.sum(ind[m2])
135+
b6_block = np.sum(ind) == np.sum(ind[m3])
136+
137+
val = 0.0
138+
if b2_block:
139+
val += np.prod([b(k[i], 2) for i in m1])
140+
if b4_block:
141+
val += np.prod([b(k[i], 4) for i in m2])
142+
if b6_block:
143+
val += np.prod([b(k[i], 6) for i in m3])
144+
return val
145+
146+
147+
# --- Norm computation ---
148+
149+
150+
def norm():
151+
result = 3.0
152+
for i in range(1, 3): # i = 1, 2
153+
for j in range(i + 1, 4): # j = 2, 3
154+
result += 2 * b(0, 2 * i) ** 2 * b(0, 2 * j) ** 2
155+
return np.sqrt(result)
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# pip install pyANOVAapprox
2+
3+
# Example for approximating an periodic function
4+
5+
import math
6+
7+
import matplotlib.pyplot as plt
8+
import numpy as np
9+
from TestFunctionPeriodic import *
10+
11+
import pyANOVAapprox as ANOVAapprox
12+
13+
14+
def TestFunction(x):
15+
return b_spline_2(x[0]) * b_spline_4(x[1]) * b_spline_6(x[2])
16+
17+
18+
rng = np.random.default_rng(1234)
19+
20+
##################################
21+
## Definition of the parameters ##
22+
##################################
23+
24+
d = 3 # dimension
25+
26+
M = 100000 # number of used evaluation points to train the model
27+
M_test = 100000 # number of used evaluation points to test the accuracity the model
28+
29+
U = [(), (0,), (1,), (2,), (0, 1), (0, 2), (1, 2), (0, 1, 2)]
30+
31+
lambdas = np.array([0.0]) # used regularisation parameters λ
32+
33+
############################
34+
## Generation of the data ##
35+
############################
36+
37+
X = rng.random((M, d)) # construct the evaluation points for training
38+
y = np.array(
39+
[TestFunction(X[i, :].T) for i in range(M)], dtype=complex
40+
) # evaluate the function at these points
41+
X = X - 0.5
42+
X_test = rng.random((M_test, d))
43+
y_test = np.array(
44+
[TestFunction(X_test[i, :].T) for i in range(M_test)], dtype=complex
45+
) # the same for the test points
46+
X_test = X_test - 0.5
47+
48+
##########################
49+
## Do the approximation ##
50+
##########################
51+
52+
ads = ANOVAapprox.approx(X, y, U=U, basis="per")
53+
ads.autoapproximate()
54+
55+
################################
56+
## get approximation accuracy ##
57+
################################
58+
59+
# mse = ANOVAapprox.get_mse(ads) # get mse error at the given training points
60+
mse = ads.get_mse(X=X_test, y=y_test) # get mse error at the test points
61+
λ_min = min(
62+
mse, key=mse.get
63+
) # get the regularisation parameter which leads to the minimal error
64+
mse_min = mse[λ_min]
65+
66+
print("mse = " + str(mse_min))

simpleTest/exampleCheb.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def TestFunction(x):
103103
ar = ANOVAapprox.get_AttributeRanking(ads, λ_min) # get the attrbute ranking
104104

105105
plt.figure()
106-
(markers, stemlines, baseline) = plt.stem(
106+
markers, stemlines, baseline = plt.stem(
107107
np.arange(1, d + 1), # x-Werte (1:d)
108108
ar, # y-Werte (ar)
109109
linefmt="C0-", # Stil der Stiele
@@ -126,7 +126,7 @@ def TestFunction(x):
126126
l = len(label)
127127
plt.figure()
128128
x_values = np.arange(1, l + 1)
129-
(markers, stemlines, baseline) = plt.stem(
129+
markers, stemlines, baseline = plt.stem(
130130
x_values, # X-Werte: 1 bis l
131131
gsis, # Y-Werte: gsis
132132
linefmt="C0-", # Stil der Stiele

simpleTest/exampleClassification.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def TestFunction(x):
100100
ar = ANOVAapprox.get_AttributeRanking(ads, 0.0) # get the attrbute ranking
101101

102102
plt.figure()
103-
(markers, stemlines, baseline) = plt.stem(
103+
markers, stemlines, baseline = plt.stem(
104104
np.arange(1, d + 1), # x-Werte (1:d)
105105
ar, # y-Werte (ar)
106106
linefmt="C0-", # Stil der Stiele
@@ -123,7 +123,7 @@ def TestFunction(x):
123123
l = len(label)
124124
plt.figure()
125125
x_values = np.arange(1, l + 1)
126-
(markers, stemlines, baseline) = plt.stem(
126+
markers, stemlines, baseline = plt.stem(
127127
x_values, # X-Werte: 1 bis l
128128
gsis, # Y-Werte: gsis
129129
linefmt="C0-", # Stil der Stiele

simpleTest/exampleNonPeriodic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def TestFunction(x):
101101
ar = ANOVAapprox.get_AttributeRanking(ads, λ_min) # get the attrbute ranking
102102

103103
plt.figure()
104-
(markers, stemlines, baseline) = plt.stem(
104+
markers, stemlines, baseline = plt.stem(
105105
np.arange(1, d + 1), # x-Werte (1:d)
106106
ar, # y-Werte (ar)
107107
linefmt="C0-", # Stil der Stiele
@@ -124,7 +124,7 @@ def TestFunction(x):
124124
l = len(label)
125125
plt.figure()
126126
x_values = np.arange(1, l + 1)
127-
(markers, stemlines, baseline) = plt.stem(
127+
markers, stemlines, baseline = plt.stem(
128128
x_values, # X-Werte: 1 bis l
129129
gsis, # Y-Werte: gsis
130130
linefmt="C0-", # Stil der Stiele

simpleTest/examplePeriodic.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def TestFunction(x):
9090
################################
9191

9292
# mse = ANOVAapprox.get_mse(ads) # get mse error at the given training points
93-
mse = ANOVAapprox.get_mse(ads, X_test, y_test) # get mse error at the test points
93+
mse = ads.get_mse(X=X_test, y=y_test) # get mse error at the test points
9494
λ_min = min(
9595
mse, key=mse.get
9696
) # get the regularisation parameter which leads to the minimal error
@@ -103,10 +103,10 @@ def TestFunction(x):
103103
###############################################
104104

105105

106-
ar = ANOVAapprox.get_AttributeRanking(ads, λ_min) # get the attrbute ranking
106+
ar = ads.get_AttributeRanking(lam=λ_min) # get the attrbute ranking
107107

108108
plt.figure()
109-
(markers, stemlines, baseline) = plt.stem(
109+
markers, stemlines, baseline = plt.stem(
110110
np.arange(1, d + 1), # x-Werte (1:d)
111111
ar, # y-Werte (ar)
112112
linefmt="C0-", # Stil der Stiele
@@ -124,12 +124,12 @@ def TestFunction(x):
124124
plt.show() # plot the arrtibute ranking in an logplot
125125
print("active dimensions: " + str(ar[ar > 1e-2]))
126126

127-
gsis = ANOVAapprox.get_GSI(ads, λ_min)
127+
gsis = ads.get_GSI(lam=λ_min)
128128
label = list(ads.U[1:])
129129
l = len(label)
130130
plt.figure()
131131
x_values = np.arange(1, l + 1)
132-
(markers, stemlines, baseline) = plt.stem(
132+
markers, stemlines, baseline = plt.stem(
133133
x_values, # X-Werte: 1 bis l
134134
gsis, # Y-Werte: gsis
135135
linefmt="C0-", # Stil der Stiele
@@ -169,7 +169,7 @@ def TestFunction(x):
169169
lam=lambdas, max_iter=max_iter, solver="lsqr"
170170
) # do the approximation for all specified regularisation parameters
171171

172-
mse = ANOVAapprox.get_mse(a, X_test, y_test) # get mse error at the test points
172+
mse = a.get_mse(X=X_test, y=y_test) # get mse error at the test points
173173
λ_min = min(
174174
mse, key=mse.get
175175
) # get the regularisation parameter which leads to the minimal error

0 commit comments

Comments
 (0)