Skip to content

Commit c3c21c7

Browse files
cperivol authored and Google-ML-Automation committed
[mgpu_pallas] Better support for unsigned integers and floats in iota.
PiperOrigin-RevId: 701307324
1 parent ea69401 commit c3c21c7

File tree

2 files changed

+22
-5
lines changed

2 files changed

+22
-5
lines changed

jax/_src/pallas/mosaic_gpu/primitives.py

Lines changed: 19 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -706,17 +706,33 @@ def _broadcasted_iota_abstract_eval(dtype, shape, dimension, layout):
706706
@lowering.register_lowering_rule(broadcasted_iota_p)
707707
def _broadcasted_iota_lowering(ctx: lowering.LoweringRuleContext, dtype, shape, dimension, layout):
708708
del ctx
709-
undef = llvm_dialect.mlir_undef(mlir.dtype_to_ir_type(dtype))
709+
# Unsigned integers (as opposed to signless) cause MLIR verification
710+
# errors so we only use signless like Mosaic GPU does.
711+
#
712+
# TODO(cperivol): use mgpu.utils.dtype_to_ir_type() instead.
713+
mlir_dtype = (
714+
ir.IntegerType.get_signless(dtype.itemsize * 8)
715+
if jnp.issubdtype(dtype, jnp.integer)
716+
else mlir.dtype_to_ir_type(dtype)
717+
)
718+
undef = llvm_dialect.mlir_undef(mlir_dtype)
710719
is_signed = (
711720
jnp.issubdtype(dtype, jnp.signedinteger)
712721
if jnp.issubdtype(dtype, jnp.integer)
713722
else None
714723
)
715-
mlir_dtype = mlir.dtype_to_ir_type(dtype)
724+
725+
i32 = ir.IntegerType.get_signless(32)
726+
def _cast(x):
727+
if ir.FloatType.isinstance(mlir_dtype):
728+
x = arith_dialect.index_cast(i32, x)
729+
return arith_dialect.uitofp(mlir_dtype, x)
730+
else:
731+
return arith_dialect.index_cast(mlir_dtype, x)
716732
return mgpu.FragmentedArray.splat(
717733
undef, shape, layout.value, is_signed=is_signed
718734
).foreach(
719-
lambda _, idx: arith_dialect.index_cast(mlir_dtype, idx[dimension]), create_array=True, is_signed=is_signed
735+
lambda _, idx: _cast(idx[dimension]), create_array=True, is_signed=is_signed
720736
)
721737

722738

tests/pallas/mosaic_gpu_test.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -241,8 +241,9 @@ def kernel(x_ref, o_ref):
241241
# are never written to.
242242
np.testing.assert_array_equal(kernel(x)[:, :16], y[:, :16])
243243

244-
def test_iota(self):
245-
dtype, dimension = jnp.int8, 1
244+
@parameterized.parameters(jnp.float32, jnp.int32, jnp.uint32)
245+
def test_iota(self, dtype):
246+
dimension = 1
246247
@functools.partial(
247248
pl.pallas_call,
248249
out_shape=jax.ShapeDtypeStruct((128, 128), dtype),

0 commit comments

Comments
 (0)