Commit 74e86ba

apaszke authored and Google-ML-Automation committed
[Mosaic GPU] Add support for collective MMA in the Blackwell matmul example
PiperOrigin-RevId: 725630722
1 parent 21598d0 commit 74e86ba
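
Summary of the change: build_kernel gains a `collective: bool = False` flag. When it is set, the kernel is launched on clusters of two CTAs (`cluster=(2, 1, 1)`), the logical `tile_m`/`tile_n` are doubled while each block works on its own `block_tile_m`/`block_tile_n` slice, the A/B TMA loads become collective along `gpu.Dimension.x` (partitioned on the non-contracting dimension), only the leader block of each cluster issues `tcgen05.mma`/`commit_arrive`, and TMEM is allocated with `collective=collective`. The sketch below is an illustrative way to exercise the flag; the shapes and the `collective=True` call are assumptions, not part of this commit (the `main` function in the diff still builds the kernel with the default `collective=False`).

# Hypothetical usage sketch (not from this commit). Assumes build_kernel from
# matmul_blackwell.py is in scope and that the problem shapes divide the
# (doubled) tile sizes evenly.
import jax
import jax.numpy as jnp
import numpy as np

m = n = k = 4096  # illustrative shapes
ka, kb = jax.random.split(jax.random.key(0), 2)
a = jax.random.normal(ka, (m, k)).astype(jnp.float16)
b = jax.random.normal(kb, (n, k)).astype(jnp.float16)  # B is passed transposed, as in main()

f_single = build_kernel(m, n, k, tile_m=128, tile_n=128)                       # one CTA per output tile
f_collective = build_kernel(m, n, k, tile_m=128, tile_n=128, collective=True)  # cluster of 2 CTAs per tile

y_single = f_single(a, b).block_until_ready()
y_collective = f_collective(a, b).block_until_ready()
np.testing.assert_allclose(y_collective, y_single, atol=1e-3, rtol=1e-3)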

File tree: 1 file changed (+59, -44 lines)


jax/experimental/mosaic/gpu/examples/matmul_blackwell.py

Lines changed: 59 additions & 44 deletions
@@ -42,10 +42,10 @@ def build_kernel(
     tile_m: int = 128,
     tile_n: int = 128,
     max_concurrent_steps: int = 2,
+    collective: bool = False,
 ):
   i1 = ir.IntegerType.get_signless(1)
   i32 = ir.IntegerType.get_signless(32)
-  f32 = ir.F32Type.get()
   index = ir.IndexType.get()
 
   swizzle = 128
@@ -64,32 +64,46 @@ def build_kernel(
   tma_tile_m = 128
   tma_tile_kn = 64
 
+  block_tile_m = tile_m
+  block_tile_n = tile_n
+  if collective:
+    tile_m *= 2
+    tile_n *= 2
+
   def kernel(ctx, a, b, d, smem):
     a_smem, b_smem, d_smem, barriers, mma_done_barrier, acc = smem
     (ab_full_barriers, ab_empty_barriers) = barriers
 
     warp_idx = mgpu.warp_idx(sync=True)
-    warp_leader = nvvm.elect_sync(i1)
+    is_warp_leader = nvvm.elect_sync(i1)
 
-    is_warp = lambda i: arith.cmpi(arith.CmpIPredicate.eq, warp_idx, c(i, i32))
+    is_leader_of = lambda i: arith.andi(arith.cmpi(arith.CmpIPredicate.eq, warp_idx, c(i, i32)), is_warp_leader)
 
-    m_start = arith.muli(gpu.block_id(gpu.Dimension.y), c(tile_m,index))
-    n_start = arith.muli(gpu.block_id(gpu.Dimension.x), c(tile_n,index))
+    m_start = arith.muli(gpu.cluster_id(gpu.Dimension.x), c(tile_m,index))
+    block_m_start = arith.muli(gpu.block_id(gpu.Dimension.x), c(block_tile_m,index))
+    n_start = arith.muli(gpu.block_id(gpu.Dimension.y), c(tile_n,index))
+    is_leader_block = arith.cmpi(arith.CmpIPredicate.eq, ctx.cluster_idx(gpu.Dimension.x), c(0, index))
 
-    with mgpu.when(arith.andi(is_warp(TMA_WARP), warp_leader)):
+    with mgpu.when(is_leader_of(TMA_WARP)):
       @mgpu.fori(c(k_loop_iter, index), None)
       def _tma_body(ki, _):
         slot = arith.remui(ki, c(max_concurrent_steps, index))
         # TODO(apaszke): Use a predicate instead of a conditional.
         with mgpu.when(arith.cmpi(arith.CmpIPredicate.uge, ki, c(max_concurrent_steps, index))):
           ab_empty_barriers[slot].wait()
         full_barrier = ab_full_barriers[slot]
-        full_barrier.arrive_expect_tx(
-            bytecount((tile_m, tile_k), in_dtype) + bytecount((tile_n, tile_k), in_dtype)
-        )
+        with mgpu.when(is_leader_block):
+          full_barrier.arrive_expect_tx(
+              bytecount((tile_m, tile_k), in_dtype) + bytecount((tile_n, tile_k), in_dtype)
+          )
         k_start = arith.muli(ki, c(tile_k, index))
         common_args = dict(
-            swizzle=swizzle, barrier=full_barrier, arrive=False, uniform=False,
+            swizzle=swizzle,
+            barrier=full_barrier,
+            arrive=False,
+            uniform=False,
+            collective=gpu.Dimension.x,
+            partitioned=0,  # Non-contracting dim is always 0.
         )
         ctx.async_copy(
             src_ref=a,
@@ -109,66 +123,67 @@ def _tma_body(ki, _):
             **common_args,
         )
 
-    with mgpu.when(arith.andi(is_warp(MMA_WARP), warp_leader)):
-      with mgpu.when(warp_leader):
-        @mgpu.fori(c(k_loop_iter, index), arith.constant(i1, 0))
-        def _mma_body(ki, accumulate):
-          slot = arith.remui(ki, c(max_concurrent_steps, index))
-          ab_full_barriers[slot].wait()
-          tcgen05.mma(
-              acc,
-              mgpu.memref_slice(a_smem, slot),
-              mgpu.memref_transpose(mgpu.memref_slice(b_smem, slot), (0, 1, 3, 2)),
-              a_swizzle=swizzle,
-              b_swizzle=swizzle,
-              accumulate=accumulate,
-          )
-          accumulate = arith.constant(i1, 1)
-          is_last_iter = arith.cmpi(
-              arith.CmpIPredicate.eq, ki, c(k_loop_iter - 1, index)
-          )
-          barrier_ptr = arith.select(
-              is_last_iter,
-              mma_done_barrier.get_ptr(),
-              ab_empty_barriers[slot].get_ptr(),
-          )
-          tcgen05.commit_arrive(barrier_ptr)
-          return accumulate
+    with mgpu.when(arith.andi(is_leader_of(MMA_WARP), is_leader_block)):
+      @mgpu.fori(c(k_loop_iter, index), arith.constant(i1, 0))
+      def _mma_body(ki, accumulate):
+        slot = arith.remui(ki, c(max_concurrent_steps, index))
+        ab_full_barriers[slot].wait()
+        tcgen05.mma(
+            acc,
+            mgpu.memref_slice(a_smem, slot),
+            mgpu.memref_transpose(mgpu.memref_slice(b_smem, slot), (0, 1, 3, 2)),
+            a_swizzle=swizzle,
+            b_swizzle=swizzle,
+            accumulate=accumulate,
+            collective=collective,
+        )
+        accumulate = arith.constant(i1, 1)
+        is_last_iter = arith.cmpi(
+            arith.CmpIPredicate.eq, ki, c(k_loop_iter - 1, index)
+        )
+        barrier_ptr = arith.select(
+            is_last_iter,
+            mma_done_barrier.get_ptr(),
+            ab_empty_barriers[slot].get_ptr(),
+        )
+        tcgen05.commit_arrive(barrier_ptr, collective=collective, ctx=ctx)
+        return accumulate
 
     gpu.barrier()
     mma_done_barrier.wait(for_tensor_core=True)
 
     acc[:].astype(ir.F16Type.get()).store_tiled(d_smem, swizzle=128)
     mgpu.commit_shared()
-    # TODO(apaszke): Free up TMEM?
     ctx.async_copy(
         src_ref=d_smem,
         dst_ref=d,
-        gmem_slice=(ds(m_start, tile_m), ds(n_start, tile_n)),
+        gmem_slice=(ds(block_m_start, block_tile_m), ds(n_start, tile_n)),
        gmem_transform=mgpu.TileTransform((128, 64)),
        swizzle=swizzle,
    )
+    # TODO(apaszke): Free up TMEM?
     ctx.await_async_copy(0)
 
   # TODO(apaszke): Use a union for output SMEM.
   smem = (
-      jax.ShapeDtypeStruct((max_concurrent_steps, *mgpu.tile_shape((tile_m, tile_k), (tma_tile_m, tma_tile_kn))), jnp.float16),
-      jax.ShapeDtypeStruct((max_concurrent_steps, *mgpu.tile_shape((tile_k, tile_n), (tma_tile_kn, tma_tile_kn))), jnp.float16),
-      jax.ShapeDtypeStruct(mgpu.tile_shape((tile_m, tile_n), (tma_tile_m, tma_tile_kn)), jnp.float16),
+      jax.ShapeDtypeStruct((max_concurrent_steps, *mgpu.tile_shape((block_tile_m, tile_k), (tma_tile_m, tma_tile_kn))), jnp.float16),
+      jax.ShapeDtypeStruct((max_concurrent_steps, *mgpu.tile_shape((tile_k, block_tile_n), (tma_tile_kn, tma_tile_kn))), jnp.float16),
+      jax.ShapeDtypeStruct(mgpu.tile_shape((block_tile_m, tile_n), (tma_tile_m, tma_tile_kn)), jnp.float16),
       [mgpu.Barrier(arrival_count=1, num_barriers=max_concurrent_steps)] * 2,
       mgpu.Barrier(arrival_count=1),
-      mgpu.TMEM((128, tile_n), jnp.float32, tcgen05.TMEMLayout.D),
+      mgpu.TMEM((128, tile_n), jnp.float32, tcgen05.TMEMLayout.D, collective=collective),
   )
   return mgpu.as_gpu_kernel(
       kernel,
-      (n // tile_n, m // tile_m, 1),
+      (m // block_tile_m, n // tile_n, 1),
       (128, 1, 1),
       (
           jax.ShapeDtypeStruct((m, k), jnp.float16),
          jax.ShapeDtypeStruct((n, k), jnp.float16),
      ),
      jax.ShapeDtypeStruct((m, n), jnp.float16),
      smem,
+      cluster=(2 if collective else 1, 1, 1),
  )
 
 
@@ -188,8 +203,8 @@ def main(unused_argv):
   f = build_kernel(m, n, k, tile_m=m_tile, tile_n=n_tile)
   y = f(a, b).block_until_ready()
 
-  ref = np.asarray(a) @ np.asarray(b).T
-  np.testing.assert_allclose(y, ref, atol=1e-3, rtol=1e-3)
+  y_ref = jax.jit(lambda a, b: a @ b.T)(a, b)
+  np.testing.assert_allclose(y, y_ref, atol=1e-3, rtol=1e-3)
   print("OK!")
 
 
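
For reference, the launch geometry implied by the new grid and cluster computation, as a minimal sketch assuming m = n = 4096 and the default tile_m = tile_n = 128 passed by the caller:

# Illustrative arithmetic only; mirrors the collective=True branch in the diff above.
tile_m = tile_n = 128                        # values passed to build_kernel
block_tile_m, block_tile_n = tile_m, tile_n
tile_m, tile_n = 2 * tile_m, 2 * tile_n      # doubled when collective=True
m = n = 4096                                 # assumed problem shape
grid = (m // block_tile_m, n // tile_n, 1)   # == (32, 16, 1)
cluster = (2, 1, 1)                          # adjacent blocks along x form one cluster
# Each cluster computes a (tile_m, tile_n) == (256, 256) output tile, while each
# block stores its own (block_tile_m, tile_n) == (128, 256) slice of D to GMEM.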
