Skip to content

Commit f93f994

Browse files
committed
Implemented more general heterogeneous solves
1 parent 13e2109 commit f93f994

File tree

2 files changed

+56
-147
lines changed

2 files changed

+56
-147
lines changed

pySDC/implementations/problem_classes/RayleighBenard3D.py

Lines changed: 0 additions & 144 deletions
Original file line numberDiff line numberDiff line change
@@ -357,147 +357,3 @@ def plot(self, u, t=None, fig=None, quantity='T'): # pragma: no cover
357357
axs[1].set_ylabel(r'$z$')
358358
fig.colorbar(imT, self.cax[0])
359359
fig.colorbar(imV, self.cax[1])
360-
361-
362-
class RayleighBenard3DHeterogeneous(RayleighBenard3D):
    """
    Variant of ``RayleighBenard3D`` that assembles the linear system from separate ``*_CPU`` copies of the sparse
    operators.

    On construction, the matrices ``Pl``, ``Pr``, ``L`` and ``M`` are stored a second time under the names
    ``Pl_CPU``, ``Pr_CPU``, ``L_CPU`` and ``M_CPU``. If ``useGPU`` is set, the copies are obtained via ``.get()``
    (presumably a device-to-host transfer -- confirm against the array library in use) and the boundary condition
    matrices ``BC_line_zero_matrix`` and ``BCs`` of the spectral helper are replaced by the result of ``.get()``.
    """

    def __init__(self, *args, **kwargs):
        """
        Set up the parent problem and create the ``*_CPU`` copies of the operators.

        Args:
            *args: Forwarded to the parent constructor
            **kwargs: Forwarded to the parent constructor
        """
        super().__init__(*args, **kwargs)

        CPU_only = ['BC_line_zero_matrix', 'BCs']
        both = ['Pl', 'Pr', 'L', 'M']

        # copy matrices we need on CPU
        if self.useGPU:
            # these are replaced in place on the helper, i.e. afterwards only the ``.get()`` result is available
            for key in CPU_only:
                setattr(self.spectral, key, getattr(self.spectral, key).get())

            # these are kept twice: the original under ``key`` and the ``.get()`` result under ``key_CPU``
            for key in both:
                setattr(self, f'{key}_CPU', getattr(self, key).get())
        else:
            # without GPU, the ``_CPU`` names are just aliases for the existing matrices
            for key in both:
                setattr(self, f'{key}_CPU', getattr(self, key))

    def solve_system(self, rhs, dt, u0=None, *args, skip_itransform=False, **kwargs):
        """
        Do an implicit Euler step to solve M u_t + Lu = rhs, with M the mass matrix and L the linear operator as setup by
        ``GenericSpectralLinear.setup_L`` and ``GenericSpectralLinear.setup_M``.

        The implicit Euler step is (M + dt L) u = M rhs. Note that M need not be invertible as long as (M + dt*L) is.
        This means solving with dt=0 to mimic explicit methods does not work for all problems, in particular simple DAEs.

        Note that by putting M rhs on the right hand side, this function can only solve algebraic conditions equal to
        zero. If you want something else, it should be easy to overload this function.

        The system matrix is assembled from the ``*_CPU`` copies of the operators created in ``__init__``, whereas the
        right hand side and the final application of ``Pr`` use the original ``M``, ``Pl`` and ``Pr``.

        Args:
            rhs: Right hand side, in spectral space if ``self.spectral_space`` is set, otherwise it is transformed first
            dt (float): Step size, also used as key for cached factorizations
            u0: Initial guess for iterative solvers in the same space as ``rhs`` (optional)
            skip_itransform (bool): Accepted for interface compatibility, not used in this implementation

        Returns:
            The solution in spectral space if ``self.spectral_space`` is set, otherwise the real part of the
            back-transformed solution

        Raises:
            NotImplementedError: If ``self.solver_type`` matches none of the implemented solvers
        """
        sp = self.spectral.sparse_lib
        solver_type = self.solver_type.lower()

        if self.spectral_space:
            rhs_hat = rhs.copy()
            u0_hat = None if u0 is None else u0.copy().flatten()
        else:
            rhs_hat = self.spectral.transform(rhs)
            u0_hat = None if u0 is None else self.spectral.transform(u0).flatten()

        # apply inverse right preconditioner to initial guess
        if u0_hat is not None and 'direct' not in self.solver_type:
            # the factorization of Pr is stored under the same name that is checked here, so it is computed only once
            if not hasattr(self, '_Pr_inv'):
                self._Pr_inv = self.linalg.splu(self.Pr.astype(complex)).solve
            u0_hat[...] = self._Pr_inv(u0_hat)

        rhs_hat = (self.M @ rhs_hat.flatten()).reshape(rhs_hat.shape)
        rhs_hat = self.spectral.put_BCs_in_rhs_hat(rhs_hat)
        rhs_hat = self.Pl @ rhs_hat.flatten()

        # the matrix is only needed if we cannot reuse a cached direct factorization for this step size
        if solver_type != 'cached_direct' or dt not in self.cached_factorizations:
            A = self.M_CPU + dt * self.L_CPU
            A = self.Pl_CPU @ self.spectral.put_BCs_in_matrix(A) @ self.Pr_CPU
            A = self.spectral.sparse_lib.csc_matrix(A)

        if 'ilu' in solver_type:
            if dt not in self.cached_factorizations:
                if len(self.cached_factorizations) >= self.max_cached_factorizations:
                    # evict the entry that was inserted first
                    to_evict = list(self.cached_factorizations.keys())[0]
                    self.cached_factorizations.pop(to_evict)
                    self.logger.debug(f'Evicted matrix factorization for {to_evict=:.6f} from cache')
                # the drop tolerance is scaled with the step size
                iLU = self.linalg.spilu(
                    A, **{**self.preconditioner_args, 'drop_tol': dt * self.preconditioner_args['drop_tol']}
                )
                self.cached_factorizations[dt] = self.linalg.LinearOperator(A.shape, iLU.solve)
                self.logger.debug(f'Cached incomplete LU factorization for {dt=:.6f}')
                self.work_counters['factorizations']()
            M = self.cached_factorizations[dt]
        else:
            M = None
        info = 0

        if solver_type == 'cached_direct':
            if dt not in self.cached_factorizations:
                if len(self.cached_factorizations) >= self.max_cached_factorizations:
                    # evict the entry that was inserted first and report the key that was actually removed
                    to_evict = list(self.cached_factorizations.keys())[0]
                    self.cached_factorizations.pop(to_evict)
                    self.logger.debug(f'Evicted matrix factorization for {to_evict=:.6f} from cache')
                self.cached_factorizations[dt] = self.spectral.linalg.factorized(A)
                self.logger.debug(f'Cached matrix factorization for {dt=:.6f}')
                self.work_counters['factorizations']()

            _sol_hat = self.cached_factorizations[dt](rhs_hat)
            self.logger.debug(f'Used cached matrix factorization for {dt=:.6f}')

        elif solver_type == 'direct':
            _sol_hat = sp.linalg.spsolve(A, rhs_hat)
        elif 'gmres' in solver_type:
            # the convergence flag of GMRES is discarded, as in the original implementation
            _sol_hat, _ = sp.linalg.gmres(
                A,
                rhs_hat,
                x0=u0_hat,
                **self.solver_args,
                callback=self.work_counters[self.solver_type],
                callback_type='pr_norm',
                M=M,
            )
        elif solver_type == 'cg':
            _sol_hat, info = sp.linalg.cg(
                A, rhs_hat, x0=u0_hat, **self.solver_args, callback=self.work_counters[self.solver_type]
            )
        elif 'bicgstab' in solver_type:
            _sol_hat, info = self.linalg.bicgstab(
                A,
                rhs_hat,
                x0=u0_hat,
                **self.solver_args,
                callback=self.work_counters[self.solver_type],
                M=M,
            )
        else:
            raise NotImplementedError(f'Solver {self.solver_type=} not implemented in {type(self).__name__}!')

        if info != 0:
            self.logger.warning(f'{self.solver_type} not converged! {info=}')

        # undo the right preconditioning and write the result into the helper's spectral-space array
        sol_hat = self.spectral.u_init_forward
        sol_hat[...] = (self.Pr @ _sol_hat).reshape(sol_hat.shape)

        if self.spectral_space:
            return sol_hat

        sol = self.spectral.u_init
        sol[:] = self.spectral.itransform(sol_hat).real

        if self.spectral.debug:
            self.spectral.check_BCs(sol)

        return sol

pySDC/implementations/problem_classes/generic_spectral.py

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def __init__(
6464
max_cached_factorizations=12,
6565
spectral_space=True,
6666
real_spectral_coefficients=False,
67+
heterogeneous=False,
6768
debug=False,
6869
):
6970
"""
@@ -81,6 +82,7 @@ def __init__(
8182
max_cached_factorizations (int): Number of matrix decompositions to cache before starting eviction
8283
spectral_space (bool): If yes, the solution will not be transformed back after solving and evaluating the RHS, and is expected as input in spectral space to these functions
8384
real_spectral_coefficients (bool): If yes, allow only real values in spectral space, otherwise, allow complex.
85+
heterogeneous (bool): If yes, perform memory intensive sparse matrix operations on CPU
8486
debug (bool): Make additional tests at extra computational cost
8587
"""
8688
solver_args = {} if solver_args is None else solver_args
@@ -100,6 +102,7 @@ def __init__(
100102
'comm',
101103
'spectral_space',
102104
'real_spectral_coefficients',
105+
'heterogeneous',
103106
'debug',
104107
localVars=locals(),
105108
)
@@ -126,6 +129,29 @@ def __init__(
126129

127130
self.cached_factorizations = {}
128131

132+
if self.heterogeneous:
133+
self.__heterogeneous_setup = False
134+
135+
def heterogeneous_setup(self):
    """
    Create the ``*_CPU`` copies of the sparse operators used when ``heterogeneous`` is enabled.

    The matrices ``Pl``, ``Pr``, ``L`` and ``M`` are stored a second time as ``Pl_CPU``, ``Pr_CPU``, ``L_CPU`` and
    ``M_CPU``. With ``useGPU``, the copies are the result of ``.get()`` (presumably a device-to-host transfer --
    confirm against the array library in use) and ``BC_line_zero_matrix`` and ``BCs`` of the spectral helper are
    replaced by their ``.get()`` result. Without ``useGPU``, the ``*_CPU`` names alias the existing matrices.

    The work is done at most once; later calls return immediately. Nothing happens if ``heterogeneous`` is not set.
    """
    # check ``heterogeneous`` first because the setup flag is only initialised when it is enabled
    if not self.heterogeneous or self.__heterogeneous_setup:
        return

    CPU_only = ['BC_line_zero_matrix', 'BCs']
    both = ['Pl', 'Pr', 'L', 'M']

    if self.useGPU:
        # replaced in place on the helper, exactly once: the ``.get()`` result need not offer ``.get()`` itself
        for key in CPU_only:
            setattr(self.spectral, key, getattr(self.spectral, key).get())

        # kept twice: the original under ``key`` and the ``.get()`` result under ``key_CPU``
        for key in both:
            setattr(self, f'{key}_CPU', getattr(self, key).get())
    else:
        # the solver reads the ``_CPU`` names whenever ``heterogeneous`` is set, so they must also exist without GPU
        for key in both:
            setattr(self, f'{key}_CPU', getattr(self, key))

    self.__heterogeneous_setup = True
154+
129155
def __getattr__(self, name):
130156
"""
131157
Pass requests on to the helper if they are not directly attributes of this class for convenience.
@@ -233,6 +259,8 @@ def solve_system(self, rhs, dt, u0=None, *args, skip_itransform=False, **kwargs)
233259

234260
sp = self.spectral.sparse_lib
235261

262+
self.heterogeneous_setup()
263+
236264
if self.spectral_space:
237265
rhs_hat = rhs.copy()
238266
if u0 is not None:
@@ -257,8 +285,19 @@ def solve_system(self, rhs, dt, u0=None, *args, skip_itransform=False, **kwargs)
257285
rhs_hat = self.Pl @ rhs_hat.flatten()
258286

259287
if dt not in self.cached_factorizations.keys() or not self.solver_type.lower() == 'cached_direct':
260-
A = self.M + dt * self.L
261-
A = self.Pl @ self.spectral.put_BCs_in_matrix(A) @ self.Pr
288+
if self.heterogeneous:
289+
M = self.M_CPU
290+
L = self.L_CPU
291+
Pl = self.Pl_CPU
292+
Pr = self.Pr_CPU
293+
else:
294+
M = self.M
295+
L = self.L
296+
Pl = self.Pl
297+
Pr = self.Pr
298+
299+
A = M + dt * L
300+
A = Pl @ self.spectral.put_BCs_in_matrix(A) @ Pr
262301

263302
# if A.shape[0] < 200e20:
264303
# import matplotlib.pyplot as plt
@@ -290,7 +329,21 @@ def solve_system(self, rhs, dt, u0=None, *args, skip_itransform=False, **kwargs)
290329
if len(self.cached_factorizations) >= self.max_cached_factorizations:
291330
self.cached_factorizations.pop(list(self.cached_factorizations.keys())[0])
292331
self.logger.debug(f'Evicted matrix factorization for {dt=:.6f} from cache')
293-
self.cached_factorizations[dt] = self.spectral.linalg.factorized(A)
332+
333+
if self.heterogeneous:
334+
import scipy.sparse as sp
335+
336+
cpu_decomp = sp.linalg.splu(A)
337+
if self.useGPU:
338+
from cupyx.scipy.sparse.linalg import SuperLU
339+
340+
solver = SuperLU(cpu_decomp).solve
341+
else:
342+
solver = cpu_decomp.solve
343+
else:
344+
solver = self.spectral.linalg.factorized(A)
345+
346+
self.cached_factorizations[dt] = solver
294347
self.logger.debug(f'Cached matrix factorization for {dt=:.6f}')
295348
self.work_counters['factorizations']()
296349

0 commit comments

Comments
 (0)