Skip to content

Commit 8477580

Browse files
cperivol and Google-ML-Automation
authored and committed
[mgpu pallas] Layout iota operation.
PiperOrigin-RevId: 700711177
1 parent f3acfa9 commit 8477580

File tree

4 files changed

+43
-1
lines changed

4 files changed

+43
-1
lines changed

jax/_src/pallas/mosaic_gpu/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ pytype_strict_library(
9191
":lowering",
9292
"//jax",
9393
"//jax:core",
94-
"//jax:effects",
94+
"//jax:mlir",
9595
"//jax:mosaic_gpu",
9696
"//jax:tree_util",
9797
"//jax:util",

jax/_src/pallas/mosaic_gpu/primitives.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,10 @@
2525
from jax._src import state
2626
from jax._src import tree_util
2727
from jax._src import util
28+
from jax._src.interpreters import mlir
2829
from jax._src.lib.mlir import ir
2930
from jax._src.lib.mlir.dialects import arith as arith_dialect
31+
from jax._src.lib.mlir.dialects import llvm as llvm_dialect
3032
from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect
3133
from jax._src.pallas import core as pallas_core
3234
from jax._src.pallas.mosaic_gpu import core as gpu_core
@@ -692,3 +694,31 @@ def _commit_smem_lowering(ctx: lowering.LoweringRuleContext):
692694
def commit_smem():
693695
"""Commits all writes to SMEM, making them visible to loads, TMA and WGMMA."""
694696
commit_smem_p.bind()
697+
698+
699+
# Primitive bound by the `broadcasted_iota` wrapper.
broadcasted_iota_p = jax_core.Primitive("broadcasted_iota")
700+
701+
@broadcasted_iota_p.def_abstract_eval
def _broadcasted_iota_abstract_eval(dtype, shape, dimension, layout):
  """Abstract evaluation rule for ``broadcasted_iota_p``.

  The output aval is built from ``shape`` and ``dtype`` only; ``dimension``
  and ``layout`` are accepted but do not influence the result type.
  """
  # Neither of these parameters affects the abstract value.
  del dimension, layout
  out_aval = jax_core.ShapedArray(shape, dtype)
  return out_aval
705+
706+
@lowering.register_lowering_rule(broadcasted_iota_p)
def _broadcasted_iota_lowering(
    ctx: lowering.LoweringRuleContext, dtype, shape, dimension, layout
):
  """Lowers ``broadcasted_iota_p`` to a ``FragmentedArray`` of indices.

  An array of the requested ``shape`` and ``layout`` is first created by
  splatting an undefined value of the target element type; every element is
  then replaced by its index along ``dimension``, converted to ``dtype``.

  Args:
    ctx: The lowering rule context (unused).
    dtype: Element dtype of the result.
    shape: Shape of the result.
    dimension: The axis whose index is written into each element.
    layout: The requested layout; its ``.value`` is forwarded to
      ``FragmentedArray.splat``.

  Returns:
    The array produced by ``FragmentedArray.foreach(..., create_array=True)``.

  Raises:
    ValueError: If ``layout`` is ``None``.
  """
  del ctx  # Unused.
  if layout is None:
    # The public wrapper defaults ``layout`` to None; fail with a clear message
    # instead of an AttributeError on ``layout.value`` below.
    raise ValueError("broadcasted_iota requires an explicit layout.")

  # Compute the element type once; it is needed both for the placeholder
  # value and for the per-element cast.
  mlir_dtype = mlir.dtype_to_ir_type(dtype)

  if ir.FloatType.isinstance(mlir_dtype):
    # arith.index_cast only produces integer/index results, so go through i32
    # and convert to the floating-point type afterwards.
    i32 = ir.IntegerType.get_signless(32)

    def cast(index):
      return arith_dialect.uitofp(
          mlir_dtype, arith_dialect.index_cast(i32, index)
      )
  else:

    def cast(index):
      return arith_dialect.index_cast(mlir_dtype, index)

  # Signedness is only meaningful for integer dtypes; None otherwise.
  is_signed = (
      jnp.issubdtype(dtype, jnp.signedinteger)
      if jnp.issubdtype(dtype, jnp.integer)
      else None
  )

  # The splatted value is never observed: foreach overwrites every element.
  placeholder = mgpu.FragmentedArray.splat(
      llvm_dialect.mlir_undef(mlir_dtype),
      shape,
      layout.value,
      is_signed=is_signed,
  )
  return placeholder.foreach(
      lambda _, idx: cast(idx[dimension]),
      create_array=True,
      is_signed=is_signed,
  )
721+
722+
723+
def broadcasted_iota(dtype, shape, dimension, *, layout: Layout | None = None):
  """Binds ``broadcasted_iota_p`` with the given parameters.

  Args:
    dtype: Element dtype of the result; normalized with ``jnp.dtype``.
    shape: Shape of the result.
    dimension: The axis along which the indices are taken.
    layout: Optional layout forwarded unchanged to the primitive.

  Returns:
    The result of binding ``broadcasted_iota_p``.
  """
  params = dict(
      dtype=jnp.dtype(dtype),
      shape=shape,
      dimension=dimension,
      layout=layout,
  )
  return broadcasted_iota_p.bind(**params)

jax/experimental/pallas/mosaic_gpu.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem as copy_smem_to_gmem
3737
from jax._src.pallas.mosaic_gpu.primitives import Layout as Layout
3838
from jax._src.pallas.mosaic_gpu.primitives import layout_cast as layout_cast
39+
from jax._src.pallas.mosaic_gpu.primitives import broadcasted_iota as broadcasted_iota
3940
from jax._src.pallas.mosaic_gpu.primitives import set_max_registers as set_max_registers
4041
from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem as wait_smem_to_gmem
4142
from jax._src.pallas.mosaic_gpu.primitives import wgmma as wgmma

tests/pallas/mosaic_gpu_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,17 @@ def kernel(x_ref, o_ref):
241241
# are never written to.
242242
np.testing.assert_array_equal(kernel(x)[:, :16], y[:, :16])
243243

244+
def test_iota(self):
  # Checks the Mosaic GPU iota against the reference jax.lax implementation.
  shape = (128, 128)
  dtype = jnp.int8
  dimension = 1

  @functools.partial(
      pl.pallas_call,
      out_shape=jax.ShapeDtypeStruct(shape, dtype),
  )
  def kernel(o_ref):
    o_ref[...] = plgpu.broadcasted_iota(
        dtype, shape, dimension, layout=plgpu.Layout.WGMMA
    )

  expected = jax.lax.broadcasted_iota(dtype, shape, dimension)
  np.testing.assert_array_equal(kernel(), expected)
254+
244255
@parameterized.product(indexer=[..., slice(128), slice(None, 128)])
245256
def test_copy_smem_to_gmem(self, indexer):
246257
@functools.partial(

0 commit comments

Comments
 (0)