Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions pySDC/helpers/spectral_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ def get_differentiation_matrix(self):
def get_integration_matrix(self):
raise NotImplementedError()

def get_integration_weights(self):
"""Weights for integration across entire domain"""
raise NotImplementedError()

def get_wavenumbers(self):
"""
Get the grid in spectral space
Expand Down Expand Up @@ -379,6 +383,16 @@ def get_integration_matrix(self, lbnd=0):
raise NotImplementedError(f'This function allows to integrate only from x=0, you attempted from x={lbnd}.')
return S

def get_integration_weights(self):
"""Weights for integration across entire domain"""
n = self.xp.arange(self.N, dtype=float)

weights = (-1) ** n + 1
weights[2:] /= 1 - (n**2)[2:]

weights /= 2 / self.L
return weights

def get_differentiation_matrix(self, p=1):
'''
Keep in mind that the T2T differentiation matrix is dense.
Expand Down Expand Up @@ -808,6 +822,12 @@ def get_integration_matrix(self, p=1):
k[0] = 1j * self.L
return self.linalg.matrix_power(self.sparse_lib.diags(1 / (1j * k)), p)

def get_integration_weights(self):
"""Weights for integration across entire domain"""
weights = self.xp.zeros(self.N)
weights[0] = self.L / self.N
return weights

def get_plan(self, u, forward, *args, **kwargs):
if self.fft_lib.__name__ == 'mpi4py_fft.fftw':
if 'axes' in kwargs.keys():
Expand Down
29 changes: 29 additions & 0 deletions pySDC/tests/test_helpers/test_spectral_helper_1d_chebychev.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,35 @@ def test_integration_matrix(N):
assert np.allclose(exact.coef[:-1], du)


@pytest.mark.base
@pytest.mark.parametrize('x0', [-1, 0])
@pytest.mark.parametrize('x1', [0.789, 1])
@pytest.mark.parametrize('N', [4, 7])
def test_integral_whole_interval(x0, x1, N):
import numpy as np
from pySDC.helpers.spectral_helper import ChebychevHelper
from qmat.lagrange import LagrangeApproximation

cheby = ChebychevHelper(N, x0=x0, x1=x1)
x = cheby.get_1dgrid()

coeffs = np.random.random(N)
coeffs[-1] = 0

u_hat = coeffs
u = cheby.itransform(u_hat)

weights = cheby.get_integration_weights()
integral = weights @ u_hat

# generate a reference solution with qmat
lag = LagrangeApproximation(points=x)
Q = lag.getIntegrationMatrix(intervals=[(x0, x1)])
integral_ref = (Q @ u)[0]

assert np.isclose(integral, integral_ref)


@pytest.mark.base
@pytest.mark.parametrize('N', [4])
@pytest.mark.parametrize('d', [1, 2, 3])
Expand Down
33 changes: 32 additions & 1 deletion pySDC/tests/test_helpers/test_spectral_helper_1d_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,36 @@ def test_integration_matrix(N, plot=False):
assert np.allclose(expect, Du)


@pytest.mark.base
@pytest.mark.parametrize('x0', [-1, 0])
@pytest.mark.parametrize('x1', [0.789, 1])
@pytest.mark.parametrize('N', [32, 45])
def test_integral_whole_interval(x0, x1, N):
import numpy as np
from pySDC.helpers.spectral_helper import FFTHelper
from qmat.lagrange import LagrangeApproximation

helper = FFTHelper(N, x0=x0, x1=x1)
x = helper.get_1dgrid()

u = np.zeros_like(x)

num_coef = N // 2 - 1
coeffs = np.random.random((2, N))
u += coeffs[0, 0]
for i in range(1, num_coef + 1):
u += coeffs[0, i] * np.sin(2 * np.pi * i * x / helper.L)
u += coeffs[1, i] * np.cos(2 * np.pi * i * x / helper.L)

u_hat = helper.transform(u)

weights = helper.get_integration_weights()
integral = weights @ u_hat
integral_ref = coeffs[0, 0] * helper.L

assert np.isclose(integral, integral_ref, atol=1e-7), abs(integral_ref - integral)


@pytest.mark.base
@pytest.mark.parametrize('N', [4, 32])
@pytest.mark.parametrize('v', [0, 4.78])
Expand Down Expand Up @@ -141,4 +171,5 @@ def test_tau_method(N, v):
# test_tau_method(6, 1)
# test_transform(True)
# test_transform(False)
test_transform_cupy(4)
# test_transform_cupy(4)
test_integral_whole_interval(0, 2, 90)
Loading