Skip to content

Commit d56045a

Browse files
Reduced memory footprint of spectral discretizations via block-diagonal (#565)
sparse operators
1 parent 0e4830a commit d56045a

File tree

4 files changed

+72
-21
lines changed

4 files changed

+72
-21
lines changed

pySDC/helpers/spectral_helper.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,16 +1024,22 @@ def index(self, name):
10241024
else:
10251025
raise NotImplementedError(f'Don\'t know how to compute index for {type(name)=}')
10261026

1027-
def get_empty_operator_matrix(self):
1027+
def get_empty_operator_matrix(self, diag=False):
10281028
"""
10291029
Return a matrix of operators to be filled with the connections between the solution components.
10301030
1031+
Args:
1032+
diag (bool): Whether operator is block-diagonal
1033+
10311034
Returns:
10321035
list containing sparse zeros
10331036
"""
10341037
S = len(self.components)
10351038
O = self.get_Id() * 0
1036-
return [[O for _ in range(S)] for _ in range(S)]
1039+
if diag:
1040+
return [O for _ in range(S)]
1041+
else:
1042+
return [[O for _ in range(S)] for _ in range(S)]
10371043

10381044
def get_BC(self, axis, kind, line=-1, scalar=False, **kwargs):
10391045
"""
@@ -1296,7 +1302,7 @@ def put_BCs_in_rhs(self, rhs):
12961302

12971303
return rhs
12981304

1299-
def add_equation_lhs(self, A, equation, relations):
1305+
def add_equation_lhs(self, A, equation, relations, diag=False):
13001306
"""
13011307
Add the left hand part (that you want to solve implicitly) of an equation to a list of lists of sparse matrices
13021308
that you will convert to an operator later.
@@ -1331,11 +1337,16 @@ def add_equation_lhs(self, A, equation, relations):
13311337
A (list of lists of sparse matrices): The operator to be
13321338
equation (str): The equation of the component you want this in
13331339
relations: (dict): Relations between quantities
1340+
diag (bool): Whether operator is block-diagonal
13341341
"""
13351342
for k, v in relations.items():
1336-
A[self.index(equation)][self.index(k)] = v
1343+
if diag:
1344+
assert k == equation, 'You are trying to put a non-diagonal equation into a diagonal operator'
1345+
A[self.index(equation)] = v
1346+
else:
1347+
A[self.index(equation)][self.index(k)] = v
13371348

1338-
def convert_operator_matrix_to_operator(self, M):
1349+
def convert_operator_matrix_to_operator(self, M, diag=False):
13391350
"""
13401351
Promote the list of lists of sparse matrices to a single sparse matrix that can be used as linear operator.
13411352
See documentation of `SpectralHelper.add_equation_lhs` for an example.
@@ -1347,9 +1358,14 @@ def convert_operator_matrix_to_operator(self, M):
13471358
sparse linear operator
13481359
"""
13491360
if len(self.components) == 1:
1350-
return M[0][0]
1361+
if diag:
1362+
return M[0]
1363+
else:
1364+
return M[0][0]
1365+
elif diag:
1366+
return self.sparse_lib.block_diag(M, format='csc')
13511367
else:
1352-
return self.sparse_lib.bmat(M, format='csc')
1368+
return self.sparse_lib.block_array(M, format='csc')
13531369

13541370
def get_wavenumbers(self):
13551371
"""

pySDC/implementations/problem_classes/RayleighBenard.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,12 @@ def eval_f(self, u, *args, **kwargs):
213213

214214
# start by computing derivatives
215215
if not hasattr(self, '_Dx_expanded') or not hasattr(self, '_Dz_expanded'):
216-
self._Dx_expanded = self._setup_operator({'u': {'u': Dx}, 'v': {'v': Dx}, 'T': {'T': Dx}, 'p': {}})
217-
self._Dz_expanded = self._setup_operator({'u': {'u': Dz}, 'v': {'v': Dz}, 'T': {'T': Dz}, 'p': {}})
216+
self._Dx_expanded = self._setup_operator(
217+
{'u': {'u': Dx}, 'v': {'v': Dx}, 'T': {'T': Dx}, 'p': {}}, diag=True
218+
)
219+
self._Dz_expanded = self._setup_operator(
220+
{'u': {'u': Dz}, 'v': {'v': Dz}, 'T': {'T': Dz}, 'p': {}}, diag=True
221+
)
218222
Dx_u_hat = (self._Dx_expanded @ u_hat.flatten()).reshape(u_hat.shape)
219223
Dz_u_hat = (self._Dz_expanded @ u_hat.flatten()).reshape(u_hat.shape)
220224

pySDC/implementations/problem_classes/generic_spectral.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -136,20 +136,21 @@ def __getattr__(self, name):
136136
"""
137137
return getattr(self.spectral, name)
138138

139-
def _setup_operator(self, LHS):
139+
def _setup_operator(self, LHS, diag=False):
140140
"""
141141
Setup a sparse linear operator by adding relationships. See documentation for ``GenericSpectralLinear.setup_L`` to learn more.
142142
143143
Args:
144144
LHS (dict): Equations to be added to the operator
145+
diag (bool): Whether operator is block-diagonal
145146
146147
Returns:
147148
sparse linear operator
148149
"""
149-
operator = self.spectral.get_empty_operator_matrix()
150+
operator = self.spectral.get_empty_operator_matrix(diag=diag)
150151
for line, equation in LHS.items():
151-
self.spectral.add_equation_lhs(operator, line, equation)
152-
return self.spectral.convert_operator_matrix_to_operator(operator)
152+
self.spectral.add_equation_lhs(operator, line, equation, diag=diag)
153+
return self.spectral.convert_operator_matrix_to_operator(operator, diag=diag)
153154

154155
def setup_L(self, LHS):
155156
"""
@@ -171,13 +172,13 @@ def setup_L(self, LHS):
171172
"""
172173
self.L = self._setup_operator(LHS)
173174

174-
def setup_M(self, LHS):
175+
def setup_M(self, LHS, diag=True):
175176
'''
176177
Setup mass matrix, see documentation of ``GenericSpectralLinear.setup_L``.
177178
'''
178179
diff_index = list(LHS.keys())
179180
self.diff_mask = [me in diff_index for me in self.components]
180-
self.M = self._setup_operator(LHS)
181+
self.M = self._setup_operator(LHS, diag=diag)
181182

182183
def setup_preconditioner(self, Dirichlet_recombination=True, left_preconditioner=True):
183184
"""
@@ -192,7 +193,7 @@ def setup_preconditioner(self, Dirichlet_recombination=True, left_preconditioner
192193

193194
Id = sp.eye(N)
194195
Pl_lhs = {comp: {comp: Id} for comp in self.components}
195-
self.Pl = self._setup_operator(Pl_lhs)
196+
self.Pl = self._setup_operator(Pl_lhs, diag=True)
196197

197198
if left_preconditioner:
198199
# reverse Kronecker product
@@ -214,7 +215,7 @@ def setup_preconditioner(self, Dirichlet_recombination=True, left_preconditioner
214215
_Pr = Id
215216

216217
Pr_lhs = {comp: {comp: _Pr} for comp in self.components}
217-
self.Pr = self._setup_operator(Pr_lhs) @ self.Pl.T
218+
self.Pr = self._setup_operator(Pr_lhs, diag=True) @ self.Pl.T
218219

219220
def solve_system(self, rhs, dt, u0=None, *args, skip_itransform=False, **kwargs):
220221
"""

pySDC/tests/test_helpers/test_spectral_helper.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -491,9 +491,10 @@ def test_tau_method2D(variant, nz, nx, bc_val, bc=-1, useMPI=False, plotting=Fal
491491
Dxx = helper.get_differentiation_matrix(axes=(0,), p=2)
492492

493493
# generate operator
494-
_A = helper.get_empty_operator_matrix()
495-
helper.add_equation_lhs(_A, 'u', {'u': Dz - Dxx * 1e-1 - Dx})
496-
A = helper.convert_operator_matrix_to_operator(_A)
494+
diag = True
495+
_A = helper.get_empty_operator_matrix(diag=diag)
496+
helper.add_equation_lhs(_A, 'u', {'u': Dz - Dxx * 1e-1 - Dx}, diag=diag)
497+
A = helper.convert_operator_matrix_to_operator(_A, diag=diag)
497498

498499
# prepare system to solve
499500
A = helper.put_BCs_in_matrix(A)
@@ -608,6 +609,34 @@ def function():
608609
assert track[0] == 0, "possible memory leak with the @cache"
609610

610611

612+
@pytest.mark.base
613+
def test_block_diagonal_operators(N=16):
614+
from pySDC.helpers.spectral_helper import SpectralHelper
615+
import numpy as np
616+
617+
helper = SpectralHelper(comm=None, debug=True)
618+
helper.add_axis('fft', N=N)
619+
helper.add_axis('cheby', N=N)
620+
helper.add_component(['u', 'v'])
621+
helper.setup_fft()
622+
623+
# generate matrices
624+
Dz = helper.get_differentiation_matrix(axes=(1,))
625+
Dx = helper.get_differentiation_matrix(axes=(0,))
626+
627+
def get_operator(diag):
628+
_A = helper.get_empty_operator_matrix(diag=diag)
629+
helper.add_equation_lhs(_A, 'u', {'u': Dx}, diag=diag)
630+
helper.add_equation_lhs(_A, 'v', {'v': Dz}, diag=diag)
631+
return helper.convert_operator_matrix_to_operator(_A, diag=diag)
632+
633+
AD = get_operator(True)
634+
A = get_operator(False)
635+
636+
assert np.allclose(A.toarray(), AD.toarray()), 'Operators don\'t match'
637+
assert A.data.nbytes > AD.data.nbytes, 'Block diagonal operator did not conserve memory over general operator'
638+
639+
611640
if __name__ == '__main__':
612641
str_to_bool = lambda me: False if me == 'False' else True
613642
str_to_tuple = lambda arg: tuple(int(me) for me in arg.split(','))
@@ -642,9 +671,10 @@ def function():
642671
# test_differentiation_matrix2D(2**5, 2**5, 'T2U', bx='fft', bz='fft', axes=(-2, -1))
643672
# test_matrix1D(4, 'cheby', 'int')
644673
# test_tau_method(-1, 8, 99, kind='Dirichlet')
645-
test_tau_method2D('T2U', 2**8, 2**8, -2, plotting=True)
674+
# test_tau_method2D('T2U', 2**8, 2**8, -2, plotting=True)
646675
# test_filter(6, 6, (0,))
647676
# _test_transform_dealias('fft', 'cheby', (-1, -2))
677+
test_block_diagonal_operators()
648678
else:
649679
raise NotImplementedError
650680
print('done')

0 commit comments

Comments
 (0)