Skip to content

Commit 34cd5b0

Browse files
apaszke authored and Google-ML-Automation committed
[Mosaic GPU] Remove sub-byte conversion restriction
XLA:GPU recently changed its endianness to little endian to better match LLVM and the rest of the CUDA ecosystem, so we can lift the earlier restrictions.

PiperOrigin-RevId: 737934373
1 parent 549973d commit 34cd5b0

File tree

2 files changed

+24
-5
lines changed

2 files changed

+24
-5
lines changed

jax/experimental/mosaic/gpu/fragmented_array.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1244,6 +1244,11 @@ def astype(self, new_dtype: ir.Type, *, is_signed: bool | None = None):
12441244
is_vector_reg = ir.VectorType.isinstance(reg_type)
12451245
reg_shape = tuple(ir.VectorType(reg_type).shape) if is_vector_reg else (1,)
12461246
[vector_len] = reg_shape # This is meant to be a 1D assertion.
1247+
if (new_reg_bitwidth := utils.bitwidth(new_dtype) * vector_len) % 8:
1248+
raise ValueError(
1249+
"Register bitwidth in target type must be divisible by 8, got"
1250+
f" {new_reg_bitwidth}"
1251+
)
12471252
if cur_dtype == i4 and self.is_signed and new_dtype == bf16:
12481253
new_registers = np.empty_like(self.registers)
12491254
out_vec_ty = ir.VectorType.get((vector_len,), new_dtype)
@@ -1344,11 +1349,6 @@ def upcast_to_bf16(reg, high):
13441349
_registers=new_registers, _layout=self.layout, _is_signed=is_signed
13451350
)
13461351
# Generic path.
1347-
# XLA packs elements into bytes in big-endian order, while LLVM assumes the
1348-
# same endianness as the target machine (which is little for NVIDIA GPUs).
1349-
# We'll need to add specialized casting routines that flip the endianness.
1350-
if 1 < utils.bitwidth(cur_dtype) < 8 or 1 < utils.bitwidth(new_dtype) < 8:
1351-
raise NotImplementedError("Conversion involving sub-byte types unsupported")
13521352
from_float = ir.FloatType.isinstance(cur_dtype)
13531353
to_float = ir.FloatType.isinstance(new_dtype)
13541354
from_integer = ir.IntegerType.isinstance(cur_dtype)

tests/mosaic/gpu_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,25 @@ def kernel(ctx, out, smem):
518518
)()
519519
np.testing.assert_array_equal(iota, expected)
520520

521+
@parameterized.parameters(jnp.int8, jnp.int16, jnp.int32)
def test_sub_byte_conversion(self, jax_dtype_to):
  """Checks signed int4 -> wider integer `astype` against a host reference.

  A random int4 array is converted inside the kernel with
  `FragmentedArray.astype` and the result is compared element-wise with
  `x.astype(jax_dtype_to)` computed on the host.

  Args:
    jax_dtype_to: Target integer dtype (int8, int16 or int32, supplied by
      the parameterization).
  """
  jax_dtype_from = jnp.int4
  def kernel(ctx, inp, out, smem):
    del ctx  # Unused.
    smem_inp, smem_out = smem
    # Stage the input through `smem_inp`, load it as a signed tiled array,
    # convert it, and write the result back out through `smem_out`.
    copy(inp, smem_inp, swizzle=16)
    t = mgpu.FragmentedArray.load_tiled(smem_inp, is_signed=True, swizzle=16)
    t = t.astype(utils.dtype_to_ir_type(jax_dtype_to), is_signed=True)
    # The output swizzle scales with the target element size in bytes.
    t.store_tiled(smem_out, swizzle=32 * jnp.dtype(jax_dtype_to).itemsize)
    copy(smem_out, out, swizzle=32 * jnp.dtype(jax_dtype_to).itemsize)

  # `high` is exclusive in `Generator.integers`, so use 8 to cover the full
  # signed 4-bit range [-8, 7]; with `high=7` the maximum value was never
  # exercised. (Assumes `self.prng` is a NumPy `Generator`, as the
  # `.integers` method suggests.)
  x = self.prng.integers(
      low=-8, high=8, size=(1, 1, 64, 64), dtype=np.int32
  ).astype(jax_dtype_from)
  y = x.astype(jax_dtype_to)
  f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, y, (x, y))
  np.testing.assert_array_equal(f(x), y)
539+
521540
@parameterized.product(
522541
jax_dtype_from_to=(
523542
(jnp.int8, jnp.bfloat16),

0 commit comments

Comments
 (0)