diff --git a/pySDC/helpers/spectral_helper.py b/pySDC/helpers/spectral_helper.py index 105add7374..c37458b25f 100644 --- a/pySDC/helpers/spectral_helper.py +++ b/pySDC/helpers/spectral_helper.py @@ -1135,7 +1135,7 @@ def get_BC(self, axis, kind, line=-1, scalar=False, **kwargs): ndim = len(self.axes) if ndim == 1: - return self.sparse_lib.csc_matrix(BC) + mat = self.sparse_lib.csc_matrix(BC) elif ndim == 2: axis2 = (axis + 1) % ndim @@ -1151,8 +1151,8 @@ def get_BC(self, axis, kind, line=-1, scalar=False, **kwargs): ] * ndim mats[axis] = self.get_local_slice_of_1D_matrix(BC, axis=axis) mats[axis2] = Id - return self.sparse_lib.csc_matrix(self.sparse_lib.kron(*mats)) - if ndim == 3: + mat = self.sparse_lib.csc_matrix(self.sparse_lib.kron(*mats)) + elif ndim == 3: mats = [ None, ] * ndim @@ -1170,11 +1170,13 @@ def get_BC(self, axis, kind, line=-1, scalar=False, **kwargs): mats[axis] = self.get_local_slice_of_1D_matrix(BC, axis=axis) - return self.sparse_lib.csc_matrix(self.sparse_lib.kron(mats[0], self.sparse_lib.kron(*mats[1:]))) + mat = self.sparse_lib.csc_matrix(self.sparse_lib.kron(mats[0], self.sparse_lib.kron(*mats[1:]))) else: raise NotImplementedError( f'Matrix expansion for boundary conditions not implemented for {ndim} dimensions!' ) + mat = self.eliminate_zeros(mat) + return mat def remove_BC(self, component, equation, axis, kind, line=-1, scalar=False, **kwargs): """ @@ -1192,6 +1194,7 @@ def remove_BC(self, component, equation, axis, kind, line=-1, scalar=False, **kw scalar (bool): Put the BC in all space positions in the other direction """ _BC = self.get_BC(axis=axis, kind=kind, line=line, scalar=scalar, **kwargs) + _BC = self.eliminate_zeros(_BC) self.BC_mat[self.index(equation)][self.index(component)] -= _BC if scalar: @@ -1375,7 +1378,7 @@ def put_BCs_in_rhs(self, rhs): return rhs - def add_equation_lhs(self, A, equation, relations, diag=False): + def add_equation_lhs(self, A, equation, relations): """ Add the left hand part (that you want to solve implicitly) of an equation to a list of lists of sparse matrices that you will convert to an operator later. @@ -1410,16 +1413,31 @@ def add_equation_lhs(self, A, equation, relations, diag=False): A (list of lists of sparse matrices): The operator to be equation (str): The equation of the component you want this in relations: (dict): Relations between quantities - diag (bool): Whether operator is block-diagonal """ for k, v in relations.items(): - if diag: - assert k == equation, 'You are trying to put a non-diagonal equation into a diagonal operator' - A[self.index(equation)] = v - else: - A[self.index(equation)][self.index(k)] = v + A[self.index(equation)][self.index(k)] = v - def convert_operator_matrix_to_operator(self, M, diag=False): + def eliminate_zeros(self, A): + """ + Eliminate zeros from sparse matrix. This can reduce memory footprint of matrices somewhat. + Note: At the time of writing, there are memory problems in the cupy implementation of `eliminate_zeros`. + Therefore, this function copies the matrix to host, eliminates the zeros there and then copies back to GPU. + + Args: + A: sparse matrix to be pruned + + Returns: + CSC sparse matrix + """ + if self.useGPU: + A = A.get() + A = A.tocsc() + A.eliminate_zeros() + if self.useGPU: + A = self.sparse_lib.csc_matrix(A) + return A + + def convert_operator_matrix_to_operator(self, M): """ Promote the list of lists of sparse matrices to a single sparse matrix that can be used as linear operator. See documentation of `SpectralHelper.add_equation_lhs` for an example. @@ -1431,14 +1449,12 @@ def convert_operator_matrix_to_operator(self, M, diag=False): sparse linear operator """ if len(self.components) == 1: - if diag: - return M[0] - else: - return M[0][0] - elif diag: - return self.sparse_lib.block_diag(M, format='csc') + op = M[0][0] else: - return self.sparse_lib.block_array(M, format='csc') + op = self.sparse_lib.bmat(M, format='csc') + + op = self.eliminate_zeros(op) + return op def get_wavenumbers(self): """ @@ -1792,7 +1808,7 @@ def expand_matrix_ND(self, matrix, aligned): ndim = len(axes) + 1 if ndim == 1: - return matrix + mat = matrix elif ndim == 2: axis = axes[0] I1D = sp.eye(self.axes[axis].N) @@ -1801,7 +1817,7 @@ def expand_matrix_ND(self, matrix, aligned): mats[aligned] = self.get_local_slice_of_1D_matrix(matrix, aligned) mats[axis] = self.get_local_slice_of_1D_matrix(I1D, axis) - return sp.kron(*mats) + mat = sp.kron(*mats) elif ndim == 3: mats = [None] * ndim @@ -1810,11 +1826,14 @@ def expand_matrix_ND(self, matrix, aligned): I1D = sp.eye(self.axes[axis].N) mats[axis] = self.get_local_slice_of_1D_matrix(I1D, axis) - return sp.kron(mats[0], sp.kron(*mats[1:])) + mat = sp.kron(mats[0], sp.kron(*mats[1:])) else: raise NotImplementedError(f'Matrix expansion not implemented for {ndim} dimensions!') + mat = self.eliminate_zeros(mat) + return mat + def get_filter_matrix(self, axis, **kwargs): """ Get bandpass filter along `axis`. See the documentation `get_filter_matrix` in the 1D bases for what kwargs are diff --git a/pySDC/implementations/problem_classes/RayleighBenard.py b/pySDC/implementations/problem_classes/RayleighBenard.py index 504c036576..6c60ed5e4b 100644 --- a/pySDC/implementations/problem_classes/RayleighBenard.py +++ b/pySDC/implementations/problem_classes/RayleighBenard.py @@ -213,12 +213,8 @@ def eval_f(self, u, *args, **kwargs): # start by computing derivatives if not hasattr(self, '_Dx_expanded') or not hasattr(self, '_Dz_expanded'): - self._Dx_expanded = self._setup_operator( - {'u': {'u': Dx}, 'v': {'v': Dx}, 'T': {'T': Dx}, 'p': {}}, diag=True - ) - self._Dz_expanded = self._setup_operator( - {'u': {'u': Dz}, 'v': {'v': Dz}, 'T': {'T': Dz}, 'p': {}}, diag=True - ) + self._Dx_expanded = self._setup_operator({'u': {'u': Dx}, 'v': {'v': Dx}, 'T': {'T': Dx}, 'p': {}}) + self._Dz_expanded = self._setup_operator({'u': {'u': Dz}, 'v': {'v': Dz}, 'T': {'T': Dz}, 'p': {}}) Dx_u_hat = (self._Dx_expanded @ u_hat.flatten()).reshape(u_hat.shape) Dz_u_hat = (self._Dz_expanded @ u_hat.flatten()).reshape(u_hat.shape) diff --git a/pySDC/implementations/problem_classes/generic_spectral.py b/pySDC/implementations/problem_classes/generic_spectral.py index b8fec120f8..e78c7b0847 100644 --- a/pySDC/implementations/problem_classes/generic_spectral.py +++ b/pySDC/implementations/problem_classes/generic_spectral.py @@ -138,21 +138,20 @@ def __getattr__(self, name): """ return getattr(self.spectral, name) - def _setup_operator(self, LHS, diag=False): + def _setup_operator(self, LHS): """ Setup a sparse linear operator by adding relationships. See documentation for ``GenericSpectralLinear.setup_L`` to learn more. Args: LHS (dict): Equations to be added to the operator - diag (bool): Whether operator is block-diagonal Returns: sparse linear operator """ - operator = self.spectral.get_empty_operator_matrix(diag=diag) + operator = self.spectral.get_empty_operator_matrix() for line, equation in LHS.items(): - self.spectral.add_equation_lhs(operator, line, equation, diag=diag) - return self.spectral.convert_operator_matrix_to_operator(operator, diag=diag) + self.spectral.add_equation_lhs(operator, line, equation) + return self.spectral.convert_operator_matrix_to_operator(operator) def setup_L(self, LHS): """ @@ -174,13 +173,13 @@ def setup_L(self, LHS): """ self.L = self._setup_operator(LHS) - def setup_M(self, LHS, diag=True): + def setup_M(self, LHS): ''' Setup mass matrix, see documentation of ``GenericSpectralLinear.setup_L``. ''' diff_index = list(LHS.keys()) self.diff_mask = [me in diff_index for me in self.components] - self.M = self._setup_operator(LHS, diag=diag) + self.M = self._setup_operator(LHS) def setup_preconditioner(self, Dirichlet_recombination=True, left_preconditioner=True): """ @@ -195,7 +194,7 @@ def setup_preconditioner(self, Dirichlet_recombination=True, left_preconditioner Id = sp.eye(N) Pl_lhs = {comp: {comp: Id} for comp in self.components} - self.Pl = self._setup_operator(Pl_lhs, diag=True) + self.Pl = self._setup_operator(Pl_lhs) if left_preconditioner: # reverse Kronecker product @@ -217,7 +216,7 @@ def setup_preconditioner(self, Dirichlet_recombination=True, left_preconditioner _Pr = Id Pr_lhs = {comp: {comp: _Pr} for comp in self.components} - self.Pr = self._setup_operator(Pr_lhs, diag=True) @ self.Pl.T + self.Pr = self._setup_operator(Pr_lhs) @ self.Pl.T def solve_system(self, rhs, dt, u0=None, *args, skip_itransform=False, **kwargs): """ diff --git a/pySDC/tests/test_helpers/test_spectral_helper.py b/pySDC/tests/test_helpers/test_spectral_helper.py index a56fab01c2..c69dd5a953 100644 --- a/pySDC/tests/test_helpers/test_spectral_helper.py +++ b/pySDC/tests/test_helpers/test_spectral_helper.py @@ -577,10 +577,9 @@ def test_tau_method2D(nz, nx, bc_val, bc=-1, plotting=False, useMPI=False, **kwa Dxx = helper.get_differentiation_matrix(axes=(0,), p=2) # generate operator - diag = True - _A = helper.get_empty_operator_matrix(diag=diag) - helper.add_equation_lhs(_A, 'u', {'u': Dz - Dxx * 1e-1 - Dx}, diag=diag) - A = helper.convert_operator_matrix_to_operator(_A, diag=diag) + _A = helper.get_empty_operator_matrix() + helper.add_equation_lhs(_A, 'u', {'u': Dz - Dxx * 1e-1 - Dx}) + A = helper.convert_operator_matrix_to_operator(_A) # prepare system to solve A = helper.put_BCs_in_matrix(A) @@ -841,34 +840,6 @@ def function(): assert track[0] == 0, "possible memory leak with the @cache" -@pytest.mark.base -def test_block_diagonal_operators(N=16): - from pySDC.helpers.spectral_helper import SpectralHelper - import numpy as np - - helper = SpectralHelper(comm=None, debug=True) - helper.add_axis('fft', N=N) - helper.add_axis('cheby', N=N) - helper.add_component(['u', 'v']) - helper.setup_fft() - - # generate matrices - Dz = helper.get_differentiation_matrix(axes=(1,)) - Dx = helper.get_differentiation_matrix(axes=(0,)) - - def get_operator(diag): - _A = helper.get_empty_operator_matrix(diag=diag) - helper.add_equation_lhs(_A, 'u', {'u': Dx}, diag=diag) - helper.add_equation_lhs(_A, 'v', {'v': Dz}, diag=diag) - return helper.convert_operator_matrix_to_operator(_A, diag=diag) - - AD = get_operator(True) - A = get_operator(False) - - assert np.allclose(A.toarray(), AD.toarray()), 'Operators don\'t match' - assert A.data.nbytes > AD.data.nbytes, 'Block diagonal operator did not conserve memory over general operator' - - if __name__ == '__main__': str_to_bool = lambda me: False if me == 'False' else True str_to_tuple = lambda arg: tuple(int(me) for me in arg.split(','))