Skip to content

Commit ec17d80

Browse files
committed
Merge remote-tracking branch 'origin/main' into multi-volume
2 parents ff6c73e + d4b21e0 commit ec17d80

File tree

9 files changed

+548
-160
lines changed

9 files changed

+548
-160
lines changed

examples/wave/wave-op-mpi.py

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
import pyopencl.tools as cl_tools
3131

3232
from arraycontext import (
33-
thaw, freeze,
33+
thaw,
3434
with_container_arithmetic,
3535
dataclass_array_container
3636
)
@@ -45,7 +45,7 @@
4545
from grudge.dof_desc import as_dofdesc, DOFDesc, DISCR_TAG_BASE, DISCR_TAG_QUAD
4646
from grudge.trace_pair import TracePair
4747
from grudge.discretization import DiscretizationCollection
48-
from grudge.shortcuts import make_visualizer, rk4_step
48+
from grudge.shortcuts import make_visualizer, compiled_lsrk45_step
4949

5050
import grudge.op as op
5151

@@ -57,7 +57,8 @@
5757

5858
# {{{ wave equation bits
5959

60-
@with_container_arithmetic(bcast_obj_array=True, rel_comparison=True)
60+
@with_container_arithmetic(bcast_obj_array=True, rel_comparison=True,
61+
_cls_has_array_context_attr=True)
6162
@dataclass_array_container
6263
@dataclass(frozen=True)
6364
class WaveState:
@@ -251,7 +252,8 @@ def main(ctx_factory, dim=2, order=3,
251252
c = 1
252253

253254
# FIXME: Sketchy, empirically determined fudge factor
254-
dt = actx.to_numpy(0.45 * estimate_rk4_timestep(actx, dcoll, c))
255+
# 5/4 to account for larger LSRK45 stability region
256+
dt = actx.to_numpy(0.45 * estimate_rk4_timestep(actx, dcoll, c)) * 5/4
255257

256258
vis = make_visualizer(dcoll)
257259

@@ -271,25 +273,32 @@ def rhs(t, w):
271273
istep = 0
272274
while t < t_final:
273275
start = time.time()
274-
if lazy:
275-
fields = thaw(freeze(fields, actx), actx)
276276

277-
fields = rk4_step(fields, t, dt, compiled_rhs)
278-
279-
l2norm = actx.to_numpy(op.norm(dcoll, fields.u, 2))
277+
fields = compiled_lsrk45_step(actx, fields, t, dt, compiled_rhs)
280278

281279
if istep % 10 == 0:
282280
stop = time.time()
283-
linfnorm = actx.to_numpy(op.norm(dcoll, fields.u, np.inf))
284-
nodalmax = actx.to_numpy(op.nodal_max(dcoll, "vol", fields.u))
285-
nodalmin = actx.to_numpy(op.nodal_min(dcoll, "vol", fields.u))
286-
if comm.rank == 0:
287-
logger.info(f"step: {istep} t: {t} "
288-
f"L2: {l2norm} "
289-
f"Linf: {linfnorm} "
290-
f"sol max: {nodalmax} "
291-
f"sol min: {nodalmin} "
292-
f"wall: {stop-start} ")
281+
if args.no_diagnostics:
282+
if comm.rank == 0:
283+
logger.info(f"step: {istep} t: {t} "
284+
f"wall: {stop-start} ")
285+
else:
286+
l2norm = actx.to_numpy(op.norm(dcoll, fields.u, 2))
287+
288+
# NOTE: These are here to ensure the solution is bounded for the
289+
# time interval specified
290+
assert l2norm < 1
291+
292+
linfnorm = actx.to_numpy(op.norm(dcoll, fields.u, np.inf))
293+
nodalmax = actx.to_numpy(op.nodal_max(dcoll, "vol", fields.u))
294+
nodalmin = actx.to_numpy(op.nodal_min(dcoll, "vol", fields.u))
295+
if comm.rank == 0:
296+
logger.info(f"step: {istep} t: {t} "
297+
f"L2: {l2norm} "
298+
f"Linf: {linfnorm} "
299+
f"sol max: {nodalmax} "
300+
f"sol min: {nodalmin} "
301+
f"wall: {stop-start} ")
293302
if visualize:
294303
vis.write_parallel_vtk_file(
295304
comm,
@@ -304,10 +313,6 @@ def rhs(t, w):
304313
t += dt
305314
istep += 1
306315

307-
# NOTE: These are here to ensure the solution is bounded for the
308-
# time interval specified
309-
assert l2norm < 1
310-
311316

312317
if __name__ == "__main__":
313318
import argparse
@@ -320,6 +325,7 @@ def rhs(t, w):
320325
help="switch to a lazy computation mode")
321326
parser.add_argument("--quad", action="store_true")
322327
parser.add_argument("--nonaffine", action="store_true")
328+
parser.add_argument("--no-diagnostics", action="store_true")
323329

324330
args = parser.parse_args()
325331

grudge/models/euler.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -318,18 +318,10 @@ def operator(self, t, q):
318318
qtag = self.qtag
319319
dq = DOFDesc("vol", qtag)
320320
df = DOFDesc("all_faces", qtag)
321-
df_int = DOFDesc("int_faces", qtag)
322321

323322
def interp_to_quad(u):
324323
return op.project(dcoll, "vol", dq, u)
325324

326-
def interp_to_quad_surf(u):
327-
return TracePair(
328-
df_int,
329-
interior=op.project(dcoll, "int_faces", df_int, u.int),
330-
exterior=op.project(dcoll, "int_faces", df_int, u.ext)
331-
)
332-
333325
# Compute volume fluxes
334326
volume_fluxes = op.weak_local_div(
335327
dcoll, dq,
@@ -341,7 +333,7 @@ def interp_to_quad_surf(u):
341333
sum(
342334
euler_numerical_flux(
343335
dcoll,
344-
interp_to_quad_surf(tpair),
336+
op.tracepair_with_discr_tag(dcoll, qtag, tpair),
345337
gamma=gamma,
346338
lf_stabilization=self.lf_stabilization
347339
) for tpair in op.interior_trace_pairs(dcoll, q)

0 commit comments

Comments
 (0)