Skip to content

Commit 44333e1

Browse files
superbobryGoogle-ML-Automation
authored andcommitted
[pallas:mosaic_gpu] Addressed a todo in broadcasted_iota lowering
PiperOrigin-RevId: 709310152
1 parent 4eff131 commit 44333e1

File tree

1 file changed

+34
-31
lines changed

1 file changed

+34
-31
lines changed

jax/_src/pallas/mosaic_gpu/primitives.py

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from __future__ import annotations
1818

19+
from collections.abc import Sequence
1920
import enum
2021
import math
2122
from typing import Any, Literal
@@ -25,7 +26,6 @@
2526
from jax._src import state
2627
from jax._src import tree_util
2728
from jax._src import util
28-
from jax._src.interpreters import mlir
2929
from jax._src.lib.mlir import ir
3030
from jax._src.lib.mlir.dialects import arith as arith_dialect
3131
from jax._src.lib.mlir.dialects import llvm as llvm_dialect
@@ -36,7 +36,8 @@
3636
from jax._src.state import discharge
3737
from jax._src.state import indexing
3838
from jax._src.state import primitives as state_primitives
39-
import jax.experimental.mosaic.gpu as mgpu
39+
from jax.experimental.mosaic import gpu as mgpu
40+
from jax.experimental.mosaic.gpu import utils as mgpu_utils
4041
import jax.numpy as jnp
4142

4243

@@ -703,38 +704,40 @@ def _broadcasted_iota_abstract_eval(dtype, shape, dimension, layout):
703704
del layout, dimension
704705
return jax_core.ShapedArray(shape, dtype)
705706

706-
@lowering.register_lowering_rule(broadcasted_iota_p)
707-
def _broadcasted_iota_lowering(ctx: lowering.LoweringRuleContext, dtype, shape, dimension, layout):
708-
del ctx
709-
# Unsigned integers (as opposed to signless) cause MLIR verification
710-
# errors so we only use signless like Mosaic GPU does.
711-
#
712-
# TODO(cperivol): use mgpu.utils.dtype_to_ir_type() instead.
713-
mlir_dtype = (
714-
ir.IntegerType.get_signless(dtype.itemsize * 8)
715-
if jnp.issubdtype(dtype, jnp.integer)
716-
else mlir.dtype_to_ir_type(dtype)
717-
)
718-
undef = llvm_dialect.mlir_undef(mlir_dtype)
719-
is_signed = (
720-
jnp.issubdtype(dtype, jnp.signedinteger)
721-
if jnp.issubdtype(dtype, jnp.integer)
722-
else None
723-
)
724707

725-
i32 = ir.IntegerType.get_signless(32)
726-
def _cast(x):
727-
if ir.FloatType.isinstance(mlir_dtype):
728-
x = arith_dialect.index_cast(i32, x)
729-
return arith_dialect.uitofp(mlir_dtype, x)
730-
else:
731-
return arith_dialect.index_cast(mlir_dtype, x)
708+
@lowering.register_lowering_rule(broadcasted_iota_p)
709+
def _broadcasted_iota_lowering(
710+
ctx: lowering.LoweringRuleContext, dtype, shape, dimension, layout
711+
):
712+
del ctx # Unused.
713+
mlir_dtype = mgpu_utils.dtype_to_ir_type(dtype)
714+
if ir.FloatType.isinstance(mlir_dtype):
715+
i32 = ir.IntegerType.get_signless(32)
716+
cast = lambda x: arith_dialect.uitofp(
717+
mlir_dtype, arith_dialect.index_cast(i32, x)
718+
)
719+
else:
720+
cast = lambda x: arith_dialect.index_cast(mlir_dtype, x)
721+
is_signed = mgpu_utils.is_signed(dtype)
732722
return mgpu.FragmentedArray.splat(
733-
undef, shape, layout.value, is_signed=is_signed
723+
llvm_dialect.mlir_undef(mlir_dtype),
724+
shape,
725+
layout.value,
726+
is_signed=is_signed,
734727
).foreach(
735-
lambda _, idx: _cast(idx[dimension]), create_array=True, is_signed=is_signed
728+
lambda _, idx: cast(idx[dimension]),
729+
create_array=True,
730+
is_signed=is_signed,
736731
)
737732

738733

739-
def broadcasted_iota(dtype, shape, dimension, *, layout: Layout | None = None):
740-
return broadcasted_iota_p.bind(dtype=jnp.dtype(dtype), shape=shape, dimension=dimension, layout=layout)
734+
def broadcasted_iota(
735+
dtype: jax.typing.DTypeLike,
736+
shape: Sequence[int],
737+
dimension: int,
738+
*,
739+
layout: Layout | None = None,
740+
) -> jax.Array:
741+
return broadcasted_iota_p.bind(
742+
dtype=jnp.dtype(dtype), shape=shape, dimension=dimension, layout=layout
743+
)

0 commit comments

Comments
 (0)