|
16 | 16 |
|
17 | 17 | from __future__ import annotations |
18 | 18 |
|
| 19 | +from collections.abc import Sequence |
19 | 20 | import enum |
20 | 21 | import math |
21 | 22 | from typing import Any, Literal |
|
25 | 26 | from jax._src import state |
26 | 27 | from jax._src import tree_util |
27 | 28 | from jax._src import util |
28 | | -from jax._src.interpreters import mlir |
29 | 29 | from jax._src.lib.mlir import ir |
30 | 30 | from jax._src.lib.mlir.dialects import arith as arith_dialect |
31 | 31 | from jax._src.lib.mlir.dialects import llvm as llvm_dialect |
|
36 | 36 | from jax._src.state import discharge |
37 | 37 | from jax._src.state import indexing |
38 | 38 | from jax._src.state import primitives as state_primitives |
39 | | -import jax.experimental.mosaic.gpu as mgpu |
| 39 | +from jax.experimental.mosaic import gpu as mgpu |
| 40 | +from jax.experimental.mosaic.gpu import utils as mgpu_utils |
40 | 41 | import jax.numpy as jnp |
41 | 42 |
|
42 | 43 |
|
@@ -703,38 +704,40 @@ def _broadcasted_iota_abstract_eval(dtype, shape, dimension, layout): |
703 | 704 | del layout, dimension |
704 | 705 | return jax_core.ShapedArray(shape, dtype) |
705 | 706 |
|
@lowering.register_lowering_rule(broadcasted_iota_p)
def _broadcasted_iota_lowering(
    ctx: lowering.LoweringRuleContext, dtype, shape, dimension, layout
):
  """Lowering rule for ``broadcasted_iota_p``.

  Builds a ``FragmentedArray`` of the requested ``shape`` and ``layout`` in
  which every element is derived from its own index along ``dimension``.

  Args:
    ctx: The lowering rule context. Unused.
    dtype: Element dtype of the result; converted to an MLIR type with
      ``mgpu_utils.dtype_to_ir_type``.
    shape: Shape of the resulting array.
    dimension: Position in the per-element index tuple that supplies the value.
    layout: Layout wrapper; its ``.value`` is forwarded to
      ``FragmentedArray.splat``.

  Returns:
    The array produced by ``FragmentedArray.foreach(..., create_array=True)``.
  """
  del ctx  # Unused.
  elem_type = mgpu_utils.dtype_to_ir_type(dtype)
  produces_float = ir.FloatType.isinstance(elem_type)
  # The i32 type is only needed as an intermediate for float results.
  i32 = ir.IntegerType.get_signless(32) if produces_float else None
  signedness = mgpu_utils.is_signed(dtype)

  def element_from_index(_, idx):
    position = idx[dimension]
    if not produces_float:
      return arith_dialect.index_cast(elem_type, position)
    # Float results go index -> i32 -> float through an unsigned conversion.
    return arith_dialect.uitofp(
        elem_type, arith_dialect.index_cast(i32, position)
    )

  # Start from an undef splat: its values are never read, only its shape and
  # layout are used by the element-wise pass that follows.
  placeholder = mgpu.FragmentedArray.splat(
      llvm_dialect.mlir_undef(elem_type),
      shape,
      layout.value,
      is_signed=signedness,
  )
  return placeholder.foreach(
      element_from_index,
      create_array=True,
      is_signed=signedness,
  )
737 | 732 |
|
738 | 733 |
|
def broadcasted_iota(
    dtype: jax.typing.DTypeLike,
    shape: Sequence[int],
    dimension: int,
    *,
    layout: Layout | None = None,
) -> jax.Array:
  """Binds ``broadcasted_iota_p`` with the given parameters.

  Args:
    dtype: Result dtype; normalized with ``jnp.dtype`` before binding.
    shape: Shape of the result.
    dimension: Dimension passed through to the primitive.
    layout: Optional layout passed through to the primitive unchanged.
      NOTE(review): the lowering rule reads ``layout.value``, so confirm how a
      ``None`` layout is meant to be handled.

  Returns:
    The result of ``broadcasted_iota_p.bind``.
  """
  normalized_dtype = jnp.dtype(dtype)
  return broadcasted_iota_p.bind(
      dtype=normalized_dtype,
      shape=shape,
      dimension=dimension,
      layout=layout,
  )
0 commit comments