
Commit b09b077

apaszke authored and Google-ML-Automation committed
[Mosaic GPU] Add support for fast upcasts of s8 to bf16 for vectors of 4 elements
To complement the current path that only handles 2 elements.

PiperOrigin-RevId: 700998965
1 parent a158e02 commit b09b077


2 files changed: +30 -19 lines changed


jax/experimental/mosaic/gpu/fragmented_array.py

Lines changed: 25 additions & 14 deletions
@@ -1032,37 +1032,48 @@ def astype(self, new_dtype: ir.Type, *, is_signed: bool | None = None):
       )
     reg_type = self.registers.flat[0].type
     is_vector_reg = ir.VectorType.isinstance(reg_type)
-    reg_shape = tuple(ir.VectorType(reg_type).shape) if is_vector_reg else ()
-    if cur_dtype == i8 and new_dtype == bf16 and reg_shape == (2,):
+    reg_shape = tuple(ir.VectorType(reg_type).shape) if is_vector_reg else (1,)
+    [vector_len] = reg_shape  # This is meant to be a 1D assertion.
+    if cur_dtype == i8 and self.is_signed and new_dtype == bf16 and vector_len in {2, 4}:
       new_registers = np.empty_like(self.registers)
-      for idx, reg in np.ndenumerate(self.registers):
-        reg_16 = vector.bitcast(ir.VectorType.get((1,), i16), reg)
-        val_16 = llvm.extractelement(reg_16, c(0, i32))
+      def upcast_to_bf16(reg, high):
         # We first embed the s8 into a bf16 with the exponent equal to
         # bias + mantissa bits. Then, we zero the msb that didn't fit into the
         # mantissa, zero out all bits other than msb, and subtract the last
         # two values from each other. This takes advantage of the fact that the
         # lsb of the exponent (msb of the second byte) is zero, which allows us
         # to losslessly pack the msb there. When 1, it doubles the value of s2,
         # making the result negative.
-        new_val_32 = llvm.inline_asm(
+        return llvm.inline_asm(
             i32,
-            [val_16],
-            """
-            {
+            [reg],
+            f"""
+            {{
             .reg .b32 s<3>;
-            prmt.b32 s0, $1, 0x43, 0x4140;
+            prmt.b32 s0, $1, 0x43, {0x4342 if high else 0x4140};
             and.b32 s1, s0, 0xff7fff7f;
             and.b32 s2, s0, 0xff80ff80;
             sub.bf16x2 $0, s1, s2;
-            }
+            }}
             """,
             "=r,r",
         )
-        new_vec = llvm.mlir_undef(ir.VectorType.get((1,), i32))
-        new_vec = llvm.insertelement(new_vec, new_val_32, c(0, i32))
+      empty_vec_32 = llvm.mlir_undef(ir.VectorType.get((vector_len // 2,), i32))
+      for idx, reg in np.ndenumerate(self.registers):
+        if vector_len == 2:
+          reg_16 = vector.bitcast(ir.VectorType.get((1,), i16), reg)
+          new_reg_32 = upcast_to_bf16(reg_16, high=False)
+          new_vec_32 = llvm.insertelement(empty_vec_32, new_reg_32, c(0, i32))
+        elif vector_len == 4:
+          reg_32 = vector.bitcast(ir.VectorType.get((1,), i32), reg)
+          low = upcast_to_bf16(reg_32, high=False)
+          high = upcast_to_bf16(reg_32, high=True)
+          new_vec_32 = llvm.insertelement(empty_vec_32, low, c(0, i32))
+          new_vec_32 = llvm.insertelement(new_vec_32, high, c(1, i32))
+        else:
+          raise NotImplementedError(vector_len)
         new_registers[idx] = vector.bitcast(
-            ir.VectorType.get((2,), new_dtype), new_vec
+            ir.VectorType.get((vector_len,), new_dtype), new_vec_32
         )
       return FragmentedArray(
           _registers=new_registers, _layout=self.layout, _is_signed=is_signed
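
The PTX block above is terse, so here is a minimal host-side sketch (not part of the commit) that emulates the prmt/and/sub.bf16x2 sequence in NumPy and checks it against a reference conversion. It assumes the ml_dtypes package for a NumPy-compatible bfloat16 dtype, and the helper names bits_to_bf16 and emulate_upcast_pair are illustrative only, not anything from the Mosaic GPU codebase. The prmt selector chooses which byte pair gets unpacked: 0x4140 interleaves source bytes 0 and 1 with the constant 0x43, while 0x4342 does the same for bytes 2 and 3, which is how the 4-element path reuses the same helper twice.

import numpy as np
import ml_dtypes  # provides a NumPy-compatible bfloat16 dtype (already a JAX dependency)

def bits_to_bf16(bits):
  # Reinterpret uint16 bit patterns as bfloat16 values.
  return np.asarray(bits, dtype=np.uint16).view(ml_dtypes.bfloat16)

def emulate_upcast_pair(reg: bytes, high: bool):
  # Emulates the inline asm for one 32-bit register holding four s8 bytes.
  b0, b1 = (reg[2], reg[3]) if high else (reg[0], reg[1])  # prmt 0x4342 vs 0x4140
  s0 = np.array([0x4300 | b0, 0x4300 | b1], dtype=np.uint16)  # put each byte under exponent byte 0x43
  s1 = bits_to_bf16(s0 & 0xFF7F)  # clear the s8 msb that landed in the exponent lsb
  s2 = bits_to_bf16(s0 & 0xFF80)  # exponent-only part: 128.0, or 256.0 when the msb was set
  return s1 - s2  # sub.bf16x2; exact, since every intermediate fits in bf16's 8 significant bits

vals = np.arange(-128, 128, dtype=np.int8)
for i in range(0, len(vals), 4):
  reg = vals[i:i + 4].tobytes()
  got = np.concatenate([emulate_upcast_pair(reg, high=False),
                        emulate_upcast_pair(reg, high=True)])
  np.testing.assert_array_equal(got, vals[i:i + 4].astype(ml_dtypes.bfloat16))

For a non-negative byte v the two masked lanes decode to 128 + v and 128; for a negative byte they decode to 256 + v and 256, so the subtraction recovers v exactly in both cases without any integer-to-float conversion instruction.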

tests/mosaic/gpu_test.py

Lines changed: 5 additions & 5 deletions
@@ -1608,19 +1608,19 @@ def kernel(ctx, out, *_):
 
     np.testing.assert_array_equal(result, x)
 
-  @parameterized.named_parameters(
-      ("_bf16", jnp.bfloat16)
-  )
-  def test_fast_i8_convert(self, jax_dtype_to):
-    jax_dtype_to = jnp.dtype(jax_dtype_to)
+  @parameterized.parameters(2, 4)
+  def test_fast_i8_convert(self, reg_length):
+    jax_dtype_to = jnp.dtype(jnp.bfloat16)
     jax_dtype_from = jnp.dtype(jnp.int8)
     mlir_dtype_to = utils.dtype_to_ir_type(jax_dtype_to)
     def kernel(ctx, inp, out, smem):
       del ctx, smem
       arr = mgpu.FragmentedArray.load_strided(inp, is_signed=True)
+      assert ir.VectorType(arr.registers.flat[0].type).shape == [reg_length]
       arr.astype(mlir_dtype_to).store_untiled(out)
 
     x = jnp.arange(-128, 128, dtype=jax_dtype_from)
+    x = jnp.tile(x, reg_length // 2)
     reference = x.astype(jax_dtype_to)
 
     result = mgpu.as_gpu_kernel(
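
One sizing detail of the test is worth spelling out; the sketch below is a back-of-the-envelope check, and the 128-lane figure is an assumption about the warpgroup-strided layout used by load_strided, not something stated in the diff. Tiling the 256-element int8 input by reg_length // 2 yields 256 or 512 elements, which a 128-lane warpgroup splits into per-lane vectors of length 2 or 4, exactly what the new assert on arr.registers verifies before the conversion path runs.

# Assumed sizing logic: a 128-lane warpgroup and an even strided split of the flat array.
WARPGROUP_LANES = 128  # 4 warps x 32 threads

for reg_length in (2, 4):
  num_elements = 256 * (reg_length // 2)  # size of jnp.tile(jnp.arange(-128, 128), reg_length // 2)
  assert num_elements // WARPGROUP_LANES == reg_length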
