Skip to content

Commit b926fac

Browse files
apaszkeGoogle-ML-Automation
authored andcommitted
[Mosaic GPU] Simplify load/store methods now that we have fewer layouts
PiperOrigin-RevId: 745139008
1 parent d6524dc commit b926fac

File tree

3 files changed

+50
-102
lines changed

3 files changed

+50
-102
lines changed

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1226,7 +1226,7 @@ def _swap_lowering_rule(
12261226
is_signed=mgpu_utils.is_signed(x_aval.dtype),
12271227
optimized=False,
12281228
)
1229-
value.store_untiled(x_smem)
1229+
value.store_untiled(x_smem, optimized=False)
12301230
return old_value
12311231
case _:
12321232
old_value = mgpu.FragmentedArray.load_strided(

jax/experimental/mosaic/gpu/fragmented_array.py

Lines changed: 12 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1788,25 +1788,23 @@ def _(val, idx):
17881788
fmt_str = fmt.format(f"[{idx_fmt}]: {{}}")
17891789
utils.debug_print(fmt_str, *idx, val, uniform=False)
17901790

1791-
def store_untiled(self, ref: ir.Value, *, vector_store: bool = True):
1791+
def store_untiled(
1792+
self, ref: ir.Value, *, swizzle: int = 16, optimized: bool = True
1793+
):
17921794
if not ir.MemRefType.isinstance(ref.type):
17931795
raise ValueError(ref)
1794-
1795-
def vs_unsupported():
1796-
if not vector_store:
1797-
raise NotImplementedError(
1798-
f"Can't use non-vector stores with layout {self.layout}"
1799-
)
1800-
18011796
match self.layout:
18021797
case WGSplatFragLayout():
1803-
vs_unsupported()
1798+
# All values are the same so swizzle does not affect anything here.
18041799
self._store_untiled_splat(ref)
18051800
case WGStridedFragLayout():
1806-
vs_unsupported()
1801+
if swizzle != 16:
1802+
raise NotImplementedError
18071803
self._store_untiled_wg_strided(ref)
18081804
case TiledLayout():
1809-
self._store_untiled_tiled(ref, vector_store=vector_store)
1805+
ref_shape = ir.MemRefType(ref.type).shape
1806+
ref = utils.memref_reshape(ref, (*(1 for _ in ref_shape), *ref_shape))
1807+
self.store_tiled(ref, swizzle=swizzle, optimized=optimized)
18101808
case _:
18111809
raise NotImplementedError(self.layout)
18121810

@@ -1861,61 +1859,15 @@ def _store_untiled_wg_strided(self, ref: ir.Value):
18611859
for idx, reg in zip(idxs, self.registers.flat):
18621860
vector.store(reg, ref_, idx)
18631861

1864-
def _store_untiled_tiled(self, ref: ir.Value, *, vector_store: bool = True):
1865-
"""Stores an array with a tiled layout. Not optimized at the moment."""
1866-
if utils.bitwidth(self.mlir_dtype) < 8:
1867-
raise NotImplementedError(f"Can't store sub-byte types ({self.mlir_dtype=})")
1868-
i32 = ir.IntegerType.get_signless(32)
1869-
layout = self.layout
1870-
assert isinstance(layout, TiledLayout)
1871-
ref_strides, _ = ir.MemRefType(ref.type).get_strides_and_offset()
1872-
if vector_store and ref_strides[layout.vector_dim] != 1:
1873-
raise NotImplementedError(
1874-
"Can't use vector stores with non-unit minormost stride"
1875-
)
1876-
strides = layout.tiling.tile_strides(ref_strides)
1877-
smem_space = ir.Attribute.parse("#gpu.address_space<workgroup>")
1878-
ref_space = ir.MemRefType(ref.type).memory_space
1879-
memory_space = None
1880-
if str(ref_space) == str(smem_space):
1881-
memory_space = 3
1882-
elif ref_space:
1883-
raise NotImplementedError(f"Unexpected ref space {ref_space}")
1884-
ptr = utils.memref_ptr(ref, memory_space=memory_space)
1885-
# Fold warp and lane offsets into the pointer once, since they are dynamic.
1886-
dyn_strides = [
1887-
arith.constant(i32, s) for s in strides[-layout.tiled_tiling_rank :]
1888-
]
1889-
warp_offset = utils.dyn_dot(layout.warp_indices(), dyn_strides)
1890-
lane_offset = utils.dyn_dot(layout.lane_indices(), dyn_strides)
1891-
dyn_offset = arith.addi(warp_offset, lane_offset)
1892-
ptr = utils.getelementptr(ptr, [dyn_offset], self.mlir_dtype)
1893-
# All warp tile offsets are static and can be fused into the store.
1894-
for tile_idx, reg in np.ndenumerate(self.registers):
1895-
if vector_store:
1896-
elems = [reg]
1897-
else:
1898-
index = ir.IndexType.get()
1899-
elems = [
1900-
vector.extractelement(reg, position=c(i, index))
1901-
for i in range(ir.VectorType(reg.type).shape[0])
1902-
]
1903-
for i, e in enumerate(elems):
1904-
tile_idx_local = list(tile_idx)
1905-
tile_idx_local[layout.vector_dim] += i
1906-
tile_idx_local = list(tile_idx_local)
1907-
lin_idx = sum(i * s for i, s in zip(tile_idx_local, strides, strict=True))
1908-
reg_ptr = utils.getelementptr(ptr, [lin_idx], self.mlir_dtype)
1909-
llvm.store(e, reg_ptr)
1910-
1911-
def store_tiled(self, ref, swizzle: int | None):
1862+
def store_tiled(self, ref, swizzle: int | None, optimized: bool = True):
19121863
if not isinstance(self.layout, TiledLayout):
19131864
raise NotImplementedError(self.layout)
19141865
layout, shape = self.layout, self.shape
19151866
# Note that the loop below will "race" for layouts that replicate data.
19161867
# However, in that case all of the racing writes store the same data, which
19171868
# is ok in the CUDA memory model.
1918-
for get, _, ptr in self.transfer_tiled2(ref, swizzle, layout, shape):
1869+
stores = self.transfer_tiled2(ref, swizzle, layout, shape, optimized)
1870+
for get, _, ptr in stores:
19191871
llvm.store(get(self.registers), ptr)
19201872

19211873
@classmethod

tests/mosaic/gpu_test.py

Lines changed: 37 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -489,19 +489,12 @@ def get_packed_shape(strides, shape):
489489

490490
class WGMMALayoutTest(TestCase):
491491

492-
@parameterized.product(dtype=[jnp.float16, jnp.float32],
493-
transposed_smem=[False, True])
494-
def test_store_untiled(self, dtype, transposed_smem):
492+
@parameterized.product(dtype=[jnp.float16, jnp.float32])
493+
def test_store_untiled(self, dtype):
495494
def kernel(ctx, out, _):
496495
del ctx
497-
if transposed_smem:
498-
out = memref_transpose(out, (1, 0))
499-
iota_tensor(64, 64, dtype).store_untiled(
500-
out, vector_store=not transposed_smem
501-
)
496+
iota_tensor(64, 64, dtype).store_untiled(out, optimized=False)
502497
expected = np.arange(64 * 64, dtype=dtype).reshape(64, 64)
503-
if transposed_smem:
504-
expected = expected.T
505498
iota = mgpu.as_gpu_kernel(
506499
kernel, (1, 1, 1), (128, 1, 1), (), expected, ()
507500
)()
@@ -749,7 +742,7 @@ def kernel(ctx, lhs, rhs, out, scratch):
749742
acc = mgpu.wgmma(init_acc, lhs_smem, rhs_smem, swizzle=swizzle)
750743
nvvm.wgmma_commit_group_sync_aligned()
751744
nvvm.wgmma_wait_group_sync_aligned(0)
752-
acc.value.store_untiled(out)
745+
acc.value.store_untiled(out, optimized=False)
753746

754747
def quantize(x):
755748
# Quantize the input to avoid rounding when feeding the WGMMA
@@ -821,7 +814,7 @@ def kernel(ctx, rhs, out, rhs_smem):
821814
acc = mgpu.wgmma(init_acc, lhs_regs, rhs_smem, swizzle=swizzle)
822815
nvvm.wgmma_commit_group_sync_aligned()
823816
nvvm.wgmma_wait_group_sync_aligned(0)
824-
acc.value.store_untiled(out)
817+
acc.value.store_untiled(out, optimized=False)
825818

826819
y_shape = (n, k) if rhs_transpose else (k, n)
827820
y = self.prng.uniform(-1, 1, y_shape).astype(dtype)
@@ -881,7 +874,7 @@ def kernel(ctx, rhs, out, smem):
881874
acc = mgpu.wgmma(init_acc, lhs_regs, rhs_smem, swizzle=swizzle)
882875
nvvm.wgmma_commit_group_sync_aligned()
883876
nvvm.wgmma_wait_group_sync_aligned(0)
884-
acc.value.store_untiled(out)
877+
acc.value.store_untiled(out, optimized=False)
885878

886879
jax_dtype = jnp.float16
887880
y_shape = (n, k) if rhs_transpose else (k, n)
@@ -1042,7 +1035,7 @@ def kernel(ctx, lhs, rhs, out, scratch):
10421035
)
10431036
tcgen05.commit_arrive(barriers[2])
10441037
barriers[2].wait(for_tensor_core=True)
1045-
acc[:].store_untiled(out)
1038+
acc[:].store_untiled(out, optimized=False)
10461039

10471040
x_shape = (k, m) if lhs_transpose else (m, k)
10481041
x = self.prng.uniform(-1, 1, x_shape).astype(in_jax_dtype)
@@ -1145,7 +1138,7 @@ def kernel(ctx, lhs, rhs, out, scratch):
11451138
tcgen05.commit_arrive(barriers[2], collective=True, ctx=ctx)
11461139
barriers[2].wait(for_tensor_core=True)
11471140
m_slice = ds(arith.muli(block_id, c(m_block_tile, index)), m_block_tile)
1148-
acc[:].store_untiled(memref_slice(out, m_slice))
1141+
acc[:].store_untiled(memref_slice(out, m_slice), optimized=False)
11491142

11501143
in_finfo = jnp.finfo(in_jax_dtype)
11511144
exponent_bits, mantissa_bits = in_finfo.nexp, in_finfo.nmant
@@ -1198,7 +1191,7 @@ def kernel(ctx, dst, scratch):
11981191
final_arr = arr + mgpu.FragmentedArray.load_strided(
11991192
tmp, is_signed=False
12001193
)
1201-
final_arr.store_untiled(memref_slice(dst, 0))
1194+
final_arr.store_untiled(memref_slice(dst, 0), optimized=False)
12021195
scf.yield_([])
12031196
with ir.InsertionPoint(scf.IfOp(is_second_wg).then_block):
12041197
barriers[0].wait()
@@ -1209,7 +1202,7 @@ def kernel(ctx, dst, scratch):
12091202
barriers[2].wait() # Synchronize this warpgroup before we overwrite tmp.
12101203
arr.store_untiled(tmp)
12111204
barriers[1].arrive() # Signal that tmp is ready.
1212-
final_arr.store_untiled(memref_slice(dst, 1))
1205+
final_arr.store_untiled(memref_slice(dst, 1), optimized=False)
12131206
scf.yield_([])
12141207
out_shape = jax.ShapeDtypeStruct((2, 128), jnp.int32)
12151208
y = mgpu.as_gpu_kernel(
@@ -1670,7 +1663,7 @@ def kernel(ctx, dst, _):
16701663
mlir_dtype = utils.dtype_to_ir_type(dtype)
16711664
iota = iota_tensor(m, n, dtype)
16721665
rhs = iota if scalar_rhs is None else c(scalar_rhs, mlir_dtype)
1673-
op(iota, rhs).store_untiled(dst)
1666+
op(iota, rhs).store_untiled(dst, optimized=False)
16741667
out_shape = jax.ShapeDtypeStruct((m, n), dtype)
16751668
result = mgpu.as_gpu_kernel(
16761669
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
@@ -1716,7 +1709,7 @@ def test_division(self, op, dtype, m=64, n=32):
17161709

17171710
def kernel(ctx, dst, _):
17181711
iota = iota_tensor(m, n, dtype)
1719-
op(dtype(4.2).item() * iota, iota + 1).store_untiled(dst)
1712+
op(dtype(4.2).item() * iota, iota + 1).store_untiled(dst, optimized=False)
17201713

17211714
out_shape = jax.ShapeDtypeStruct((m, n), dtype)
17221715
result = mgpu.as_gpu_kernel(
@@ -1746,14 +1739,14 @@ def kernel(ctx, dst, _):
17461739
rhs = 0 if rhs_is_literal else iota + 1
17471740
res = op(iota, rhs)
17481741
assert not res.is_signed
1749-
res.astype(i8, is_signed=False).store_untiled(dst)
1742+
res.astype(i8, is_signed=False).store_untiled(dst, optimized=False)
17501743

17511744
out_shape = jax.ShapeDtypeStruct((m, n), jnp.int8)
17521745
result = mgpu.as_gpu_kernel(
17531746
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
17541747
)()
17551748
iota = np.arange(m * n, dtype=dtype).reshape(m, n)
1756-
rhs = rhs = 0 if rhs_is_literal else iota + 1
1749+
rhs = 0 if rhs_is_literal else iota + 1
17571750
np.testing.assert_array_equal(result, op(iota, rhs).astype(jnp.int8))
17581751

17591752
def test_foreach_wgmma_row_array(self):
@@ -1784,22 +1777,25 @@ def _(v, idx):
17841777
def test_foreach(self):
17851778
dtype = jnp.int32
17861779
swizzle = 128
1787-
tile = 64, swizzle // jnp.dtype(dtype).itemsize
1780+
tiling = (8, swizzle // jnp.dtype(dtype).itemsize)
17881781
shape = 128, 192
1789-
tiled_shape = mgpu.tile_shape(shape, tile)
17901782
mlir_dtype = utils.dtype_to_ir_type(dtype)
17911783
cst = 9999
17921784
def causal(val, idx):
17931785
row, col = idx
17941786
mask = arith.cmpi(arith.CmpIPredicate.uge, row, col)
17951787
return arith.select(mask, val, c(cst, mlir_dtype))
17961788

1797-
tiling = mgpu.TileTransform(tile)
17981789
def kernel(ctx, dst, smem):
17991790
x = iota_tensor(shape[0], shape[1], dtype)
1800-
x.foreach(causal, create_array=True, is_signed=False).store_untiled(smem)
1791+
x.foreach(causal, create_array=True, is_signed=False).store_tiled(smem, swizzle=128)
18011792
mgpu.commit_shared()
1802-
ctx.async_copy(src_ref=smem, dst_ref=dst)
1793+
ctx.async_copy(
1794+
src_ref=smem,
1795+
dst_ref=dst,
1796+
gmem_transform=mgpu.TileTransform(tiling),
1797+
swizzle=128,
1798+
)
18031799
ctx.await_async_copy(0)
18041800

18051801
iota = np.arange(np.prod(shape), dtype=dtype).reshape(*shape)
@@ -1809,7 +1805,7 @@ def kernel(ctx, dst, smem):
18091805
(128, 1, 1),
18101806
(),
18111807
jax.ShapeDtypeStruct(shape=shape, dtype=dtype),
1812-
jax.ShapeDtypeStruct(shape=shape, dtype=dtype),
1808+
jax.ShapeDtypeStruct(shape=mgpu.tile_shape(shape, tiling), dtype=dtype),
18131809
)()
18141810
expected = jnp.tril(iota) + jnp.triu(jnp.ones(shape), k=1) * cst
18151811
np.testing.assert_array_equal(result, expected)
@@ -1821,7 +1817,7 @@ def kernel(ctx, dst, smem):
18211817
def test_bitwise(self, op, dtype, m=64, n=8):
18221818
def kernel(ctx, dst, _):
18231819
iota = iota_tensor(m, n, dtype)
1824-
op(iota, iota + 1).store_untiled(dst)
1820+
op(iota, iota + 1).store_untiled(dst, optimized=False)
18251821

18261822
out_shape = jax.ShapeDtypeStruct((m, n), dtype)
18271823
result = mgpu.as_gpu_kernel(
@@ -1845,7 +1841,7 @@ def test_unary(self, ops, dtype, m=64, n=32):
18451841

18461842
def kernel(ctx, dst, _):
18471843
iota = iota_tensor(m, n, dtype)
1848-
op(iota).store_untiled(dst)
1844+
op(iota).store_untiled(dst, optimized=False)
18491845

18501846
out_shape = jax.ShapeDtypeStruct((m, n), dtype)
18511847
result = mgpu.as_gpu_kernel(
@@ -1858,7 +1854,7 @@ def test_select(self, m=64, n=32):
18581854

18591855
def kernel(ctx, dst, _):
18601856
iota = iota_tensor(m, n, jnp.int32)
1861-
(iota < 16).select(iota * 2, iota * 3).store_untiled(dst)
1857+
(iota < 16).select(iota * 2, iota * 3).store_untiled(dst, optimized=False)
18621858

18631859
out_shape = jax.ShapeDtypeStruct((m, n), jnp.int32)
18641860
result = mgpu.as_gpu_kernel(
@@ -1881,7 +1877,7 @@ def test_math(self, ops, approx, m=64, n=32):
18811877
op, np_op = ops
18821878
def kernel(ctx, dst, _):
18831879
iota = iota_tensor(m, n, jnp.float32)
1884-
op(iota).store_untiled(dst)
1880+
op(iota).store_untiled(dst, optimized=False)
18851881
out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32)
18861882
result = mgpu.as_gpu_kernel(
18871883
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
@@ -1902,7 +1898,7 @@ def kernel(ctx, src, dst, scratch):
19021898
src, is_signed=utils.is_signed(dtype)
19031899
)
19041900
acc = src.reduce_sum(scratch).broadcast((m,))
1905-
acc.store_untiled(dst)
1901+
acc.store_untiled(dst, optimized=False)
19061902

19071903
in_shape = jax.ShapeDtypeStruct((m, n), dtype)
19081904
out_shape = jax.ShapeDtypeStruct((m,), dtype)
@@ -1930,7 +1926,7 @@ def kernel(ctx, dst, _):
19301926
is_signed=utils.is_signed(dtype),
19311927
)
19321928
acc = src.reduce_sum().broadcast((m,))
1933-
acc.store_untiled(dst)
1929+
acc.store_untiled(dst, optimized=False)
19341930

19351931
kernel_fn = mgpu.as_gpu_kernel(
19361932
kernel,
@@ -1950,7 +1946,7 @@ def kernel(ctx, dst, _):
19501946
def test_reduce(self, op, m=64, n=32):
19511947
def kernel(ctx, dst, _):
19521948
iota = iota_tensor(m, n, jnp.float32)
1953-
iota.reduce(op, axis=1).broadcast_minor(n).store_untiled(dst)
1949+
iota.reduce(op, axis=1).broadcast_minor(n).store_untiled(dst, optimized=False)
19541950
out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32)
19551951
result = mgpu.as_gpu_kernel(
19561952
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
@@ -1971,7 +1967,7 @@ def kernel(ctx, dst, _):
19711967
cte = c(1, iota.mlir_dtype)
19721968
cte_arr = mgpu.FragmentedArray.splat(cte, ())
19731969
cte_arr = cte_arr.reshape((1, 1)).broadcast((m, n))
1974-
(iota + cte_arr).store_untiled(dst)
1970+
(iota + cte_arr).store_untiled(dst, optimized=False)
19751971
out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32)
19761972
result = mgpu.as_gpu_kernel(
19771973
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
@@ -1986,7 +1982,7 @@ def kernel(ctx, dst, _):
19861982
t = mgpu.FragmentedArray.splat(
19871983
v, (128,), mgpu.WGMMA_ROW_LAYOUT
19881984
)
1989-
t.broadcast_minor(32).store_untiled(dst)
1985+
t.broadcast_minor(32).store_untiled(dst, optimized=False)
19901986
out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32)
19911987
result = mgpu.as_gpu_kernel(
19921988
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
@@ -2005,7 +2001,7 @@ def kernel(ctx, src, dst, _):
20052001
assert isinstance(pi_arr_sq.layout, mgpu.WGStridedFragLayout)
20062002
pi_arr_cube = pi_splat.broadcast(pi_arr.shape) * pi_arr_sq
20072003
assert isinstance(pi_arr_cube.layout, mgpu.WGStridedFragLayout)
2008-
(pi_arr == pi_arr).select(pi_splat, pi_arr_cube).store_untiled(dst)
2004+
(pi_arr == pi_arr).select(pi_splat, pi_arr_cube).store_untiled(dst, optimized=False)
20092005

20102006
out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32)
20112007
inp = jnp.ones_like(out_shape) * 3.14
@@ -2077,7 +2073,7 @@ def kernel(ctx, gmem_input, gmem_output, _):
20772073
t = mgpu.FragmentedArray.load_untiled(
20782074
gmem_input, layout=mgpu.WGMMA_COL_LAYOUT, optimized=False
20792075
)
2080-
t.broadcast_major(m).store_untiled(gmem_output)
2076+
t.broadcast_major(m).store_untiled(gmem_output, optimized=False)
20812077

20822078
inp = self.prng.uniform(-1, 1, (n,)).astype(jnp.float16)
20832079
out_shape = jax.ShapeDtypeStruct((m, n), jnp.float16)
@@ -2114,7 +2110,7 @@ def kernel(ctx, inp, out, smem):
21142110
del ctx, smem
21152111
arr = mgpu.FragmentedArray.load_strided(inp, is_signed=True)
21162112
assert ir.VectorType(arr.registers.flat[0].type).shape == [reg_length]
2117-
arr.astype(mlir_dtype_to).store_untiled(out)
2113+
arr.astype(mlir_dtype_to).store_untiled(out, optimized=False)
21182114

21192115
x = jnp.arange(-128, 128, dtype=jax_dtype_from)
21202116
x = jnp.tile(x, reg_length // 2)
@@ -2190,7 +2186,7 @@ def test_convert_bool_to_u8(self):
21902186
def kernel(ctx, dst, _):
21912187
i8 = ir.IntegerType.get_signless(8)
21922188
iota = iota_tensor(m, n, jnp.uint8)
2193-
(iota > 10).astype(i8, is_signed=False).store_untiled(dst)
2189+
(iota > 10).astype(i8, is_signed=False).store_untiled(dst, optimized=False)
21942190

21952191
out_shape = jax.ShapeDtypeStruct((m, n), jnp.int8)
21962192
result = mgpu.as_gpu_kernel(
@@ -2318,7 +2314,7 @@ def kernel(ctx, dst, _):
23182314
)
23192315
self.assertEqual(tiled.shape, shape)
23202316
self.assertEqual(tiled.mlir_dtype, iota.mlir_dtype)
2321-
tiled.store_untiled(dst)
2317+
tiled.store_untiled(dst, optimized=False)
23222318
ty = jax.ShapeDtypeStruct(shape, dtype)
23232319
f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), (), ty, ())
23242320
expected = np.arange(math.prod(shape), dtype=dtype).reshape(shape)

0 commit comments

Comments
 (0)