Commit bd66f52

apaszke authored and Google-ML-Automation committed
[Mosaic GPU] Add a bank-conflict checker to tiled transfer + transfer planner
Instead of only allowing a fixed set of layouts that we've hand-verified as bank-conflict free, we now simulate the transactions performed within each warp and verify that no bank conflicts happen. If we detect that the simple schedule does not work out, we attempt to partition the threads in a warp into two groups and stagger the transfers in a way that lets us avoid conflicts.

This allows us to match the hand-designed transfer schedule I wrote for 32-bit types, and it even generalizes to more cases automatically (e.g. swizzle=32).

PiperOrigin-RevId: 701919158
1 parent e124c05 commit bd66f52
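The heart of the change is a simulation of SMEM transactions. As a rough standalone sketch of the idea (a toy model with hypothetical names; the commit's actual checker is has_bank_conflicts inside plan_tiled_transfer below):

import numpy as np

SMEM_BANKS = 32      # SMEM is organized into 32 banks
SMEM_BANK_BYTES = 4  # each bank serves 4 bytes per transaction

def warp_has_bank_conflicts(byte_offsets: np.ndarray) -> bool:
  # byte_offsets: shape [32], one SMEM byte address per lane of a warp.
  banks = (byte_offsets // SMEM_BANK_BYTES) % SMEM_BANKS
  banks = np.sort(banks)  # lane order within a transaction doesn't matter
  # Conflict iff two lanes of the same transaction map to the same bank.
  return bool(np.any(banks[1:] == banks[:-1]))

lanes = np.arange(32)
print(warp_has_bank_conflicts(lanes * 4))    # False: one lane per bank
print(warp_has_bank_conflicts(lanes * 128))  # True: all lanes hit bank 0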

File tree: 2 files changed, +225 −25

jax/experimental/mosaic/gpu/fragmented_array.py

Lines changed: 208 additions & 18 deletions
@@ -20,7 +20,7 @@
 import functools
 import math
 from collections.abc import Callable
-from typing import Iterable, Sequence, TypeVar
+from typing import Iterable, Protocol, Sequence, TypeVar

 import jax
 from jaxlib.mlir import ir
@@ -42,6 +42,8 @@
 WARPGROUP_SIZE = utils.WARPGROUP_SIZE
 WARP_SIZE = 32
 WARPS_IN_WARPGROUP = WARPGROUP_SIZE // WARP_SIZE
+SMEM_BANKS = 32
+SMEM_BANK_BYTES = 4
 c = utils.c

@@ -1455,11 +1457,15 @@ def load_tiled(
           raise ValueError("Tiled reference must have even rank")
         tiling = Tiling((tiled_shape[len(tiled_shape) // 2 :],))
         shape = tiling.untile_shape(tiled_shape)
-        registers = np.full(layout.registers_shape(shape), None, dtype=object)
+        zero = (
+            vector.splat(
+                ir.VectorType.get((layout.vector_length,), dtype), c(0, dtype)
+            ),
+        )
+        registers = np.full(layout.registers_shape(shape), zero, dtype=object)
         reg_ty = ir.VectorType.get((layout.vector_length,), ref_ty.element_type)
         for _, update, ptr in cls.transfer_tiled2(ref, swizzle, layout, shape):
           update(registers, llvm.load(reg_ty, ptr))
-        assert all(r is not None for r in registers.flat)
       case WGMMAFragLayout():
         bw = mgpu.bytewidth(dtype)
         m_tiles, n_tiles, m_tile_size, n_tile_size = ref_ty.shape
@@ -1611,19 +1617,18 @@ def transfer_tiled2(

     tiled_strides = list(tiling.tile_strides(tuple(ref_strides)))
     tiled_shape = list(tiling.tile_shape(tuple(ref_ty.shape)))
+    lane_strides = [tiled_strides[d] for d in layout.lane_dims]
+    lane_shape = [tiled_shape[d] for d in layout.lane_dims]
     if tiled_strides[layout.vector_dim] != 1:
       raise ValueError("Stride of the vectorized dimension should be 1")
     for d in (layout.warp_dim, *layout.lane_dims, layout.vector_dim):
       tiled_shape[d] = 1
     full_tiling = Tiling((ref_tiling_shape, *tiling.tiles))
     full_layout = dataclasses.replace(layout, tiling=full_tiling)

-    # XXX: This method is still slightly incomplete. For example, it does not
-    # verify that the vector transfers don't cross swizzle tile boundaries. It
-    # also does not guarantee that the transfer pattern does not cause bank
-    # conflicts. For that reason, we only allow a select subset of layouts.
-    if layout != _tiled_wgmma_layout(shape) or bw > 2:
-      raise NotImplementedError("transfer_tiled2 not general enough yet")
+    plan = plan_tiled_transfer(
+        tiled_shape, tiled_strides, lane_shape, lane_strides, layout, bw, swizzle
+    )

     dyn_tiled_strides = [c(s) for s in tiled_strides]
     lane_offset = utils.dyn_dot(full_layout.lane_indices(), dyn_tiled_strides)
@@ -1632,27 +1637,45 @@
     if ref_ty.memory_space != ir.Attribute.parse("#gpu.address_space<workgroup>"):
       raise ValueError("Tiled stores can only be performed into SMEM")
     ptr = utils.memref_ptr(ref, memory_space=3)
+    _as_consts = lambda consts: [c(const) for const in consts.tolist()]
+    # This has bits set only for the offset bits that influence swizzling.
+    swizzle_mask = swizzle_block_elems - swizzle_tile_elems
     for tile_idx in np.ndindex(*tiled_shape):
-      const_offset = sum(i * s for i, s in zip(tile_idx, tiled_strides))
+      indices = np.asarray([f(tile_idx) for f in plan.tile_index_transforms])
+      const_offset = np.dot(indices, tiled_strides)
       # We split the offset into a part that interacts with swizzling and a
       # part that doesn't. This lets us generate better code because constant
       # offsets can be fused into load and store instructions.
-      const_offset_swizzle = const_offset % swizzle_block_elems
+      const_offset_swizzle = const_offset & swizzle_mask
       const_offset_no_swizzle = const_offset - const_offset_swizzle
-      offset_pre_swizzle = arith.addi(dyn_offset, c(const_offset_swizzle))
+      offset_pre_swizzle = arith.addi(
+          dyn_offset, plan.select(_as_consts(const_offset_swizzle))
+      )
       swizzle_group = arith.remui(
           arith.divui(offset_pre_swizzle, c(swizzle_group_elems)),
           c(swizzle_groups_per_block),
       )
       swizzle_bits = arith.muli(swizzle_group, c(swizzle_tile_elems))
       offset = arith.xori(offset_pre_swizzle, swizzle_bits)
       reg_ptr = utils.getelementptr(ptr, [offset], dtype)
-      reg_ptr = utils.getelementptr(reg_ptr, [const_offset_no_swizzle], dtype)
-      reg_idx = tiling.tile_indices(full_tiling.untile_indices(tile_idx))
-      def get_register(regs, reg_idx=reg_idx):
-        return regs[reg_idx]
-      def update_registers(regs, new, reg_idx=reg_idx):
-        regs[reg_idx] = new
+      offset_no_swizzle = plan.select(_as_consts(const_offset_no_swizzle))
+      reg_ptr = utils.getelementptr(reg_ptr, [offset_no_swizzle], dtype)
+      reg_idxs = [
+          tiling.tile_indices(full_tiling.untile_indices(idx))
+          for idx in indices.tolist()
+      ]
+      def get_register(regs, reg_idxs=reg_idxs):
+        return plan.select([regs[reg_idx] for reg_idx in reg_idxs])
+      def update_registers(regs, new, reg_idxs=reg_idxs):
+        # TODO(apaszke): If the staggering forms a permutation with a small
+        # cycle length, then instead of blending at each step we could construct
+        # a small routing network (kind of like a sorting network) to fix up
+        # each cycle separately after all the loads are performed.
+        # This would be especially useful for dims that are powers of two and
+        # staggered by another power of 2, since all cycles are of length 2 (and
+        # we could save half the selects).
+        for i, reg_idx in enumerate(reg_idxs):
+          regs[reg_idx] = plan.select_if_group(i, regs[reg_idx], new)
       yield get_register, update_registers, reg_ptr

   def tree_flatten(self):
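To see why the modulo above could be replaced with a bitmask, here is a worked example (illustrative values only, assuming swizzle=128 and a 2-byte dtype, so swizzle_tile_elems == 8 and swizzle_block_elems == 64):

swizzle_mask = 64 - 8  # 56 == 0b111000: only the bits that affect swizzling

const_offset = 204  # 0b11001100
const_offset_swizzle = const_offset & swizzle_mask  # 8 (the old modulo gave 12)
const_offset_no_swizzle = const_offset - const_offset_swizzle  # 196
# Bits below the swizzle tile size never change the swizzle group, so they
# can be folded into const_offset_no_swizzle, which is fused into the
# load/store instruction as an immediate.
assert const_offset_swizzle + const_offset_no_swizzle == const_offset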
@@ -1666,6 +1689,173 @@ def tree_unflatten(cls, aux, flat_registers):
     return cls(_registers=registers, _layout=layout, _is_signed=is_signed)


+class TransferPlan(Protocol):
+  IndexTransform = Callable[[tuple[int, ...]], tuple[int, ...]]
+  tile_index_transforms: tuple[IndexTransform, ...]
+
+  def select(self, group_elems: Sequence[ir.Value]) -> ir.Value:
+    """Selects the value corresponding to the group of the current thread.
+
+    The argument must be of the same length as tile_index_transforms.
+    """
+    raise NotImplementedError
+
+  def select_if_group(self, group_idx: int, old: ir.Value, new: ir.Value) -> ir.Value:
+    """Returns `new` if the current thread belongs to the given group and `old` otherwise.
+
+    group_idx must be between 0 and len(tile_index_transforms) - 1.
+    """
+    raise NotImplementedError
+
+
+@dataclasses.dataclass(frozen=True)
+class TrivialTransferPlan(TransferPlan):
+  @property
+  def tile_index_transforms(self):
+    return (lambda x: x,)
+
+  def select(self, group_elems: Sequence[ir.Value]) -> ir.Value:
+    assert len(group_elems) == 1
+    return group_elems[0]
+
+  def select_if_group(self, group_idx: int, old: ir.Value, new: ir.Value) -> ir.Value:
+    assert group_idx == 0
+    return new
+
+
+@dataclasses.dataclass(frozen=True)
+class StaggeredTransferPlan(TransferPlan):
+  stagger: int
+  dim: int
+  size: int
+  group_pred: ir.Value
+
+  @property
+  def tile_index_transforms(self):
+    dim = self.dim
+    def rotate(idx: tuple[int, ...]) -> tuple[int, ...]:
+      return (
+          *idx[:dim], (idx[dim] + self.stagger) % self.size, *idx[dim + 1 :],
+      )
+    return (lambda x: x, rotate)
+
+  def select(self, group_elems: Sequence[ir.Value]) -> ir.Value:
+    assert len(group_elems) == 2
+    return arith.select(self.group_pred, group_elems[1], group_elems[0])
+
+  def select_if_group(self, group_idx: int, old: ir.Value, new: ir.Value) -> ir.Value:
+    assert 0 <= group_idx <= 1
+    sides = [old, new] if group_idx == 0 else [new, old]
+    return arith.select(self.group_pred, *sides)
+
+
+def plan_tiled_transfer(
+    tiled_shape: Sequence[int],
+    tiled_strides: Sequence[int],
+    lane_shape: Sequence[int],
+    lane_strides: Sequence[int],
+    layout: TiledLayout,
+    bw: int,
+    swizzle: int,
+) -> TransferPlan:
+  i32 = ir.IntegerType.get_signless(32)
+  c = lambda x: arith.constant(i32, x)
+  swizzle_tile_elems = 16 // bw
+  swizzle_group_elems = 128 // bw
+  # Below, all calculations are in elements, not in bytes, since it should
+  # generalize better to sub-byte types.
+  # Here, we verify two conditions:
+  # 1. Each vector transfer only accesses addresses that fall within a single
+  # swizzle tile (if not we'd need to split it and swizzle parts differently).
+  transfer_alignment = math.gcd(*(
+      s
+      for i, (s, d) in enumerate_negative(list(zip(tiled_strides, tiled_shape)))
+      if d > 1 or i in {layout.warp_dim, *layout.lane_dims}
+  ))
+  if (
+      swizzle_tile_elems % transfer_alignment
+      and layout.vector_length <= transfer_alignment
+  ):
+    raise ValueError(
+        "Failed to prove that vector transfers don't cross swizzle tile"
+        " boundaries. This check is incomplete, and does not guarantee that"
+        " this is a user error, but it might be." + str(transfer_alignment)
+    )
+
+  # 2. The transfer pattern does not cause bank conflicts.
+  # TODO(apaszke): For now, when performing transfers narrower than a bank,
+  # we simply narrow each bank to the transfer width. The truth is more likely
+  # that bank conflicts only don't occur if the addresses mapping to the same
+  # bank are contiguous, but that's a more complicated check to perform.
+  transfer_bytes = layout.vector_length * bw
+  if transfer_bytes > SMEM_BANK_BYTES * 4:
+    raise NotImplementedError
+  if bw > SMEM_BANK_BYTES:
+    raise NotImplementedError
+  smem_bank_bytes = min(SMEM_BANK_BYTES, transfer_bytes)
+  num_banks = SMEM_BANKS * (SMEM_BANK_BYTES // smem_bank_bytes)
+  elems_per_bank = smem_bank_bytes // bw
+  num_wavefronts = max(transfer_bytes // smem_bank_bytes, 1)
+  wavefront_lanes = WARP_SIZE // num_wavefronts
+
+  lane_offsets_in_tile = np.dot(list(np.ndindex(*lane_shape)), lane_strides)
+  def has_bank_conflicts(tile_idx_transform):
+    tile_idxs = np.unravel_index(np.arange(math.prod(tiled_shape)), tiled_shape)
+    tile_idxs = np.expand_dims(np.stack(tile_idxs, 1), 1)  # [#tiles, 1, #dims]
+    lane_tile_idx = tile_idx_transform(tile_idxs)  # [#tiles, #lanes/1, #dims]
+    assert lane_tile_idx.shape[1] in {1, WARP_SIZE}
+    lane_tile_offsets = np.dot(lane_tile_idx, tiled_strides)
+    offsets = lane_tile_offsets + lane_offsets_in_tile  # [#tiles, #lanes]
+    assert offsets.shape[-1] == WARP_SIZE
+    swizzle_groups = (offsets // swizzle_group_elems) % (swizzle // 16)
+    swizzle_bits = swizzle_groups * swizzle_tile_elems
+    lane_banks = ((offsets ^ swizzle_bits) // elems_per_bank) % num_banks
+    wavefront_banks = lane_banks.reshape(-1, num_wavefronts, wavefront_lanes)
+    # Order of threads within the wavefront is unimportant.
+    wavefront_banks = np.sort(wavefront_banks, axis=-1)
+    # There are no conflicts if each wavefront only contains unique banks.
+    return np.any(wavefront_banks[..., 1:] == wavefront_banks[..., :-1])
+
+  # We don't need any special treatment if there are no conflicts when each lane
+  # transfers the same tile at a time.
+  if not has_bank_conflicts(lambda tile_idx: tile_idx):
+    return TrivialTransferPlan()
+
+  # Otherwise, we will try to partition the lanes into two groups and have
+  # each group store to a different tile. The only tile dimensions that can
+  # help us with bank conflicts are those that have multiple elements and a
+  # stride that's not a multiple of the number of banks.
+  #
+  # Note that the code is set up so that we could also consider partitioning
+  # the lanes into more groups, but the selects will become more expensive if
+  # we do that. It's a possibility we have if we need it.
+  candidate_dims = (
+      i for i, (s, d) in enumerate(zip(tiled_strides, tiled_shape))
+      if d > 1 and s % (SMEM_BANKS * elems_per_bank)
+  )
+  for dim in candidate_dims:
+    for group_stride in (1, 2, 4, 8, 16):
+      # We change the group assignment every group_stride lanes.
+      lane_id = np.arange(WARP_SIZE)[:, None]
+      lane_group = (lane_id // group_stride) % 2
+      # We only consider a transformation where the second group stores to a
+      # tile that's a constant offset (modulo dim size) from the first one.
+      for stagger in range(1, tiled_shape[dim]):
+        offset = np.zeros(len(tiled_shape), np.int64)
+        offset[dim] = stagger
+        transform = lambda idx: (idx + offset * lane_group) % tiled_shape
+        if not has_bank_conflicts(transform):
+          # We've found a strategy that avoids bank conflicts!
+          lane_idx = arith.remui(utils.thread_idx(), c(WARP_SIZE))
+          group_idx = arith.remui(arith.divui(lane_idx, c(group_stride)), c(2))
+          group_pred = arith.cmpi(arith.CmpIPredicate.ne, group_idx, c(0))
+          return StaggeredTransferPlan(
+              stagger, dim, tiled_shape[dim], group_pred
+          )
+  raise ValueError(
+      "Failed to synthesize a transfer pattern that avoids bank conflicts"
+  )
+

 # We allow contractions, to potentially take advantage of FMA instructions.
 # They can change the results, but the precision should only increase.
 def addf(a: ir.Value, b: ir.Value):
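For intuition about the plan this function returns in the staggered case, here is a small pure-Python trace of tile_index_transforms (hypothetical toy parameters, no MLIR involved):

stagger, dim, size = 1, 0, 4  # as in StaggeredTransferPlan(stagger=1, dim=0, size=4, ...)

def rotate(idx):
  return (*idx[:dim], (idx[dim] + stagger) % size, *idx[dim + 1:])

transforms = (lambda x: x, rotate)  # group 0: identity; group 1: rotated
for tile_idx in [(0,), (1,), (2,), (3,)]:
  print(tile_idx, [f(tile_idx) for f in transforms])
# (0,) [(0,), (1,)]
# (1,) [(1,), (2,)]
# (2,) [(2,), (3,)]
# (3,) [(3,), (0,)]

Because the rotation is a permutation, each lane group still visits every tile exactly once; at run time, select and select_if_group blend the two groups with a single arith.select on group_pred, so the staggered schedule stays branchless.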

tests/mosaic/gpu_test.py

Lines changed: 17 additions & 7 deletions
@@ -1755,18 +1755,21 @@ def kernel(ctx, dst, _):
   @parameterized.product(
       load_tiled=[False, True],
       store_tiled=[False, True],
-      dtype=[jnp.int16],
+      dtype=[jnp.int8, jnp.int16, jnp.int32],
       swizzle=[32, 64, 128],
-      num_col_tiles=[1, 2, 4],
+      num_col_tiles=[1, 2, 3],
   )
   def test_copy_tiled(self, load_tiled, store_tiled, dtype, swizzle, num_col_tiles):
     mlir_dtype = utils.dtype_to_ir_type(dtype)
-    col_tiling = swizzle // bytewidth(mlir_dtype)
+    bw = bytewidth(mlir_dtype)
+    col_tiling = swizzle // bw
     m, n = 128, col_tiling * num_col_tiles
     tiling = (64, col_tiling)
     tiled_layout = fa._tiled_wgmma_layout((m, n))
     load_layout = tiled_layout if load_tiled else mgpu.WGMMA_LAYOUT
     store_layout = tiled_layout if store_tiled else mgpu.WGMMA_LAYOUT
+    if (not load_tiled or not store_tiled) and bw == 4 and swizzle == 32:
+      self.skipTest("Old code path does not support this")
     def kernel(ctx, in_, out, smems):
       smem_in, smem_out, barrier = smems
       ctx.async_copy(src_ref=in_, dst_ref=smem_in, swizzle=swizzle, barrier=barrier)
@@ -1800,14 +1803,21 @@ def kernel(ctx, in_, out, smems):
     # Verify that we don't use too many registers for the transfers.
     # We verify LDS and STS separately, because they might use two different
    # methods of computing offsets and we don't rely on CSE between them.
-    register_pattern = re.compile(r"(R[0-9]+)")
     expected_regs = swizzle // bytewidth(mlir_dtype) // 8
+    # When the bytewidth is smaller than 2 the swizzle pattern changes every 2
+    # column tiles, so we only need half the registers.
+    if load_tiled and store_tiled:  # The old code doesn't optimize properly.
+      if bytewidth(mlir_dtype) < 2:
+        expected_regs //= 2
     for instr in ("STS", "LDS"):
       with self.subTest(instr + " count"):
         addrs = re.findall(instr + r".* \[(.*)\]", get_sass())
-        chain = itertools.chain.from_iterable
-        used_regs = set(chain(register_pattern.findall(addr) for addr in addrs))
-        self.assertLen(used_regs, expected_regs)
+        def get_reg(addr):
+          if (pos := addr.find("+")) != -1:
+            return addr[:pos]
+          return addr
+        used_regs = {get_reg(addr) for addr in addrs}
+        self.assertLessEqual(len(used_regs), expected_regs)


 if __name__ == "__main__":

0 commit comments
