
Commit 14da7eb

petebu authored and Google-ML-Automation committed
[pallas:mosaic_gpu] Add Pallas Mosaic GPU lowering for jax.lax.bitcast_convert_type.
Only handles the case where operand type and target type have the same bitwidth.

PiperOrigin-RevId: 698332564
1 parent 1afb05e commit 14da7eb
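
For orientation, here is a minimal sketch of the kind of kernel this lowering enables, assuming the Mosaic GPU backend is selected for Pallas; the function name, shapes, and dtypes below are illustrative and not part of the commit:

import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

# Reinterpret float32 bits as uint32 inside a Pallas kernel. Both dtypes are
# 32 bits wide, which is the only case this lowering supports.
@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct((16, 8), jnp.uint32),
)
def bitcast_kernel(x_ref, y_ref):
  y_ref[...] = jax.lax.bitcast_convert_type(x_ref[...], jnp.uint32)

x = jnp.arange(16 * 8, dtype=jnp.float32).reshape(16, 8)
y = bitcast_kernel(x)  # same bits as x, viewed as uint32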

File tree

3 files changed: +55, -2 lines changed


jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 25 additions & 0 deletions
@@ -1501,6 +1501,31 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches):
   return list(switch_op.results)
 
 
+@register_lowering_rule(lax.bitcast_convert_type_p)
+def _bitcast_convert_type_lowering_rule(
+    ctx: LoweringRuleContext, operand, *, new_dtype
+):
+  # TODO(petebu) Handle case where src and dst types have different bitwidths
+  [operand_aval] = ctx.avals_in
+  operand = _ensure_fa(operand, operand_aval.dtype)
+  src_elem_type = mgpu_utils.dtype_to_ir_type(operand_aval.dtype)
+  dst_elem_type = mgpu_utils.dtype_to_ir_type(new_dtype)
+  assert isinstance(src_elem_type, (ir.IntegerType, ir.FloatType))
+  assert isinstance(dst_elem_type, (ir.IntegerType, ir.FloatType))
+  if src_elem_type.width != dst_elem_type.width:
+    raise NotImplementedError(
+        f"Can't bitcast from {operand_aval.dtype} to {new_dtype} because they"
+        " have different widths"
+    )
+  if ir.IntegerType.isinstance(dst_elem_type):
+    output_is_signed = mgpu_utils.is_signed(new_dtype)
+  else:
+    output_is_signed = None
+  return mgpu.FragmentedArray.bitcast(
+      operand, dst_elem_type, output_is_signed=output_is_signed
+  )
+
+
 def _bcast(
     x: ir.Value,
     y: ir.Value,
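
The width guard above is the commit's stated limitation: dtype pairs whose element types differ in bitwidth raise NotImplementedError at lowering time. A quick host-side way to see which pairs pass the guard (plain NumPy dtype arithmetic, not part of the commit):

import jax.numpy as jnp
import numpy as np

# Equal itemsize -> accepted by the lowering; unequal -> NotImplementedError.
assert np.dtype(jnp.float16).itemsize == np.dtype(jnp.int16).itemsize  # 16 vs 16 bits: accepted
assert np.dtype(jnp.float32).itemsize != np.dtype(jnp.int16).itemsize  # 32 vs 16 bits: rejected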

jax/experimental/mosaic/gpu/fragmented_array.py

Lines changed: 6 additions & 2 deletions
@@ -929,15 +929,19 @@ def fast_instr(x):
       raise NotImplementedError(x.type)
     return fast_instr
 
-  def bitcast(self, elt: ir.Type):
+  def bitcast(self, elt: ir.Type, *, output_is_signed: bool | None = None):
+    if elt == self.mlir_dtype:
+      return self
     reg_type = self.registers.flat[0].type
     if ir.VectorType.isinstance(reg_type):
       reg_shape = ir.VectorType(reg_type).shape
       ty = ir.VectorType.get(reg_shape, elt)
     else:
       ty = elt
 
-    return self._pointwise(lambda x: arith.bitcast(ty, x))
+    return self._pointwise(
+        lambda x: arith.bitcast(ty, x), output_is_signed=output_is_signed
+    )
 
   def __getitem__(self, idx):
     if self.layout != WGMMA_LAYOUT:
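
The output_is_signed plumbing matters because Mosaic GPU tracks signedness on integer FragmentedArrays, while the bitcast itself leaves the bit pattern untouched. A small host-side illustration of why the target dtype's signedness changes how the same bits read back (plain jax.lax, not the Mosaic path; printed values hold for this specific input):

import jax
import jax.numpy as jnp

x = jnp.float32(-1.0)                               # bit pattern 0xBF800000
print(jax.lax.bitcast_convert_type(x, jnp.uint32))  # 3212836864
print(jax.lax.bitcast_convert_type(x, jnp.int32))   # -1082130432 (same bits, signed view)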

tests/pallas/mosaic_gpu_test.py

Lines changed: 24 additions & 0 deletions
@@ -1052,6 +1052,30 @@ def kernel(x_ref, o_ref):
     self.assertEqual(data.count('"name": "store"'), 2)
     np.testing.assert_array_equal(y, x + x)
 
+  @parameterized.parameters(
+      (jnp.float16, jnp.float16),  # Noop
+      (jnp.int16, jnp.bfloat16),
+      (jnp.int16, jnp.float16),
+      (jnp.uint16, jnp.float16),
+      (jnp.float32, jnp.int32),
+      (jnp.float32, jnp.uint32),
+      (jnp.uint32, jnp.int32),
+      (jnp.int32, jnp.uint32),
+  )
+  def test_bitcast_convert_type(self, in_dtype, out_dtype):
+    m, n = 16, 8
+    out_shape = jax.ShapeDtypeStruct((m, n), out_dtype)
+    grid = ()
+
+    @functools.partial(pl.pallas_call, out_shape=out_shape, grid=grid)
+    def convert(x_ref, y_ref):
+      y_ref[...] = jax.lax.bitcast_convert_type(x_ref[...], out_shape)
+
+    x = jnp.arange(m * n, dtype=in_dtype).reshape((m, n))
+    y = convert(x)
+    y_ref = jax.lax.bitcast_convert_type(x, out_dtype)
+    np.testing.assert_array_equal(y, y_ref)
+
 
 class PipelineTest(PallasTest):
 
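
Assuming a pytest-based setup and a GPU supported by the Mosaic GPU backend, the new cases can be targeted with something like: pytest tests/pallas/mosaic_gpu_test.py -k test_bitcast_convert_type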
