Commit 3895e03

cperivol authored and Google-ML-Automation committed
[mgpu_pallas] Allow loading scalars or indexing arrays from gmem using splat.
PiperOrigin-RevId: 702704429
1 parent 1ddba9b commit 3895e03
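
For context, the user-facing pattern this change enables looks roughly like the sketch below. It mirrors the test added in this commit; the imports and the pl / plgpu aliases are the ones used in tests/pallas/mosaic_gpu_test.py, and broadcast_scalar is an illustrative name, not part of the change.

import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

# Indexing a GMEM input ref with a scalar index yields a rank-0 value, which the
# lowering below turns into a splat so it can be broadcast to the output shape.
@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct((128,), jnp.int32),
    in_specs=[plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)],
)
def broadcast_scalar(x_ref, o_ref):
  o_ref[...] = jnp.broadcast_to(x_ref[10], (128,))

# broadcast_scalar(jnp.arange(11, dtype=jnp.int32)) returns jnp.full((128,), 10).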

File tree

  jax/_src/pallas/mosaic_gpu/lowering.py
  tests/pallas/mosaic_gpu_test.py

2 files changed: +23 -0 lines changed

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 11 additions & 0 deletions

@@ -1039,10 +1039,15 @@ def _ndindexer_indices(indexer: indexing.NDIndexer) -> tuple[gpu_core.Index, ...
 def _get_lowering_rule(ctx: LoweringRuleContext, x_smem, *leaves, tree):
   if not isinstance(x_smem, ir.Value) and ir.MemRefType.isinstance(x_smem):
     raise TypeError(f"Can only load from references (got {x_smem}).")
+
   x_aval = ctx.avals_in[0]
+
   transforms = jax.tree.unflatten(tree, leaves)
   x_smem, transforms = _handle_reshaping(x_smem, transforms)
   x_smem, transforms = _handle_indexing(x_smem, transforms)
+
+  print("ctx:", ctx)
+  print("transforms:", transforms)
   match transforms:
     case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)):
       if tiling != (64, swizzle // x_aval.dtype.itemsize):
@@ -1051,6 +1056,12 @@ def _get_lowering_rule(ctx: LoweringRuleContext, x_smem, *leaves, tree):
           x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype), swizzle=swizzle
       )
     case ():
+      # Handle scalar indexing.
+      if not ctx.avals_out[0].shape:
+        is_signed = mgpu_utils.is_signed(x_aval.dtype)
+        val = memref_dialect.load(x_smem, [])
+        return mgpu.FragmentedArray.splat(val, shape=(), is_signed=is_signed)
+
       return mgpu.FragmentedArray.load_strided(
           x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype)
       )
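
In plain terms, the new case () branch says: when the value being loaded has an empty shape, read it directly out of the already-indexed ref and splat it into a rank-0 FragmentedArray instead of taking the strided-load path. A rough JAX-level analogue of that step, purely for illustration (splat_scalar is a hypothetical helper, not code from this commit):

import jax.numpy as jnp

def splat_scalar(val):
  # Hypothetical stand-in for the new lowering path: a rank-0 read
  # (memref_dialect.load(x_smem, []) above) followed by
  # mgpu.FragmentedArray.splat(val, shape=()).
  return jnp.broadcast_to(val, ())

x = jnp.arange(11, dtype=jnp.int32)
scalar = splat_scalar(x[10])            # shape ()
out = jnp.broadcast_to(scalar, (128,))  # later broadcast inside the kernel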

tests/pallas/mosaic_gpu_test.py

Lines changed: 12 additions & 0 deletions

@@ -574,6 +574,18 @@ def kernel(x_ref, o_ref):
 
     self.assertIn(f"x: [1, 0, 43, 23]/{in_shape}: 6871\n", output())
 
+  def test_load_scalar(self):
+    @functools.partial(
+        pl.pallas_call,
+        out_shape=jax.ShapeDtypeStruct((128,), jnp.int32),
+        in_specs=[plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)],
+    )
+    def kernel(x_ref, o_ref):
+      o_ref[...] = jnp.broadcast_to(x_ref[10], (128,))
+
+    np.testing.assert_array_equal(kernel(jnp.arange(11, dtype=jnp.int32)),
+                                  jnp.full((128,), 10, dtype=jnp.int32))
+
   def test_run_scoped(self):
     def kernel(x_ref, o_ref):
       def body(tmp_ref):
