Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 38 additions & 116 deletions pySDC/helpers/spectral_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1853,6 +1853,26 @@ def get_local_slice_of_1D_matrix(self, M, axis):
"""
return M.tocsc()[self.local_slice[axis], self.local_slice[axis]]

def expand_matrix_ND(self, matrix, aligned):
sp = self.sparse_lib
axes = np.delete(np.arange(self.ndim), aligned)
ndim = len(axes) + 1

if ndim == 1:
return matrix
elif ndim == 2:
axis = axes[0]
I1D = sp.eye(self.axes[axis].N)

mats = [None] * ndim
mats[aligned] = self.get_local_slice_of_1D_matrix(matrix, aligned)
mats[axis] = self.get_local_slice_of_1D_matrix(I1D, axis)

return sp.kron(*mats)

else:
raise NotImplementedError(f'Matrix expansion not implemented for {ndim} dimensions!')

def get_filter_matrix(self, axis, **kwargs):
"""
Get bandpass filter along `axis`. See the documentation `get_filter_matrix` in the 1D bases for what kwargs are
Expand All @@ -1878,31 +1898,10 @@ def get_differentiation_matrix(self, axes, **kwargs):
Returns:
sparse differentiation matrix
"""
sp = self.sparse_lib
ndim = self.ndim

if ndim == 1:
D = self.axes[0].get_differentiation_matrix(**kwargs)
elif ndim == 2:
for axis in axes:
axis2 = (axis + 1) % ndim
D1D = self.axes[axis].get_differentiation_matrix(**kwargs)

if len(axes) > 1:
I1D = sp.eye(self.axes[axis2].N)
else:
I1D = self.axes[axis2].get_Id()

mats = [None] * ndim
mats[axis] = self.get_local_slice_of_1D_matrix(D1D, axis)
mats[axis2] = self.get_local_slice_of_1D_matrix(I1D, axis2)

if axis == axes[0]:
D = sp.kron(*mats)
else:
D = D @ sp.kron(*mats)
else:
raise NotImplementedError(f'Differentiation matrix not implemented for {ndim} dimension!')
D = self.expand_matrix_ND(self.axes[axes[0]].get_differentiation_matrix(**kwargs), axes[0])
for axis in axes[1:]:
_D = self.axes[axis].get_differentiation_matrix(**kwargs)
D = D @ self.expand_matrix_ND(_D, axis)

return D

Expand All @@ -1916,31 +1915,10 @@ def get_integration_matrix(self, axes):
Returns:
sparse integration matrix
"""
sp = self.sparse_lib
ndim = len(self.axes)

if ndim == 1:
S = self.axes[0].get_integration_matrix()
elif ndim == 2:
for axis in axes:
axis2 = (axis + 1) % ndim
S1D = self.axes[axis].get_integration_matrix()

if len(axes) > 1:
I1D = sp.eye(self.axes[axis2].N)
else:
I1D = self.axes[axis2].get_Id()

mats = [None] * ndim
mats[axis] = self.get_local_slice_of_1D_matrix(S1D, axis)
mats[axis2] = self.get_local_slice_of_1D_matrix(I1D, axis2)

if axis == axes[0]:
S = sp.kron(*mats)
else:
S = S @ sp.kron(*mats)
else:
raise NotImplementedError(f'Integration matrix not implemented for {ndim} dimension!')
S = self.expand_matrix_ND(self.axes[axes[0]].get_integration_matrix(), axes[0])
for axis in axes[1:]:
_S = self.axes[axis].get_integration_matrix()
S = S @ self.expand_matrix_ND(_S, axis)

return S

Expand All @@ -1951,27 +1929,10 @@ def get_Id(self):
Returns:
sparse identity matrix
"""
sp = self.sparse_lib
ndim = self.ndim
I = sp.eye(np.prod(self.init[0][1:]), dtype=complex)

if ndim == 1:
I = self.axes[0].get_Id()
elif ndim == 2:
for axis in range(ndim):
axis2 = (axis + 1) % ndim
I1D = self.axes[axis].get_Id()

I1D2 = sp.eye(self.axes[axis2].N)

mats = [None] * ndim
mats[axis] = self.get_local_slice_of_1D_matrix(I1D, axis)
mats[axis2] = self.get_local_slice_of_1D_matrix(I1D2, axis2)

I = I @ sp.kron(*mats)
else:
raise NotImplementedError(f'Identity matrix not implemented for {ndim} dimension!')

I = self.expand_matrix_ND(self.axes[0].get_Id(), 0)
for axis in range(1, self.ndim):
_I = self.axes[axis].get_Id()
I = I @ self.expand_matrix_ND(_I, axis)
return I

def get_Dirichlet_recombination_matrix(self, axis=-1):
Expand All @@ -1984,26 +1945,8 @@ def get_Dirichlet_recombination_matrix(self, axis=-1):
Returns:
sparse matrix
"""
sp = self.sparse_lib
ndim = len(self.axes)

if ndim == 1:
C = self.axes[0].get_Dirichlet_recombination_matrix()
elif ndim == 2:
axis2 = (axis + 1) % ndim
C1D = self.axes[axis].get_Dirichlet_recombination_matrix()

I1D = self.axes[axis2].get_Id()

mats = [None] * ndim
mats[axis] = self.get_local_slice_of_1D_matrix(C1D, axis)
mats[axis2] = self.get_local_slice_of_1D_matrix(I1D, axis2)

C = sp.kron(*mats)
else:
raise NotImplementedError(f'Basis change matrix not implemented for {ndim} dimension!')

return C
C1D = self.axes[axis].get_Dirichlet_recombination_matrix()
return self.expand_matrix_ND(C1D, axis)

def get_basis_change_matrix(self, axes=None, **kwargs):
"""
Expand All @@ -2018,30 +1961,9 @@ def get_basis_change_matrix(self, axes=None, **kwargs):
"""
axes = tuple(-i - 1 for i in range(self.ndim)) if axes is None else axes

sp = self.sparse_lib
ndim = len(self.axes)

if ndim == 1:
C = self.axes[0].get_basis_change_matrix(**kwargs)
elif ndim == 2:
for axis in axes:
axis2 = (axis + 1) % ndim
C1D = self.axes[axis].get_basis_change_matrix(**kwargs)

if len(axes) > 1:
I1D = sp.eye(self.axes[axis2].N)
else:
I1D = self.axes[axis2].get_Id()

mats = [None] * ndim
mats[axis] = self.get_local_slice_of_1D_matrix(C1D, axis)
mats[axis2] = self.get_local_slice_of_1D_matrix(I1D, axis2)

if axis == axes[0]:
C = sp.kron(*mats)
else:
C = C @ sp.kron(*mats)
else:
raise NotImplementedError(f'Basis change matrix not implemented for {ndim} dimension!')
C = self.expand_matrix_ND(self.axes[axes[0]].get_basis_change_matrix(**kwargs), axes[0])
for axis in axes[1:]:
_C = self.axes[axis].get_basis_change_matrix(**kwargs)
C = C @ self.expand_matrix_ND(_C, axis)

return C