 import functools
 import math
 from collections.abc import Callable
-from typing import Iterable, Sequence, TypeVar
+from typing import Iterable, Protocol, Sequence, TypeVar

 import jax
 from jaxlib.mlir import ir
@@ -42,6 +42,8 @@
 WARPGROUP_SIZE = utils.WARPGROUP_SIZE
 WARP_SIZE = 32
 WARPS_IN_WARPGROUP = WARPGROUP_SIZE // WARP_SIZE
+SMEM_BANKS = 32
+SMEM_BANK_BYTES = 4
 c = utils.c

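The two new constants describe the standard NVIDIA shared-memory banking model: SMEM is split into 32 banks, each 4 bytes wide, and consecutive 4-byte words map to consecutive banks. A minimal sketch of that mapping (illustrative only, not code from this change):

SMEM_BANKS = 32
SMEM_BANK_BYTES = 4

def bank_of(byte_addr: int) -> int:
  # Bank holding the 4-byte word that starts at `byte_addr`.
  return (byte_addr // SMEM_BANK_BYTES) % SMEM_BANKS

# A warp access is conflict-free when its 32 lanes hit 32 distinct banks,
# e.g. a contiguous 128-byte row; a 128-byte stride puts every lane in bank 0.
assert sorted(bank_of(4 * lane) for lane in range(32)) == list(range(32))
assert all(bank_of(128 * lane) == 0 for lane in range(32))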
@@ -1455,11 +1457,15 @@ def load_tiled(
           raise ValueError("Tiled reference must have even rank")
         tiling = Tiling((tiled_shape[len(tiled_shape) // 2 :],))
         shape = tiling.untile_shape(tiled_shape)
-        registers = np.full(layout.registers_shape(shape), None, dtype=object)
+        zero = (
+            vector.splat(
+                ir.VectorType.get((layout.vector_length,), dtype), c(0, dtype)
+            ),
+        )
+        registers = np.full(layout.registers_shape(shape), zero, dtype=object)
         reg_ty = ir.VectorType.get((layout.vector_length,), ref_ty.element_type)
         for _, update, ptr in cls.transfer_tiled2(ref, swizzle, layout, shape):
           update(registers, llvm.load(reg_ty, ptr))
-        assert all(r is not None for r in registers.flat)
       case WGMMAFragLayout():
         bw = mgpu.bytewidth(dtype)
         m_tiles, n_tiles, m_tile_size, n_tile_size = ref_ty.shape
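A quick aside on the new initializer above (illustrative only; a plain string stands in for the MLIR splat value): np.full broadcasts the one-element tuple, so every cell of the object array starts out holding the same zero vector rather than None. The tuple wrapping presumably just keeps NumPy from trying to interpret the MLIR value itself as an array-like, and since no cell can be None anymore, the removed assert would have been vacuous.

import numpy as np

zero = ("zero_splat",)  # stand-in for the 1-element tuple wrapping the splat
registers = np.full((2, 3), zero, dtype=object)
# Each cell holds the splat value itself, not the wrapping tuple.
assert all(r == "zero_splat" for r in registers.flat)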
@@ -1611,19 +1617,18 @@ def transfer_tiled2(

     tiled_strides = list(tiling.tile_strides(tuple(ref_strides)))
     tiled_shape = list(tiling.tile_shape(tuple(ref_ty.shape)))
+    lane_strides = [tiled_strides[d] for d in layout.lane_dims]
+    lane_shape = [tiled_shape[d] for d in layout.lane_dims]
     if tiled_strides[layout.vector_dim] != 1:
       raise ValueError("Stride of the vectorized dimension should be 1")
     for d in (layout.warp_dim, *layout.lane_dims, layout.vector_dim):
       tiled_shape[d] = 1
     full_tiling = Tiling((ref_tiling_shape, *tiling.tiles))
     full_layout = dataclasses.replace(layout, tiling=full_tiling)

-    # XXX: This method is still slightly incompete. For example, it does not
-    # verify that the vector transfers don't cross swizzle tile boundaries. It
-    # also does not guarantee that the transfer pattern does not cause bank
-    # conflicts. For that reason, we only allow a select subset of layouts.
-    if layout != _tiled_wgmma_layout(shape) or bw > 2:
-      raise NotImplementedError("transfer_tiled2 not general enough yet")
+    plan = plan_tiled_transfer(
+        tiled_shape, tiled_strides, lane_shape, lane_strides, layout, bw, swizzle
+    )

     dyn_tiled_strides = [c(s) for s in tiled_strides]
     lane_offset = utils.dyn_dot(full_layout.lane_indices(), dyn_tiled_strides)
@@ -1632,27 +1637,45 @@ def transfer_tiled2(
     if ref_ty.memory_space != ir.Attribute.parse("#gpu.address_space<workgroup>"):
       raise ValueError("Tiled stores can be performed into SMEM")
     ptr = utils.memref_ptr(ref, memory_space=3)
+    _as_consts = lambda consts: [c(const) for const in consts.tolist()]
+    # This has bits set only for the offset bits that influence swizzling.
+    swizzle_mask = swizzle_block_elems - swizzle_tile_elems
     for tile_idx in np.ndindex(*tiled_shape):
-      const_offset = sum(i * s for i, s in zip(tile_idx, tiled_strides))
+      indices = np.asarray([f(tile_idx) for f in plan.tile_index_transforms])
+      const_offset = np.dot(indices, tiled_strides)
       # We split the offset into a part that interacts with swizzling and a
       # part that doesn't. This lets us generate better code because constant
       # offsets can be fused into load and store instructions.
-      const_offset_swizzle = const_offset % swizzle_block_elems
+      const_offset_swizzle = const_offset & swizzle_mask
       const_offset_no_swizzle = const_offset - const_offset_swizzle
-      offset_pre_swizzle = arith.addi(dyn_offset, c(const_offset_swizzle))
+      offset_pre_swizzle = arith.addi(
+          dyn_offset, plan.select(_as_consts(const_offset_swizzle))
+      )
       swizzle_group = arith.remui(
           arith.divui(offset_pre_swizzle, c(swizzle_group_elems)),
           c(swizzle_groups_per_block),
       )
       swizzle_bits = arith.muli(swizzle_group, c(swizzle_tile_elems))
       offset = arith.xori(offset_pre_swizzle, swizzle_bits)
       reg_ptr = utils.getelementptr(ptr, [offset], dtype)
-      reg_ptr = utils.getelementptr(reg_ptr, [const_offset_no_swizzle], dtype)
-      reg_idx = tiling.tile_indices(full_tiling.untile_indices(tile_idx))
-      def get_register(regs, reg_idx=reg_idx):
-        return regs[reg_idx]
-      def update_registers(regs, new, reg_idx=reg_idx):
-        regs[reg_idx] = new
+      offset_no_swizzle = plan.select(_as_consts(const_offset_no_swizzle))
+      reg_ptr = utils.getelementptr(reg_ptr, [offset_no_swizzle], dtype)
+      reg_idxs = [
+          tiling.tile_indices(full_tiling.untile_indices(idx))
+          for idx in indices.tolist()
+      ]
+      def get_register(regs, reg_idxs=reg_idxs):
+        return plan.select([regs[reg_idx] for reg_idx in reg_idxs])
+      def update_registers(regs, new, reg_idxs=reg_idxs):
+        # TODO(apaszke): If the staggering forms a permutation with a small
+        # cycle length, then instead of blending at each step we could construct
+        # a small routing network (kind of like a sorting network) to fix up
+        # each cycle separately after all the loads are performed.
+        # This would be especially useful for dims that are powers of two and
+        # staggered by another power of 2, since all cycles are of length 2 (and
+        # we could save half the selects).
+        for i, reg_idx in enumerate(reg_idxs):
+          regs[reg_idx] = plan.select_if_group(i, regs[reg_idx], new)
       yield get_register, update_registers, reg_ptr

   def tree_flatten(self):
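To make the swizzle_mask split above concrete, here is a small worked example (illustrative only; it assumes swizzle_block_elems = swizzle // bytewidth and swizzle_tile_elems = 16 // bytewidth, consistent with how those values are used in this function, taking 2-byte elements and a 128-byte swizzle):

bw, swizzle = 2, 128                       # assumed example parameters
swizzle_tile_elems = 16 // bw              # 8
swizzle_block_elems = swizzle // bw        # 64
swizzle_mask = swizzle_block_elems - swizzle_tile_elems  # 56 == 0b111000

const_offset = 100                                       # 0b1100100
const_offset_swizzle = const_offset & swizzle_mask       # 32: bits 3..5 only
const_offset_no_swizzle = const_offset - const_offset_swizzle  # 68
assert const_offset_swizzle | const_offset_no_swizzle == const_offset

# The old `% swizzle_block_elems` split would have kept the sub-tile bits in
# the swizzled part (100 % 64 == 36); the mask folds them into the constant
# getelementptr offset instead.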
@@ -1666,6 +1689,173 @@ def tree_unflatten(cls, aux, flat_registers):
     return cls(_registers=registers, _layout=layout, _is_signed=is_signed)


+class TransferPlan(Protocol):
+  IndexTransform = Callable[[tuple[int, ...]], tuple[int, ...]]
+  tile_index_transforms: tuple[IndexTransform, ...]
+
+  def select(self, group_elems: Sequence[ir.Value]) -> ir.Value:
+    """Selects the value corresponding to the group of the current thread.
+
+    The argument must be of the same length as tile_index_transforms.
+    """
+    raise NotImplementedError
+
+  def select_if_group(self, group_idx: int, old: ir.Value, new: ir.Value) -> ir.Value:
+    """Returns `new` if the current thread belongs to the given group and `old` otherwise.
+
+    group_idx must be between 0 and len(tile_index_transforms) - 1.
+    """
+    raise NotImplementedError
+
+
+@dataclasses.dataclass(frozen=True)
+class TrivialTransferPlan(TransferPlan):
+  @property
+  def tile_index_transforms(self):
+    return (lambda x: x,)
+
+  def select(self, group_elems: Sequence[ir.Value]) -> ir.Value:
+    assert len(group_elems) == 1
+    return group_elems[0]
+
+  def select_if_group(self, group_idx: int, old: ir.Value, new: ir.Value) -> ir.Value:
+    assert group_idx == 0
+    return new
+
+
+@dataclasses.dataclass(frozen=True)
+class StaggeredTransferPlan(TransferPlan):
+  stagger: int
+  dim: int
+  size: int
+  group_pred: ir.Value
+
+  @property
+  def tile_index_transforms(self):
+    dim = self.dim
+    def rotate(idx: tuple[int, ...]) -> tuple[int, ...]:
+      return (
+          *idx[:dim], (idx[dim] + self.stagger) % self.size, *idx[dim + 1:],
+      )
+    return (lambda x: x, rotate)
+
+  def select(self, group_elems: Sequence[ir.Value]) -> ir.Value:
+    assert len(group_elems) == 2
+    return arith.select(self.group_pred, group_elems[1], group_elems[0])
+
+  def select_if_group(self, group_idx: int, old: ir.Value, new: ir.Value) -> ir.Value:
+    assert 0 <= group_idx <= 1
+    sides = [old, new] if group_idx == 0 else [new, old]
+    return arith.select(self.group_pred, *sides)
+
+
+def plan_tiled_transfer(
+    tiled_shape: Sequence[int],
+    tiled_strides: Sequence[int],
+    lane_shape: Sequence[int],
+    lane_strides: Sequence[int],
+    layout: TiledLayout,
+    bw: int,
+    swizzle: int,
+) -> TransferPlan:
+  i32 = ir.IntegerType.get_signless(32)
+  c = lambda x: arith.constant(i32, x)
+  swizzle_tile_elems = 16 // bw
+  swizzle_group_elems = 128 // bw
+  # Below, all calculations are in elements, not in bytes, since it should
+  # generalize better to sub-byte types.
+  # Here, we verify two conditions:
+  # 1. Each vector transfer only accesses addresses that fall within a single
+  # swizzle tile (if not we'd need to split it and swizzle parts differently).
+  transfer_alignment = math.gcd(*(
+      s
+      for i, (s, d) in enumerate_negative(list(zip(tiled_strides, tiled_shape)))
+      if d > 1 or i in {layout.warp_dim, *layout.lane_dims}
+  ))
+  if (
+      swizzle_tile_elems % transfer_alignment
+      and layout.vector_length <= transfer_alignment
+  ):
+    raise ValueError(
+        "Failed to prove that vector transfers don't cross swizzle tile"
+        " boundaries. This check is incomplete, and does not guarantee that"
+        " this is a user error, but it might be." + str(transfer_alignment)
+    )
+
+  # 2. The transfer pattern does not cause bank conflicts.
+  # TODO(apaszke): For now, when performing transfers narrower than a bank,
+  # we simply narrow each bank to the transfer width. The truth is more likely
+  # that bank conflicts only don't occur if the addresses mapping to the same
+  # bank are contiguous, but that's a more complicated check to perform.
+  transfer_bytes = layout.vector_length * bw
+  if transfer_bytes > SMEM_BANK_BYTES * 4:
+    raise NotImplementedError
+  if bw > SMEM_BANK_BYTES:
+    raise NotImplementedError
+  smem_bank_bytes = min(SMEM_BANK_BYTES, transfer_bytes)
+  num_banks = SMEM_BANKS * (SMEM_BANK_BYTES // smem_bank_bytes)
+  elems_per_bank = smem_bank_bytes // bw
+  num_wavefronts = max(transfer_bytes // smem_bank_bytes, 1)
+  wavefront_lanes = WARP_SIZE // num_wavefronts
+
+  lane_offsets_in_tile = np.dot(list(np.ndindex(*lane_shape)), lane_strides)
+  def has_bank_conflicts(tile_idx_transform):
+    tile_idxs = np.unravel_index(np.arange(math.prod(tiled_shape)), tiled_shape)
+    tile_idxs = np.expand_dims(np.stack(tile_idxs, 1), 1)  # [#tiles, 1, #dims]
+    lane_tile_idx = tile_idx_transform(tile_idxs)  # [#tiles, #lanes/1, #dims]
+    assert lane_tile_idx.shape[1] in {1, WARP_SIZE}
+    lane_tile_offsets = np.dot(lane_tile_idx, tiled_strides)
+    offsets = lane_tile_offsets + lane_offsets_in_tile  # [#tiles, #lanes]
+    assert offsets.shape[-1] == WARP_SIZE
+    swizzle_groups = (offsets // swizzle_group_elems) % (swizzle // 16)
+    swizzle_bits = swizzle_groups * swizzle_tile_elems
+    lane_banks = ((offsets ^ swizzle_bits) // elems_per_bank) % num_banks
+    wavefront_banks = lane_banks.reshape(-1, num_wavefronts, wavefront_lanes)
+    # Order of threads within the wavefront is unimportant.
+    wavefront_banks = np.sort(wavefront_banks, axis=-1)
+    # There are no conflicts if each wavefront only contains unique banks.
+    return np.any(wavefront_banks[..., 1:] == wavefront_banks[..., :-1])
+
+  # We don't need any special treatment if there are no conflicts when all
+  # lanes transfer the same tile at each step.
+  if not has_bank_conflicts(lambda tile_idx: tile_idx):
+    return TrivialTransferPlan()
+
+  # Otherwise, we will try to partition the lanes into two groups and have
+  # each group store to a different tile. The only tile dimensions that can
+  # help us avoid bank conflicts are those that have multiple elements and a
+  # stride that's not a multiple of the number of banks.
+  #
+  # Note that the code is set up so that we could also consider partitioning
+  # the lanes into more groups, but the selects will become more expensive if
+  # we do that. It's a possibility we have if we need it.
+  candidate_dims = (
+      i for i, (s, d) in enumerate(zip(tiled_strides, tiled_shape))
+      if d > 1 and s % (SMEM_BANKS * elems_per_bank)
+  )
+  for dim in candidate_dims:
+    for group_stride in (1, 2, 4, 8, 16):
+      # We change the group assignment every group_stride lanes.
+      lane_id = np.arange(WARP_SIZE)[:, None]
+      lane_group = (lane_id // group_stride) % 2
+      # We only consider a transformation where the second group stores to a
+      # tile that's at a constant offset (modulo the dim size) from the first one.
+      for stagger in range(1, tiled_shape[dim]):
+        offset = np.zeros(len(tiled_shape), np.int64)
+        offset[dim] = stagger
+        transform = lambda idx: (idx + offset * lane_group) % tiled_shape
+        if not has_bank_conflicts(transform):
+          # We've found a strategy that avoids bank conflicts!
+          lane_idx = arith.remui(utils.thread_idx(), c(WARP_SIZE))
+          group_idx = arith.remui(arith.divui(lane_idx, c(group_stride)), c(2))
+          group_pred = arith.cmpi(arith.CmpIPredicate.ne, group_idx, c(0))
+          return StaggeredTransferPlan(
+              stagger, dim, tiled_shape[dim], group_pred
+          )
+  raise ValueError(
+      "Failed to synthesize a transfer pattern that avoids bank conflicts"
+  )
+
 # We allow contractions, to potentially take advantage of FMA instructions.
 # They can change the results, but the precision should only increase.
 def addf(a: ir.Value, b: ir.Value):
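Finally, a compact sketch of the bank-conflict reasoning that has_bank_conflicts and the staggered plan implement (illustrative only, and simplified: it ignores swizzling and assumes one 4-byte access per lane, whereas the real check also splits wider transfers into wavefronts before comparing banks):

import numpy as np

SMEM_BANKS, SMEM_BANK_BYTES, WARP_SIZE = 32, 4, 32

def warp_has_conflicts(byte_addrs: np.ndarray) -> bool:
  # One 4-byte access per lane; a conflict means two lanes share a bank.
  # Like the real check, sort the banks and compare neighbours for duplicates.
  banks = np.sort((byte_addrs // SMEM_BANK_BYTES) % SMEM_BANKS)
  return bool(np.any(banks[1:] == banks[:-1]))

lanes = np.arange(WARP_SIZE)
# Two 16-lane rows that sit 256 bytes apart alias to the same 16 banks.
addrs = (lanes % 16) * 4 + (lanes // 16) * 256
assert warp_has_conflicts(addrs)
# Shifting the second half of the warp to a tile 64 bytes away, which is the
# kind of transform StaggeredTransferPlan encodes, spreads it over all 32 banks.
assert not warp_has_conflicts(addrs + (lanes // 16) * 64)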