@@ -363,20 +363,30 @@ def barrier_wait(barrier: pallas_core.AbstractMemoryRef) -> None:
363363
364364
365365@wait_smem_to_gmem_p .def_effectful_abstract_eval
366- def _wait_smem_to_gmem_abstract_eval (n ):
367- del n # Unused.
366+ def _wait_smem_to_gmem_abstract_eval (n , * , wait_read_only ):
367+ del n , wait_read_only # Unused.
368368 return (), {gpu_core ._memory_effect }
369369
370370
371371@lowering .register_lowering_rule (wait_smem_to_gmem_p )
372- def _wait_smem_to_gmem_lowering (ctx : lowering .LoweringRuleContext , n ):
373- ctx .launch_ctx .await_async_copy (allow_groups = n )
372+ def _wait_smem_to_gmem_lowering (
373+ ctx : lowering .LoweringRuleContext , n , * , wait_read_only
374+ ):
375+ ctx .launch_ctx .await_async_copy (
376+ allow_groups = n , await_read_only = wait_read_only
377+ )
374378 return ()
375379
376380
377- def wait_smem_to_gmem (n : int ) -> None :
378- """Waits until there are no more than ``n`` SMEM->GMEM copies in flight."""
379- wait_smem_to_gmem_p .bind (n )
381+ def wait_smem_to_gmem (n : int , wait_read_only : bool = False ) -> None :
382+ """Waits until there are no more than ``n`` SMEM->GMEM copies in flight.
383+
384+ Args:
385+ n: The maximum number of copies in flight to wait for.
386+ wait_read_only: If ``True``, wait for the in flight copies to finish
387+ reading from SMEM. The writes to GMEM are not waited for.
388+ """
389+ wait_smem_to_gmem_p .bind (n , wait_read_only = wait_read_only )
380390
381391
382392# WGMMA on an accumulator reference
0 commit comments