Skip to content

Commit 85dc966

Browse files
More efficient eval_f in RBC on GPU (#492)
1 parent 3d59549 commit 85dc966

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

pySDC/implementations/problem_classes/RayleighBenard.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -191,20 +191,24 @@ def eval_f(self, u, *args, **kwargs):
191191
iu, iv, iT, ip = self.index(['u', 'v', 'T', 'p'])
192192

193193
# evaluate implicit terms
194-
f_impl_hat = -(self.base_change @ self.L @ u_hat.flatten()).reshape(u_hat.shape)
194+
if not hasattr(self, '_L_T_base'):
195+
self._L_T_base = self.base_change @ self.L
196+
f_impl_hat = -(self._L_T_base @ u_hat.flatten()).reshape(u_hat.shape)
195197

196198
if self.spectral_space:
197199
f.impl[:] = f_impl_hat
198200
else:
199201
f.impl[:] = self.itransform(f_impl_hat).real
200202

203+
# -------------------------------------------
201204
# treat convection explicitly with dealiasing
202-
Dx_u_hat = self.u_init_forward
203-
for i in [iu, iv, iT]:
204-
Dx_u_hat[i][:] = (Dx @ u_hat[i].flatten()).reshape(Dx_u_hat[i].shape)
205-
Dz_u_hat = self.u_init_forward
206-
for i in [iu, iv, iT]:
207-
Dz_u_hat[i][:] = (Dz @ u_hat[i].flatten()).reshape(Dz_u_hat[i].shape)
205+
206+
# start by computing derivatives
207+
if not hasattr(self, '_Dx_expanded') or not hasattr(self, '_Dz_expanded'):
208+
self._Dx_expanded = self._setup_operator({'u': {'u': Dx}, 'v': {'v': Dx}, 'T': {'T': Dx}, 'p': {}})
209+
self._Dz_expanded = self._setup_operator({'u': {'u': Dz}, 'v': {'v': Dz}, 'T': {'T': Dz}, 'p': {}})
210+
Dx_u_hat = (self._Dx_expanded @ u_hat.flatten()).reshape(u_hat.shape)
211+
Dz_u_hat = (self._Dz_expanded @ u_hat.flatten()).reshape(u_hat.shape)
208212

209213
padding = [self.dealiasing, self.dealiasing]
210214
Dx_u_pad = self.itransform(Dx_u_hat, padding=padding).real

0 commit comments

Comments
 (0)