Skip to content

Commit 4a4aaf4

Browse files
committed
Eliminate zeros in sparse operators to save memory
1 parent 4e2ce4d commit 4a4aaf4

File tree

3 files changed

+33
-63
lines changed

3 files changed

+33
-63
lines changed

pySDC/helpers/spectral_helper.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1135,7 +1135,7 @@ def get_BC(self, axis, kind, line=-1, scalar=False, **kwargs):
11351135

11361136
ndim = len(self.axes)
11371137
if ndim == 1:
1138-
return self.sparse_lib.csc_matrix(BC)
1138+
mat = self.sparse_lib.csc_matrix(BC)
11391139
elif ndim == 2:
11401140
axis2 = (axis + 1) % ndim
11411141

@@ -1151,8 +1151,8 @@ def get_BC(self, axis, kind, line=-1, scalar=False, **kwargs):
11511151
] * ndim
11521152
mats[axis] = self.get_local_slice_of_1D_matrix(BC, axis=axis)
11531153
mats[axis2] = Id
1154-
return self.sparse_lib.csc_matrix(self.sparse_lib.kron(*mats))
1155-
if ndim == 3:
1154+
mat = self.sparse_lib.csc_matrix(self.sparse_lib.kron(*mats))
1155+
elif ndim == 3:
11561156
mats = [
11571157
None,
11581158
] * ndim
@@ -1170,11 +1170,13 @@ def get_BC(self, axis, kind, line=-1, scalar=False, **kwargs):
11701170

11711171
mats[axis] = self.get_local_slice_of_1D_matrix(BC, axis=axis)
11721172

1173-
return self.sparse_lib.csc_matrix(self.sparse_lib.kron(mats[0], self.sparse_lib.kron(*mats[1:])))
1173+
mat = self.sparse_lib.csc_matrix(self.sparse_lib.kron(mats[0], self.sparse_lib.kron(*mats[1:])))
11741174
else:
11751175
raise NotImplementedError(
11761176
f'Matrix expansion for boundary conditions not implemented for {ndim} dimensions!'
11771177
)
1178+
mat.eliminate_zeros()
1179+
return mat
11781180

11791181
def remove_BC(self, component, equation, axis, kind, line=-1, scalar=False, **kwargs):
11801182
"""
@@ -1192,6 +1194,7 @@ def remove_BC(self, component, equation, axis, kind, line=-1, scalar=False, **kw
11921194
scalar (bool): Put the BC in all space positions in the other direction
11931195
"""
11941196
_BC = self.get_BC(axis=axis, kind=kind, line=line, scalar=scalar, **kwargs)
1197+
_BC.eliminate_zeros()
11951198
self.BC_mat[self.index(equation)][self.index(component)] -= _BC
11961199

11971200
if scalar:
@@ -1375,7 +1378,7 @@ def put_BCs_in_rhs(self, rhs):
13751378

13761379
return rhs
13771380

1378-
def add_equation_lhs(self, A, equation, relations, diag=False):
1381+
def add_equation_lhs(self, A, equation, relations):
13791382
"""
13801383
Add the left hand part (that you want to solve implicitly) of an equation to a list of lists of sparse matrices
13811384
that you will convert to an operator later.
@@ -1410,16 +1413,11 @@ def add_equation_lhs(self, A, equation, relations, diag=False):
14101413
A (list of lists of sparse matrices): The operator to be
14111414
equation (str): The equation of the component you want this in
14121415
relations: (dict): Relations between quantities
1413-
diag (bool): Whether operator is block-diagonal
14141416
"""
14151417
for k, v in relations.items():
1416-
if diag:
1417-
assert k == equation, 'You are trying to put a non-diagonal equation into a diagonal operator'
1418-
A[self.index(equation)] = v
1419-
else:
1420-
A[self.index(equation)][self.index(k)] = v
1418+
A[self.index(equation)][self.index(k)] = v
14211419

1422-
def convert_operator_matrix_to_operator(self, M, diag=False):
1420+
def convert_operator_matrix_to_operator(self, M):
14231421
"""
14241422
Promote the list of lists of sparse matrices to a single sparse matrix that can be used as linear operator.
14251423
See documentation of `SpectralHelper.add_equation_lhs` for an example.
@@ -1431,14 +1429,12 @@ def convert_operator_matrix_to_operator(self, M, diag=False):
14311429
sparse linear operator
14321430
"""
14331431
if len(self.components) == 1:
1434-
if diag:
1435-
return M[0]
1436-
else:
1437-
return M[0][0]
1438-
elif diag:
1439-
return self.sparse_lib.block_diag(M, format='csc')
1432+
op = M[0][0]
14401433
else:
1441-
return self.sparse_lib.block_array(M, format='csc')
1434+
op = self.sparse_lib.bmat(M, format='csc')
1435+
1436+
op.eliminate_zeros()
1437+
return op
14421438

14431439
def get_wavenumbers(self):
14441440
"""
@@ -1792,7 +1788,7 @@ def expand_matrix_ND(self, matrix, aligned):
17921788
ndim = len(axes) + 1
17931789

17941790
if ndim == 1:
1795-
return matrix
1791+
mat = matrix
17961792
elif ndim == 2:
17971793
axis = axes[0]
17981794
I1D = sp.eye(self.axes[axis].N)
@@ -1801,7 +1797,7 @@ def expand_matrix_ND(self, matrix, aligned):
18011797
mats[aligned] = self.get_local_slice_of_1D_matrix(matrix, aligned)
18021798
mats[axis] = self.get_local_slice_of_1D_matrix(I1D, axis)
18031799

1804-
return sp.kron(*mats)
1800+
mat = sp.kron(*mats)
18051801
elif ndim == 3:
18061802

18071803
mats = [None] * ndim
@@ -1810,11 +1806,15 @@ def expand_matrix_ND(self, matrix, aligned):
18101806
I1D = sp.eye(self.axes[axis].N)
18111807
mats[axis] = self.get_local_slice_of_1D_matrix(I1D, axis)
18121808

1813-
return sp.kron(mats[0], sp.kron(*mats[1:]))
1809+
mat = sp.kron(mats[0], sp.kron(*mats[1:]))
18141810

18151811
else:
18161812
raise NotImplementedError(f'Matrix expansion not implemented for {ndim} dimensions!')
18171813

1814+
mat = mat.tocsc()
1815+
mat.eliminate_zeros()
1816+
return mat
1817+
18181818
def get_filter_matrix(self, axis, **kwargs):
18191819
"""
18201820
Get bandpass filter along `axis`. See the documentation `get_filter_matrix` in the 1D bases for what kwargs are

pySDC/implementations/problem_classes/generic_spectral.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -138,21 +138,20 @@ def __getattr__(self, name):
138138
"""
139139
return getattr(self.spectral, name)
140140

141-
def _setup_operator(self, LHS, diag=False):
141+
def _setup_operator(self, LHS):
142142
"""
143143
Setup a sparse linear operator by adding relationships. See documentation for ``GenericSpectralLinear.setup_L`` to learn more.
144144
145145
Args:
146146
LHS (dict): Equations to be added to the operator
147-
diag (bool): Whether operator is block-diagonal
148147
149148
Returns:
150149
sparse linear operator
151150
"""
152-
operator = self.spectral.get_empty_operator_matrix(diag=diag)
151+
operator = self.spectral.get_empty_operator_matrix()
153152
for line, equation in LHS.items():
154-
self.spectral.add_equation_lhs(operator, line, equation, diag=diag)
155-
return self.spectral.convert_operator_matrix_to_operator(operator, diag=diag)
153+
self.spectral.add_equation_lhs(operator, line, equation)
154+
return self.spectral.convert_operator_matrix_to_operator(operator)
156155

157156
def setup_L(self, LHS):
158157
"""
@@ -174,13 +173,13 @@ def setup_L(self, LHS):
174173
"""
175174
self.L = self._setup_operator(LHS)
176175

177-
def setup_M(self, LHS, diag=True):
176+
def setup_M(self, LHS):
178177
'''
179178
Setup mass matrix, see documentation of ``GenericSpectralLinear.setup_L``.
180179
'''
181180
diff_index = list(LHS.keys())
182181
self.diff_mask = [me in diff_index for me in self.components]
183-
self.M = self._setup_operator(LHS, diag=diag)
182+
self.M = self._setup_operator(LHS)
184183

185184
def setup_preconditioner(self, Dirichlet_recombination=True, left_preconditioner=True):
186185
"""
@@ -195,7 +194,7 @@ def setup_preconditioner(self, Dirichlet_recombination=True, left_preconditioner
195194

196195
Id = sp.eye(N)
197196
Pl_lhs = {comp: {comp: Id} for comp in self.components}
198-
self.Pl = self._setup_operator(Pl_lhs, diag=True)
197+
self.Pl = self._setup_operator(Pl_lhs)
199198

200199
if left_preconditioner:
201200
# reverse Kronecker product
@@ -217,7 +216,7 @@ def setup_preconditioner(self, Dirichlet_recombination=True, left_preconditioner
217216
_Pr = Id
218217

219218
Pr_lhs = {comp: {comp: _Pr} for comp in self.components}
220-
self.Pr = self._setup_operator(Pr_lhs, diag=True) @ self.Pl.T
219+
self.Pr = self._setup_operator(Pr_lhs) @ self.Pl.T
221220

222221
def solve_system(self, rhs, dt, u0=None, *args, skip_itransform=False, **kwargs):
223222
"""

pySDC/tests/test_helpers/test_spectral_helper.py

Lines changed: 3 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -577,10 +577,9 @@ def test_tau_method2D(nz, nx, bc_val, bc=-1, plotting=False, useMPI=False, **kwa
577577
Dxx = helper.get_differentiation_matrix(axes=(0,), p=2)
578578

579579
# generate operator
580-
diag = True
581-
_A = helper.get_empty_operator_matrix(diag=diag)
582-
helper.add_equation_lhs(_A, 'u', {'u': Dz - Dxx * 1e-1 - Dx}, diag=diag)
583-
A = helper.convert_operator_matrix_to_operator(_A, diag=diag)
580+
_A = helper.get_empty_operator_matrix()
581+
helper.add_equation_lhs(_A, 'u', {'u': Dz - Dxx * 1e-1 - Dx})
582+
A = helper.convert_operator_matrix_to_operator(_A)
584583

585584
# prepare system to solve
586585
A = helper.put_BCs_in_matrix(A)
@@ -841,34 +840,6 @@ def function():
841840
assert track[0] == 0, "possible memory leak with the @cache"
842841

843842

844-
@pytest.mark.base
845-
def test_block_diagonal_operators(N=16):
846-
from pySDC.helpers.spectral_helper import SpectralHelper
847-
import numpy as np
848-
849-
helper = SpectralHelper(comm=None, debug=True)
850-
helper.add_axis('fft', N=N)
851-
helper.add_axis('cheby', N=N)
852-
helper.add_component(['u', 'v'])
853-
helper.setup_fft()
854-
855-
# generate matrices
856-
Dz = helper.get_differentiation_matrix(axes=(1,))
857-
Dx = helper.get_differentiation_matrix(axes=(0,))
858-
859-
def get_operator(diag):
860-
_A = helper.get_empty_operator_matrix(diag=diag)
861-
helper.add_equation_lhs(_A, 'u', {'u': Dx}, diag=diag)
862-
helper.add_equation_lhs(_A, 'v', {'v': Dz}, diag=diag)
863-
return helper.convert_operator_matrix_to_operator(_A, diag=diag)
864-
865-
AD = get_operator(True)
866-
A = get_operator(False)
867-
868-
assert np.allclose(A.toarray(), AD.toarray()), 'Operators don\'t match'
869-
assert A.data.nbytes > AD.data.nbytes, 'Block diagonal operator did not conserve memory over general operator'
870-
871-
872843
if __name__ == '__main__':
873844
str_to_bool = lambda me: False if me == 'False' else True
874845
str_to_tuple = lambda arg: tuple(int(me) for me in arg.split(','))

0 commit comments

Comments
 (0)