1616
1717from __future__ import annotations
1818
19+ from typing import Literal
20+
1921from jax ._src import core as jax_core
2022from jax ._src import effects
2123from jax ._src import state
@@ -177,7 +179,8 @@ def copy_gmem_to_smem(
177179 """Asynchronously copies a GMEM reference to a SMEM reference.
178180
179181 See also:
180- :func:`jax.experimental.mosaic.gpu.wait_barrier`
182+ :func:`jax.experimental.mosaic.gpu.barrier_arrive`
183+ :func:`jax.experimental.mosaic.gpu.barrier_wait`
181184 """
182185 if src .memory_space is not gpu_core .GMEM :
183186 raise TypeError (f"src must be a GMEM reference, got { src .memory_space } " )
@@ -237,6 +240,52 @@ def _extract_barrier_indexer(transforms) -> indexing.NDIndexer | None:
237240 raise ValueError ("Barrier does not support arbirary transforms" )
238241
239242
class ArriveEffect(jax_core.Effect):
  """Effect type attached to the ``barrier_arrive`` primitive."""


# Register the effect type with the set of effects permitted inside
# control-flow primitives.
effects.control_flow_allowed_effects.add_type(ArriveEffect)

# Single shared instance reported by the abstract evaluation rule below.
_arrive_effect = ArriveEffect()


barrier_arrive_p = jax_core.Primitive("barrier_arrive")
barrier_arrive_p.multiple_results = True


@barrier_arrive_p.def_effectful_abstract_eval
def _barrier_arrive_abstract_eval(*avals, **params):
  """Abstract evaluation rule for ``barrier_arrive_p``.

  Args:
    *avals: Abstract values of the operands (ignored).
    **params: Primitive parameters (ignored).

  Returns:
    A pair of an empty tuple (the primitive produces no outputs) and a set
    holding the arrive effect.
  """
  del avals, params  # Unused.
  # Arriving must be tagged with the arrive effect; the wait effect belongs to
  # ``barrier_wait_p`` only.
  return (), {_arrive_effect}
260+
261+
@lowering.register_lowering_rule(barrier_arrive_p)
def _barrier_arrive_lowering(
    ctx: lowering.LoweringRuleContext,
    barrier,
    *flat_transforms,
    transforms_treedef,
):
  """Lowering rule for ``barrier_arrive_p``.

  Rebuilds the transforms from their flattened form, applies the indexer
  extracted from them (if any) to ``barrier``, and calls ``arrive()`` on the
  result.

  Args:
    ctx: The lowering rule context (unused).
    barrier: The barrier object to arrive at.
    *flat_transforms: Flattened leaves of the transforms pytree.
    transforms_treedef: Tree definition used to unflatten ``flat_transforms``.

  Returns:
    An empty tuple; the primitive has no results.
  """
  del ctx  # Unused.
  unflattened = transforms_treedef.unflatten(flat_transforms)
  barrier_indexer = _extract_barrier_indexer(unflattened)
  if barrier_indexer is not None:
    # Convert every index with the lowering helper before selecting the
    # indexed barrier.
    converted = [lowering._as_index(idx) for idx in barrier_indexer.indices]
    barrier = barrier.__getitem__(*converted)
  barrier.arrive()
  return ()
276+
277+
def barrier_arrive(barrier: pallas_core.AbstractMemoryRef) -> None:
  """Arrives at the given barrier.

  Args:
    barrier: Reference to the barrier; any transforms attached to it are
      flattened and forwarded to the ``barrier_arrive_p`` primitive.
  """
  ref, ref_transforms = state_primitives.get_ref_and_transforms(
      barrier, None, "barrier_arrive"
  )
  # The transforms are passed as flat operands and the tree structure as a
  # parameter so the lowering rule can rebuild them.
  leaves, treedef = tree_util.tree_flatten(ref_transforms)
  barrier_arrive_p.bind(ref, *leaves, transforms_treedef=treedef)
287+
288+
class WaitEffect(jax_core.Effect):
  """Effect type attached to the ``barrier_wait`` primitive."""
242291
@@ -245,18 +294,18 @@ class WaitEffect(jax_core.Effect):
# Single shared instance reported by the abstract evaluation rule below.
_wait_effect = WaitEffect()


barrier_wait_p = jax_core.Primitive("barrier_wait")
barrier_wait_p.multiple_results = True


@barrier_wait_p.def_effectful_abstract_eval
def _barrier_wait_abstract_eval(*avals, **params):
  """Abstract evaluation rule for ``barrier_wait_p``.

  Args:
    *avals: Abstract values of the operands (ignored).
    **params: Primitive parameters (ignored).

  Returns:
    A pair of an empty tuple (no outputs) and a set holding the wait effect.
  """
  del avals, params  # Unused.
  no_outputs = ()
  return no_outputs, {_wait_effect}
256305
257306
258- @lowering .register_lowering_rule (wait_barrier_p )
259- def _wait_barrier_lowering (
307+ @lowering .register_lowering_rule (barrier_wait_p )
308+ def _barrier_wait_lowering (
260309 ctx : lowering .LoweringRuleContext ,
261310 barrier ,
262311 * flat_transforms ,
@@ -271,13 +320,13 @@ def _wait_barrier_lowering(
271320 return ()
272321
273322
def barrier_wait(barrier: pallas_core.AbstractMemoryRef) -> None:
  """Waits on the given barrier.

  Args:
    barrier: Reference to the barrier; any transforms attached to it are
      flattened and forwarded to the ``barrier_wait_p`` primitive.
  """
  ref, ref_transforms = state_primitives.get_ref_and_transforms(
      barrier, None, "barrier_wait"
  )
  # The transforms are passed as flat operands and the tree structure as a
  # parameter so the lowering rule can rebuild them.
  leaves, treedef = tree_util.tree_flatten(ref_transforms)
  barrier_wait_p.bind(ref, *leaves, transforms_treedef=treedef)
283332
@@ -498,3 +547,41 @@ def _wgmma_accumulator_deref_lowering(ctx: lowering.LoweringRuleContext, acc):
498547 del ctx
499548 nvvm_dialect .wgmma_wait_group_sync_aligned (0 )
500549 return acc .value
550+
551+
class SetRegistersEffect(jax_core.Effect):
  """Effect type attached to the ``set_max_registers_p`` primitive."""


# Register the effect type with the set of effects permitted inside
# control-flow primitives.
effects.control_flow_allowed_effects.add_type(SetRegistersEffect)

# Single shared instance reported by the abstract evaluation rule below.
_set_max_registers_effect = SetRegistersEffect()


# NOTE(review): the primitive name carries a "_p" suffix, unlike
# "barrier_arrive" / "barrier_wait" -- confirm whether this is intentional.
set_max_registers_p = jax_core.Primitive("set_max_registers_p")
set_max_registers_p.multiple_results = True


@set_max_registers_p.def_effectful_abstract_eval
def _set_max_registers_abstract_eval(n, *, action):
  """Abstract evaluation rule for ``set_max_registers_p``.

  Args:
    n: The register-count operand (ignored).
    action: The requested action (ignored).

  Returns:
    A pair of an empty tuple (no outputs) and a set holding the
    set-registers effect.
  """
  del n, action  # Unused.
  no_outputs = ()
  return no_outputs, {_set_max_registers_effect}
569+
570+
@lowering.register_lowering_rule(set_max_registers_p)
def _set_max_registers_lowering(
    ctx: lowering.LoweringRuleContext, n, *, action
):
  """Lowering rule for ``set_max_registers_p``.

  Emits ``nvvm_dialect.setmaxregister`` with ``n`` and the NVVM action that
  corresponds to ``action``.

  Args:
    ctx: The lowering rule context (unused).
    n: The register count passed through to the NVVM op.
    action: ``"increase"`` selects the increase action; every other value
      selects the decrease action.

  Returns:
    An empty tuple; the primitive has no results.
  """
  del ctx
  if action == "increase":
    nvvm_action = nvvm_dialect.SetMaxRegisterAction.increase
  else:
    nvvm_action = nvvm_dialect.SetMaxRegisterAction.decrease
  nvvm_dialect.setmaxregister(n, nvvm_action)
  return ()
583+
584+
def set_max_registers(
    n: int, *, action: Literal["increase", "decrease"]
) -> None:
  """Sets the maximum number of registers owned by a warp.

  Args:
    n: The register count forwarded to the ``set_max_registers_p`` primitive.
    action: Either ``"increase"`` or ``"decrease"``.

  Raises:
    ValueError: If ``action`` is neither ``"increase"`` nor ``"decrease"``.
  """
  # The lowering rule maps anything other than "increase" to a decrease, so a
  # misspelled action would otherwise silently lower to the wrong operation.
  if action not in ("increase", "decrease"):
    raise ValueError(
        f"action must be 'increase' or 'decrease', got {action!r}"
    )
  set_max_registers_p.bind(n, action=action)
0 commit comments