Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 38 additions & 29 deletions pySDC/implementations/problem_classes/generic_spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(
left_preconditioner=True,
solver_type='cached_direct',
solver_args=None,
preconditioner_args=None,
useGPU=False,
max_cached_factorizations=12,
spectral_space=True,
Expand All @@ -83,11 +84,15 @@ def __init__(
debug (bool): Make additional tests at extra computational cost
"""
solver_args = {} if solver_args is None else solver_args
preconditioner_args = {} if preconditioner_args is None else preconditioner_args
preconditioner_args['drop_tol'] = preconditioner_args.get('drop_tol', 1e-3)
preconditioner_args['fill_factor'] = preconditioner_args.get('fill_factor', 100)
self._makeAttributeAndRegister(
'max_cached_factorizations',
'useGPU',
'solver_type',
'solver_args',
'preconditioner_args',
'left_preconditioner',
'Dirichlet_recombination',
'comm',
Expand Down Expand Up @@ -229,10 +234,14 @@ def solve_system(self, rhs, dt, u0=None, *args, skip_itransform=False, **kwargs)
rhs_hat = rhs.copy()
if u0 is not None:
u0_hat = self.Pr.T @ u0.copy().flatten()
else:
u0_hat = None
else:
rhs_hat = self.spectral.transform(rhs)
if u0 is not None:
u0_hat = self.Pr.T @ self.spectral.transform(u0).flatten()
else:
u0_hat = None

if self.useGPU:
self.xp.cuda.Device().synchronize()
Expand All @@ -257,6 +266,23 @@ def solve_system(self, rhs, dt, u0=None, *args, skip_itransform=False, **kwargs)
# plt.colorbar(im)
# plt.show()

if 'ilu' in self.solver_type.lower():
if dt not in self.cached_factorizations.keys():
if len(self.cached_factorizations) >= self.max_cached_factorizations:
to_evict = list(self.cached_factorizations.keys())[0]
self.cached_factorizations.pop(to_evict)
self.logger.debug(f'Evicted matrix factorization for {to_evict=:.6f} from cache')
iLU = self.linalg.spilu(
A, **{**self.preconditioner_args, 'drop_tol': dt * self.preconditioner_args['drop_tol']}
)
self.cached_factorizations[dt] = self.linalg.LinearOperator(A.shape, iLU.solve)
self.logger.debug(f'Cached incomplete LU factorization for {dt=:.6f}')
self.work_counters['factorizations']()
M = self.cached_factorizations[dt]
else:
M = None
info = 0

if self.solver_type.lower() == 'cached_direct':
if dt not in self.cached_factorizations.keys():
if len(self.cached_factorizations) >= self.max_cached_factorizations:
Expand All @@ -271,52 +297,35 @@ def solve_system(self, rhs, dt, u0=None, *args, skip_itransform=False, **kwargs)

elif self.solver_type.lower() == 'direct':
_sol_hat = sp.linalg.spsolve(A, rhs_hat)
elif self.solver_type.lower() == 'lsqr':
lsqr = sp.linalg.lsqr(
A,
rhs_hat,
x0=u0_hat,
**self.solver_args,
)
_sol_hat = lsqr[0]
elif self.solver_type.lower() == 'gmres':
elif 'gmres' in self.solver_type.lower():
_sol_hat, _ = sp.linalg.gmres(
A,
rhs_hat,
x0=u0_hat,
**self.solver_args,
callback=self.work_counters[self.solver_type],
callback_type='pr_norm',
M=M,
)
elif self.solver_type.lower() == 'gmres+ilu':
linalg = self.spectral.linalg

if dt not in self.cached_factorizations.keys():
if len(self.cached_factorizations) >= self.max_cached_factorizations:
to_evict = list(self.cached_factorizations.keys())[0]
self.cached_factorizations.pop(to_evict)
self.logger.debug(f'Evicted matrix factorization for {to_evict=:.6f} from cache')
iLU = linalg.spilu(A, drop_tol=dt * 1e-4, fill_factor=100)
self.cached_factorizations[dt] = linalg.LinearOperator(A.shape, iLU.solve)
self.logger.debug(f'Cached matrix factorization for {dt=:.6f}')
self.work_counters['factorizations']()

_sol_hat, _ = linalg.gmres(
elif self.solver_type.lower() == 'cg':
_sol_hat, info = sp.linalg.cg(
A, rhs_hat, x0=u0_hat, **self.solver_args, callback=self.work_counters[self.solver_type]
)
elif 'bicgstab' in self.solver_type.lower():
_sol_hat, info = self.linalg.bicgstab(
A,
rhs_hat,
x0=u0_hat,
**self.solver_args,
callback=self.work_counters[self.solver_type],
callback_type='pr_norm',
M=self.cached_factorizations[dt],
)
elif self.solver_type.lower() == 'cg':
_sol_hat, _ = sp.linalg.cg(
A, rhs_hat, x0=u0_hat, **self.solver_args, callback=self.work_counters[self.solver_type]
M=M,
)
else:
raise NotImplementedError(f'Solver {self.solver_type=} not implemented in {type(self).__name__}!')

if info != 0:
self.logger.warn(f'{self.solver_type} not converged! {info=}')

sol_hat = self.spectral.u_init_forward
sol_hat[...] = (self.Pr @ _sol_hat).reshape(sol_hat.shape)

Expand Down
5 changes: 4 additions & 1 deletion pySDC/tests/test_problems/test_heat_chebychev.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
@pytest.mark.parametrize('noise', [0, 1e-3])
@pytest.mark.parametrize('use_ultraspherical', [True, False])
@pytest.mark.parametrize('spectral_space', [True, False])
def test_heat1d_chebychev(a, b, f, noise, use_ultraspherical, spectral_space, nvars=2**4):
@pytest.mark.parametrize('solver_type', ['cached_direct', 'direct', 'gmres', 'bicgstab', 'gmres+ilu', 'bicgstab+ilu'])
def test_heat1d_chebychev(a, b, f, noise, use_ultraspherical, spectral_space, solver_type, nvars=2**4):
import numpy as np

if use_ultraspherical:
Expand All @@ -25,6 +26,8 @@ def test_heat1d_chebychev(a, b, f, noise, use_ultraspherical, spectral_space, nv
left_preconditioner=False,
debug=True,
spectral_space=spectral_space,
solver_type=solver_type,
solver_args={'rtol': 1e-12},
)

u0 = P.u_exact(0, noise=noise)
Expand Down