Skip to content

Commit 09e023e

Browse files
apaszkeGoogle-ML-Automation
authored andcommitted
[Mosaic GPU] Add support for all sensible conversions involving f8
PiperOrigin-RevId: 860078929
1 parent 9622105 commit 09e023e

File tree

2 files changed

+111
-120
lines changed

2 files changed

+111
-120
lines changed

jax/experimental/mosaic/gpu/fragmented_array.py

Lines changed: 56 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2070,9 +2070,11 @@ def astype(
20702070
i8 = ir.IntegerType.get_signless(8)
20712071
i16 = ir.IntegerType.get_signless(16)
20722072
i32 = ir.IntegerType.get_signless(32)
2073-
bf16 = ir.BF16Type.get()
20742073
f32 = ir.F32Type.get()
2074+
f16 = ir.F16Type.get()
2075+
bf16 = ir.BF16Type.get()
20752076
f8e4m3fn = ir.Float8E4M3FNType.get()
2077+
f8e5m2 = ir.Float8E5M2Type.get()
20762078
f8e8m0fnu = ir.Float8E8M0FNUType.get()
20772079

20782080
cur_dtype = self.mlir_dtype
@@ -2430,23 +2432,21 @@ def pairwise_convert(do_convert):
24302432
_registers=new_registers, _layout=self.layout, _is_signed=is_signed
24312433
)
24322434

2435+
# Here we handle all conversions involving f8 types.
2436+
# TODO(apaszke): Figure out proper satfinite and rounding modes.
2437+
supported_f8_f16 = {f8e4m3fn: f16, f8e5m2: f16, f8e8m0fnu: bf16}
2438+
f8_ptx_names = {f8e4m3fn: "e4m3", f8e5m2: "e5m2", f8e8m0fnu: "ue8m0"}
2439+
f16_ptx_names = {f16: "f16", bf16: "bf16"}
2440+
f8_types = f8_ptx_names.keys()
2441+
f16_types = f16_ptx_names.keys()
24332442
if f8e8m0fnu in {cur_dtype, new_dtype} and utils.get_arch().major < 10:
24342443
raise ValueError(
24352444
"f8e8m0fnu type only supported on Blackwell and newer GPUs"
24362445
)
2437-
if cur_dtype == f8e8m0fnu and new_dtype == bf16:
2438-
def do_convert(pair_vec):
2439-
return llvm.inline_asm(
2440-
i32,
2441-
[utils.bitcast(pair_vec, i16)],
2442-
"cvt.rn.bf16x2.ue8m0x2 $0, $1;",
2443-
"=r,h",
2444-
)
2445-
return pairwise_convert(do_convert)
2446-
# TODO(bchetioui): handle conversions to/from other float8 types.
2447-
if cur_dtype == f32 and new_dtype in {f8e4m3fn, f8e8m0fnu}:
2448-
tgt_ty = "e4m3" if new_dtype == f8e4m3fn else "ue8m0"
2449-
rounding = "rn" if new_dtype == f8e4m3fn else "rz"
2446+
# f8 <-> f32
2447+
if cur_dtype == f32 and new_dtype in f8_types:
2448+
name_8 = f8_ptx_names[new_dtype]
2449+
rounding = "rz" if new_dtype == f8e8m0fnu else "rn"
24502450
def do_convert(pair_vec):
24512451
e0, e1 = (
24522452
vector.extract(pair_vec, dynamic_position=[], static_position=[i])
@@ -2455,16 +2455,51 @@ def do_convert(pair_vec):
24552455
return llvm.inline_asm(
24562456
i16,
24572457
[e1, e0],
2458-
f"cvt.{rounding}.satfinite.{tgt_ty}x2.f32 $0, $1, $2;",
2458+
f"cvt.{rounding}.satfinite.{name_8}x2.f32 $0, $1, $2;",
24592459
"=h,r,r",
24602460
)
24612461
return pairwise_convert(do_convert)
2462-
2463-
if cur_dtype == f8e8m0fnu and new_dtype == f32:
2464-
return self.astype(bf16).astype(f32)
2465-
if cur_dtype == bf16 and new_dtype == f8e4m3fn:
2466-
# There are no instructions to convert bf16 to f8e4m3fn directly.
2467-
return self.astype(f32).astype(f8e4m3fn)
2462+
# No f8 type supports direct conversion to f32, so we go via 16-bit floats.
2463+
if cur_dtype in f8_types and new_dtype == f32:
2464+
return self.astype(supported_f8_f16[cur_dtype]).astype(f32)
2465+
# f8 <-> f16
2466+
if new_dtype in f8_types and cur_dtype == supported_f8_f16[new_dtype]:
2467+
name_16 = f16_ptx_names[cur_dtype]
2468+
name_8 = f8_ptx_names[new_dtype]
2469+
rounding = "rz" if new_dtype == f8e8m0fnu else "rn"
2470+
ptx = f"cvt.{rounding}.satfinite.{name_8}x2.{name_16}x2 $0, $1;"
2471+
def do_convert(pair_vec):
2472+
return llvm.inline_asm(i16, [utils.bitcast(pair_vec, i32)], ptx, "=h,r")
2473+
return pairwise_convert(do_convert)
2474+
if cur_dtype in f8_types and new_dtype == supported_f8_f16[cur_dtype]:
2475+
name_8 = f8_ptx_names[cur_dtype]
2476+
name_16 = f16_ptx_names[new_dtype]
2477+
ptx = f"cvt.rn.{name_16}x2.{name_8}x2 $0, $1;"
2478+
def do_convert(pair_vec):
2479+
return llvm.inline_asm(i32, [utils.bitcast(pair_vec, i16)], ptx, "=r,h")
2480+
return pairwise_convert(do_convert)
2481+
# We don't emulate the unsupported f8 <-> f16 conversions, but rather force
2482+
# the user to go via f32 to let them know it's expensive.
2483+
if (new_dtype in f8_types and cur_dtype in f16_types) or (
2484+
new_dtype in f16_types and cur_dtype in f8_types
2485+
):
2486+
# Remap the 16-bit type to the supported one.
2487+
ok_cur_dtype = supported_f8_f16.get(new_dtype, cur_dtype)
2488+
ok_new_dtype = supported_f8_f16.get(cur_dtype, new_dtype)
2489+
raise NotImplementedError(
2490+
f"Hardware has no support for converting from {cur_dtype} to"
2491+
f" {new_dtype} (only cast from {ok_cur_dtype} to {ok_new_dtype} is"
2492+
" supported). Cast to f32 first and then to the target type"
2493+
" (expensive, but sufficient)."
2494+
)
2495+
# Repack through a shared 16-bit type.
2496+
if cur_dtype in f8_types and new_dtype in f8_types:
2497+
if supported_f8_f16[cur_dtype] == supported_f8_f16[new_dtype]:
2498+
return self.astype(supported_f8_f16[cur_dtype]).astype(new_dtype)
2499+
raise NotImplementedError(
2500+
f"Conversion from {cur_dtype} to {new_dtype} must go through f32,"
2501+
" which is expensive. Cast to f32 explicitly if you really want it."
2502+
)
24682503

24692504
# Generic path.
24702505
from_float = isinstance(cur_dtype, ir.FloatType)

tests/mosaic/gpu_test.py

Lines changed: 55 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -662,56 +662,6 @@ def kernel(ctx, inp, out, smem):
662662
f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, y, (x, y))
663663
np.testing.assert_array_equal(f(x), y)
664664

665-
@parameterized.parameters(
666-
(jnp.float32, jnp.float8_e4m3fn),
667-
(jnp.bfloat16, jnp.float8_e4m3fn)
668-
)
669-
def test_f8_conversions(self, jax_dtype_from, jax_dtype_to):
670-
mlir_dtype_to = utils.dtype_to_ir_type(jax_dtype_to)
671-
def kernel(ctx, inp, out, smem):
672-
del ctx
673-
smem_from, smem_to = smem
674-
copy(inp, smem_from, swizzle=128)
675-
t = mgpu.FragmentedArray.load_tiled(
676-
smem_from,
677-
swizzle=128,
678-
is_signed=None,
679-
layout=fa.WGMMA_LAYOUT,
680-
)
681-
t = t.astype(mlir_dtype_to, is_signed=utils.is_signed(jax_dtype_to))
682-
t.store_tiled(smem_to, swizzle=128)
683-
copy(smem_to, out, swizzle=128)
684-
685-
# These generative shenanigans are to ensure that we don't generate values
686-
# that are too large for the target type. That is because the saturation
687-
# behavior of the conversion is different between XLA and Mosaic GPU here
688-
# (to use the NVIDIA internal, we allow Mosaic GPU to use the .satfinite
689-
# modifier, which saturates to the largest finite value---while XLA would
690-
# give us NaNs in this case).
691-
max_finite_val = 0b111_1110
692-
693-
expected = jax.lax.bitcast_convert_type(
694-
jax.random.randint(
695-
jax.random.key(42),
696-
(1, 1, 64, 128),
697-
-max_finite_val,
698-
max_finite_val + 1,
699-
dtype=jnp.uint8,
700-
),
701-
jax_dtype_to,
702-
)
703-
x = expected.astype(jax_dtype_from)
704-
705-
res = mgpu.as_gpu_kernel(
706-
kernel,
707-
(1, 1, 1),
708-
(128, 1, 1),
709-
x,
710-
expected,
711-
(x, expected),
712-
)(x)
713-
np.testing.assert_array_equal(res, expected)
714-
715665
@parameterized.product(
716666
jax_dtype_from_to=(
717667
(jnp.int8, jnp.bfloat16),
@@ -3473,30 +3423,33 @@ def kernel(ctx, dst, _):
34733423
np.testing.assert_array_equal(result, op(iota, rhs).astype(jnp.int8))
34743424

34753425
@parameterized.product(
3476-
# TODO(apaszke): Add float16, float8_e5m2
3477-
jax_dtype_from=(jnp.float32, jnp.bfloat16, jnp.float8_e4m3fn, jnp.float8_e8m0fnu),
3478-
jax_dtype_to=(jnp.float32, jnp.bfloat16, jnp.float8_e4m3fn, jnp.float8_e8m0fnu),
3479-
# Test different vector lengths.
3426+
# TODO(apaszke): Add float16
3427+
jax_dtype_from=(jnp.float32, jnp.bfloat16, jnp.float8_e5m2, jnp.float8_e4m3fn, jnp.float8_e8m0fnu),
3428+
jax_dtype_to=(jnp.float32, jnp.bfloat16, jnp.float8_e5m2, jnp.float8_e4m3fn, jnp.float8_e8m0fnu),
34803429
vec_len=(1, 2, 4, 8),
34813430
)
34823431
def test_conversion_f8_(self, jax_dtype_from, jax_dtype_to, vec_len):
34833432
from_bitwidth = jnp.finfo(jax_dtype_from).bits
34843433
to_bitwidth = jnp.finfo(jax_dtype_to).bits
34853434
if from_bitwidth > 8 and to_bitwidth > 8:
34863435
self.skipTest("At least one of the types should be 8-bit")
3487-
if from_bitwidth == to_bitwidth == 8:
3488-
self.skipTest("f8 <-> f8 conversions unimplemented")
3436+
if jax_dtype_from == jax_dtype_to:
3437+
self.skipTest("Identical types, so nothing to test")
34893438
if jnp.float8_e8m0fnu in {
34903439
jax_dtype_from,
34913440
jax_dtype_to,
34923441
} and not jtu.is_cuda_compute_capability_at_least("10.0"):
34933442
self.skipTest("f8e8m0fnu not supported on pre-Blackwell GPUs")
3494-
unimplemented = [
3495-
(jnp.float8_e4m3fn, jnp.bfloat16),
3496-
(jnp.float8_e4m3fn, jnp.float32),
3497-
(jnp.bfloat16, jnp.float8_e8m0fnu),
3498-
]
3499-
if (jax_dtype_from, jax_dtype_to) in unimplemented:
3443+
if from_bitwidth == to_bitwidth == 8 and {jax_dtype_from, jax_dtype_to} != {
3444+
jnp.float8_e4m3fn, jnp.float8_e5m2,
3445+
}:
3446+
self.skipTest("An unimplemented f8 <-> f8 conversion")
3447+
unimplemented = {
3448+
frozenset((jnp.float8_e4m3fn, jnp.bfloat16)),
3449+
frozenset((jnp.float8_e5m2, jnp.bfloat16)),
3450+
frozenset((jnp.float8_e8m0fnu, jnp.float16)),
3451+
}
3452+
if {jax_dtype_from, jax_dtype_to} in unimplemented:
35003453
self.skipTest("Unimplemented")
35013454
layout = fa.tmem_native_layout(vec_len)
35023455
mlir_dtype_to = utils.dtype_to_ir_type(jax_dtype_to)
@@ -3516,7 +3469,10 @@ def kernel(ctx, inp, out, smem):
35163469
bits = self.prng.integers(
35173470
low=sample_iinfo.min, high=sample_iinfo.max, size=(m, n), dtype=np.int32
35183471
).astype(int_sample_dtype)
3519-
values = jax.lax.bitcast_convert_type(bits, narrow_type).astype(jax_dtype_from)
3472+
values = jax.lax.bitcast_convert_type(bits, narrow_type)
3473+
# A bunch of conversions are only supported for finite values.
3474+
values = values.at[jnp.isinf(values)].set(jnp.finfo(narrow_type).max)
3475+
values = values.astype(jax_dtype_from)
35203476

35213477
expected = values.astype(jax_dtype_to)
35223478
res = mgpu.as_gpu_kernel(
@@ -3860,44 +3816,44 @@ def kernel(ctx, dst, _):
38603816
)
38613817
@jtu.thread_unsafe_test()
38623818
def test_max(self, vec_size, dtype):
3863-
def kernel(ctx, src, src2, dst, _):
3864-
is_signed = utils.is_signed(dtype)
3865-
src = fa.FragmentedArray.load_strided(src, vec_size=vec_size, is_signed=is_signed)
3866-
src2 = fa.FragmentedArray.load_strided(src2, vec_size=vec_size, is_signed=is_signed)
3867-
src.max(src2).store_untiled(dst)
3868-
x = self.prng.uniform(-1, 1, (12 * 128,)).astype(dtype)
3869-
y = self.prng.uniform(-1, 1, (12 * 128,)).astype(dtype)
3870-
f = mgpu.as_gpu_kernel(
3819+
def kernel(ctx, src, src2, dst, _):
3820+
is_signed = utils.is_signed(dtype)
3821+
src = fa.FragmentedArray.load_strided(src, vec_size=vec_size, is_signed=is_signed)
3822+
src2 = fa.FragmentedArray.load_strided(src2, vec_size=vec_size, is_signed=is_signed)
3823+
src.max(src2).store_untiled(dst)
3824+
x = self.prng.uniform(-1, 1, (12 * 128,)).astype(dtype)
3825+
y = self.prng.uniform(-1, 1, (12 * 128,)).astype(dtype)
3826+
f = mgpu.as_gpu_kernel(
38713827
kernel, (1, 1, 1), (128, 1, 1), (x, y), x, ()
38723828
)
3873-
with jtu.set_env(MOSAIC_GPU_DUMP_PTX="1"), self.capture_stdout() as ptx:
3874-
z = f(x, y).block_until_ready()
3875-
if dtype == jnp.float32:
3876-
dtype_short = "f32"
3877-
elif dtype == jnp.float16:
3878-
dtype_short = "f16"
3879-
elif dtype == jnp.bfloat16:
3880-
dtype_short = "bf16"
3881-
elif jnp.issubdtype(dtype, jnp.signedinteger):
3882-
dtype_short = f"s{dtypes.itemsize_bits(dtype)}"
3883-
elif jnp.issubdtype(dtype, jnp.unsignedinteger):
3884-
dtype_short = f"u{dtypes.itemsize_bits(dtype)}"
3885-
else:
3886-
raise NotImplementedError(f"Unsupported dtype: {dtype}")
3887-
ptx = ptx()
3888-
nan_modifier = ".NaN" if jnp.issubdtype(dtype, jnp.floating) else ""
3889-
instr = f"max{nan_modifier}.{dtype_short} "
3890-
instr_double = f"max{nan_modifier}.{dtype_short}x2 "
3891-
single_converts = ptx.count(instr)
3892-
double_converts = ptx.count(instr_double)
3893-
self.assertEqual(128 * (single_converts + 2 * double_converts), 12 * 128)
3894-
if vec_size % 2:
3895-
self.assertGreater(single_converts, 0)
3896-
elif dtypes.itemsize_bits(dtype) < 32:
3897-
# This, together with the assertion above, implies that all converts
3898-
# happened through doubled operations.
3899-
self.assertEqual(single_converts, 0)
3900-
np.testing.assert_array_equal(z, np.maximum(x, y))
3829+
with jtu.set_env(MOSAIC_GPU_DUMP_PTX="1"), self.capture_stdout() as ptx:
3830+
z = f(x, y).block_until_ready()
3831+
if dtype == jnp.float32:
3832+
dtype_short = "f32"
3833+
elif dtype == jnp.float16:
3834+
dtype_short = "f16"
3835+
elif dtype == jnp.bfloat16:
3836+
dtype_short = "bf16"
3837+
elif jnp.issubdtype(dtype, jnp.signedinteger):
3838+
dtype_short = f"s{dtypes.itemsize_bits(dtype)}"
3839+
elif jnp.issubdtype(dtype, jnp.unsignedinteger):
3840+
dtype_short = f"u{dtypes.itemsize_bits(dtype)}"
3841+
else:
3842+
raise NotImplementedError(f"Unsupported dtype: {dtype}")
3843+
ptx = ptx()
3844+
nan_modifier = ".NaN" if jnp.issubdtype(dtype, jnp.floating) else ""
3845+
instr = f"max{nan_modifier}.{dtype_short} "
3846+
instr_double = f"max{nan_modifier}.{dtype_short}x2 "
3847+
single_converts = ptx.count(instr)
3848+
double_converts = ptx.count(instr_double)
3849+
self.assertEqual(128 * (single_converts + 2 * double_converts), 12 * 128)
3850+
if vec_size % 2:
3851+
self.assertGreater(single_converts, 0)
3852+
elif dtypes.itemsize_bits(dtype) < 32:
3853+
# This, together with the assertion above, implies that all converts
3854+
# happened through doubled operations.
3855+
self.assertEqual(single_converts, 0)
3856+
np.testing.assert_array_equal(z, np.maximum(x, y))
39013857

39023858
def test_splat_layout(self):
39033859
m, n = 64, 8

0 commit comments

Comments
 (0)