Skip to content

Commit 4c0d828

Browse files
superbobry and Google-ML-Automation
authored and committed
[pallas:mosaic_gpu] Added a few more operations necessary to port Flash Attention
PiperOrigin-RevId: 686451398
1 parent 56eea2b commit 4c0d828

File tree

5 files changed

+110
-17
lines changed

5 files changed

+110
-17
lines changed

docs/jax.experimental.pallas.mosaic_gpu.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,11 @@ Functions
2424
.. autosummary::
2525
:toctree: _autosummary
2626

27+
barrier_arrive
28+
barrier_wait
2729
copy_gmem_to_smem
2830
copy_smem_to_gmem
29-
wait_barrier
31+
set_max_registers
3032
wait_smem_to_gmem
3133
wgmma
3234
wgmma_wait

jax/_src/pallas/mosaic_gpu/core.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232

3333
AbstractMemoryRef = pallas_core.AbstractMemoryRef
3434

35+
DimensionSemantics = Literal["parallel", "sequential"]
36+
3537

3638
@dataclasses.dataclass(frozen=True, kw_only=True)
3739
class GPUCompilerParams(pallas_core.CompilerParams):
@@ -53,7 +55,7 @@ class GPUCompilerParams(pallas_core.CompilerParams):
5355
"""
5456
PLATFORM: ClassVar[str] = "mosaic_gpu"
5557
approx_math: bool = False
56-
dimension_semantics: Sequence[Literal["parallel", "sequential"]] | None = None
58+
dimension_semantics: Sequence[DimensionSemantics] | None = None
5759
max_concurrent_steps: int = 1
5860
delay_release: int = 0
5961

jax/_src/pallas/mosaic_gpu/primitives.py

Lines changed: 97 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
from __future__ import annotations
1818

19+
from typing import Literal
20+
1921
from jax._src import core as jax_core
2022
from jax._src import effects
2123
from jax._src import state
@@ -177,7 +179,8 @@ def copy_gmem_to_smem(
177179
"""Asynchronously copies a GMEM reference to a SMEM reference.
178180
179181
See also:
180-
:func:`jax.experimental.mosaic.gpu.wait_barrier`
182+
:func:`jax.experimental.pallas.mosaic_gpu.barrier_arrive`
183+
:func:`jax.experimental.pallas.mosaic_gpu.barrier_wait`
181184
"""
182185
if src.memory_space is not gpu_core.GMEM:
183186
raise TypeError(f"src must be a GMEM reference, got {src.memory_space}")
@@ -237,6 +240,52 @@ def _extract_barrier_indexer(transforms) -> indexing.NDIndexer | None:
237240
raise ValueError("Barrier does not support arbirary transforms")
238241

239242

243+
class ArriveEffect(jax_core.Effect):
  """Marker effect type for arriving at a barrier."""
245+
246+
247+
# Module-level instance of the arrive effect. The effect type is also added to
# the set of effect types allowed inside control-flow primitives.
_arrive_effect = ArriveEffect()
effects.control_flow_allowed_effects.add_type(ArriveEffect)

# Primitive backing `barrier_arrive`. It is declared as multiple-results
# because its rules return an (empty) tuple of outputs.
barrier_arrive_p = jax_core.Primitive("barrier_arrive")
barrier_arrive_p.multiple_results = True
254+
255+
256+
@barrier_arrive_p.def_effectful_abstract_eval
def _barrier_arrive_abstract_eval(*avals, **params):
  """Abstract evaluation rule for ``barrier_arrive_p``.

  The primitive has no outputs; it only carries an effect.

  Args:
    *avals: Abstract values of the operands (ignored).
    **params: Primitive parameters (ignored).

  Returns:
    A pair of an empty tuple of output avals and the set of effects.
  """
  del avals, params  # Unused.
  # Report the dedicated arrive effect. The wait effect belongs to the
  # barrier-wait primitive and was only used here by copy-paste.
  return (), {_arrive_effect}
260+
261+
262+
@lowering.register_lowering_rule(barrier_arrive_p)
def _barrier_arrive_lowering(
    ctx: lowering.LoweringRuleContext,
    barrier,
    *flat_transforms,
    transforms_treedef,
):
  """Lowers ``barrier_arrive_p`` by calling ``arrive()`` on the barrier.

  Args:
    ctx: The lowering rule context (unused).
    barrier: The barrier object to arrive at.
    *flat_transforms: Flattened leaves of the transforms applied to the
      barrier reference.
    transforms_treedef: Tree definition used to rebuild the transforms from
      ``flat_transforms``.

  Returns:
    An empty tuple, since the primitive has no outputs.
  """
  del ctx  # Unused.
  barrier_indexer = _extract_barrier_indexer(
      transforms_treedef.unflatten(flat_transforms)
  )
  if barrier_indexer is not None:
    # Select the indexed barrier before arriving at it.
    indices = [lowering._as_index(idx) for idx in barrier_indexer.indices]
    barrier = barrier.__getitem__(*indices)
  barrier.arrive()
  return ()
276+
277+
278+
def barrier_arrive(barrier: pallas_core.AbstractMemoryRef) -> None:
  """Arrives at the given barrier.

  Args:
    barrier: The barrier reference to arrive at. Any transforms attached to
      the reference are forwarded to the primitive.
  """
  ref, ref_transforms = state_primitives.get_ref_and_transforms(
      barrier, None, "barrier_arrive"
  )
  # Transforms are passed as flat operands plus a treedef parameter so that
  # the lowering rule can reconstruct them.
  leaves, treedef = tree_util.tree_flatten(ref_transforms)
  barrier_arrive_p.bind(ref, *leaves, transforms_treedef=treedef)
287+
288+
240289
class WaitEffect(jax_core.Effect):
  """Marker effect type for waiting on a barrier."""
242291

@@ -245,18 +294,18 @@ class WaitEffect(jax_core.Effect):
245294
_wait_effect = WaitEffect()
246295

247296

248-
wait_barrier_p = jax_core.Primitive("wait")
249-
wait_barrier_p.multiple_results = True
297+
barrier_wait_p = jax_core.Primitive("barrier_wait")
298+
barrier_wait_p.multiple_results = True
250299

251300

252-
@wait_barrier_p.def_effectful_abstract_eval
253-
def _wait_barrier_abstract_eval(*avals, **params):
301+
@barrier_wait_p.def_effectful_abstract_eval
302+
def _barrier_wait_abstract_eval(*avals, **params):
254303
del avals, params # Unused.
255304
return (), {_wait_effect}
256305

257306

258-
@lowering.register_lowering_rule(wait_barrier_p)
259-
def _wait_barrier_lowering(
307+
@lowering.register_lowering_rule(barrier_wait_p)
308+
def _barrier_wait_lowering(
260309
ctx: lowering.LoweringRuleContext,
261310
barrier,
262311
*flat_transforms,
@@ -271,13 +320,13 @@ def _wait_barrier_lowering(
271320
return ()
272321

273322

274-
def wait_barrier(barrier: pallas_core.AbstractMemoryRef) -> None:
323+
def barrier_wait(barrier: pallas_core.AbstractMemoryRef) -> None:
275324
"""Waits on the given barrier."""
276325
barrier, transforms = state_primitives.get_ref_and_transforms(
277-
barrier, None, "wait_barrier"
326+
barrier, None, "barrier_wait"
278327
)
279328
flat_transforms, transforms_treedef = tree_util.tree_flatten(transforms)
280-
wait_barrier_p.bind(
329+
barrier_wait_p.bind(
281330
barrier, *flat_transforms, transforms_treedef=transforms_treedef
282331
)
283332

@@ -498,3 +547,41 @@ def _wgmma_accumulator_deref_lowering(ctx: lowering.LoweringRuleContext, acc):
498547
del ctx
499548
nvvm_dialect.wgmma_wait_group_sync_aligned(0)
500549
return acc.value
550+
551+
552+
class SetRegistersEffect(jax_core.Effect):
  """Marker effect type for changing the maximum register count."""
554+
555+
556+
# Module-level instance of the set-registers effect. The effect type is also
# added to the set of effect types allowed inside control-flow primitives.
effects.control_flow_allowed_effects.add_type(SetRegistersEffect)

_set_max_registers_effect = SetRegistersEffect()

# Primitive backing `set_max_registers`. The primitive name omits the `_p`
# suffix of the Python variable, matching "barrier_arrive" and "barrier_wait".
set_max_registers_p = jax_core.Primitive("set_max_registers")
set_max_registers_p.multiple_results = True
563+
564+
565+
@set_max_registers_p.def_effectful_abstract_eval
def _set_max_registers_abstract_eval(n, *, action):
  """Abstract evaluation rule for ``set_max_registers_p``.

  Args:
    n: The requested register count (ignored).
    action: The requested action (ignored).

  Returns:
    A pair of an empty tuple of output avals and the set of effects.
  """
  del n, action  # Unused.
  no_outputs = ()
  return no_outputs, {_set_max_registers_effect}
569+
570+
571+
@lowering.register_lowering_rule(set_max_registers_p)
def _set_max_registers_lowering(
    ctx: lowering.LoweringRuleContext, n, *, action
):
  """Lowers ``set_max_registers_p`` to ``nvvm_dialect.setmaxregister``.

  Args:
    ctx: The lowering rule context (unused).
    n: The register count forwarded to ``setmaxregister``.
    action: ``"increase"`` selects the increase action; any other value
      selects the decrease action.

  Returns:
    An empty tuple, since the primitive has no outputs.
  """
  del ctx  # Unused.
  if action == "increase":
    nvvm_action = nvvm_dialect.SetMaxRegisterAction.increase
  else:
    nvvm_action = nvvm_dialect.SetMaxRegisterAction.decrease
  nvvm_dialect.setmaxregister(n, nvvm_action)
  return ()
583+
584+
585+
def set_max_registers(n: int, *, action: Literal["increase", "decrease"]):
  """Sets the maximum number of registers owned by a warp.

  Args:
    n: The register count to request.
    action: Either ``"increase"`` or ``"decrease"``, selecting whether the
      register limit is raised or lowered.

  Raises:
    ValueError: If ``action`` is not ``"increase"`` or ``"decrease"``.
  """
  # Validate eagerly: the lowering rule treats every value other than
  # "increase" as a decrease, so a typo would otherwise silently do the
  # opposite of what was intended.
  if action not in ("increase", "decrease"):
    raise ValueError(
        f"action must be 'increase' or 'decrease', got {action!r}"
    )
  set_max_registers_p.bind(n, action=action)

jax/experimental/pallas/mosaic_gpu.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@
2828
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC
2929
from jax._src.pallas.mosaic_gpu.primitives import copy_gmem_to_smem
3030
from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem
31-
from jax._src.pallas.mosaic_gpu.primitives import wait_barrier
31+
from jax._src.pallas.mosaic_gpu.primitives import barrier_wait
32+
from jax._src.pallas.mosaic_gpu.primitives import barrier_arrive
33+
from jax._src.pallas.mosaic_gpu.primitives import set_max_registers
3234
from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem
3335
from jax._src.pallas.mosaic_gpu.primitives import wgmma
3436
from jax._src.pallas.mosaic_gpu.primitives import wgmma_wait

tests/pallas/mosaic_gpu_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref):
226226
plgpu.copy_gmem_to_smem(
227227
x_ref_gmem.at[indexer], scratch_ref.at[indexer], barrier=barrier_ref
228228
)
229-
plgpu.wait_barrier(barrier_ref)
229+
plgpu.barrier_wait(barrier_ref)
230230
o_ref[...] = scratch_ref[...] + 1
231231

232232
x = jnp.arange(256).astype(jnp.float32)
@@ -247,7 +247,7 @@ def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref):
247247
plgpu.copy_gmem_to_smem(
248248
x_ref_gmem, scratch_ref, barrier=barrier_ref.at[indexer]
249249
)
250-
plgpu.wait_barrier(barrier_ref.at[indexer])
250+
plgpu.barrier_wait(barrier_ref.at[indexer])
251251
o_ref[...] = scratch_ref[...] + 1
252252

253253
x = jnp.arange(128).astype(jnp.float32)
@@ -263,7 +263,7 @@ def kernel(x_ref_gmem, o_ref):
263263
def body(barrier_ref):
264264
def inner_body(scratch_ref):
265265
plgpu.copy_gmem_to_smem(x_ref_gmem, scratch_ref, barrier=barrier_ref)
266-
plgpu.wait_barrier(barrier_ref)
266+
plgpu.barrier_wait(barrier_ref)
267267
o_ref[...] = scratch_ref[...] + 1
268268
pl.run_scoped(inner_body, plgpu.SMEM((256,), jnp.float32))
269269
pl.run_scoped(body, plgpu.Barrier(num_arrivals=1))
@@ -759,7 +759,7 @@ def body(step, _):
759759
slot = step % max_concurrent_steps
760760

761761
# Wait for the current GMEM->SMEM copy to complete.
762-
plgpu.wait_barrier(barrier.at[slot])
762+
plgpu.barrier_wait(barrier.at[slot])
763763
# Wait for the previous output SMEM->GMEM copy to complete.
764764
plgpu.wait_smem_to_gmem(max_concurrent_steps - 1)
765765

0 commit comments

Comments
 (0)