Skip to content

Commit 13e2109

Browse files
Thomasbrownbaerchen
authored andcommitted
Refactor
1 parent 76f0d93 commit 13e2109

File tree

2 files changed

+11
-12
lines changed

2 files changed

+11
-12
lines changed

pySDC/implementations/problem_classes/RayleighBenard3D.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -364,20 +364,19 @@ class RayleighBenard3DHeterogeneous(RayleighBenard3D):
364364
def __init__(self, *args, **kwargs):
365365
super().__init__(*args, **kwargs)
366366

367+
CPU_only = ['BC_line_zero_matrix', 'BCs']
368+
both = ['Pl', 'Pr', 'L', 'M']
369+
367370
# copy matrices we need on CPU
368371
if self.useGPU:
369-
for key in ['BC_line_zero_matrix', 'BCs']: # TODO complete this list!
372+
for key in CPU_only:
370373
setattr(self.spectral, key, getattr(self.spectral, key).get())
371-
for key in ['Pl', 'Pr', 'M']: # TODO complete this list!
372-
setattr(self, key, getattr(self, key).get())
373374

374-
self.L_CPU = self.L.get()
375+
for key in both:
376+
setattr(self, f'{key}_CPU', getattr(self, key).get())
375377
else:
376-
self.L_CPU = self.L.copy()
377-
378-
# delete matrices we do not need on GPU
379-
for key in []: # TODO: complete list
380-
delattr(self, key)
378+
for key in both:
379+
setattr(self, f'{key}_CPU', getattr(self, key))
381380

382381
def solve_system(self, rhs, dt, u0=None, *args, skip_itransform=False, **kwargs):
383382
"""
@@ -417,8 +416,8 @@ def solve_system(self, rhs, dt, u0=None, *args, skip_itransform=False, **kwargs)
417416
rhs_hat = self.Pl @ rhs_hat.flatten()
418417

419418
if dt not in self.cached_factorizations.keys() or not self.solver_type.lower() == 'cached_direct':
420-
A = self.M + dt * self.L_CPU
421-
A = self.Pl @ self.spectral.put_BCs_in_matrix(A) @ self.Pr
419+
A = self.M_CPU + dt * self.L_CPU
420+
A = self.Pl_CPU @ self.spectral.put_BCs_in_matrix(A) @ self.Pr_CPU
422421
A = self.spectral.sparse_lib.csc_matrix(A)
423422

424423
# if A.shape[0] < 200e20:

pySDC/tests/test_problems/test_RayleighBenard3D.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ def test_banded_matrix(preconditioning):
277277
def test_heterogeneous_implementation():
278278
from pySDC.implementations.problem_classes.RayleighBenard3D import RayleighBenard3D, RayleighBenard3DHeterogeneous
279279

280-
params = {'nx': 2, 'ny': 2, 'nz': 2, 'useGPU': False}
280+
params = {'nx': 2, 'ny': 2, 'nz': 2, 'useGPU': True}
281281
gpu = RayleighBenard3D(**params)
282282
het = RayleighBenard3DHeterogeneous(**params)
283283

0 commit comments

Comments
 (0)