@@ -357,147 +357,3 @@ def plot(self, u, t=None, fig=None, quantity='T'): # pragma: no cover
357357 axs [1 ].set_ylabel (r'$z$' )
358358 fig .colorbar (imT , self .cax [0 ])
359359 fig .colorbar (imV , self .cax [1 ])
360-
361-
362- class RayleighBenard3DHeterogeneous (RayleighBenard3D ):
363-
364- def __init__ (self , * args , ** kwargs ):
365- super ().__init__ (* args , ** kwargs )
366-
367- CPU_only = ['BC_line_zero_matrix' , 'BCs' ]
368- both = ['Pl' , 'Pr' , 'L' , 'M' ]
369-
370- # copy matrices we need on CPU
371- if self .useGPU :
372- for key in CPU_only :
373- setattr (self .spectral , key , getattr (self .spectral , key ).get ())
374-
375- for key in both :
376- setattr (self , f'{ key } _CPU' , getattr (self , key ).get ())
377- else :
378- for key in both :
379- setattr (self , f'{ key } _CPU' , getattr (self , key ))
380-
381- def solve_system (self , rhs , dt , u0 = None , * args , skip_itransform = False , ** kwargs ):
382- """
383- Do an implicit Euler step to solve M u_t + Lu = rhs, with M the mass matrix and L the linear operator as setup by
384- ``GenericSpectralLinear.setup_L`` and ``GenericSpectralLinear.setup_M``.
385-
386- The implicit Euler step is (M - dt L) u = M rhs. Note that M need not be invertible as long as (M + dt*L) is.
387- This means solving with dt=0 to mimic explicit methods does not work for all problems, in particular simple DAEs.
388-
389- Note that by putting M rhs on the right hand side, this function can only solve algebraic conditions equal to
390- zero. If you want something else, it should be easy to overload this function.
391- """
392-
393- sp = self .spectral .sparse_lib
394-
395- if self .spectral_space :
396- rhs_hat = rhs .copy ()
397- if u0 is not None :
398- u0_hat = u0 .copy ().flatten ()
399- else :
400- u0_hat = None
401- else :
402- rhs_hat = self .spectral .transform (rhs )
403- if u0 is not None :
404- u0_hat = self .spectral .transform (u0 ).flatten ()
405- else :
406- u0_hat = None
407-
408- # apply inverse right preconditioner to initial guess
409- if u0_hat is not None and 'direct' not in self .solver_type :
410- if not hasattr (self , '_Pr_inv' ):
411- self ._PR_inv = self .linalg .splu (self .Pr .astype (complex )).solve
412- u0_hat [...] = self ._PR_inv (u0_hat )
413-
414- rhs_hat = (self .M @ rhs_hat .flatten ()).reshape (rhs_hat .shape )
415- rhs_hat = self .spectral .put_BCs_in_rhs_hat (rhs_hat )
416- rhs_hat = self .Pl @ rhs_hat .flatten ()
417-
418- if dt not in self .cached_factorizations .keys () or not self .solver_type .lower () == 'cached_direct' :
419- A = self .M_CPU + dt * self .L_CPU
420- A = self .Pl_CPU @ self .spectral .put_BCs_in_matrix (A ) @ self .Pr_CPU
421- A = self .spectral .sparse_lib .csc_matrix (A )
422-
423- # if A.shape[0] < 200e20:
424- # import matplotlib.pyplot as plt
425-
426- # # M = self.spectral.put_BCs_in_matrix(self.L.copy())
427- # M = A # self.L
428- # im = plt.spy(M)
429- # plt.show()
430-
431- if 'ilu' in self .solver_type .lower ():
432- if dt not in self .cached_factorizations .keys ():
433- if len (self .cached_factorizations ) >= self .max_cached_factorizations :
434- to_evict = list (self .cached_factorizations .keys ())[0 ]
435- self .cached_factorizations .pop (to_evict )
436- self .logger .debug (f'Evicted matrix factorization for { to_evict = :.6f} from cache' )
437- iLU = self .linalg .spilu (
438- A , ** {** self .preconditioner_args , 'drop_tol' : dt * self .preconditioner_args ['drop_tol' ]}
439- )
440- self .cached_factorizations [dt ] = self .linalg .LinearOperator (A .shape , iLU .solve )
441- self .logger .debug (f'Cached incomplete LU factorization for { dt = :.6f} ' )
442- self .work_counters ['factorizations' ]()
443- M = self .cached_factorizations [dt ]
444- else :
445- M = None
446- info = 0
447-
448- if self .solver_type .lower () == 'cached_direct' :
449- if dt not in self .cached_factorizations .keys ():
450- if len (self .cached_factorizations ) >= self .max_cached_factorizations :
451- self .cached_factorizations .pop (list (self .cached_factorizations .keys ())[0 ])
452- self .logger .debug (f'Evicted matrix factorization for { dt = :.6f} from cache' )
453- self .cached_factorizations [dt ] = self .spectral .linalg .factorized (A )
454- self .logger .debug (f'Cached matrix factorization for { dt = :.6f} ' )
455- self .work_counters ['factorizations' ]()
456-
457- _sol_hat = self .cached_factorizations [dt ](rhs_hat )
458- self .logger .debug (f'Used cached matrix factorization for { dt = :.6f} ' )
459-
460- elif self .solver_type .lower () == 'direct' :
461- _sol_hat = sp .linalg .spsolve (A , rhs_hat )
462- elif 'gmres' in self .solver_type .lower ():
463- _sol_hat , _ = sp .linalg .gmres (
464- A ,
465- rhs_hat ,
466- x0 = u0_hat ,
467- ** self .solver_args ,
468- callback = self .work_counters [self .solver_type ],
469- callback_type = 'pr_norm' ,
470- M = M ,
471- )
472- elif self .solver_type .lower () == 'cg' :
473- _sol_hat , info = sp .linalg .cg (
474- A , rhs_hat , x0 = u0_hat , ** self .solver_args , callback = self .work_counters [self .solver_type ]
475- )
476- elif 'bicgstab' in self .solver_type .lower ():
477- _sol_hat , info = self .linalg .bicgstab (
478- A ,
479- rhs_hat ,
480- x0 = u0_hat ,
481- ** self .solver_args ,
482- callback = self .work_counters [self .solver_type ],
483- M = M ,
484- )
485- else :
486- raise NotImplementedError (f'Solver { self .solver_type = } not implemented in { type (self ).__name__ } !' )
487-
488- if info != 0 :
489- self .logger .warn (f'{ self .solver_type } not converged! { info = } ' )
490-
491- sol_hat = self .spectral .u_init_forward
492- sol_hat [...] = (self .Pr @ _sol_hat ).reshape (sol_hat .shape )
493-
494- if self .spectral_space :
495- return sol_hat
496- else :
497- sol = self .spectral .u_init
498- sol [:] = self .spectral .itransform (sol_hat ).real
499-
500- if self .spectral .debug :
501- self .spectral .check_BCs (sol )
502-
503- return sol
0 commit comments