Skip to content

Commit 563cd58

Browse files
[MLIR] Fix construction of i1 attributes from numpy arrays and add tests.
This change fixes the construction of MLIR attributes from `numpy` arrays for a Boolean `dtype`. This happened when calling `jax.mlir.ir_attribute` and the problem occurred in `_numpy_array_attribute`: that function *first* bit-packed the Booleans and *then* determined the `dtype` from the result, which is `uint8` at that point rather than `bool`. The fix thus consists of a simple swapping of lines. Note that the virtually equivalent `_numpy_array_constant` already used the right order. The change extends and improves the tests for the FFI primitive, which uses `jax.mlir.ir_attribute` under the hood. The previous tests use a flawed comparison to verify the round-tripped type: it checked whether the *Python* type of the created MLIR attribute is the expected one. However, that is not enough: two `DenseElementAttr`s with different element types pass that check but really should not. That check would not have caught the bug fixed in this change (plus it didn't test the variant with `bool`s). The new version uses strings instead, which is not only more precise but also more concise in building the expected values. In addition to this improvement, the change adds instances of the corresponding FFI test for essentially all attribute types I could find. PiperOrigin-RevId: 834770477
1 parent ecd0b33 commit 563cd58

File tree

2 files changed

+56
-17
lines changed

2 files changed

+56
-17
lines changed

jax/_src/interpreters/mlir.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,11 +364,11 @@ def _numpy_scalar_attribute(val: Any) -> ir.Attribute:
364364
raise TypeError(f"Unsupported scalar attribute type: {type(val)}")
365365

366366
def _numpy_array_attribute(x: np.ndarray | np.generic) -> ir.Attribute:
367+
element_type = dtype_to_ir_type(x.dtype)
367368
shape = x.shape
368369
if x.dtype == np.bool_:
369370
x = np.packbits(x, bitorder='little') # type: ignore
370371
x = np.ascontiguousarray(x)
371-
element_type = dtype_to_ir_type(x.dtype)
372372
return ir.DenseElementsAttr.get(x, type=element_type, shape=shape) # type: ignore
373373

374374
def _numpy_array_attribute_handler(val: np.ndarray | np.generic) -> ir.Attribute:

tests/ffi_test.py

Lines changed: 55 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from jax._src import config
2929
from jax._src import core
3030
from jax._src import dispatch
31+
from jax._src import dtypes
3132
from jax._src import test_util as jtu
3233
from jax._src.interpreters import mlir
3334
from jax._src.layout import Layout
@@ -85,29 +86,67 @@ def lowering_rule(ctx, x):
8586
pattern = rf"result_layouts = \[dense<\[{expected}\]>"
8687
self.assertRegex(text, pattern)
8788

88-
@parameterized.parameters([
89-
(True, mlir.ir.BoolAttr.get),
90-
(1, mlir.i64_attr),
91-
(5.0, lambda x: mlir.ir.FloatAttr.get(mlir.ir.F64Type.get(), x)),
92-
("param", mlir.ir.StringAttr.get),
93-
(np.float32(0.5),
94-
lambda x: mlir.ir.FloatAttr.get(mlir.ir.F32Type.get(), x)),
95-
])
96-
def test_params(self, param, expected_builder):
89+
# Concise helpers to every test instance below in one line.
90+
_arr = lambda value, dtype=None: np.array(value, dtype=dtype)
91+
_ftens1 = lambda et: f"dense<1.000000e+00> : tensor<{et}>"
92+
_itens1 = lambda et: f"dense<1> : tensor<{et}>"
93+
94+
@parameterized.parameters(
95+
(_arr(1, dtypes.int2), _itens1("i2")),
96+
(_arr(1, dtypes.int4), _itens1("i4")),
97+
(_arr(1, dtypes.uint2), _itens1("ui2")),
98+
(_arr(1, dtypes.uint4), _itens1("ui4")),
99+
(_arr(1, np.int16), _itens1("i16")),
100+
(_arr(1, np.int32), _itens1("i32")),
101+
(_arr(1, np.int64), _itens1("i64")),
102+
(_arr(1, np.int8), _itens1("i8")),
103+
(_arr(1, np.uint16), _itens1("ui16")),
104+
(_arr(1, np.uint32), _itens1("ui32")),
105+
(_arr(1, np.uint64), _itens1("ui64")),
106+
(_arr(1, np.uint8), _itens1("ui8")),
107+
(_arr(1.0, dtypes.bfloat16), _ftens1("bf16")),
108+
(_arr(1.0, dtypes.float4_e2m1fn), _ftens1("f4E2M1FN")),
109+
(_arr(1.0, dtypes.float8_e3m4), _ftens1("f8E3M4")),
110+
(_arr(1.0, dtypes.float8_e4m3), _ftens1("f8E4M3")),
111+
(_arr(1.0, dtypes.float8_e4m3b11fnuz), _ftens1("f8E4M3B11FNUZ")),
112+
(_arr(1.0, dtypes.float8_e4m3fn), _ftens1("f8E4M3FN")),
113+
(_arr(1.0, dtypes.float8_e4m3fnuz), _ftens1("f8E4M3FNUZ")),
114+
(_arr(1.0, dtypes.float8_e5m2), _ftens1("f8E5M2")),
115+
(_arr(1.0, dtypes.float8_e5m2fnuz), _ftens1("f8E5M2FNUZ")),
116+
(_arr(1.0, dtypes.float8_e8m0fnu), _ftens1("f8E8M0FNU")),
117+
(_arr(1.0, np.bool), "dense<true> : tensor<i1>"),
118+
(_arr(1.0, np.float16), _ftens1("f16")),
119+
(_arr(1.0, np.float32), _ftens1("f32")),
120+
(_arr(1.0, np.float64), _ftens1("f64")),
121+
(dtypes.bfloat16(1.0), "1.000000e+00 : bf16"),
122+
(np.bool(False), "false"),
123+
(np.bool(True), "true"),
124+
(np.float16(1.0), "1.000000e+00 : f16"),
125+
(np.float32(1.0), "1.000000e+00 : f32"),
126+
(np.float64(1.0), "1.000000e+00 : f64"),
127+
(np.int16(1), "1 : i16"),
128+
(np.int32(1), "1 : i32"),
129+
(np.int64(1), "1 : i64"),
130+
(np.int8(1), "1 : i8"),
131+
(np.uint16(1), "1 : ui16"),
132+
(np.uint32(1), "1 : ui32"),
133+
(np.uint64(1), "1 : ui64"),
134+
(np.uint8(1), "1 : ui8"),
135+
(np.zeros((), dtype=dtypes.float0), "dense<false> : tensor<i1>"),
136+
("param", '"param"'),
137+
)
138+
def test_params(self, param, expected_str):
97139
def fun(x):
98140
return jax.ffi.ffi_call("test_ffi", x)(x, param=param)
99141

100142
# Here we inspect the lowered IR to test that the parameter has been
101143
# serialized with the appropriate type.
102144
module = jax.jit(fun).lower(0.5).compiler_ir("stablehlo")
103145
op = self.find_custom_call_in_module(module)
104-
config = op.attributes["mhlo.backend_config"]
105-
self.assertIsInstance(config, mlir.ir.DictAttr)
106-
self.assertIn("param", config)
107-
with mlir.make_ir_context(), mlir.ir.Location.unknown():
108-
expected = expected_builder(param)
109-
self.assertEqual(type(config["param"]), type(expected))
110-
self.assertTrue(expected.type.isinstance(config["param"].type))
146+
conf = op.attributes["mhlo.backend_config"]
147+
self.assertIsInstance(conf, mlir.ir.DictAttr)
148+
self.assertIn("param", conf)
149+
self.assertEqual(str(conf["param"]), expected_str)
111150

112151
def test_token(self):
113152
def fun():

0 commit comments

Comments
 (0)