Skip to content

Commit ed6cdec

Browse files
Refactored iterative solvers in generic spectral class (#560)
1 parent a8c7111 commit ed6cdec

File tree

2 files changed

+42
-30
lines changed

2 files changed

+42
-30
lines changed

pySDC/implementations/problem_classes/generic_spectral.py

Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def __init__(
5959
left_preconditioner=True,
6060
solver_type='cached_direct',
6161
solver_args=None,
62+
preconditioner_args=None,
6263
useGPU=False,
6364
max_cached_factorizations=12,
6465
spectral_space=True,
@@ -83,11 +84,15 @@ def __init__(
8384
debug (bool): Make additional tests at extra computational cost
8485
"""
8586
solver_args = {} if solver_args is None else solver_args
87+
preconditioner_args = {} if preconditioner_args is None else preconditioner_args
88+
preconditioner_args['drop_tol'] = preconditioner_args.get('drop_tol', 1e-3)
89+
preconditioner_args['fill_factor'] = preconditioner_args.get('fill_factor', 100)
8690
self._makeAttributeAndRegister(
8791
'max_cached_factorizations',
8892
'useGPU',
8993
'solver_type',
9094
'solver_args',
95+
'preconditioner_args',
9196
'left_preconditioner',
9297
'Dirichlet_recombination',
9398
'comm',
@@ -229,10 +234,14 @@ def solve_system(self, rhs, dt, u0=None, *args, skip_itransform=False, **kwargs)
229234
rhs_hat = rhs.copy()
230235
if u0 is not None:
231236
u0_hat = self.Pr.T @ u0.copy().flatten()
237+
else:
238+
u0_hat = None
232239
else:
233240
rhs_hat = self.spectral.transform(rhs)
234241
if u0 is not None:
235242
u0_hat = self.Pr.T @ self.spectral.transform(u0).flatten()
243+
else:
244+
u0_hat = None
236245

237246
if self.useGPU:
238247
self.xp.cuda.Device().synchronize()
@@ -257,6 +266,23 @@ def solve_system(self, rhs, dt, u0=None, *args, skip_itransform=False, **kwargs)
257266
# plt.colorbar(im)
258267
# plt.show()
259268

269+
if 'ilu' in self.solver_type.lower():
270+
if dt not in self.cached_factorizations.keys():
271+
if len(self.cached_factorizations) >= self.max_cached_factorizations:
272+
to_evict = list(self.cached_factorizations.keys())[0]
273+
self.cached_factorizations.pop(to_evict)
274+
self.logger.debug(f'Evicted matrix factorization for {to_evict=:.6f} from cache')
275+
iLU = self.linalg.spilu(
276+
A, **{**self.preconditioner_args, 'drop_tol': dt * self.preconditioner_args['drop_tol']}
277+
)
278+
self.cached_factorizations[dt] = self.linalg.LinearOperator(A.shape, iLU.solve)
279+
self.logger.debug(f'Cached incomplete LU factorization for {dt=:.6f}')
280+
self.work_counters['factorizations']()
281+
M = self.cached_factorizations[dt]
282+
else:
283+
M = None
284+
info = 0
285+
260286
if self.solver_type.lower() == 'cached_direct':
261287
if dt not in self.cached_factorizations.keys():
262288
if len(self.cached_factorizations) >= self.max_cached_factorizations:
@@ -271,52 +297,35 @@ def solve_system(self, rhs, dt, u0=None, *args, skip_itransform=False, **kwargs)
271297

272298
elif self.solver_type.lower() == 'direct':
273299
_sol_hat = sp.linalg.spsolve(A, rhs_hat)
274-
elif self.solver_type.lower() == 'lsqr':
275-
lsqr = sp.linalg.lsqr(
276-
A,
277-
rhs_hat,
278-
x0=u0_hat,
279-
**self.solver_args,
280-
)
281-
_sol_hat = lsqr[0]
282-
elif self.solver_type.lower() == 'gmres':
300+
elif 'gmres' in self.solver_type.lower():
283301
_sol_hat, _ = sp.linalg.gmres(
284302
A,
285303
rhs_hat,
286304
x0=u0_hat,
287305
**self.solver_args,
288306
callback=self.work_counters[self.solver_type],
289307
callback_type='pr_norm',
308+
M=M,
290309
)
291-
elif self.solver_type.lower() == 'gmres+ilu':
292-
linalg = self.spectral.linalg
293-
294-
if dt not in self.cached_factorizations.keys():
295-
if len(self.cached_factorizations) >= self.max_cached_factorizations:
296-
to_evict = list(self.cached_factorizations.keys())[0]
297-
self.cached_factorizations.pop(to_evict)
298-
self.logger.debug(f'Evicted matrix factorization for {to_evict=:.6f} from cache')
299-
iLU = linalg.spilu(A, drop_tol=dt * 1e-4, fill_factor=100)
300-
self.cached_factorizations[dt] = linalg.LinearOperator(A.shape, iLU.solve)
301-
self.logger.debug(f'Cached matrix factorization for {dt=:.6f}')
302-
self.work_counters['factorizations']()
303-
304-
_sol_hat, _ = linalg.gmres(
310+
elif self.solver_type.lower() == 'cg':
311+
_sol_hat, info = sp.linalg.cg(
312+
A, rhs_hat, x0=u0_hat, **self.solver_args, callback=self.work_counters[self.solver_type]
313+
)
314+
elif 'bicgstab' in self.solver_type.lower():
315+
_sol_hat, info = self.linalg.bicgstab(
305316
A,
306317
rhs_hat,
307318
x0=u0_hat,
308319
**self.solver_args,
309320
callback=self.work_counters[self.solver_type],
310-
callback_type='pr_norm',
311-
M=self.cached_factorizations[dt],
312-
)
313-
elif self.solver_type.lower() == 'cg':
314-
_sol_hat, _ = sp.linalg.cg(
315-
A, rhs_hat, x0=u0_hat, **self.solver_args, callback=self.work_counters[self.solver_type]
321+
M=M,
316322
)
317323
else:
318324
raise NotImplementedError(f'Solver {self.solver_type=} not implemented in {type(self).__name__}!')
319325

326+
if info != 0:
327+
self.logger.warn(f'{self.solver_type} not converged! {info=}')
328+
320329
sol_hat = self.spectral.u_init_forward
321330
sol_hat[...] = (self.Pr @ _sol_hat).reshape(sol_hat.shape)
322331

pySDC/tests/test_problems/test_heat_chebychev.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
@pytest.mark.parametrize('noise', [0, 1e-3])
99
@pytest.mark.parametrize('use_ultraspherical', [True, False])
1010
@pytest.mark.parametrize('spectral_space', [True, False])
11-
def test_heat1d_chebychev(a, b, f, noise, use_ultraspherical, spectral_space, nvars=2**4):
11+
@pytest.mark.parametrize('solver_type', ['cached_direct', 'direct', 'gmres', 'bicgstab', 'gmres+ilu', 'bicgstab+ilu'])
12+
def test_heat1d_chebychev(a, b, f, noise, use_ultraspherical, spectral_space, solver_type, nvars=2**4):
1213
import numpy as np
1314

1415
if use_ultraspherical:
@@ -25,6 +26,8 @@ def test_heat1d_chebychev(a, b, f, noise, use_ultraspherical, spectral_space, nv
2526
left_preconditioner=False,
2627
debug=True,
2728
spectral_space=spectral_space,
29+
solver_type=solver_type,
30+
solver_args={'rtol': 1e-12},
2831
)
2932

3033
u0 = P.u_exact(0, noise=noise)

0 commit comments

Comments
 (0)