
Commit 1c1e2e6

apaszke authored and Google-ML-Automation committed
[Mosaic GPU] Add support for stores to TMEM
We can support reading and writing of both 32- and 16-bit types now.

PiperOrigin-RevId: 741487690
1 parent 3045147 commit 1c1e2e6
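The user-facing change in this commit is that a TMEM reference now supports slice assignment from a FragmentedArray in the tcgen05 layout, followed by an explicit commit before the data can be read back. The snippet below is a condensed sketch distilled from the test added in this commit; the kernel scaffolding (SMEM refs, async copies, grid setup) is assumed to exist, and the helper name tmem_roundtrip is illustrative, not part of the Mosaic GPU API.

# Sketch only: assumes a Mosaic GPU kernel body with `smem` and `tmem` scratch
# refs already set up (see the test in tests/mosaic/gpu_test.py below).
from jax.experimental.mosaic.gpu import fragmented_array as fa
from jax.experimental.mosaic.gpu import tcgen05

def tmem_roundtrip(smem, tmem, swizzle=128):  # illustrative helper, not JAX API
  # Load a tile from SMEM into registers using the TMEM-compatible layout.
  arr = fa.FragmentedArray.load_tiled(smem, swizzle, layout=tcgen05.LAYOUT)
  tmem[:] = arr              # new in this commit: lowers to tcgen05.st
  tcgen05.commit_tmem()      # stores are async; wait before reading TMEM back
  return tmem[:]             # reads still lower to tcgen05.ld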

File tree

2 files changed: +212 additions, -46 deletions


jax/experimental/mosaic/gpu/tcgen05.py

Lines changed: 175 additions & 46 deletions
@@ -327,24 +327,30 @@ def tmem_relinquish_alloc_permit():
       has_side_effects=True,
   )
 
-def tmem_load(tmem_addr, shape, num, packing: int = 1):
+def _tmem_access_helper(shape, num, packing: int = 1):
   if num.bit_count() != 1 or num > 128:
     raise ValueError(f"num must be a power of 2 and <= 128, got: {num}")
   match shape:
     case "16x128b":
-      num_out_regs = 2
+      num_regs = 2
     case "16x256b":
-      num_out_regs = 4
+      num_regs = 4
     case _:
       raise NotImplementedError(f"{shape=} is unsupported")
-  if num * num_out_regs >= 256:
+  num_regs *= num
+  if num_regs > 255:
     raise ValueError(
-        f"Loading too much TMEM at once: {num=} and each load requires"
-        f" {num_out_regs} registers, which exceeds the limit of 256"
+        f"TMEM transaction too big: {shape=} and {num=} involve"
+        f" {num_regs} registers per-thread, which exceeds the limit of 255"
     )
-  num_out_regs *= num
+  regs_vector = ",".join(f"${i}" for i in range(num_regs))
+  regs_vector = "{" + regs_vector + "}"
+  return num_regs, regs_vector
+
+
+def tmem_load(tmem_addr, shape, num, packing: int = 1):
   i32 = ir.IntegerType.get_signless(32)
-  out_regs = ",".join("$" + str(i) for i in range(num_out_regs))
+  num_out_regs, regs_vector = _tmem_access_helper(shape, num, packing)
   if packing == 1:
     pack_mod = ""
   elif packing == 2:
@@ -356,13 +362,30 @@ def tmem_load(tmem_addr, shape, num, packing: int = 1):
           "!llvm.struct<(" + ",".join("i32" for _ in range(num_out_regs)) + ")>"
       ),
       [tmem_addr],
-      f"tcgen05.ld.sync.aligned.{shape}.x{num}{pack_mod}.b32 {{{out_regs}}}, [${num_out_regs}];",
+      f"tcgen05.ld.sync.aligned.{shape}.x{num}{pack_mod}.b32 {regs_vector}, [${num_out_regs}];",
       "=r," * num_out_regs + "r",
       has_side_effects=True,
   )
   return [llvm.extractvalue(i32, regs, [i]) for i in range(num_out_regs)]
 
 
+def tmem_store(tmem_addr, shape, num, regs, packing: int = 1):
+  num_out_regs, regs_vector = _tmem_access_helper(shape, num, packing)
+  if packing == 1:
+    pack_mod = ""
+  elif packing == 2:
+    pack_mod = ".unpack::16b"
+  else:
+    raise ValueError(f"Unsupported packing: {packing}")
+  llvm.inline_asm(
+      ir.Type.parse("!llvm.void"),
+      [*regs, tmem_addr],
+      f"tcgen05.st.sync.aligned.{shape}.x{num}{pack_mod}.b32 [${num_out_regs}], {regs_vector};",
+      "r," * num_out_regs + "r",
+      has_side_effects=True,
+  )
+
+
 @dataclasses.dataclass(frozen=True)
 class TMEMLayout:
   """Represents the way a shape is laid out in TMEM.
@@ -562,62 +585,168 @@ def __getitem__(self, *idxs):
     )
     return fa.FragmentedArray(_registers=registers, _layout=layout, _is_signed=None)
 
+  def __setitem__(self, idxs, value):
+    if not isinstance(idxs, tuple):
+      idxs = (idxs,)
+    base_idxs, slice_shape, is_squeezed = utils.parse_indices(idxs, self.shape)
+    if any(is_squeezed):
+      raise ValueError(
+          "TMEM stores don't support integer indexing (only slices allowed)"
+      )
+    if any(idx != 0 for idx in base_idxs) or tuple(slice_shape) != self.shape:
+      raise NotImplementedError("Slicing parts of TMEM not implemented yet")
+    if self.shape[1] % 8:
+      raise NotImplementedError
+    if utils.bitwidth(self.dtype) not in {16, 32}:
+      raise NotImplementedError(f"Unsupported dtype: {self.dtype}")
+    if not isinstance(value, fa.FragmentedArray):
+      raise ValueError(f"TMEM stores expect a FragmentedArray, got: {value}")
+    if value.shape != self.shape:
+      raise ValueError(
+          f"Stored array has shape {value.shape}, but TMEM has shape"
+          f" {self.shape}"
+      )
+    if value.mlir_dtype != self.dtype:
+      raise ValueError(
+          f"Stored array has dtype {value.mlir_dtype}, but TMEM has dtype"
+          f" {self.dtype}"
+      )
+    if value.layout != LAYOUT:
+      raise ValueError(
+          f"Stored array has layout {value.layout}, but only tcgen05.LAYOUT is"
+          " supported"
+      )
+    if self.layout == TMEMLayout(elements_in_tile=(TMEM_ROWS, 8)):
+      # store_32xcols needs a 4xN array, but the FA tiling we use here tiles
+      # columns before rows, and so it is Nx4 (after ignoring all 1 dims).
+      _store_32xcols(
+          self.address, value.registers.T.reshape((4, -1))
+      )
+    else:  # TODO(apaszke): Collective MMA layout
+      raise NotImplementedError(
+          f"Stores only implemented for refs with standard layout, got: {self.layout}"
+      )
+
+
+def _transfer_32xcols(base_addr, cols):
+  i32 = ir.IntegerType.get_signless(32)
+  cols_per_num = 8  # Here we generate a plan compatible with tcgen05.LAYOUT.
+  assert cols % cols_per_num == 0
+  total_num = cols // cols_per_num
+  if total_num <= 32:
+    instr_num = total_num
+  elif total_num == 64:
+    instr_num = 32
+  else:
+    raise NotImplementedError(total_num)
+  # We transfer 16 lanes at a time, but have 32 to deal with.
+  for lane_step in range(2):
+    addr_row = arith.addi(base_addr, utils.c((lane_step * 16) << 16, i32))
+    cols_per_instr = instr_num * cols_per_num
+    for num_step in range(total_num // instr_num):
+      num_slice = slice(num_step * instr_num, (num_step + 1) * instr_num)
+      addr_row_col = arith.addi(addr_row, utils.c(num_step * cols_per_instr, i32))
+      yield addr_row_col, instr_num, lane_step, num_slice
+
+
+def _store_32xcols(base_addr, vector_regs):
+  i32 = ir.IntegerType.get_signless(32)
+  assert vector_regs.ndim == 2 and vector_regs.shape[0] == 4
+  cols = vector_regs.shape[1] * 8
+
+  packing = 64 // utils.bitwidth(vector_regs.flat[0].type)
+  if packing == 1:
+    store_shape = "16x256b"  # 4 threads * 64 bits per vreg = 256 bits
+    regs = np.empty((4, vector_regs.shape[1], 2), dtype=object)
+    c0 = arith.constant(i32, 0)
+    c1 = arith.constant(i32, 1)
+    for idx, vreg in np.ndenumerate(vector_regs):
+      regs[(*idx, 0)] = llvm.extractelement(vreg, c0)
+      regs[(*idx, 1)] = llvm.extractelement(vreg, c1)
+    regs = regs.reshape(2, 2, vector_regs.shape[1], 2).swapaxes(1, 2)
+    # From a single lane perspective a num tile consists of a 2x2, with the
+    # minor dim traversing columns and major being 8 rows apart.
+    # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b
+    assert regs.shape[-2:] == (2, 2)
+  elif packing == 2:
+    store_shape = "16x128b"  # 4 threads * 32 bits per vreg = 128 bits
+    # From a single lane perspective a num tile has 2 registers, 8 rows apart.
+    # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16128b
+    regs = vector_regs.reshape(2, 2, vector_regs.shape[1]).swapaxes(1, 2)
+  else:
+    raise NotImplementedError(packing)
+
+  it = _transfer_32xcols(base_addr, cols)
+  for addr_row_col, instr_num, lane_step, num_slice in it:
+    regs_slice = regs[lane_step, num_slice].flat
+    tmem_store(addr_row_col, store_shape, instr_num, regs_slice, packing)
+
 
 def _load_32xcols(base_addr, cols, dtype):
-  # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b
   i32 = ir.IntegerType.get_signless(32)
+  vec_ty = ir.VectorType.get((2,), dtype)
   packing = 32 // utils.bitwidth(dtype)
   if packing == 1:
-    load_shape = "16x256b"  # 8 columns * 32 bits = 256 bits
-    cols_per_num_tile = 8 * packing
+    load_shape = "16x256b"  # 4 threads * 64 bits per vreg = 256 bits
   elif packing == 2:
-    load_shape = "16x128b"  # 8 columns * 16 bits = 128 bits
-    cols_per_num_tile = 4 * packing
+    load_shape = "16x128b"  # 4 threads * 32 bits per vreg = 128 bits
   else:
     raise NotImplementedError(packing)
-  assert cols % cols_per_num_tile == 0
-  num = cols // cols_per_num_tile
-  if num <= 32:
-    num_tiling = num
-  elif num == 64:
-    num_tiling = 32
-  else:
-    raise NotImplementedError(num)
+
   vector_regs = np.ndarray((4, cols // 8), dtype=object)
-  # We load 16 lanes at a time, but need 32 in total.
-  for row_group in range(2):
-    addr_row = arith.addi(base_addr, arith.constant(i32, (row_group * 16) << 16))
-    regs = []
-    for num_group in range(num // num_tiling):
-      addr_row_col = arith.addi(
-          addr_row,
-          arith.constant(i32, num_tiling * num_group * cols_per_num_tile),
-      )
-      regs += tmem_load(addr_row_col, load_shape, num_tiling, packing)
+
+  it = _transfer_32xcols(base_addr, cols)
+  c0 = arith.constant(i32, 0)
+  c1 = arith.constant(i32, 1)
+  for addr_row_col, instr_num, lane_step, num_slice in it:
+    regs = tmem_load(addr_row_col, load_shape, instr_num, packing)
+    row_slice = slice(lane_step * 2, (lane_step + 1) * 2)
+    # This aliases the original array, so updates will be reflected there.
+    vector_regs_update = vector_regs[row_slice, num_slice]
+    assert vector_regs_update.shape == (2, instr_num), (vector_regs_update.shape, instr_num)
     if packing == 1:
       regs = [llvm.bitcast(dtype, r) for r in regs]
-      undef = llvm.mlir_undef(ir.VectorType.get((2,), dtype))
-      for r_low, r_high, idx in zip(regs[::2], regs[1::2], np.ndindex(cols // 8, 2), strict=True):
-        high_undef = llvm.insertelement(undef, r_low, utils.c(0, i32))
-        vreg = llvm.insertelement(high_undef, r_high, utils.c(1, i32))
-        vector_regs[idx[1] + 2 * row_group, idx[0]] = vreg
+      # From a single lane perspective a num tile consists of a 2x2, with the
+      # minor dim traversing columns and major being 8 rows apart.
+      # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b
+      regs = np.asarray(regs, dtype=object).reshape(instr_num, 2, 2).swapaxes(0, 1)
+      undef = llvm.mlir_undef(vec_ty)
+      assert regs.shape == (*vector_regs_update.shape, 2)
+      for idx in np.ndindex(vector_regs_update.shape):
+        high_undef = llvm.insertelement(undef, regs[(*idx, 0)], c0)
+        vreg = llvm.insertelement(high_undef, regs[(*idx, 1)], c1)
+        vector_regs_update[idx] = vreg
     else:
       assert packing == 2
-      regs = [llvm.bitcast(ir.VectorType.get((2,), dtype), r) for r in regs]
-      for vreg, idx in zip(regs, np.ndindex(cols // 8, 2), strict=True):
-        vector_regs[idx[1] + 2 * row_group, idx[0]] = vreg
+      regs = [llvm.bitcast(vec_ty, r) for r in regs]
+      # From a single lane perspective a num tile has 2 registers, 8 rows apart.
+      # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16128b
+      regs = np.asarray(regs, dtype=object).reshape(instr_num, 2).swapaxes(0, 1)
+      vector_regs_update[...] = regs
+
   return vector_regs
 
 
-# Like WGMMA_LAYOUT, only each warp holds a 32xN strip instead of 16xN.
 def _m128_layout(shape: tuple[int, ...]):
   if len(shape) != 2:
     raise ValueError(f"Shape {shape} is not 2D")
   if shape[0] % 128 != 0 or shape[1] % 8 != 0:
     raise ValueError(f"Shape {shape} is not a multiple of 64x8")
-  return fa.TiledLayout(
-      fa.Tiling(((128, 8), (32, 8), (8, 8), (1, 2))),
-      warp_dim=-8,
-      lane_dims=(-4, -3),
-      vector_dim=-1,
+  return LAYOUT
+
+
+# Like WGMMA_LAYOUT, only each warp holds a 32xN strip instead of 16xN.
+# The name is so short, because it's meant to be used qualified (tcgen05.LAYOUT)
+LAYOUT = fa.TiledLayout(
+    fa.Tiling(((128, 8), (32, 8), (8, 8), (1, 2))),
+    warp_dim=-8,
+    lane_dims=(-4, -3),
+    vector_dim=-1,
+)
+
+
+def commit_tmem():
+  void = ir.Type.parse("!llvm.void")
+  llvm.inline_asm(
+      void, [], "tcgen05.wait::st.sync.aligned;", "", has_side_effects=True,
   )
+  utils.warpgroup_barrier()
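A note on the transfer planning in _transfer_32xcols above: the TMEM address encodes the lane (row) in its upper 16 bits and the column offset in its lower 16 bits, which is why the code adds (lane_step * 16) << 16 to the base address. The snippet below is a minimal, MLIR-free sketch of the same plan, assuming the tcgen05.LAYOUT tiling where each num step covers 8 columns and each "16x256b" instruction uses 4 registers per thread per step; the helper name transfer_plan and its return format are illustrative only.

# Minimal sketch (plain Python, no MLIR) of the plan _transfer_32xcols emits,
# assuming the "16x256b" shape (4 registers per thread per num step).
def transfer_plan(cols, cols_per_num=8, regs_per_num=4):
  total_num = cols // cols_per_num
  if total_num <= 32:
    instr_num = total_num
  elif total_num == 64:
    instr_num = 32            # 64 steps at once would need 256 > 255 registers
  else:
    raise NotImplementedError(total_num)
  assert instr_num * regs_per_num <= 255
  plan = []
  for lane_step in range(2):  # each instruction covers 16 lanes; we have 32 rows
    for num_step in range(total_num // instr_num):
      row = lane_step * 16
      col = num_step * instr_num * cols_per_num
      plan.append((row, col, instr_num))
  return plan

# 128 f32 columns fit in one x16 transfer per 16-lane half:
assert transfer_plan(128) == [(0, 0, 16), (16, 0, 16)]
# 512 columns need x32 transfers, two per half:
assert transfer_plan(512) == [(0, 0, 32), (0, 256, 32), (16, 0, 32), (16, 256, 32)]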

tests/mosaic/gpu_test.py

Lines changed: 37 additions & 0 deletions
@@ -908,6 +908,43 @@ def setUp(self):
     if not any(jtu.is_cuda_compute_capability_equal(sm) for sm in capabilities):
       self.skipTest("Only works on GPU with capability sm_100a or sm_101a")
 
+  @parameterized.parameters([jnp.float32, jnp.float16])
+  def test_load_store_tmem(self, jax_dtype):
+    swizzle = 128
+    in_mlir_dtype = utils.dtype_to_ir_type(jax_dtype)
+    swizzle_elems = swizzle // bytewidth(in_mlir_dtype)
+    tiling = (8, swizzle_elems)
+
+    def kernel(ctx, input, output, scratch):
+      smem, barrier, tmem = scratch
+      ctx.async_copy(
+          src_ref=input,
+          dst_ref=smem,
+          swizzle=swizzle,
+          gmem_transform=mgpu.TileTransform(tiling),
+          barrier=barrier,
+      )
+      barrier.wait()
+      tmem[:] = fa.FragmentedArray.load_tiled(smem, swizzle, layout=tcgen05.LAYOUT)
+      tcgen05.commit_tmem()
+      tmem[:].store_tiled(smem, swizzle)
+      mgpu.commit_shared()
+      ctx.async_copy(
+          src_ref=smem, dst_ref=output, swizzle=swizzle, gmem_transform=mgpu.TileTransform(tiling),
+      )
+      ctx.await_async_copy(0)
+
+    x = self.prng.uniform(-1, 1, (128, 128)).astype(jax_dtype)
+    scratch_shape = [
+        jax.ShapeDtypeStruct(tile_shape(x.shape, tiling), jax_dtype),
+        mgpu.TMABarrier(),
+        mgpu.TMEM(x.shape, jax_dtype),
+    ]
+    y = mgpu.as_gpu_kernel(
+        kernel, (1, 1, 1), (128, 1, 1), x, x, scratch_shape
+    )(x)
+    np.testing.assert_array_equal(x, y)
+
   @parameterized.product(
       lhs_transpose=(False, True),
       rhs_transpose=(False, True),