Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 16 additions & 18 deletions jax/experimental/mosaic/gpu/tcgen05.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,6 @@ def mma(
f" accumulators, but got: {d.dtype}"
)
elif any(isinstance(element_type, t) for t in {ir.Float4E2M1FNType}):
if is_sparse:
raise NotImplementedError("Sparse MMA unsupported for f4e2m1fn")
if not is_scaled:
raise ValueError(
f"MMA with element type {element_type} only supports block scaling"
Expand Down Expand Up @@ -426,10 +424,13 @@ def mma(
a_sparse_metadata = cast(TMEMRef, a_sparse_metadata)
if n % 32:
raise ValueError(f"Sparse MMA requires N to be divisible by 32, got: {n}")
if a_sparse_metadata.shape != (m, k // 2):
sparse_group_elems = 8 if utils.bitwidth(element_type) == 4 else 4
# Each sparse group has 2 entries.
expected_meta_k = k // sparse_group_elems * 2
if a_sparse_metadata.shape != (m, expected_meta_k):
raise ValueError(
f"A sparse metadata shape mismatch: expected {(m, k // 2)}, got"
f" {a_sparse_metadata.shape}"
f"A sparse metadata shape mismatch: expected {(m, expected_meta_k)},"
f" got {a_sparse_metadata.shape}"
)
if a_sparse_metadata.dtype != ir.IntegerType.get_signless(2):
raise ValueError(
Expand Down Expand Up @@ -508,8 +509,9 @@ def mma(
if a_sparse_addr_base is not None:
if n_groups != 1 or m_groups != 1:
raise NotImplementedError("A sparse metadata address calculation for multiple tiles")
assert k_group_elems % 32 == 0
cols_per_k_group = k_group_elems // 32
sparse_group_elems = 8 if utils.bitwidth(mma_element_type) == 4 else 4
# Each sparse group has 2 entries, each TMEM column holds 16 i2 entries.
cols_per_k_group = k_group_elems // sparse_group_elems * 2 // 16
a_sparse_addr = arith.addi(a_sparse_addr_base, utils.c(ki * cols_per_k_group, i32))
else:
a_sparse_addr = None
Expand Down Expand Up @@ -611,11 +613,11 @@ def _do_mma(
sparse=is_sparse,
)
elif isinstance(element_type, ir.Float4E2M1FNType):
assert not is_sparse
assert not a_transpose and not b_transpose
create_scaled_instr_descriptor = functools.partial(
create_scaled_f4_instr_descriptor,
scale_type=scale_element_type,
sparse=is_sparse,
)
if scale_element_type == ir.Float8E8M0FNUType.get():
kind = "mxf4.block_scale.scale_vec::2X"
Expand Down Expand Up @@ -674,18 +676,14 @@ def _get_offset(idx: int, idx_tiling: tuple[int, ...], strides: tuple[int, ...])
return offset >> 4
for k_step in range(k // instr_k):
if is_sparse:
assert 32 <= instr_k <= 64
selector_width = instr_k
k_steps_for_col_inc = 64 // selector_width
assert (k // instr_k) % k_steps_for_col_inc == 0
sp_selector = k_step % k_steps_for_col_inc
# If the K group is large, we need to increment the sparse metadata.
# TODO(apaszke): At this point the purpose of this function is becoming
# less clear, since we end up replicating address arithmetic that's
# already there in the caller. We should unify them into a single loop.
sparse_group_elems = 8 if elem_bitwidth == 4 else 4
# Each sparse group has 2 entries, each TMEM column holds 16 i2 entries.
meta_cols_per_instr = instr_k // sparse_group_elems * 2 // 16
instrs_per_col_pair = 2 // meta_cols_per_instr
sp_selector = k_step % instrs_per_col_pair
sparse_addr = (
arith.addi(
a_sparse_addr, utils.c(k_step // k_steps_for_col_inc * 2, i32)
a_sparse_addr, utils.c(k_step // instrs_per_col_pair * 2, i32)
),
)
if is_scaled:
Expand Down
120 changes: 120 additions & 0 deletions tests/mosaic/gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2525,6 +2525,126 @@ def format_scales(scales):
ref = (x32 * a_logical_scales) @ (y32 * b_logical_scales).T
np.testing.assert_allclose(z, ref, atol=2e-4, rtol=5e-6)

@parameterized.product(
    in_jax_dtype=(jnp.float4_e2m1fn,),
    scale_jax_dtype=(jnp.float8_e8m0fnu, jnp.float8_e4m3fn),
    m=(128,),
    n=(128, 256),
    swizzle=(128,),
)
def test_mma_block_scaled_sparse_f4(self, m, n, in_jax_dtype, scale_jax_dtype, swizzle):
  """Tests tcgen05 MMA that combines f4e2m1fn block scaling with structured sparsity.

  The LHS is stored in compressed 2:4-sparse form (only k // 2 physical
  elements per row) together with uint2 metadata selecting, per group, which
  logical positions the stored values occupy. Both operands additionally carry
  per-block scales in TMEM. The result is checked against a dense NumPy/JAX
  reference built by scattering the compressed LHS back to its logical shape.
  """
  out_jax_dtype = jnp.float32
  sparse_meta_dtype = jnp.uint2
  # Scale block size differs per scale format.
  # NOTE(review): presumably matches the scale_vec::2X / ::4X instruction
  # variants chosen in _do_mma — confirm against tcgen05.py.
  if scale_jax_dtype == jnp.float8_e8m0fnu:
    block_size = 64
  elif scale_jax_dtype == jnp.float8_e4m3fn:
    block_size = 32
  k_steps = 2

  in_mlir_dtype = utils.dtype_to_ir_type(in_jax_dtype)
  # Number of elements of the input dtype that span one swizzle period.
  swizzle_elems = 8 * swizzle // bitwidth(in_mlir_dtype)
  k = swizzle_elems * k_steps
  lhs_tiling = rhs_tiling = (8, swizzle_elems)

  def kernel(ctx, lhs, rhs, lhs_sparse_gmem, lhs_scales_gmem, rhs_scales_gmem, out, scratch):
    # GMEM -> SMEM copies for both operands, the sparse metadata and both
    # scale tensors; each copy signals its own TMA barrier.
    (
        lhs_smem, rhs_smem, lhs_sparse_smem,
        lhs_scales_smem, rhs_scales_smem,
        barriers, mma_barrier, acc, lhs_sparse, lhs_scales, rhs_scales,
    ) = scratch
    operand_kwargs = dict(
        swizzle=swizzle,
        gmem_transform=mgpu.TileTransform(lhs_tiling),
    )
    ctx.async_copy(src_ref=lhs, dst_ref=lhs_smem, barrier=barriers[0], **operand_kwargs)
    ctx.async_copy(src_ref=rhs, dst_ref=rhs_smem, barrier=barriers[1], swizzle=swizzle, gmem_transform=mgpu.TileTransform(rhs_tiling))
    ctx.async_copy(src_ref=lhs_sparse_gmem, dst_ref=lhs_sparse_smem, barrier=barriers[2])
    ctx.async_copy(src_ref=lhs_scales_gmem, dst_ref=lhs_scales_smem, barrier=barriers[3])
    ctx.async_copy(src_ref=rhs_scales_gmem, dst_ref=rhs_scales_smem, barrier=barriers[4])
    for i in range(5):
      barriers[i].wait()
    # Metadata and scales must reside in TMEM before the MMA can consume them.
    with mgpu.single_thread():
      tcgen05.async_copy_sparse_metadata_smem_to_tmem(lhs_sparse_smem, lhs_sparse)
      tcgen05.async_copy_scales_smem_to_tmem(lhs_scales_smem, lhs_scales)
      tcgen05.async_copy_scales_smem_to_tmem(rhs_scales_smem, rhs_scales)
      # B is passed transposed (rhs is stored as (n, k)).
      tcgen05.mma(
          acc,
          lhs_smem,
          mgpu.memref_transpose(rhs_smem, (1, 0, 3, 2)),
          a_swizzle=swizzle,
          b_swizzle=swizzle,
          a_scale=lhs_scales,
          b_scale=rhs_scales,
          a_sparse_metadata=lhs_sparse,
          accumulate=False,
      )
      tcgen05.commit_arrive(mma_barrier)
    mma_barrier.wait(orders_tensor_core=True)
    acc.load().store_untiled(out, optimized=False)

  # LHS is compressed: only half of the logical K elements are stored.
  x_shape = (m, k // 2)
  x = self.prng.uniform(-1, 1, x_shape).astype(in_jax_dtype)
  y_shape = (n, k)
  y = self.prng.uniform(-1, 1, y_shape).astype(in_jax_dtype)
  out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype)
  # One uint2 metadata entry per pair of stored elements: k // 8 groups * 2.
  meta_k = k // 4
  scratch_shape = [
      jax.ShapeDtypeStruct(tile_shape(x_shape, lhs_tiling), in_jax_dtype),
      jax.ShapeDtypeStruct(tile_shape(y_shape, rhs_tiling), in_jax_dtype),
      jax.ShapeDtypeStruct((m // 128, meta_k // 64, 128, 64), sparse_meta_dtype),
      jax.ShapeDtypeStruct((m // 128, k // (block_size * 4), 32, 16), scale_jax_dtype),
      jax.ShapeDtypeStruct((n // 128, k // (block_size * 4), 32, 16), scale_jax_dtype),
      mgpu.TMABarrier(5),
      mgpu.Barrier(1),
      mgpu.TMEM((m, n), out_jax_dtype),
      mgpu.TMEM((m, meta_k), sparse_meta_dtype, layout=tcgen05.sparse_meta_layout()),
      mgpu.TMEM((m, k // block_size), scale_jax_dtype, layout=tcgen05.scales_layout()),
      mgpu.TMEM((n, k // block_size), scale_jax_dtype, layout=tcgen05.scales_layout()),
  ]
  # For 4-bit types sparsity is expressed on pairs of elements: each group of
  # 8 logical elements keeps 2 of 4 element-pairs, so there are k // 8 groups.
  n_groups = k // 8
  # Enumerate the 6 strictly-increasing (i, j) index pairs with i < j < 4,
  # i.e. all valid choices of 2 kept positions out of 4.
  index_pairs = np.asarray(np.meshgrid(range(4), range(4))).T.reshape(-1, 2)
  valid_pairs = index_pairs[index_pairs[:, 0] < index_pairs[:, 1]]
  assert len(valid_pairs) == 6
  x_pairs = jax.random.randint(jax.random.key(1234), (m, n_groups), 0, 6, dtype=jnp.uint8)
  x_sparse = valid_pairs[x_pairs]
  assert x_sparse.shape == (m, n_groups, 2)
  def format_sparse_meta(meta):
    # Tile the (mn, meta_k) metadata into the (mn//128, meta_k//64, 128, 64)
    # layout expected by the SMEM scratch buffer above.
    mn, groups, _2 = meta.shape
    assert _2 == 2
    k_meta = groups * 2
    meta_tiled = (
        meta.reshape(mn // 128, 128, k_meta // 64, 64).transpose(0, 2, 1, 3)
    )
    return (
        meta_tiled.reshape(mn // 128, k_meta // 64, 128, 64)
        .astype(sparse_meta_dtype)
    )
  x_gpu_sparse = format_sparse_meta(x_sparse)
  a_scales, b_scales = self._sample_scales(m, k, n, block_size, scale_jax_dtype)
  def format_scales(scales):
    # Rearrange (mn, k) scales into the (mn//128, k//4, 32, 16) TMEM
    # staging layout.
    mn, k = scales.shape
    assert mn % 128 == 0 and k % 4 == 0, scales.shape
    return (
        scales.reshape(mn // 128, 4, 32, k // 4, 4)
        .transpose(0, 3, 2, 1, 4)
        .reshape(mn // 128, k // 4, 32, 16)
    )
  a_gpu_scales, b_gpu_scales = map(format_scales, (a_scales, b_scales))
  args = (x, y, x_gpu_sparse, a_gpu_scales, b_gpu_scales)
  z = mgpu.as_gpu_kernel(
      kernel, (1, 1, 1), (128, 1, 1), args, out_shape, scratch_shape
  )(*args)
  # 4-bit sparse data is filled in pairs of elements.
  # Scatter the compressed values into their logical positions to build the
  # dense reference operand.
  x32 = x.astype(np.float32).reshape(m, n_groups, 2, 2)
  x_logical32 = np.zeros((m, n_groups, 4, 2), dtype=np.float32)
  np.put_along_axis(x_logical32, x_sparse[..., np.newaxis], x32, axis=-2)
  x_logical32 = x_logical32.reshape(m, k)
  y32 = y.astype(np.float32)
  # Broadcast the per-block scales across their blocks for the reference.
  a_logical_scales = jnp.repeat(a_scales, block_size, axis=1).astype(jnp.float32)
  b_logical_scales = jnp.repeat(b_scales, block_size, axis=1).astype(jnp.float32)
  ref = (x32 * a_logical_scales) @ (y32 * b_logical_scales).T if False else (x_logical32 * a_logical_scales) @ (y32 * b_logical_scales).T
  np.testing.assert_allclose(z, ref, atol=2e-4, rtol=5e-6)

@parameterized.product(
lhs_transpose=(False, True),
rhs_transpose=(False, True),
Expand Down
101 changes: 101 additions & 0 deletions tests/pallas/mosaic_gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4287,6 +4287,107 @@ def kernel(a_smem, b_smem, a_sparse_smem, out_ref,
ref = x_logical.astype(jnp.float32) @ y.T.astype(jnp.float32)
np.testing.assert_allclose(z, ref, atol=7e-5, rtol=5e-6)

@parameterized.product(
    m=[128],
    n=[128, 256],
)
def test_block_scaled_sparse_matmul(self, m, n):
  """Tests the Pallas tcgen05_mma path with block scaling plus 2:4 sparsity.

  The LHS is f8e5m2 in compressed 2:4-sparse form (k // 2 stored elements per
  row) with uint2 metadata, and both operands carry f8e8m0fnu per-block
  scales. The kernel output is compared against a dense reference computed by
  scattering the compressed LHS back into its logical positions.
  """
  self.skip_if_wg_semantics()
  in_dtype = jnp.float8_e5m2
  scale_dtype = jnp.float8_e8m0fnu
  block_size = 64
  swizzle = 128
  k = 256
  transforms = self.default_transforms(swizzle=swizzle, dtype=in_dtype)
  out_transforms = self.default_transforms(dtype=jnp.float32)

  def kernel(a_smem, b_smem, a_sparse_smem, a_scale_smem, b_scale_smem,
             out_ref, barrier_ref, acc_tmem, a_sparse_tmem,
             a_scale_tmem, b_scale_tmem):
    # Stage metadata and scales into TMEM before issuing the MMA.
    plgpu.async_copy_sparse_metadata_to_tmem(a_sparse_smem, a_sparse_tmem)
    plgpu.async_copy_scales_to_tmem(a_scale_smem, a_scale_tmem)
    plgpu.async_copy_scales_to_tmem(b_scale_smem, b_scale_tmem)
    # B is stored as (n, k), hence the transpose.
    plgpu.tcgen05_mma(acc_tmem,
                      a_smem,
                      plgpu.transpose_ref(b_smem, (1, 0)),
                      a_scale=a_scale_tmem,
                      b_scale=b_scale_tmem,
                      a_sparse_metadata=a_sparse_tmem,
                      accumulate=False)
    plgpu.tcgen05_commit_arrive(barrier_ref)
    plgpu.barrier_wait(barrier_ref)
    out_ref[...] = plgpu.async_load_tmem(acc_tmem)

  scratch_shapes = [
      plgpu.Barrier(orders_tensor_core=True),
      plgpu.TMEM((m, n), jnp.float32),
      # One uint2 metadata entry per stored (compressed) element.
      plgpu.TMEM((m, k // 2), jnp.uint2, layout=plgpu.TMEMLayout.SPARSE_METADATA_LAYOUT),
      plgpu.TMEM((m, k // block_size), scale_dtype, layout=plgpu.TMEMLayout.SCALES_LAYOUT),
      plgpu.TMEM((n, k // block_size), scale_dtype, layout=plgpu.TMEMLayout.SCALES_LAYOUT),
  ]

  f = self.pallas_call(
      kernel,
      in_specs=(
          plgpu.BlockSpec(memory_space=plgpu.SMEM, transforms=transforms),
          plgpu.BlockSpec(memory_space=plgpu.SMEM, transforms=transforms),
          plgpu.BlockSpec(memory_space=plgpu.SMEM),
          plgpu.BlockSpec(memory_space=plgpu.SMEM),
          plgpu.BlockSpec(memory_space=plgpu.SMEM),
      ),
      out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32),
      out_specs=plgpu.BlockSpec(transforms=out_transforms),
      scratch_shapes=scratch_shapes,
  )
  # LHS is compressed: only k // 2 of the logical elements are stored.
  x = jax.random.uniform(jax.random.key(1), shape=(m, k // 2), dtype=jnp.float32).astype(in_dtype)
  y = jax.random.uniform(jax.random.key(2), shape=(n, k), dtype=jnp.float32).astype(in_dtype)
  # Enumerate the 6 valid choices of 2 kept positions out of each group of 4.
  index_pairs = np.asarray(np.meshgrid(range(4), range(4))).T.reshape(-1, 2)
  valid_pairs = index_pairs[index_pairs[:, 0] < index_pairs[:, 1]]
  assert len(valid_pairs) == 6
  x_pairs = jax.random.randint(jax.random.key(1234), (m, k // 4), 0, 6, dtype=jnp.uint8)
  x_sparse = valid_pairs[x_pairs]
  assert x_sparse.shape == (m, k // 4, 2)
  ksx, ksy = jax.random.split(jax.random.key(5678), 2)
  # Sample e8m0 exponents near 1.0 (biased exponent 127) to keep the
  # products in a numerically benign range.
  x_scale = jax.lax.bitcast_convert_type(
      jax.random.randint(ksx, (m, k // block_size), 122, 132, dtype=jnp.uint8),
      scale_dtype
  )
  y_scale = jax.lax.bitcast_convert_type(
      jax.random.randint(ksy, (n, k // block_size), 122, 132, dtype=jnp.uint8),
      scale_dtype
  )
  def format_scales(scales):
    # Rearrange (mn, k) scales into the (mn//128, k//4, 32, 16) staging
    # layout expected by async_copy_scales_to_tmem.
    mn, k = scales.shape
    assert mn % 128 == 0 and k % 4 == 0
    return (
        scales.reshape(mn // 128, 4, 32, k // 4, 4)
        .transpose(0, 3, 2, 1, 4)
        .reshape(mn // 128, k // 4, 32, 16)
    )
  def format_sparse_meta(meta):
    # Tile (mn, k//4, 2) metadata into (mn//128, k_meta//64, 128, 64) uint2.
    mn, k, _2 = meta.shape
    assert _2 == 2
    k *= 2
    return (
        meta.reshape(mn // 128, 128, k // 64, 64)
        .transpose(0, 2, 1, 3)
        .astype(jnp.uint2)
    )
  z = f(
      x, y,
      format_sparse_meta(x_sparse),
      format_scales(x_scale), format_scales(y_scale),
  )
  # Scatter the compressed values into their logical positions to build the
  # dense reference operand.
  x_logical = np.zeros_like(x, shape=(m, k // 4, 4))
  np.put_along_axis(x_logical, x_sparse, x.reshape(x_sparse.shape), axis=-1)
  x_logical = x_logical.reshape(m, k)
  x32 = x_logical.astype(np.float32)
  y32 = y.astype(np.float32)
  # Broadcast per-block scales across their blocks for the reference.
  a_logical_scales = jnp.repeat(x_scale, block_size, axis=1).astype(jnp.float32)
  b_logical_scales = jnp.repeat(y_scale, block_size, axis=1).astype(jnp.float32)
  ref = (x32 * a_logical_scales) @ (y32 * b_logical_scales).T
  np.testing.assert_allclose(z, ref, atol=2e-4, rtol=5e-6)

@parameterized.parameters(
(128, jnp.float16)
)
Expand Down
Loading