Skip to content

Commit 3649da5

Browse files
apaszkeGoogle-ML-Automation
authored and committed
[Mosaic GPU] Make the s4 -> bf16 upcast more flexible when it comes to vector length
We can now perform the conversion in groups of 2, 4 or even 8 elements at a time. PiperOrigin-RevId: 737626600
1 parent 0ff2340 commit 3649da5

File tree

4 files changed

+81
-31
lines changed

4 files changed

+81
-31
lines changed

jax/BUILD

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -799,7 +799,7 @@ pytype_strict_library(
799799
)
800800

801801
# This target only supports sm_90 GPUs.
802-
py_library(
802+
py_library_providing_imports_info(
803803
name = "mosaic_gpu",
804804
srcs = glob(["experimental/mosaic/gpu/*.py"]),
805805
visibility = [
@@ -824,6 +824,7 @@ py_library(
824824
"//jaxlib/mlir:pass_manager",
825825
"//jaxlib/mlir:scf_dialect",
826826
"//jaxlib/mlir:vector_dialect",
827+
"//jaxlib/mosaic/python:gpu_dialect",
827828
] + py_deps("absl/flags") + py_deps("numpy"),
828829
)
829830

jax/experimental/mosaic/gpu/fragmented_array.py

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1244,11 +1244,10 @@ def astype(self, new_dtype: ir.Type, *, is_signed: bool | None = None):
12441244
is_vector_reg = ir.VectorType.isinstance(reg_type)
12451245
reg_shape = tuple(ir.VectorType(reg_type).shape) if is_vector_reg else (1,)
12461246
[vector_len] = reg_shape # This is meant to be a 1D assertion.
1247-
if cur_dtype == i4 and self.is_signed and new_dtype == bf16 and vector_len == 2:
1247+
if cur_dtype == i4 and self.is_signed and new_dtype == bf16:
12481248
new_registers = np.empty_like(self.registers)
1249-
empty_vec_32 = llvm.mlir_undef(ir.VectorType.get((1,), i32))
1249+
out_vec_ty = ir.VectorType.get((vector_len,), new_dtype)
12501250
for idx, reg in np.ndenumerate(self.registers):
1251-
reg_8 = vector.bitcast(ir.VectorType.get((1,), i8), reg)
12521251
# The algorithm here is largely the same as CUTLASS's
12531252
# NumericArrayConverter specialization for int4 -> bf16 casts.
12541253
# We modify it slightly, because we only extract 2 values.
@@ -1262,25 +1261,41 @@ def astype(self, new_dtype: ir.Type, *, is_signed: bool | None = None):
12621261
# positive int4s will end up larger than negative int4s, with a bias of
12631262
# 8. We use the sub to subtract the base (our initial exponent) and the
12641263
# bias coming from flipping the sign bit which is 136 (0x4308 as bits).
1265-
new_reg_32 = llvm.inline_asm(
1266-
i32,
1267-
[reg_8],
1268-
"""
1269-
{
1270-
.reg .b32 s<4>;
1271-
shr.s32 s0, $1, 4;
1272-
prmt.b32 s1, $1, s0, 0xF4F0;
1273-
lop3.b32 s2, s1, 0x000F000F, 0x43084308, (0xf0 & 0xcc) ^ 0xaa;
1274-
mov.b32 s3, 0x43084308;
1275-
sub.bf16x2 $0, s2, s3;
1276-
}
1277-
""",
1278-
"=r,r",
1279-
)
1280-
new_vec_32 = llvm.insertelement(empty_vec_32, new_reg_32, c(0, i32))
1281-
new_registers[idx] = vector.bitcast(
1282-
ir.VectorType.get((vector_len,), new_dtype), new_vec_32
1283-
)
1264+
def upcast_to_bf16(reg: ir.Value, reg_shr: ir.Value, part: int):
1265+
assert 0 <= part < 4
1266+
return llvm.inline_asm(
1267+
i32,
1268+
[reg, reg_shr],
1269+
f"""
1270+
{{
1271+
.reg .b32 s<4>;
1272+
prmt.b32 s1, $1, $2, 0xF{part + 4}F{part};
1273+
lop3.b32 s2, s1, 0x000F000F, 0x43084308, (0xf0 & 0xcc) ^ 0xaa;
1274+
mov.b32 s3, 0x43084308;
1275+
sub.bf16x2 $0, s2, s3;
1276+
}}
1277+
""",
1278+
"=r,r,r",
1279+
)
1280+
offset = 0
1281+
out_int_regs = []
1282+
for group_size in (8, 4, 2):
1283+
int_ty = ir.IntegerType.get_signless(group_size * 4)
1284+
while vector_len - offset >= group_size:
1285+
reg_slice = utils.vector_slice(reg, slice(offset, offset + group_size))
1286+
reg_slice_int = arith.extsi(i32, utils.bitcast(reg_slice, int_ty))
1287+
reg_slice_int_shr = arith.shrui(reg_slice_int, c(4, i32))
1288+
out_int_regs.extend(
1289+
upcast_to_bf16(reg_slice_int, reg_slice_int_shr, part=part)
1290+
for part in range(group_size // 2)
1291+
)
1292+
offset += group_size
1293+
assert offset == vector_len
1294+
out_vec_int = utils.vector_concat([
1295+
vector.splat(ir.VectorType.get((1,), i32), reg)
1296+
for reg in out_int_regs
1297+
])
1298+
new_registers[idx] = utils.bitcast(out_vec_int, out_vec_ty)
12841299
return FragmentedArray(
12851300
_registers=new_registers, _layout=self.layout, _is_signed=None
12861301
)

jax/experimental/mosaic/gpu/utils.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,9 @@ def bitwidth_impl(ty: ir.Type):
348348
return ir.FloatType(ty).width
349349
if dialect is not None and ir.Type.parse("!mosaic_gpu.barrier"):
350350
return MBARRIER_BYTES * 8
351+
if ir.VectorType.isinstance(ty):
352+
vty = ir.VectorType(ty)
353+
return math.prod(vty.shape) * bitwidth(vty.element_type)
351354
raise NotImplementedError(ty)
352355

353356

@@ -1220,6 +1223,12 @@ def bitcast(x: ir.Value, new_type: ir.Type):
12201223
x_ty = ir.IntegerType(x.type)
12211224
assert x_ty.width == bitwidth(new_type.element_type) * math.prod(new_type.shape)
12221225
return vector.bitcast(new_type, vector.splat(ir.VectorType.get((1,), x_ty), x))
1226+
if ir.VectorType.isinstance(x.type) and ir.VectorType.isinstance(new_type):
1227+
x_ty = ir.VectorType(x.type)
1228+
new_ty = ir.VectorType(new_type)
1229+
if bitwidth(x_ty) != bitwidth(new_ty):
1230+
raise ValueError(f"Can't bitcast {x.type} to {new_type}")
1231+
return vector.bitcast(new_type, x)
12231232
raise ValueError(f"Can't bitcast {x.type} to {new_type}")
12241233

12251234

@@ -1239,3 +1248,27 @@ def vector_slice(v: ir.Value, s: slice):
12391248
elem = llvm.extractelement(v, c(src, i32))
12401249
result = llvm.insertelement(result, elem, c(tgt, i32))
12411250
return result
1251+
1252+
1253+
def vector_concat(vectors: Sequence[ir.Value]) -> ir.Value:
  """Concatenates 1D vectors of one shared type into a single 1D vector.

  Args:
    vectors: A non-empty sequence of values that all have the same 1D vector
      type.

  Returns:
    A 1D vector with the inputs' element type and a length equal to the
    per-input length times ``len(vectors)``. Elements of ``vectors[k]`` occupy
    positions ``k * n`` through ``k * n + n - 1``, where ``n`` is the length of
    each input vector.

  Raises:
    ValueError: If ``vectors`` is empty, if the first value is not a vector, or
      if the values do not all share the same type.
    NotImplementedError: If the vectors are not 1D.
  """
  # Validate the arguments before building any IR types or ops, so that bad
  # input is always reported through the errors documented above.
  if not vectors:
    raise ValueError("Cannot concatenate an empty list of vectors")
  vty = vectors[0].type
  if not ir.VectorType.isinstance(vty):
    raise ValueError("Cannot concatenate non-vector values")
  if vty.rank != 1:
    raise NotImplementedError("Only 1D vectors are supported")
  if any(v.type != vty for v in vectors):
    raise ValueError("Cannot concatenate vectors of different types")

  index = ir.IndexType.get()
  in_len = vty.shape[0]
  # Start from an undefined vector of the final size; every lane is
  # overwritten by the insertions below.
  result = llvm.mlir_undef(
      ir.VectorType.get((in_len * len(vectors),), vty.element_type)
  )
  for vec_idx, v in enumerate(vectors):
    base = vec_idx * in_len
    for i in range(in_len):
      elem = vector.extractelement(v, position=c(i, index))
      result = vector.insertelement(elem, result, position=c(base + i, index))
  return result

tests/mosaic/gpu_test.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -518,14 +518,15 @@ def kernel(ctx, out, smem):
518518
)()
519519
np.testing.assert_array_equal(iota, expected)
520520

521-
@parameterized.named_parameters(
522-
("bf16_i8", jnp.bfloat16, jnp.int8),
523-
("i8_bf16", jnp.int8, jnp.bfloat16),
524-
("i8_i8", jnp.int8, jnp.int8),
525-
("i4_i4", jnp.int4, jnp.int4),
526-
("i4_bf16", jnp.int4, jnp.bfloat16),
521+
@parameterized.product(
522+
jax_dtype_from_to=(
523+
(jnp.int8, jnp.bfloat16),
524+
(jnp.int4, jnp.bfloat16),
525+
),
526+
layout=(fa.WGMMA_LAYOUT, fa.WGMMA_LAYOUT_UPCAST_2X),
527527
)
528-
def test_convert_tiled(self, jax_dtype_from, jax_dtype_to):
528+
def test_optimized_conversion(self, jax_dtype_from_to, layout):
529+
jax_dtype_from, jax_dtype_to = jax_dtype_from_to
529530
mlir_dtype_from = utils.dtype_to_ir_type(jax_dtype_from)
530531
mlir_dtype_to = utils.dtype_to_ir_type(jax_dtype_to)
531532
m = 128
@@ -538,7 +539,7 @@ def kernel(ctx, inp, out, smem):
538539
smem_from,
539540
swizzle=128,
540541
is_signed=utils.is_signed(jax_dtype_from),
541-
layout=fa._tiled_wgmma_layout((m, n))
542+
layout=layout,
542543
)
543544
t = t.astype(mlir_dtype_to, is_signed=utils.is_signed(jax_dtype_to))
544545
t.store_tiled(smem_to, swizzle=128)

0 commit comments

Comments
 (0)