Skip to content

Commit 68771b2

Browse files
apaszke authored and Google-ML-Automation committed
[Pallas:MGPU] Add a test for tcgen05 MMA that's both sparse and block scaled
No changes were necessary, but I wanted to make sure it works in Pallas (and it did). PiperOrigin-RevId: 872382107
1 parent a58b181 commit 68771b2

File tree

3 files changed

+237
-18
lines changed

3 files changed

+237
-18
lines changed

jax/experimental/mosaic/gpu/tcgen05.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -311,8 +311,6 @@ def mma(
311311
f" accumulators, but got: {d.dtype}"
312312
)
313313
elif any(isinstance(element_type, t) for t in {ir.Float4E2M1FNType}):
314-
if is_sparse:
315-
raise NotImplementedError("Sparse MMA unsupported for f4e2m1fn")
316314
if not is_scaled:
317315
raise ValueError(
318316
f"MMA with element type {element_type} only supports block scaling"
@@ -426,10 +424,13 @@ def mma(
426424
a_sparse_metadata = cast(TMEMRef, a_sparse_metadata)
427425
if n % 32:
428426
raise ValueError(f"Sparse MMA requires N to be divisible by 32, got: {n}")
429-
if a_sparse_metadata.shape != (m, k // 2):
427+
sparse_group_elems = 8 if utils.bitwidth(element_type) == 4 else 4
428+
# Each sparse group has 2 entries.
429+
expected_meta_k = k // sparse_group_elems * 2
430+
if a_sparse_metadata.shape != (m, expected_meta_k):
430431
raise ValueError(
431-
f"A sparse metadata shape mismatch: expected {(m, k // 2)}, got"
432-
f" {a_sparse_metadata.shape}"
432+
f"A sparse metadata shape mismatch: expected {(m, expected_meta_k)},"
433+
f" got {a_sparse_metadata.shape}"
433434
)
434435
if a_sparse_metadata.dtype != ir.IntegerType.get_signless(2):
435436
raise ValueError(
@@ -508,8 +509,9 @@ def mma(
508509
if a_sparse_addr_base is not None:
509510
if n_groups != 1 or m_groups != 1:
510511
raise NotImplementedError("A sparse metadata address calculation for multiple tiles")
511-
assert k_group_elems % 32 == 0
512-
cols_per_k_group = k_group_elems // 32
512+
sparse_group_elems = 8 if utils.bitwidth(mma_element_type) == 4 else 4
513+
# Each sparse group has 2 entries, each TMEM column holds 16 i2 entries.
514+
cols_per_k_group = k_group_elems // sparse_group_elems * 2 // 16
513515
a_sparse_addr = arith.addi(a_sparse_addr_base, utils.c(ki * cols_per_k_group, i32))
514516
else:
515517
a_sparse_addr = None
@@ -611,11 +613,11 @@ def _do_mma(
611613
sparse=is_sparse,
612614
)
613615
elif isinstance(element_type, ir.Float4E2M1FNType):
614-
assert not is_sparse
615616
assert not a_transpose and not b_transpose
616617
create_scaled_instr_descriptor = functools.partial(
617618
create_scaled_f4_instr_descriptor,
618619
scale_type=scale_element_type,
620+
sparse=is_sparse,
619621
)
620622
if scale_element_type == ir.Float8E8M0FNUType.get():
621623
kind = "mxf4.block_scale.scale_vec::2X"
@@ -674,18 +676,14 @@ def _get_offset(idx: int, idx_tiling: tuple[int, ...], strides: tuple[int, ...])
674676
return offset >> 4
675677
for k_step in range(k // instr_k):
676678
if is_sparse:
677-
assert 32 <= instr_k <= 64
678-
selector_width = instr_k
679-
k_steps_for_col_inc = 64 // selector_width
680-
assert (k // instr_k) % k_steps_for_col_inc == 0
681-
sp_selector = k_step % k_steps_for_col_inc
682-
# If the K group is large, we need to increment the sparse metadata.
683-
# TODO(apaszke): At this point the purpose of this function is becoming
684-
# less clear, since we end up replicating address arithmetic that's
685-
# already there in the caller. We should unify them into a single loop.
679+
sparse_group_elems = 8 if elem_bitwidth == 4 else 4
680+
# Each sparse group has 2 entries, each TMEM column holds 16 i2 entries.
681+
meta_cols_per_instr = instr_k // sparse_group_elems * 2 // 16
682+
instrs_per_col_pair = 2 // meta_cols_per_instr
683+
sp_selector = k_step % instrs_per_col_pair
686684
sparse_addr = (
687685
arith.addi(
688-
a_sparse_addr, utils.c(k_step // k_steps_for_col_inc * 2, i32)
686+
a_sparse_addr, utils.c(k_step // instrs_per_col_pair * 2, i32)
689687
),
690688
)
691689
if is_scaled:

tests/mosaic/gpu_test.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2525,6 +2525,126 @@ def format_scales(scales):
25252525
ref = (x32 * a_logical_scales) @ (y32 * b_logical_scales).T
25262526
np.testing.assert_allclose(z, ref, atol=2e-4, rtol=5e-6)
25272527

2528+
@parameterized.product(
2529+
in_jax_dtype=(jnp.float4_e2m1fn,),
2530+
scale_jax_dtype=(jnp.float8_e8m0fnu, jnp.float8_e4m3fn),
2531+
m=(128,),
2532+
n=(128, 256),
2533+
swizzle=(128,),
2534+
)
2535+
def test_mma_block_scaled_sparse_f4(self, m, n, in_jax_dtype, scale_jax_dtype, swizzle):
2536+
out_jax_dtype = jnp.float32
2537+
sparse_meta_dtype = jnp.uint2
2538+
if scale_jax_dtype == jnp.float8_e8m0fnu:
2539+
block_size = 64
2540+
elif scale_jax_dtype == jnp.float8_e4m3fn:
2541+
block_size = 32
2542+
k_steps = 2
2543+
2544+
in_mlir_dtype = utils.dtype_to_ir_type(in_jax_dtype)
2545+
swizzle_elems = 8 * swizzle // bitwidth(in_mlir_dtype)
2546+
k = swizzle_elems * k_steps
2547+
lhs_tiling = rhs_tiling = (8, swizzle_elems)
2548+
2549+
def kernel(ctx, lhs, rhs, lhs_sparse_gmem, lhs_scales_gmem, rhs_scales_gmem, out, scratch):
2550+
(
2551+
lhs_smem, rhs_smem, lhs_sparse_smem,
2552+
lhs_scales_smem, rhs_scales_smem,
2553+
barriers, mma_barrier, acc, lhs_sparse, lhs_scales, rhs_scales,
2554+
) = scratch
2555+
operand_kwargs = dict(
2556+
swizzle=swizzle,
2557+
gmem_transform=mgpu.TileTransform(lhs_tiling),
2558+
)
2559+
ctx.async_copy(src_ref=lhs, dst_ref=lhs_smem, barrier=barriers[0], **operand_kwargs)
2560+
ctx.async_copy(src_ref=rhs, dst_ref=rhs_smem, barrier=barriers[1], swizzle=swizzle, gmem_transform=mgpu.TileTransform(rhs_tiling))
2561+
ctx.async_copy(src_ref=lhs_sparse_gmem, dst_ref=lhs_sparse_smem, barrier=barriers[2])
2562+
ctx.async_copy(src_ref=lhs_scales_gmem, dst_ref=lhs_scales_smem, barrier=barriers[3])
2563+
ctx.async_copy(src_ref=rhs_scales_gmem, dst_ref=rhs_scales_smem, barrier=barriers[4])
2564+
for i in range(5):
2565+
barriers[i].wait()
2566+
with mgpu.single_thread():
2567+
tcgen05.async_copy_sparse_metadata_smem_to_tmem(lhs_sparse_smem, lhs_sparse)
2568+
tcgen05.async_copy_scales_smem_to_tmem(lhs_scales_smem, lhs_scales)
2569+
tcgen05.async_copy_scales_smem_to_tmem(rhs_scales_smem, rhs_scales)
2570+
tcgen05.mma(
2571+
acc,
2572+
lhs_smem,
2573+
mgpu.memref_transpose(rhs_smem, (1, 0, 3, 2)),
2574+
a_swizzle=swizzle,
2575+
b_swizzle=swizzle,
2576+
a_scale=lhs_scales,
2577+
b_scale=rhs_scales,
2578+
a_sparse_metadata=lhs_sparse,
2579+
accumulate=False,
2580+
)
2581+
tcgen05.commit_arrive(mma_barrier)
2582+
mma_barrier.wait(orders_tensor_core=True)
2583+
acc.load().store_untiled(out, optimized=False)
2584+
2585+
x_shape = (m, k // 2)
2586+
x = self.prng.uniform(-1, 1, x_shape).astype(in_jax_dtype)
2587+
y_shape = (n, k)
2588+
y = self.prng.uniform(-1, 1, y_shape).astype(in_jax_dtype)
2589+
out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype)
2590+
meta_k = k // 4
2591+
scratch_shape = [
2592+
jax.ShapeDtypeStruct(tile_shape(x_shape, lhs_tiling), in_jax_dtype),
2593+
jax.ShapeDtypeStruct(tile_shape(y_shape, rhs_tiling), in_jax_dtype),
2594+
jax.ShapeDtypeStruct((m // 128, meta_k // 64, 128, 64), sparse_meta_dtype),
2595+
jax.ShapeDtypeStruct((m // 128, k // (block_size * 4), 32, 16), scale_jax_dtype),
2596+
jax.ShapeDtypeStruct((n // 128, k // (block_size * 4), 32, 16), scale_jax_dtype),
2597+
mgpu.TMABarrier(5),
2598+
mgpu.Barrier(1),
2599+
mgpu.TMEM((m, n), out_jax_dtype),
2600+
mgpu.TMEM((m, meta_k), sparse_meta_dtype, layout=tcgen05.sparse_meta_layout()),
2601+
mgpu.TMEM((m, k // block_size), scale_jax_dtype, layout=tcgen05.scales_layout()),
2602+
mgpu.TMEM((n, k // block_size), scale_jax_dtype, layout=tcgen05.scales_layout()),
2603+
]
2604+
n_groups = k // 8
2605+
index_pairs = np.asarray(np.meshgrid(range(4), range(4))).T.reshape(-1, 2)
2606+
valid_pairs = index_pairs[index_pairs[:, 0] < index_pairs[:, 1]]
2607+
assert len(valid_pairs) == 6
2608+
x_pairs = jax.random.randint(jax.random.key(1234), (m, n_groups), 0, 6, dtype=jnp.uint8)
2609+
x_sparse = valid_pairs[x_pairs]
2610+
assert x_sparse.shape == (m, n_groups, 2)
2611+
def format_sparse_meta(meta):
2612+
mn, groups, _2 = meta.shape
2613+
assert _2 == 2
2614+
k_meta = groups * 2
2615+
meta_tiled = (
2616+
meta.reshape(mn // 128, 128, k_meta // 64, 64).transpose(0, 2, 1, 3)
2617+
)
2618+
return (
2619+
meta_tiled.reshape(mn // 128, k_meta // 64, 128, 64)
2620+
.astype(sparse_meta_dtype)
2621+
)
2622+
x_gpu_sparse = format_sparse_meta(x_sparse)
2623+
a_scales, b_scales = self._sample_scales(m, k, n, block_size, scale_jax_dtype)
2624+
def format_scales(scales):
2625+
mn, k = scales.shape
2626+
assert mn % 128 == 0 and k % 4 == 0, scales.shape
2627+
return (
2628+
scales.reshape(mn // 128, 4, 32, k // 4, 4)
2629+
.transpose(0, 3, 2, 1, 4)
2630+
.reshape(mn // 128, k // 4, 32, 16)
2631+
)
2632+
a_gpu_scales, b_gpu_scales = map(format_scales, (a_scales, b_scales))
2633+
args = (x, y, x_gpu_sparse, a_gpu_scales, b_gpu_scales)
2634+
z = mgpu.as_gpu_kernel(
2635+
kernel, (1, 1, 1), (128, 1, 1), args, out_shape, scratch_shape
2636+
)(*args)
2637+
# 4-bit sparse data is filled in pairs of elements.
2638+
x32 = x.astype(np.float32).reshape(m, n_groups, 2, 2)
2639+
x_logical32 = np.zeros((m, n_groups, 4, 2), dtype=np.float32)
2640+
np.put_along_axis(x_logical32, x_sparse[..., np.newaxis], x32, axis=-2)
2641+
x_logical32 = x_logical32.reshape(m, k)
2642+
y32 = y.astype(np.float32)
2643+
a_logical_scales = jnp.repeat(a_scales, block_size, axis=1).astype(jnp.float32)
2644+
b_logical_scales = jnp.repeat(b_scales, block_size, axis=1).astype(jnp.float32)
2645+
ref = (x_logical32 * a_logical_scales) @ (y32 * b_logical_scales).T
2646+
np.testing.assert_allclose(z, ref, atol=2e-4, rtol=5e-6)
2647+
25282648
@parameterized.product(
25292649
lhs_transpose=(False, True),
25302650
rhs_transpose=(False, True),

tests/pallas/mosaic_gpu_test.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4287,6 +4287,107 @@ def kernel(a_smem, b_smem, a_sparse_smem, out_ref,
42874287
ref = x_logical.astype(jnp.float32) @ y.T.astype(jnp.float32)
42884288
np.testing.assert_allclose(z, ref, atol=7e-5, rtol=5e-6)
42894289

4290+
@parameterized.product(
4291+
m=[128],
4292+
n=[128, 256],
4293+
)
4294+
def test_block_scaled_sparse_matmul(self, m, n):
4295+
self.skip_if_wg_semantics()
4296+
in_dtype = jnp.float8_e5m2
4297+
scale_dtype = jnp.float8_e8m0fnu
4298+
block_size = 64
4299+
swizzle = 128
4300+
k = 256
4301+
transforms = self.default_transforms(swizzle=swizzle, dtype=in_dtype)
4302+
out_transforms = self.default_transforms(dtype=jnp.float32)
4303+
4304+
def kernel(a_smem, b_smem, a_sparse_smem, a_scale_smem, b_scale_smem,
4305+
out_ref, barrier_ref, acc_tmem, a_sparse_tmem,
4306+
a_scale_tmem, b_scale_tmem):
4307+
plgpu.async_copy_sparse_metadata_to_tmem(a_sparse_smem, a_sparse_tmem)
4308+
plgpu.async_copy_scales_to_tmem(a_scale_smem, a_scale_tmem)
4309+
plgpu.async_copy_scales_to_tmem(b_scale_smem, b_scale_tmem)
4310+
plgpu.tcgen05_mma(acc_tmem,
4311+
a_smem,
4312+
plgpu.transpose_ref(b_smem, (1, 0)),
4313+
a_scale=a_scale_tmem,
4314+
b_scale=b_scale_tmem,
4315+
a_sparse_metadata=a_sparse_tmem,
4316+
accumulate=False)
4317+
plgpu.tcgen05_commit_arrive(barrier_ref)
4318+
plgpu.barrier_wait(barrier_ref)
4319+
out_ref[...] = plgpu.async_load_tmem(acc_tmem)
4320+
4321+
scratch_shapes = [
4322+
plgpu.Barrier(orders_tensor_core=True),
4323+
plgpu.TMEM((m, n), jnp.float32),
4324+
plgpu.TMEM((m, k // 2), jnp.uint2, layout=plgpu.TMEMLayout.SPARSE_METADATA_LAYOUT),
4325+
plgpu.TMEM((m, k // block_size), scale_dtype, layout=plgpu.TMEMLayout.SCALES_LAYOUT),
4326+
plgpu.TMEM((n, k // block_size), scale_dtype, layout=plgpu.TMEMLayout.SCALES_LAYOUT),
4327+
]
4328+
4329+
f = self.pallas_call(
4330+
kernel,
4331+
in_specs=(
4332+
plgpu.BlockSpec(memory_space=plgpu.SMEM, transforms=transforms),
4333+
plgpu.BlockSpec(memory_space=plgpu.SMEM, transforms=transforms),
4334+
plgpu.BlockSpec(memory_space=plgpu.SMEM),
4335+
plgpu.BlockSpec(memory_space=plgpu.SMEM),
4336+
plgpu.BlockSpec(memory_space=plgpu.SMEM),
4337+
),
4338+
out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32),
4339+
out_specs=plgpu.BlockSpec(transforms=out_transforms),
4340+
scratch_shapes=scratch_shapes,
4341+
)
4342+
x = jax.random.uniform(jax.random.key(1), shape=(m, k // 2), dtype=jnp.float32).astype(in_dtype)
4343+
y = jax.random.uniform(jax.random.key(2), shape=(n, k), dtype=jnp.float32).astype(in_dtype)
4344+
index_pairs = np.asarray(np.meshgrid(range(4), range(4))).T.reshape(-1, 2)
4345+
valid_pairs = index_pairs[index_pairs[:, 0] < index_pairs[:, 1]]
4346+
assert len(valid_pairs) == 6
4347+
x_pairs = jax.random.randint(jax.random.key(1234), (m, k // 4), 0, 6, dtype=jnp.uint8)
4348+
x_sparse = valid_pairs[x_pairs]
4349+
assert x_sparse.shape == (m, k // 4, 2)
4350+
ksx, ksy = jax.random.split(jax.random.key(5678), 2)
4351+
x_scale = jax.lax.bitcast_convert_type(
4352+
jax.random.randint(ksx, (m, k // block_size), 122, 132, dtype=jnp.uint8),
4353+
scale_dtype
4354+
)
4355+
y_scale = jax.lax.bitcast_convert_type(
4356+
jax.random.randint(ksy, (n, k // block_size), 122, 132, dtype=jnp.uint8),
4357+
scale_dtype
4358+
)
4359+
def format_scales(scales):
4360+
mn, k = scales.shape
4361+
assert mn % 128 == 0 and k % 4 == 0
4362+
return (
4363+
scales.reshape(mn // 128, 4, 32, k // 4, 4)
4364+
.transpose(0, 3, 2, 1, 4)
4365+
.reshape(mn // 128, k // 4, 32, 16)
4366+
)
4367+
def format_sparse_meta(meta):
4368+
mn, k, _2 = meta.shape
4369+
assert _2 == 2
4370+
k *= 2
4371+
return (
4372+
meta.reshape(mn // 128, 128, k // 64, 64)
4373+
.transpose(0, 2, 1, 3)
4374+
.astype(jnp.uint2)
4375+
)
4376+
z = f(
4377+
x, y,
4378+
format_sparse_meta(x_sparse),
4379+
format_scales(x_scale), format_scales(y_scale),
4380+
)
4381+
x_logical = np.zeros_like(x, shape=(m, k // 4, 4))
4382+
np.put_along_axis(x_logical, x_sparse, x.reshape(x_sparse.shape), axis=-1)
4383+
x_logical = x_logical.reshape(m, k)
4384+
x32 = x_logical.astype(np.float32)
4385+
y32 = y.astype(np.float32)
4386+
a_logical_scales = jnp.repeat(x_scale, block_size, axis=1).astype(jnp.float32)
4387+
b_logical_scales = jnp.repeat(y_scale, block_size, axis=1).astype(jnp.float32)
4388+
ref = (x32 * a_logical_scales) @ (y32 * b_logical_scales).T
4389+
np.testing.assert_allclose(z, ref, atol=2e-4, rtol=5e-6)
4390+
42904391
@parameterized.parameters(
42914392
(128, jnp.float16)
42924393
)

0 commit comments

Comments
 (0)