Skip to content

Commit 1b9951e

Browse files
bchetioui and Google-ML-Automation
authored and committed
[Pallas/Mosaic GPU] Allow scalar loads in warp-level lowering.
PiperOrigin-RevId: 835174212
1 parent 79045d5 commit 1b9951e

File tree

3 files changed

+40
-0
lines changed

3 files changed

+40
-0
lines changed

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1536,6 +1536,9 @@ def _ndindexer_indices(
15361536

15371537

15381538
@register_lowering_rule(sp.get_p, mgpu.LoweringSemantics.Lane)
1539+
@register_lowering_rule(
1540+
sp.get_p, mgpu.LoweringSemantics.Lane, gpu_core.PrimitiveSemantics.Warp
1541+
)
15391542
def _get_lowering_rule(
15401543
ctx: LoweringRuleContext, x_ref, *leaves, tree, optimized=True
15411544
):
@@ -1544,6 +1547,11 @@ def _get_lowering_rule(
15441547
"Loads from TMEM are asynchronous operations and cannot be performed"
15451548
" using the usual syntax. Please use plgpu.async_load_tmem instead."
15461549
)
1550+
if (
1551+
ctx.avals_out[0].shape
1552+
and ctx.module_ctx.primitive_semantics == gpu_core.PrimitiveSemantics.Warp
1553+
):
1554+
raise ValueError("Can only load scalars in warp-level code.")
15471555
if not isinstance(x_ref, ir.Value) and ir.MemRefType.isinstance(x_ref):
15481556
raise TypeError(f"Can only load from references (got {x_ref}).")
15491557
dtype = ctx.avals_out[0].dtype

jax/_src/pallas/mosaic_gpu/primitives.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2938,6 +2938,11 @@ def _load_abstract_eval(src, *avals_flat, tree, optimized):
29382938
lowering.register_lowering_rule(load_p, mgpu.LoweringSemantics.Lane)(
29392939
lowering._get_lowering_rule
29402940
)
2941+
lowering.register_lowering_rule(
2942+
load_p, mgpu.LoweringSemantics.Lane, gpu_core.PrimitiveSemantics.Warp
2943+
)(
2944+
lowering._get_lowering_rule
2945+
)
29412946
lowering.register_lowering_rule(load_p, mgpu.LoweringSemantics.Warpgroup)(
29422947
lowering._get_lowering_rule_wg
29432948
)

tests/pallas/mosaic_gpu_test.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2592,6 +2592,33 @@ def _():
25922592
jnp.ones((128,), jnp.int32) * 3), axis=0)
25932593
np.testing.assert_array_equal(result, expected)
25942594

2595+
def test_scalar_load(self):
2596+
warp_mesh = plgpu.WarpMesh(axis_name="warp")
2597+
@functools.partial(self.kernel,
2598+
out_shape=jax.ShapeDtypeStruct((), jnp.int32))
2599+
def kernel(x_ref, y_ref):
2600+
@pl.core_map(warp_mesh)
2601+
def _():
2602+
warp_id = lax.axis_index("warp")
2603+
@pl.when(warp_id == 1)
2604+
def _():
2605+
y_ref[...] = x_ref[...]
2606+
np.testing.assert_array_equal(kernel(4), 4)
2607+
2608+
def test_non_scalar_load_raises(self):
2609+
warp_mesh = plgpu.WarpMesh(axis_name="warp")
2610+
@functools.partial(self.kernel,
2611+
out_shape=jax.ShapeDtypeStruct((2,), jnp.int32))
2612+
def kernel(x_ref, y_ref):
2613+
@pl.core_map(warp_mesh)
2614+
def _():
2615+
warp_id = lax.axis_index("warp")
2616+
@pl.when(warp_id == 1)
2617+
def _():
2618+
y_ref[...] = x_ref[...]
2619+
with self.assertRaisesRegex(ValueError, "Can only load scalars",):
2620+
kernel(jnp.ones((2,), jnp.int32))
2621+
25952622
@parameterized.parameters(
25962623
lax.add, lax.sub, lax.mul, lax.div, lax.rem, lax.bitwise_and,
25972624
lax.bitwise_or, lax.bitwise_xor, lax.max, lax.min,

0 commit comments

Comments
 (0)