@@ -59,6 +59,7 @@ def __init__(
5959 left_preconditioner = True ,
6060 solver_type = 'cached_direct' ,
6161 solver_args = None ,
62+ preconditioner_args = None ,
6263 useGPU = False ,
6364 max_cached_factorizations = 12 ,
6465 spectral_space = True ,
@@ -83,11 +84,15 @@ def __init__(
8384 debug (bool): Make additional tests at extra computational cost
8485 """
8586 solver_args = {} if solver_args is None else solver_args
87+ preconditioner_args = {} if preconditioner_args is None else preconditioner_args
88+ preconditioner_args ['drop_tol' ] = preconditioner_args .get ('drop_tol' , 1e-3 )
89+ preconditioner_args ['fill_factor' ] = preconditioner_args .get ('fill_factor' , 100 )
8690 self ._makeAttributeAndRegister (
8791 'max_cached_factorizations' ,
8892 'useGPU' ,
8993 'solver_type' ,
9094 'solver_args' ,
95+ 'preconditioner_args' ,
9196 'left_preconditioner' ,
9297 'Dirichlet_recombination' ,
9398 'comm' ,
@@ -229,10 +234,14 @@ def solve_system(self, rhs, dt, u0=None, *args, skip_itransform=False, **kwargs)
229234 rhs_hat = rhs .copy ()
230235 if u0 is not None :
231236 u0_hat = self .Pr .T @ u0 .copy ().flatten ()
237+ else :
238+ u0_hat = None
232239 else :
233240 rhs_hat = self .spectral .transform (rhs )
234241 if u0 is not None :
235242 u0_hat = self .Pr .T @ self .spectral .transform (u0 ).flatten ()
243+ else :
244+ u0_hat = None
236245
237246 if self .useGPU :
238247 self .xp .cuda .Device ().synchronize ()
@@ -257,6 +266,23 @@ def solve_system(self, rhs, dt, u0=None, *args, skip_itransform=False, **kwargs)
257266 # plt.colorbar(im)
258267 # plt.show()
259268
269+ if 'ilu' in self .solver_type .lower ():
270+ if dt not in self .cached_factorizations .keys ():
271+ if len (self .cached_factorizations ) >= self .max_cached_factorizations :
272+ to_evict = list (self .cached_factorizations .keys ())[0 ]
273+ self .cached_factorizations .pop (to_evict )
274+ self .logger .debug (f'Evicted matrix factorization for { to_evict = :.6f} from cache' )
275+ iLU = self .linalg .spilu (
276+ A , ** {** self .preconditioner_args , 'drop_tol' : dt * self .preconditioner_args ['drop_tol' ]}
277+ )
278+ self .cached_factorizations [dt ] = self .linalg .LinearOperator (A .shape , iLU .solve )
279+ self .logger .debug (f'Cached incomplete LU factorization for { dt = :.6f} ' )
280+ self .work_counters ['factorizations' ]()
281+ M = self .cached_factorizations [dt ]
282+ else :
283+ M = None
284+ info = 0
285+
260286 if self .solver_type .lower () == 'cached_direct' :
261287 if dt not in self .cached_factorizations .keys ():
262288 if len (self .cached_factorizations ) >= self .max_cached_factorizations :
@@ -271,52 +297,35 @@ def solve_system(self, rhs, dt, u0=None, *args, skip_itransform=False, **kwargs)
271297
272298 elif self .solver_type .lower () == 'direct' :
273299 _sol_hat = sp .linalg .spsolve (A , rhs_hat )
274- elif self .solver_type .lower () == 'lsqr' :
275- lsqr = sp .linalg .lsqr (
276- A ,
277- rhs_hat ,
278- x0 = u0_hat ,
279- ** self .solver_args ,
280- )
281- _sol_hat = lsqr [0 ]
282- elif self .solver_type .lower () == 'gmres' :
300+ elif 'gmres' in self .solver_type .lower ():
283301 _sol_hat , _ = sp .linalg .gmres (
284302 A ,
285303 rhs_hat ,
286304 x0 = u0_hat ,
287305 ** self .solver_args ,
288306 callback = self .work_counters [self .solver_type ],
289307 callback_type = 'pr_norm' ,
308+ M = M ,
290309 )
291- elif self .solver_type .lower () == 'gmres+ilu' :
292- linalg = self .spectral .linalg
293-
294- if dt not in self .cached_factorizations .keys ():
295- if len (self .cached_factorizations ) >= self .max_cached_factorizations :
296- to_evict = list (self .cached_factorizations .keys ())[0 ]
297- self .cached_factorizations .pop (to_evict )
298- self .logger .debug (f'Evicted matrix factorization for { to_evict = :.6f} from cache' )
299- iLU = linalg .spilu (A , drop_tol = dt * 1e-4 , fill_factor = 100 )
300- self .cached_factorizations [dt ] = linalg .LinearOperator (A .shape , iLU .solve )
301- self .logger .debug (f'Cached matrix factorization for { dt = :.6f} ' )
302- self .work_counters ['factorizations' ]()
303-
304- _sol_hat , _ = linalg .gmres (
310+ elif self .solver_type .lower () == 'cg' :
311+ _sol_hat , info = sp .linalg .cg (
312+ A , rhs_hat , x0 = u0_hat , ** self .solver_args , callback = self .work_counters [self .solver_type ]
313+ )
314+ elif 'bicgstab' in self .solver_type .lower ():
315+ _sol_hat , info = self .linalg .bicgstab (
305316 A ,
306317 rhs_hat ,
307318 x0 = u0_hat ,
308319 ** self .solver_args ,
309320 callback = self .work_counters [self .solver_type ],
310- callback_type = 'pr_norm' ,
311- M = self .cached_factorizations [dt ],
312- )
313- elif self .solver_type .lower () == 'cg' :
314- _sol_hat , _ = sp .linalg .cg (
315- A , rhs_hat , x0 = u0_hat , ** self .solver_args , callback = self .work_counters [self .solver_type ]
321+ M = M ,
316322 )
317323 else :
318324 raise NotImplementedError (f'Solver { self .solver_type = } not implemented in { type (self ).__name__ } !' )
319325
326+ if info != 0 :
327+ self .logger .warn (f'{ self .solver_type } not converged! { info = } ' )
328+
320329 sol_hat = self .spectral .u_init_forward
321330 sol_hat [...] = (self .Pr @ _sol_hat ).reshape (sol_hat .shape )
322331
0 commit comments