
Commit 21598d0

apaszke authored and Google-ML-Automation committed
[Mosaic GPU] Add support for non-multicast .cta_group::2 async_copies
This instruction is particularly useful for collective MMA, since it lets us easily report on the progress of async copies from both blocks in the single block that will be performing the MMA.

PiperOrigin-RevId: 725618793
1 parent 2c165bf commit 21598d0
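
For orientation, a minimal sketch of how the new argument is meant to be used (the names `ctx`, `lhs`, `lhs_smem`, and `barriers` are hypothetical stand-ins; the real call sites are in the test diff below). Both blocks of a 2-block cluster issue the identical call; each loads its own half of the window, and only the first block's barrier is armed with the total transfer size:

# Hedged usage sketch, not part of the change. Assumes a cluster of two
# blocks along gpu.Dimension.x and pre-allocated SMEM/barrier resources.
ctx.async_copy(
    src_ref=lhs,                 # GMEM source
    dst_ref=lhs_smem,            # per-block SMEM destination
    barrier=barriers[0],         # only block 0's barrier gets expect_tx
    collective=gpu.Dimension.x,  # blocks that cooperate on the load
    partitioned=0,               # dimension split equally across the cluster
)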

File tree: 2 files changed (+126, −31 lines)

jax/experimental/mosaic/gpu/launch_context.py

Lines changed: 119 additions & 21 deletions
@@ -252,6 +252,24 @@ def named_region(self, *args, **kwargs):
     else:
       yield
 
+  def cluster_idx(
+      self, dim: gpu.Dimension | Sequence[gpu.Dimension] | None = None
+  ) -> ir.Value:
+    """Returns the index of a block within a subset of the cluster spanned by the given dimensions."""
+    if dim is None:
+      dim = gpu.Dimension
+    elif isinstance(dim, gpu.Dimension):
+      dim = (dim,)
+    index = ir.IndexType.get()
+    stride = 1
+    idx = c(0, index)
+    for d in sorted(dim):
+      if self.cluster_size[d] == 1:  # Optimize a multiply by 0.
+        continue
+      idx = arith.addi(idx, arith.muli(gpu.cluster_block_id(d), c(stride, index)))
+      stride *= self.cluster_size[d]
+    return idx
+
   def _alloc_scratch(
       self,
       size: int,
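
The `cluster_idx` helper added above is a mixed-radix linearization of the block coordinates over the selected cluster dimensions. A plain-Python analogue of what the loop computes (a hypothetical helper operating on ints rather than MLIR values, not part of the change):

def cluster_idx_py(block_coords, cluster_size, dims):
  # Linearize the block's coordinates over `dims`, skipping size-1 dimensions.
  idx, stride = 0, 1
  for d in sorted(dims):
    if cluster_size[d] == 1:
      continue
    idx += block_coords[d] * stride
    stride *= cluster_size[d]
  return idx

# For a (2, 1, 1) cluster spanned over all dimensions:
assert cluster_idx_py((0, 0, 0), (2, 1, 1), (0, 1, 2)) == 0
assert cluster_idx_py((1, 0, 0), (2, 1, 1), (0, 1, 2)) == 1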
@@ -355,8 +373,35 @@ def async_copy(
       arrive: bool | None = None,
       uniform: bool = True,
       collective: Sequence[gpu.Dimension] | gpu.Dimension | None = None,
+      partitioned: int | None = None,
       predicate: ir.Value | None = None,  # Should select 0 or 1 threads from the WG.
   ):
+    """Initiates an async copy between GMEM and SMEM.
+
+    Exactly one of `src_ref` and `dst_ref` must be in GMEM and in SMEM, and the
+    SMEM reference must be contiguous. The GMEM window that is read or written
+    to is specified by the `gmem_slice`. The copy can change the order in which
+    the data appears in the window by applying a sequence of transforms to the
+    GMEM reference (as specified by `gmem_transform`).
+
+    When `collective` is specified (only allowed for GMEM -> SMEM copies), the
+    identical async_copy must be scheduled by all blocks that share the same
+    coordinates along collective dimensions within a cluster. The behavior is
+    undefined otherwise. The semantics of collective loads depend further on the
+    `partitioned` argument:
+
+    - If `partitioned` is not specified, all blocks load the same data into
+      their shared memory and all receive the update in their barriers, unless
+      `arrive` is False. If `arrive` is False, you should expect the barrier to
+      have expect_tx incremented by the same amount of bytes as if `collective`
+      was not specified.
+    - If `partitioned` is specified, each block only loads a separate slice of
+      the data into SMEM, partitioned into equal tiles along the `partitioned`
+      dimension. In this case only the barrier of the first block in the
+      collective will have its expect_tx incremented by the total size of the
+      transfer across all blocks involved in the collective. Barriers supplied
+      by other blocks will be ignored (even if `arrive` is True).
+    """
     index = ir.IndexType.get()
     i16 = ir.IntegerType.get_signless(16)
     i32 = ir.IntegerType.get_signless(32)
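
To make the barrier accounting described in the docstring concrete, a plain-number sketch (not library code; the byte counts are made up):

total_bytes = 8192  # size of the full GMEM window being copied (made up)
num_blocks = 2      # blocks collaborating along the collective dimensions

# Non-partitioned collective load: every block loads the whole window into its
# own SMEM; with arrive=True each block's barrier expects the full window.
bytes_per_block_replicated = total_bytes
expect_tx_per_barrier = total_bytes

# Partitioned load: each block loads an equal tile, but only the first block's
# barrier is armed, and it expects the transfer summed over all blocks.
bytes_per_block_partitioned = total_bytes // num_blocks
expect_tx_block0 = total_bytes
expect_tx_other_blocks = 0  # barriers supplied by other blocks are ignored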
@@ -408,13 +453,46 @@ def async_copy(
           " multiple of 16 bytes"
       )
 
-    # TMA supports OOB indices, so we skip the check.
+    # NOTE: TMA supports OOB indices, so we skip the check.
     base_indices, slice_shape, is_squeezed = utils.parse_indices(
         gmem_slice, ir.MemRefType(gmem_ref.type).shape, check_oob=False
     )
     dyn_base_indices = tuple(
         c(i, index) if not isinstance(i, ir.Value) else i for i in base_indices
     )
+    del base_indices  # Use the dynamic indices from now on!
+
+    collective_size = 1
+    if collective is not None:
+      if isinstance(collective, gpu.Dimension):
+        collective = (collective,)
+      collective_size = math.prod(self.cluster_size[d] for d in collective)
+      if gmem_ref is dst_ref:
+        raise ValueError("Only GMEM -> SMEM copies can be collective")
+    if partitioned is not None:
+      if collective is None:
+        raise ValueError("Only collective loads can be partitioned")
+    if collective_size > 1 and partitioned is not None:
+      if math.prod(self.cluster_size) != 2:
+        raise NotImplementedError(
+            "Partitioned loads only supported for clusters of size 2"
+        )
+      if slice_shape[partitioned] % collective_size != 0:
+        raise ValueError(
+            f"The collective size ({collective_size}) must divide the slice"
+            " shape along the partitioned dimension, but it has size"
+            f" {slice_shape[partitioned]}"
+        )
+      slice_shape[partitioned] //= collective_size
+      dyn_base_indices = list(dyn_base_indices)
+      dyn_base_indices[partitioned] = arith.addi(
+          dyn_base_indices[partitioned],
+          arith.muli(
+              self.cluster_idx(collective), c(slice_shape[partitioned], index)
+          ),
+      )
+      dyn_base_indices = tuple(dyn_base_indices)
+
     squeezed_dims = [i for i, squeezed in enumerate(is_squeezed) if squeezed]
     sliced_dims = [i for i, squeezed in enumerate(is_squeezed) if not squeezed]
     # Indexing is really slicing + squeezing, and user transforms are meant to
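
The base-index adjustment added above amounts to offsetting each block's window by its cluster index times the per-block tile. A worked plain-integer example (the values are made up, not part of the change):

collective_size = 2        # math.prod of the cluster sizes along `collective`
slice_shape = [128, 64]    # requested GMEM window, before partitioning
partitioned = 0            # dimension to split across the cluster

slice_shape[partitioned] //= collective_size  # each block now loads 64 rows
for cluster_idx in range(collective_size):
  offset = cluster_idx * slice_shape[partitioned]
  print(f"block {cluster_idx}: rows [{offset}, {offset + slice_shape[partitioned]})")
# block 0: rows [0, 64)
# block 1: rows [64, 128)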
@@ -472,12 +550,9 @@ def async_copy(
     dyn_base_indices = list(dyn_base_indices)
     slice_shape = list(slice_shape)
     assert all(d == 1 for d in slice_shape[:num_squeezed_dims])
-    collective_size = 1
-    if collective is not None:
-      if isinstance(collective, gpu.Dimension):
-        collective = (collective,)
-      collective_size = math.prod(self.cluster_size[d] for d in collective)
-    if collective_size > 1:
+
+    # Partitioned loads have already been processed (before transforms).
+    if collective_size > 1 and partitioned is None:
       def partition_dim(dim: int, idx: ir.Value, num_chunks: int):
         # No need to partition squeezed dims. They don't even exist in smem_ref.
         assert dim >= num_squeezed_dims
@@ -490,13 +565,7 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int):
             (slice(None),) * (dim - num_squeezed_dims)
             + (utils.ds(block_offset, slice_shape[dim]),),
         )
-      stride = 1
-      idx = c(0, index)
-      for d in sorted(collective):
-        if self.cluster_size[d] == 1:  # Optimize a multiply by 0.
-          continue
-        idx = arith.addi(idx, arith.muli(gpu.cluster_block_id(d), c(stride, index)))
-        stride *= self.cluster_size[d]
+      idx = self.cluster_idx(collective)
       rem_collective_size = collective_size
       for dim, slice_size in enumerate(slice_shape[:-1]):
         if slice_size % rem_collective_size == 0:
@@ -572,15 +641,44 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int):
         )
       barrier_ptr = barrier.get_ptr()
       with uniform_ctx():
-        if arrive:
-          nvvm.mbarrier_arrive_expect_tx_shared(
-              barrier_ptr, transfer_bytes, predicate=predicate
+        if collective_size > 1 and partitioned is not None:
+          if predicate is None:
+            predicate = c(1, ir.IntegerType.get_signless(1))
+          if arrive:
+            first_block = arith.cmpi(
+                arith.CmpIPredicate.eq, self.cluster_idx(collective), c(0, index),
+            )
+            arrive_predicate = arith.andi(predicate, first_block)
+            nvvm.mbarrier_arrive_expect_tx_shared(
+                barrier_ptr, transfer_bytes, predicate=arrive_predicate
+            )
+          rank = len(slice_shape)
+          idx_operands = ",".join(f"${i}" for i in range(4, 4 + rank))
+          llvm.inline_asm(
+              ir.Type.parse("!llvm.void"),
+              [predicate, smem_ptr, tma_desc, barrier_ptr, *rev_dyn_base_indices],
+              f"""
+              {{
+              .reg .b32 mapped_addr;
+              @$0 mapa.shared::cluster.u32 mapped_addr, $3, 0;
+              @$0 cp.async.bulk.tensor.{rank}d.shared::cta.global.tile.mbarrier::complete_tx::bytes.cta_group::2
+                  [$1], [$2, {{{idx_operands}}}], [mapped_addr];
+              }}
+              """,
+              "b,r,l,r" + ",r" * rank,
+              has_side_effects=True,
+          )
+        else:
+          if arrive:
+            nvvm.mbarrier_arrive_expect_tx_shared(
+                barrier_ptr, transfer_bytes, predicate=predicate
+            )
+          nvvm.cp_async_bulk_tensor_shared_cluster_global(
+              smem_ptr, tma_desc, rev_dyn_base_indices, barrier_ptr, [],
+              multicast_mask=multicast_mask, predicate=predicate
           )
-        nvvm.cp_async_bulk_tensor_shared_cluster_global(
-            smem_ptr, tma_desc, rev_dyn_base_indices, barrier_ptr, [],
-            multicast_mask=multicast_mask, predicate=predicate
-        )
     else:
+      assert multicast_mask is None
       with uniform_ctx():
         nvvm.cp_async_bulk_tensor_global_shared_cta(
             tma_desc, smem_ptr, rev_dyn_base_indices, predicate=predicate
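
For reference, the operand layout assumed by the inline PTX above, checked with plain Python (not part of the change): $0 is the predicate, $1 the SMEM destination pointer, $2 the TMA descriptor, $3 the local mbarrier pointer that `mapa.shared::cluster` remaps to block 0's address space, and $4 onward are the GMEM indices.

rank = 2  # e.g. a 2D copy
idx_operands = ",".join(f"${i}" for i in range(4, 4 + rank))
constraints = "b,r,l,r" + ",r" * rank
assert idx_operands == "$4,$5"
assert constraints == "b,r,l,r,r,r"
# predicate (b), smem_ptr (r), tma_desc (l), barrier_ptr (r), indices (r each)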

tests/mosaic/gpu_test.py

Lines changed: 7 additions & 10 deletions
@@ -1082,33 +1082,29 @@ def kernel(ctx, lhs, rhs, out, scratch):
       if rhs_transpose:
         rhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
       block_id = gpu.cluster_block_id(gpu.Dimension.x)
-      m_slice = ds(arith.muli(block_id, c(m_block_tile, index)), m_block_tile)
-      n_slice = ds(arith.muli(block_id, c(n_block_tile, index)), n_block_tile)
-      # TODO(apaszke): Add support for collective partitioned loads.
       ctx.async_copy(
           src_ref=lhs,
           dst_ref=lhs_smem,
           swizzle=swizzle,
-          gmem_slice=m_slice,
           gmem_transform=lhs_transform,
           barrier=barriers[0],
+          collective=gpu.Dimension.x,
+          partitioned=1 if lhs_transpose else 0,  # Split non-contracting dim.
       )
       ctx.async_copy(
           src_ref=rhs,
           dst_ref=rhs_smem,
           swizzle=swizzle,
-          gmem_slice=n_slice,
           gmem_transform=rhs_transform,
           barrier=barriers[1],
+          collective=gpu.Dimension.x,
+          partitioned=0 if rhs_transpose else 1,  # Split non-contracting dim.
       )
-      barriers[0].wait()
-      barriers[1].wait()
-      # Make sure both blocks have loaded their data.
-      nvvm.cluster_arrive_relaxed(aligned=ir.UnitAttr.get())
-      nvvm.cluster_wait(aligned=ir.UnitAttr.get())
       is_leader_thread = single_thread_predicate()
       is_first_block = arith.cmpi(arith.CmpIPredicate.eq, block_id, c(0, index))
       with when(arith.andi(is_first_block, is_leader_thread)):
+        barriers[0].wait()
+        barriers[1].wait()
         if lhs_transpose:
           lhs_smem = memref_transpose(lhs_smem, (0, 1, 3, 2))
         if rhs_transpose:
@@ -1118,6 +1114,7 @@ def kernel(ctx, lhs, rhs, out, scratch):
         )
         tcgen05.commit_arrive(barriers[2], collective=True, ctx=ctx)
       barriers[2].wait(for_tensor_core=True)
+      m_slice = ds(arith.muli(block_id, c(m_block_tile, index)), m_block_tile)
       acc[:].store_untiled(memref_slice(out, m_slice))
 
     in_finfo = jnp.finfo(in_jax_dtype)
