@@ -364,20 +364,19 @@ class RayleighBenard3DHeterogeneous(RayleighBenard3D):
364364 def __init__ (self , * args , ** kwargs ):
365365 super ().__init__ (* args , ** kwargs )
366366
367+ CPU_only = ['BC_line_zero_matrix' , 'BCs' ]
368+ both = ['Pl' , 'Pr' , 'L' , 'M' ]
369+
367370 # copy matrices we need on CPU
368371 if self .useGPU :
369- for key in [ 'BC_line_zero_matrix' , 'BCs' ]: # TODO complete this list!
372+ for key in CPU_only :
370373 setattr (self .spectral , key , getattr (self .spectral , key ).get ())
371- for key in ['Pl' , 'Pr' , 'M' ]: # TODO complete this list!
372- setattr (self , key , getattr (self , key ).get ())
373374
374- self .L_CPU = self .L .get ()
375+ for key in both :
376+ setattr (self , f'{ key } _CPU' , getattr (self , key ).get ())
375377 else :
376- self .L_CPU = self .L .copy ()
377-
378- # delete matrices we do not need on GPU
379- for key in []: # TODO: complete list
380- delattr (self , key )
378+ for key in both :
379+ setattr (self , f'{ key } _CPU' , getattr (self , key ))
381380
382381 def solve_system (self , rhs , dt , u0 = None , * args , skip_itransform = False , ** kwargs ):
383382 """
@@ -417,8 +416,8 @@ def solve_system(self, rhs, dt, u0=None, *args, skip_itransform=False, **kwargs)
417416 rhs_hat = self .Pl @ rhs_hat .flatten ()
418417
419418 if dt not in self .cached_factorizations .keys () or not self .solver_type .lower () == 'cached_direct' :
420- A = self .M + dt * self .L_CPU
421- A = self .Pl @ self .spectral .put_BCs_in_matrix (A ) @ self .Pr
419+ A = self .M_CPU + dt * self .L_CPU
420+ A = self .Pl_CPU @ self .spectral .put_BCs_in_matrix (A ) @ self .Pr_CPU
422421 A = self .spectral .sparse_lib .csc_matrix (A )
423422
424423 # if A.shape[0] < 200e20:
0 commit comments