Skip to content

Commit c4ef614

Browse files
committed
Optimized eval_f for memory consumption in 3D RBC
1 parent 30c48e9 commit c4ef614

File tree

1 file changed

+86
-38
lines changed

1 file changed

+86
-38
lines changed

pySDC/implementations/problem_classes/RayleighBenard3D.py

Lines changed: 86 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
class RayleighBenard3D(GenericSpectralLinear):
1313
"""
14-
Rayleigh-Benard Convection is a variation of incompressible fluid dynamics.
14+
Rayleigh-Benard Convection is a variation of incompressible Navier-Stokes.
1515
1616
The equations we solve are
1717
@@ -28,11 +28,10 @@ class RayleighBenard3D(GenericSpectralLinear):
2828
2929
The domain, vertical boundary conditions and pressure gauge are
3030
31-
Omega = [0, Lx) x [0, Ly) x (0, Lz)
32-
T(z=Lz) = 0
33-
T(z=0) = Lz
34-
u(z=0) = v(z=0) = w(z=0) = 0
35-
u(z=Lz) = v(z=Lz) = w(z=Lz) = 0
31+
Omega = [0, 8) x (-1, 1)
32+
T(z=+1) = 0
33+
T(z=-1) = 2
34+
u(z=+-1) = v(z=+-1) = 0
3635
integral over p = 0
3736
3837
The spectral discretization uses FFT horizontally, implying periodic BCs, and an ultraspherical method vertically to
@@ -150,6 +149,8 @@ def __init__(
150149
self.Dyy = Dyy
151150
self.Dz = S1 @ Dz
152151
self.Dzz = S2 @ Dzz
152+
self.S2 = S2
153+
self.S1 = S1
153154

154155
# compute rescaled Rayleigh number to extract viscosity and thermal diffusivity
155156
Ra = Rayleigh / (max([abs(BCs['T_top'] - BCs['T_bottom']), np.finfo(float).eps]) * self.axes[2].L ** 3)
@@ -170,9 +171,6 @@ def __init__(
170171
M_lhs = {i: {i: U02 @ Id} for i in ['u', 'v', 'w', 'T']}
171172
self.setup_M(M_lhs)
172173

173-
# Prepare going from second (first for divergence free equation) derivative basis back to Chebychev-T
174-
self.base_change = self._setup_operator({**{comp: {comp: S2} for comp in ['u', 'v', 'w', 'T']}, 'p': {'p': S1}})
175-
176174
# BCs
177175
self.add_BC(
178176
component='p', equation='p', axis=2, v=self.BCs['p_integral'], kind='integral', line=-1, scalar=True
@@ -224,53 +222,43 @@ def eval_f(self, u, *args, **kwargs):
224222
f_impl_hat = self.u_init_forward
225223

226224
iu, iv, iw, iT, ip = self.index(['u', 'v', 'w', 'T', 'p'])
225+
derivative_indices = [iu, iv, iw, iT]
227226

228227
# evaluate implicit terms
229-
if not hasattr(self, '_L_T_base'):
230-
self._L_T_base = self.base_change @ self.L
231-
f_impl_hat = -(self._L_T_base @ u_hat.flatten()).reshape(u_hat.shape)
228+
f_impl_hat = -(self.L @ u_hat.flatten()).reshape(u_hat.shape)
229+
for i in derivative_indices:
230+
f_impl_hat[i] = (self.S2 @ f_impl_hat[i].flatten()).reshape(f_impl_hat[i].shape)
231+
f_impl_hat[ip] = (self.S1 @ f_impl_hat[ip].flatten()).reshape(f_impl_hat[ip].shape)
232232

233233
if self.spectral_space:
234-
f.impl[:] = f_impl_hat
234+
self.xp.copyto(f.impl, f_impl_hat)
235235
else:
236236
f.impl[:] = self.itransform(f_impl_hat).real
237237

238238
# -------------------------------------------
239239
# treat convection explicitly with dealiasing
240240

241241
# start by computing derivatives
242-
if not hasattr(self, '_Dx_expanded') or not hasattr(self, '_Dz_expanded'):
243-
Dz = self.Dz
244-
Dy = self.Dy
245-
Dx = self.Dx
242+
padding = (self.dealiasing,) * self.ndim
243+
derivatives = []
244+
u_hat_flat = [u_hat[i].flatten() for i in derivative_indices]
246245

247-
self._Dx_expanded = self._setup_operator(
248-
{'u': {'u': Dx}, 'v': {'v': Dx}, 'w': {'w': Dx}, 'T': {'T': Dx}, 'p': {}}
249-
)
250-
self._Dy_expanded = self._setup_operator(
251-
{'u': {'u': Dy}, 'v': {'v': Dy}, 'w': {'w': Dy}, 'T': {'T': Dy}, 'p': {}}
252-
)
253-
self._Dz_expanded = self._setup_operator(
254-
{'u': {'u': Dz}, 'v': {'v': Dz}, 'w': {'w': Dz}, 'T': {'T': Dz}, 'p': {}}
255-
)
256-
Dx_u_hat = (self._Dx_expanded @ u_hat.flatten()).reshape(u_hat.shape)
257-
Dy_u_hat = (self._Dy_expanded @ u_hat.flatten()).reshape(u_hat.shape)
258-
Dz_u_hat = (self._Dz_expanded @ u_hat.flatten()).reshape(u_hat.shape)
246+
_D_u_hat = self.u_init_forward
247+
for D in [self.Dx, self.Dy, self.Dz]:
248+
_D_u_hat *= 0
249+
for i in derivative_indices:
250+
self.xp.copyto(_D_u_hat[i], (D @ u_hat_flat[i]).reshape(_D_u_hat[i].shape))
251+
derivatives.append(self.itransform(_D_u_hat, padding=padding).real)
259252

260-
padding = (self.dealiasing,) * self.ndim
261-
Dx_u_pad = self.itransform(Dx_u_hat, padding=padding).real
262-
Dy_u_pad = self.itransform(Dy_u_hat, padding=padding).real
263-
Dz_u_pad = self.itransform(Dz_u_hat, padding=padding).real
264253
u_pad = self.itransform(u_hat, padding=padding).real
265254

266255
fexpl_pad = self.xp.zeros_like(u_pad)
267-
fexpl_pad[iu][:] = -(u_pad[iu] * Dx_u_pad[iu] + u_pad[iv] * Dy_u_pad[iu] + u_pad[iw] * Dz_u_pad[iu])
268-
fexpl_pad[iv][:] = -(u_pad[iu] * Dx_u_pad[iv] + u_pad[iv] * Dy_u_pad[iv] + u_pad[iw] * Dz_u_pad[iv])
269-
fexpl_pad[iw][:] = -(u_pad[iu] * Dx_u_pad[iw] + u_pad[iv] * Dy_u_pad[iw] + u_pad[iw] * Dz_u_pad[iw])
270-
fexpl_pad[iT][:] = -(u_pad[iu] * Dx_u_pad[iT] + u_pad[iv] * Dy_u_pad[iT] + u_pad[iw] * Dz_u_pad[iT])
256+
for i in derivative_indices:
257+
for i_vel, iD in zip([iu, iv, iw], range(self.ndim)):
258+
fexpl_pad[i] -= u_pad[i_vel] * derivatives[iD][i]
271259

272260
if self.spectral_space:
273-
f.expl[:] = self.transform(fexpl_pad, padding=padding)
261+
self.xp.copyto(f.expl, self.transform(fexpl_pad, padding=padding))
274262
else:
275263
f.expl[:] = self.itransform(self.transform(fexpl_pad, padding=padding)).real
276264

@@ -307,3 +295,63 @@ def u_exact(self, t=0, noise_level=1e-3, seed=99):
307295
return me_hat
308296
else:
309297
return me
298+
299+
def get_fig(self): # pragma: no cover
300+
"""
301+
Get a figure suitable to plot the solution of this problem
302+
303+
Returns
304+
-------
305+
self.fig : matplotlib.pyplot.figure.Figure
306+
"""
307+
import matplotlib.pyplot as plt
308+
from mpl_toolkits.axes_grid1 import make_axes_locatable
309+
310+
plt.rcParams['figure.constrained_layout.use'] = True
311+
self.fig, axs = plt.subplots(2, 1, sharex=True, sharey=True, figsize=((10, 5)))
312+
self.cax = []
313+
divider = make_axes_locatable(axs[0])
314+
self.cax += [divider.append_axes('right', size='3%', pad=0.03)]
315+
divider2 = make_axes_locatable(axs[1])
316+
self.cax += [divider2.append_axes('right', size='3%', pad=0.03)]
317+
return self.fig
318+
319+
def plot(self, u, t=None, fig=None, quantity='T'): # pragma: no cover
320+
r"""
321+
Plot the solution.
322+
323+
Parameters
324+
----------
325+
u : dtype_u
326+
Solution to be plotted
327+
t : float
328+
Time to display at the top of the figure
329+
fig : matplotlib.pyplot.figure.Figure
330+
Figure with the same structure as a figure generated by `self.get_fig`. If none is supplied, a new figure will be generated.
331+
quantity : (str)
332+
quantity you want to plot
333+
334+
Returns
335+
-------
336+
None
337+
"""
338+
fig = self.get_fig() if fig is None else fig
339+
axs = fig.axes
340+
341+
imV = axs[1].pcolormesh(self.X, self.Z, self.compute_vorticity(u).real)
342+
343+
if self.spectral_space:
344+
u = self.itransform(u)
345+
346+
imT = axs[0].pcolormesh(self.X, self.Z, u[self.index(quantity)].real)
347+
348+
for i, label in zip([0, 1], [rf'${quantity}$', 'vorticity']):
349+
axs[i].set_aspect(1)
350+
axs[i].set_title(label)
351+
352+
if t is not None:
353+
fig.suptitle(f't = {t:.2f}')
354+
axs[1].set_xlabel(r'$x$')
355+
axs[1].set_ylabel(r'$z$')
356+
fig.colorbar(imT, self.cax[0])
357+
fig.colorbar(imV, self.cax[1])

0 commit comments

Comments
 (0)