@@ -357,3 +357,148 @@ def plot(self, u, t=None, fig=None, quantity='T'): # pragma: no cover
357357 axs [1 ].set_ylabel (r'$z$' )
358358 fig .colorbar (imT , self .cax [0 ])
359359 fig .colorbar (imV , self .cax [1 ])
360+
361+
362+ class RayleighBenard3DHeterogeneous (RayleighBenard3D ):
363+
364+ def __init__ (self , * args , ** kwargs ):
365+ super ().__init__ (* args , ** kwargs )
366+
367+ # copy matrices we need on CPU
368+ if self .useGPU :
369+ for key in ['BC_line_zero_matrix' , 'BCs' ]: # TODO complete this list!
370+ setattr (self .spectral , key , getattr (self .spectral , key ).get ())
371+ for key in ['Pl' , 'Pr' , 'M' ]: # TODO complete this list!
372+ setattr (self , key , getattr (self , key ).get ())
373+
374+ self .L_CPU = self .L .get ()
375+ else :
376+ self .L_CPU = self .L .copy ()
377+
378+ # delete matrices we do not need on GPU
379+ for key in []: # TODO: complete list
380+ delattr (self , key )
381+
382+ def solve_system (self , rhs , dt , u0 = None , * args , skip_itransform = False , ** kwargs ):
383+ """
384+ Do an implicit Euler step to solve M u_t + Lu = rhs, with M the mass matrix and L the linear operator as setup by
385+ ``GenericSpectralLinear.setup_L`` and ``GenericSpectralLinear.setup_M``.
386+
387+ The implicit Euler step is (M - dt L) u = M rhs. Note that M need not be invertible as long as (M + dt*L) is.
388+ This means solving with dt=0 to mimic explicit methods does not work for all problems, in particular simple DAEs.
389+
390+ Note that by putting M rhs on the right hand side, this function can only solve algebraic conditions equal to
391+ zero. If you want something else, it should be easy to overload this function.
392+ """
393+
394+ sp = self .spectral .sparse_lib
395+
396+ if self .spectral_space :
397+ rhs_hat = rhs .copy ()
398+ if u0 is not None :
399+ u0_hat = u0 .copy ().flatten ()
400+ else :
401+ u0_hat = None
402+ else :
403+ rhs_hat = self .spectral .transform (rhs )
404+ if u0 is not None :
405+ u0_hat = self .spectral .transform (u0 ).flatten ()
406+ else :
407+ u0_hat = None
408+
409+ # apply inverse right preconditioner to initial guess
410+ if u0_hat is not None and 'direct' not in self .solver_type :
411+ if not hasattr (self , '_Pr_inv' ):
412+ self ._PR_inv = self .linalg .splu (self .Pr .astype (complex )).solve
413+ u0_hat [...] = self ._PR_inv (u0_hat )
414+
415+ rhs_hat = (self .M @ rhs_hat .flatten ()).reshape (rhs_hat .shape )
416+ rhs_hat = self .spectral .put_BCs_in_rhs_hat (rhs_hat )
417+ rhs_hat = self .Pl @ rhs_hat .flatten ()
418+
419+ if dt not in self .cached_factorizations .keys () or not self .solver_type .lower () == 'cached_direct' :
420+ A = self .M + dt * self .L_CPU
421+ A = self .Pl @ self .spectral .put_BCs_in_matrix (A ) @ self .Pr
422+ A = self .spectral .sparse_lib .csc_matrix (A )
423+
424+ # if A.shape[0] < 200e20:
425+ # import matplotlib.pyplot as plt
426+
427+ # # M = self.spectral.put_BCs_in_matrix(self.L.copy())
428+ # M = A # self.L
429+ # im = plt.spy(M)
430+ # plt.show()
431+
432+ if 'ilu' in self .solver_type .lower ():
433+ if dt not in self .cached_factorizations .keys ():
434+ if len (self .cached_factorizations ) >= self .max_cached_factorizations :
435+ to_evict = list (self .cached_factorizations .keys ())[0 ]
436+ self .cached_factorizations .pop (to_evict )
437+ self .logger .debug (f'Evicted matrix factorization for { to_evict = :.6f} from cache' )
438+ iLU = self .linalg .spilu (
439+ A , ** {** self .preconditioner_args , 'drop_tol' : dt * self .preconditioner_args ['drop_tol' ]}
440+ )
441+ self .cached_factorizations [dt ] = self .linalg .LinearOperator (A .shape , iLU .solve )
442+ self .logger .debug (f'Cached incomplete LU factorization for { dt = :.6f} ' )
443+ self .work_counters ['factorizations' ]()
444+ M = self .cached_factorizations [dt ]
445+ else :
446+ M = None
447+ info = 0
448+
449+ if self .solver_type .lower () == 'cached_direct' :
450+ if dt not in self .cached_factorizations .keys ():
451+ if len (self .cached_factorizations ) >= self .max_cached_factorizations :
452+ self .cached_factorizations .pop (list (self .cached_factorizations .keys ())[0 ])
453+ self .logger .debug (f'Evicted matrix factorization for { dt = :.6f} from cache' )
454+ self .cached_factorizations [dt ] = self .spectral .linalg .factorized (A )
455+ self .logger .debug (f'Cached matrix factorization for { dt = :.6f} ' )
456+ self .work_counters ['factorizations' ]()
457+
458+ _sol_hat = self .cached_factorizations [dt ](rhs_hat )
459+ self .logger .debug (f'Used cached matrix factorization for { dt = :.6f} ' )
460+
461+ elif self .solver_type .lower () == 'direct' :
462+ _sol_hat = sp .linalg .spsolve (A , rhs_hat )
463+ elif 'gmres' in self .solver_type .lower ():
464+ _sol_hat , _ = sp .linalg .gmres (
465+ A ,
466+ rhs_hat ,
467+ x0 = u0_hat ,
468+ ** self .solver_args ,
469+ callback = self .work_counters [self .solver_type ],
470+ callback_type = 'pr_norm' ,
471+ M = M ,
472+ )
473+ elif self .solver_type .lower () == 'cg' :
474+ _sol_hat , info = sp .linalg .cg (
475+ A , rhs_hat , x0 = u0_hat , ** self .solver_args , callback = self .work_counters [self .solver_type ]
476+ )
477+ elif 'bicgstab' in self .solver_type .lower ():
478+ _sol_hat , info = self .linalg .bicgstab (
479+ A ,
480+ rhs_hat ,
481+ x0 = u0_hat ,
482+ ** self .solver_args ,
483+ callback = self .work_counters [self .solver_type ],
484+ M = M ,
485+ )
486+ else :
487+ raise NotImplementedError (f'Solver { self .solver_type = } not implemented in { type (self ).__name__ } !' )
488+
489+ if info != 0 :
490+ self .logger .warn (f'{ self .solver_type } not converged! { info = } ' )
491+
492+ sol_hat = self .spectral .u_init_forward
493+ sol_hat [...] = (self .Pr @ _sol_hat ).reshape (sol_hat .shape )
494+
495+ if self .spectral_space :
496+ return sol_hat
497+ else :
498+ sol = self .spectral .u_init
499+ sol [:] = self .spectral .itransform (sol_hat ).real
500+
501+ if self .spectral .debug :
502+ self .spectral .check_BCs (sol )
503+
504+ return sol
0 commit comments