Skip to content

Commit 69e3f0d

Browse files
petebuGoogle-ML-Automation
authored andcommitted
[pallas:mosaic_gpu] Add test for FragmentedArray.bitcast.
PiperOrigin-RevId: 699919048
1 parent b372ce4 commit 69e3f0d

File tree

2 files changed

+42
-2
lines changed

2 files changed

+42
-2
lines changed

jax/experimental/mosaic/gpu/fragmented_array.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -463,8 +463,8 @@ def __init__(
463463

464464
if (_is_signed is not None) != ir.IntegerType.isinstance(self.mlir_dtype):
465465
raise TypeError(
466-
"is_signed must only be non-None if the MLIR type is an integer"
467-
f" type, got {_is_signed=} for {self.mlir_dtype}"
466+
"is_signed must be non-None if and only if the MLIR type is an"
467+
f" integer type, got {_is_signed=} for {self.mlir_dtype}"
468468
)
469469

470470
match self.layout:
@@ -962,6 +962,12 @@ def fast_instr(x):
962962
return fast_instr
963963

964964
def bitcast(self, elt: ir.Type, *, output_is_signed: bool | None = None):
965+
if (output_is_signed is not None) != ir.IntegerType.isinstance(elt):
966+
raise TypeError(
967+
"output_is_signed must be non-None if and only if the MLIR type is an"
968+
f" integer type, got {output_is_signed=} for {elt}"
969+
)
970+
965971
if elt == self.mlir_dtype:
966972
return self
967973
reg_type = self.registers.flat[0].type

tests/mosaic/gpu_test.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1577,6 +1577,40 @@ def kernel(ctx, _):
15771577

15781578
_ = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), (), (), None)()
15791579

1580+
@parameterized.parameters(
1581+
(jnp.float16, jnp.float16), # Noop
1582+
(jnp.int16, jnp.bfloat16),
1583+
(jnp.int16, jnp.float16),
1584+
(jnp.uint16, jnp.float16),
1585+
(jnp.float32, jnp.int32),
1586+
(jnp.float32, jnp.uint32),
1587+
(jnp.uint32, jnp.int32),
1588+
(jnp.int32, jnp.uint32),
1589+
)
1590+
def test_bitcast(self, in_dtype, out_dtype):
1591+
out_ir_type = utils.dtype_to_ir_type(out_dtype)
1592+
in_is_signed = utils.is_signed(in_dtype)
1593+
out_is_signed = utils.is_signed(out_dtype)
1594+
1595+
def kernel(ctx, inp, out, smem):
1596+
del ctx, smem
1597+
arr = mgpu.FragmentedArray.load_strided(inp, is_signed=in_is_signed)
1598+
arr = arr.bitcast(out_ir_type, output_is_signed=out_is_signed)
1599+
arr.store_untiled(out)
1600+
1601+
x = jnp.arange(256, dtype=in_dtype)
1602+
reference = jax.lax.bitcast_convert_type(x, out_dtype)
1603+
1604+
result = mgpu.as_gpu_kernel(
1605+
kernel,
1606+
(1, 1, 1),
1607+
(128, 1, 1),
1608+
x,
1609+
reference,
1610+
None,
1611+
)(x)
1612+
np.testing.assert_array_equal(result, reference)
1613+
15801614

15811615
class ProfilerTest(TestCase):
15821616

0 commit comments

Comments
 (0)