Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 41 additions & 22 deletions pySDC/helpers/spectral_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1135,7 +1135,7 @@ def get_BC(self, axis, kind, line=-1, scalar=False, **kwargs):

ndim = len(self.axes)
if ndim == 1:
return self.sparse_lib.csc_matrix(BC)
mat = self.sparse_lib.csc_matrix(BC)
elif ndim == 2:
axis2 = (axis + 1) % ndim

Expand All @@ -1151,8 +1151,8 @@ def get_BC(self, axis, kind, line=-1, scalar=False, **kwargs):
] * ndim
mats[axis] = self.get_local_slice_of_1D_matrix(BC, axis=axis)
mats[axis2] = Id
return self.sparse_lib.csc_matrix(self.sparse_lib.kron(*mats))
if ndim == 3:
mat = self.sparse_lib.csc_matrix(self.sparse_lib.kron(*mats))
elif ndim == 3:
mats = [
None,
] * ndim
Expand All @@ -1170,11 +1170,13 @@ def get_BC(self, axis, kind, line=-1, scalar=False, **kwargs):

mats[axis] = self.get_local_slice_of_1D_matrix(BC, axis=axis)

return self.sparse_lib.csc_matrix(self.sparse_lib.kron(mats[0], self.sparse_lib.kron(*mats[1:])))
mat = self.sparse_lib.csc_matrix(self.sparse_lib.kron(mats[0], self.sparse_lib.kron(*mats[1:])))
else:
raise NotImplementedError(
f'Matrix expansion for boundary conditions not implemented for {ndim} dimensions!'
)
mat = self.eliminate_zeros(mat)
return mat

def remove_BC(self, component, equation, axis, kind, line=-1, scalar=False, **kwargs):
"""
Expand All @@ -1192,6 +1194,7 @@ def remove_BC(self, component, equation, axis, kind, line=-1, scalar=False, **kw
scalar (bool): Put the BC in all space positions in the other direction
"""
_BC = self.get_BC(axis=axis, kind=kind, line=line, scalar=scalar, **kwargs)
_BC = self.eliminate_zeros(_BC)
self.BC_mat[self.index(equation)][self.index(component)] -= _BC

if scalar:
Expand Down Expand Up @@ -1375,7 +1378,7 @@ def put_BCs_in_rhs(self, rhs):

return rhs

def add_equation_lhs(self, A, equation, relations, diag=False):
def add_equation_lhs(self, A, equation, relations):
"""
Add the left hand part (that you want to solve implicitly) of an equation to a list of lists of sparse matrices
that you will convert to an operator later.
Expand Down Expand Up @@ -1410,16 +1413,31 @@ def add_equation_lhs(self, A, equation, relations, diag=False):
A (list of lists of sparse matrices): The operator to be
equation (str): The equation of the component you want this in
relations: (dict): Relations between quantities
diag (bool): Whether operator is block-diagonal
"""
for k, v in relations.items():
if diag:
assert k == equation, 'You are trying to put a non-diagonal equation into a diagonal operator'
A[self.index(equation)] = v
else:
A[self.index(equation)][self.index(k)] = v
A[self.index(equation)][self.index(k)] = v

def convert_operator_matrix_to_operator(self, M, diag=False):
def eliminate_zeros(self, A):
"""
Eliminate zeros from sparse matrix. This can reduce memory footprint of matrices somewhat.
Note: At the time of writing, there are memory problems in the cupy implementation of `eliminate_zeros`.
Therefore, this function copies the matrix to host, eliminates the zeros there and then copies back to GPU.

Args:
A: sparse matrix to be pruned

Returns:
CSC sparse matrix
"""
if self.useGPU:
A = A.get()
A = A.tocsc()
A.eliminate_zeros()
if self.useGPU:
A = self.sparse_lib.csc_matrix(A)
return A

def convert_operator_matrix_to_operator(self, M):
"""
Promote the list of lists of sparse matrices to a single sparse matrix that can be used as linear operator.
See documentation of `SpectralHelper.add_equation_lhs` for an example.
Expand All @@ -1431,14 +1449,12 @@ def convert_operator_matrix_to_operator(self, M, diag=False):
sparse linear operator
"""
if len(self.components) == 1:
if diag:
return M[0]
else:
return M[0][0]
elif diag:
return self.sparse_lib.block_diag(M, format='csc')
op = M[0][0]
else:
return self.sparse_lib.block_array(M, format='csc')
op = self.sparse_lib.bmat(M, format='csc')

op = self.eliminate_zeros(op)
return op

def get_wavenumbers(self):
"""
Expand Down Expand Up @@ -1792,7 +1808,7 @@ def expand_matrix_ND(self, matrix, aligned):
ndim = len(axes) + 1

if ndim == 1:
return matrix
mat = matrix
elif ndim == 2:
axis = axes[0]
I1D = sp.eye(self.axes[axis].N)
Expand All @@ -1801,7 +1817,7 @@ def expand_matrix_ND(self, matrix, aligned):
mats[aligned] = self.get_local_slice_of_1D_matrix(matrix, aligned)
mats[axis] = self.get_local_slice_of_1D_matrix(I1D, axis)

return sp.kron(*mats)
mat = sp.kron(*mats)
elif ndim == 3:

mats = [None] * ndim
Expand All @@ -1810,11 +1826,14 @@ def expand_matrix_ND(self, matrix, aligned):
I1D = sp.eye(self.axes[axis].N)
mats[axis] = self.get_local_slice_of_1D_matrix(I1D, axis)

return sp.kron(mats[0], sp.kron(*mats[1:]))
mat = sp.kron(mats[0], sp.kron(*mats[1:]))

else:
raise NotImplementedError(f'Matrix expansion not implemented for {ndim} dimensions!')

mat = self.eliminate_zeros(mat)
return mat

def get_filter_matrix(self, axis, **kwargs):
"""
Get bandpass filter along `axis`. See the documentation `get_filter_matrix` in the 1D bases for what kwargs are
Expand Down
8 changes: 2 additions & 6 deletions pySDC/implementations/problem_classes/RayleighBenard.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,12 +213,8 @@ def eval_f(self, u, *args, **kwargs):

# start by computing derivatives
if not hasattr(self, '_Dx_expanded') or not hasattr(self, '_Dz_expanded'):
self._Dx_expanded = self._setup_operator(
{'u': {'u': Dx}, 'v': {'v': Dx}, 'T': {'T': Dx}, 'p': {}}, diag=True
)
self._Dz_expanded = self._setup_operator(
{'u': {'u': Dz}, 'v': {'v': Dz}, 'T': {'T': Dz}, 'p': {}}, diag=True
)
self._Dx_expanded = self._setup_operator({'u': {'u': Dx}, 'v': {'v': Dx}, 'T': {'T': Dx}, 'p': {}})
self._Dz_expanded = self._setup_operator({'u': {'u': Dz}, 'v': {'v': Dz}, 'T': {'T': Dz}, 'p': {}})
Dx_u_hat = (self._Dx_expanded @ u_hat.flatten()).reshape(u_hat.shape)
Dz_u_hat = (self._Dz_expanded @ u_hat.flatten()).reshape(u_hat.shape)

Expand Down
17 changes: 8 additions & 9 deletions pySDC/implementations/problem_classes/generic_spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,21 +138,20 @@ def __getattr__(self, name):
"""
return getattr(self.spectral, name)

def _setup_operator(self, LHS, diag=False):
def _setup_operator(self, LHS):
"""
Setup a sparse linear operator by adding relationships. See documentation for ``GenericSpectralLinear.setup_L`` to learn more.

Args:
LHS (dict): Equations to be added to the operator
diag (bool): Whether operator is block-diagonal

Returns:
sparse linear operator
"""
operator = self.spectral.get_empty_operator_matrix(diag=diag)
operator = self.spectral.get_empty_operator_matrix()
for line, equation in LHS.items():
self.spectral.add_equation_lhs(operator, line, equation, diag=diag)
return self.spectral.convert_operator_matrix_to_operator(operator, diag=diag)
self.spectral.add_equation_lhs(operator, line, equation)
return self.spectral.convert_operator_matrix_to_operator(operator)

def setup_L(self, LHS):
"""
Expand All @@ -174,13 +173,13 @@ def setup_L(self, LHS):
"""
self.L = self._setup_operator(LHS)

def setup_M(self, LHS, diag=True):
def setup_M(self, LHS):
'''
Setup mass matrix, see documentation of ``GenericSpectralLinear.setup_L``.
'''
diff_index = list(LHS.keys())
self.diff_mask = [me in diff_index for me in self.components]
self.M = self._setup_operator(LHS, diag=diag)
self.M = self._setup_operator(LHS)

def setup_preconditioner(self, Dirichlet_recombination=True, left_preconditioner=True):
"""
Expand All @@ -195,7 +194,7 @@ def setup_preconditioner(self, Dirichlet_recombination=True, left_preconditioner

Id = sp.eye(N)
Pl_lhs = {comp: {comp: Id} for comp in self.components}
self.Pl = self._setup_operator(Pl_lhs, diag=True)
self.Pl = self._setup_operator(Pl_lhs)

if left_preconditioner:
# reverse Kronecker product
Expand All @@ -217,7 +216,7 @@ def setup_preconditioner(self, Dirichlet_recombination=True, left_preconditioner
_Pr = Id

Pr_lhs = {comp: {comp: _Pr} for comp in self.components}
self.Pr = self._setup_operator(Pr_lhs, diag=True) @ self.Pl.T
self.Pr = self._setup_operator(Pr_lhs) @ self.Pl.T

def solve_system(self, rhs, dt, u0=None, *args, skip_itransform=False, **kwargs):
"""
Expand Down
35 changes: 3 additions & 32 deletions pySDC/tests/test_helpers/test_spectral_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,10 +577,9 @@ def test_tau_method2D(nz, nx, bc_val, bc=-1, plotting=False, useMPI=False, **kwa
Dxx = helper.get_differentiation_matrix(axes=(0,), p=2)

# generate operator
diag = True
_A = helper.get_empty_operator_matrix(diag=diag)
helper.add_equation_lhs(_A, 'u', {'u': Dz - Dxx * 1e-1 - Dx}, diag=diag)
A = helper.convert_operator_matrix_to_operator(_A, diag=diag)
_A = helper.get_empty_operator_matrix()
helper.add_equation_lhs(_A, 'u', {'u': Dz - Dxx * 1e-1 - Dx})
A = helper.convert_operator_matrix_to_operator(_A)

# prepare system to solve
A = helper.put_BCs_in_matrix(A)
Expand Down Expand Up @@ -841,34 +840,6 @@ def function():
assert track[0] == 0, "possible memory leak with the @cache"


@pytest.mark.base
def test_block_diagonal_operators(N=16):
from pySDC.helpers.spectral_helper import SpectralHelper
import numpy as np

helper = SpectralHelper(comm=None, debug=True)
helper.add_axis('fft', N=N)
helper.add_axis('cheby', N=N)
helper.add_component(['u', 'v'])
helper.setup_fft()

# generate matrices
Dz = helper.get_differentiation_matrix(axes=(1,))
Dx = helper.get_differentiation_matrix(axes=(0,))

def get_operator(diag):
_A = helper.get_empty_operator_matrix(diag=diag)
helper.add_equation_lhs(_A, 'u', {'u': Dx}, diag=diag)
helper.add_equation_lhs(_A, 'v', {'v': Dz}, diag=diag)
return helper.convert_operator_matrix_to_operator(_A, diag=diag)

AD = get_operator(True)
A = get_operator(False)

assert np.allclose(A.toarray(), AD.toarray()), 'Operators don\'t match'
assert A.data.nbytes > AD.data.nbytes, 'Block diagonal operator did not conserve memory over general operator'


if __name__ == '__main__':
str_to_bool = lambda me: False if me == 'False' else True
str_to_tuple = lambda arg: tuple(int(me) for me in arg.split(','))
Expand Down