3030import pyopencl .tools as cl_tools
3131
3232from arraycontext import (
33- thaw , freeze ,
33+ thaw ,
3434 with_container_arithmetic ,
3535 dataclass_array_container
3636)
4545from grudge .dof_desc import as_dofdesc , DOFDesc , DISCR_TAG_BASE , DISCR_TAG_QUAD
4646from grudge .trace_pair import TracePair
4747from grudge .discretization import DiscretizationCollection
48- from grudge .shortcuts import make_visualizer , rk4_step
48+ from grudge .shortcuts import make_visualizer , compiled_lsrk45_step
4949
5050import grudge .op as op
5151
5757
5858# {{{ wave equation bits
5959
60- @with_container_arithmetic (bcast_obj_array = True , rel_comparison = True )
60+ @with_container_arithmetic (bcast_obj_array = True , rel_comparison = True ,
61+ _cls_has_array_context_attr = True )
6162@dataclass_array_container
6263@dataclass (frozen = True )
6364class WaveState :
@@ -251,7 +252,8 @@ def main(ctx_factory, dim=2, order=3,
251252 c = 1
252253
253254 # FIXME: Sketchy, empirically determined fudge factor
254- dt = actx .to_numpy (0.45 * estimate_rk4_timestep (actx , dcoll , c ))
255+ # 5/4 to account for larger LSRK45 stability region
256+ dt = actx .to_numpy (0.45 * estimate_rk4_timestep (actx , dcoll , c )) * 5 / 4
255257
256258 vis = make_visualizer (dcoll )
257259
@@ -271,25 +273,32 @@ def rhs(t, w):
271273 istep = 0
272274 while t < t_final :
273275 start = time .time ()
274- if lazy :
275- fields = thaw (freeze (fields , actx ), actx )
276276
277- fields = rk4_step (fields , t , dt , compiled_rhs )
278-
279- l2norm = actx .to_numpy (op .norm (dcoll , fields .u , 2 ))
277+ fields = compiled_lsrk45_step (actx , fields , t , dt , compiled_rhs )
280278
281279 if istep % 10 == 0 :
282280 stop = time .time ()
283- linfnorm = actx .to_numpy (op .norm (dcoll , fields .u , np .inf ))
284- nodalmax = actx .to_numpy (op .nodal_max (dcoll , "vol" , fields .u ))
285- nodalmin = actx .to_numpy (op .nodal_min (dcoll , "vol" , fields .u ))
286- if comm .rank == 0 :
287- logger .info (f"step: { istep } t: { t } "
288- f"L2: { l2norm } "
289- f"Linf: { linfnorm } "
290- f"sol max: { nodalmax } "
291- f"sol min: { nodalmin } "
292- f"wall: { stop - start } " )
281+ if args .no_diagnostics :
282+ if comm .rank == 0 :
283+ logger .info (f"step: { istep } t: { t } "
284+ f"wall: { stop - start } " )
285+ else :
286+ l2norm = actx .to_numpy (op .norm (dcoll , fields .u , 2 ))
287+
288+ # NOTE: These are here to ensure the solution is bounded for the
289+ # time interval specified
290+ assert l2norm < 1
291+
292+ linfnorm = actx .to_numpy (op .norm (dcoll , fields .u , np .inf ))
293+ nodalmax = actx .to_numpy (op .nodal_max (dcoll , "vol" , fields .u ))
294+ nodalmin = actx .to_numpy (op .nodal_min (dcoll , "vol" , fields .u ))
295+ if comm .rank == 0 :
296+ logger .info (f"step: { istep } t: { t } "
297+ f"L2: { l2norm } "
298+ f"Linf: { linfnorm } "
299+ f"sol max: { nodalmax } "
300+ f"sol min: { nodalmin } "
301+ f"wall: { stop - start } " )
293302 if visualize :
294303 vis .write_parallel_vtk_file (
295304 comm ,
@@ -304,10 +313,6 @@ def rhs(t, w):
304313 t += dt
305314 istep += 1
306315
307- # NOTE: These are here to ensure the solution is bounded for the
308- # time interval specified
309- assert l2norm < 1
310-
311316
312317if __name__ == "__main__" :
313318 import argparse
@@ -320,6 +325,7 @@ def rhs(t, w):
320325 help = "switch to a lazy computation mode" )
321326 parser .add_argument ("--quad" , action = "store_true" )
322327 parser .add_argument ("--nonaffine" , action = "store_true" )
328+ parser .add_argument ("--no-diagnostics" , action = "store_true" )
323329
324330 args = parser .parse_args ()
325331
0 commit comments