
Commit 16c3c1a

joviliast, antiagainst, and scxiao authored and committed
[AMD] Fix packed fp16 atomic optimization conditions (triton-lang#5839)
We need to make sure threads in a pair are both active and the address is aligned to 4 bytes.

Signed-off-by: Ilya Veselov <[email protected]>
Co-authored-by: Lei Zhang <[email protected]>
Co-authored-by: Shucai Xiao <[email protected]>
1 parent 3efde92 commit 16c3c1a
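
The fixed condition can be summarized per lane pair. Below is a minimal Python sketch of the eligibility check, assuming hypothetical per-lane addresses and activity flags; it mirrors the three conditions listed in the LoadStoreOpToLLVM.cpp diff below and is an illustration, not the generated code.

# Hypothetical sketch: when may an even/odd lane pair be combined into one
# packed 2 x fp16 atomic? Mirrors the conditions checked in the diff below.
def can_use_packed_fp16(even_addr, odd_addr, even_active, odd_active):
    is_aligned = even_addr % 4 == 0            # 1. address is aligned by 4 bytes
    is_adjacent = odd_addr == even_addr + 2    # 2. right neighbour holds the adjacent fp16
    both_active = even_active and odd_active   # 3. both threads in the pair are active
    return is_aligned and is_adjacent and both_active

# Lanes writing x[0] and x[1] of an fp16 tensor at 0x1000 qualify; a pair
# starting at x[1] (address 0x1002) is misaligned and must fall back.
assert can_use_packed_fp16(0x1000, 0x1002, True, True)
assert not can_use_packed_fp16(0x1002, 0x1004, True, True)
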

File tree

2 files changed, +168 -33 lines changed

python/test/unit/language/test_core.py

Lines changed: 73 additions & 0 deletions
@@ -1655,6 +1655,79 @@ def torch_to_triton_dtype(t):
     np.testing.assert_equal(old_ref, to_numpy(old_tri))
 
 
+@pytest.mark.interpreter
+@pytest.mark.parametrize("size, num_ctas, dtype_x_str", [(size, num_ctas, dtype_x_str)
+                                                         for size in [2, 4, 8, 32, 64, 128]
+                                                         for num_ctas in num_ctas_list
+                                                         for dtype_x_str in ['float16']])
+def test_tensor_atomic_add_non_exclusive_offset(size, num_ctas, dtype_x_str, device):
+
+    @triton.jit
+    def kernel(X, val, NUM: tl.constexpr):
+        off = tl.arange(0, NUM)
+        offset = off[:, None] * NUM + off[None, :]
+        val = tl.load(val + offset)
+        tl.atomic_add(X + offset // 2, val)
+
+    shape = (size // 2, size)
+    x = torch.zeros(shape, dtype=getattr(torch, dtype_x_str), device=device)
+    val = torch.randn((size**2), dtype=getattr(torch, dtype_x_str), device=device)
+    kernel[(1, )](x, val, size, num_warps=1, num_ctas=num_ctas)
+    ref = val[0::2] + val[1::2]
+    torch.testing.assert_close(ref, x.reshape(math.prod(shape)))
+
+
+@pytest.mark.interpreter
+@pytest.mark.parametrize("shape, idx_order, mask_step, num_ctas, dtype_x_str",
+                         [(shape, idx_order, mask_step, num_ctas, dtype_x_str)
+                          for shape in [(2, 2), (5, 5), (6, 6), (8, 8)]
+                          for idx_order in ['increase', 'decrease', 'random_no_duplication', 'random']
+                          for mask_step in range(1, 5)
+                          for num_ctas in num_ctas_list
+                          for dtype_x_str in ['float16']])
+def test_tensor_atomic_add_access_patterns(shape, idx_order, mask_step, num_ctas, dtype_x_str, device):
+    check_type_supported(dtype_x_str, device)
+    if is_interpreter():
+        pytest.skip("not supported in the interpreter")
+
+    @triton.jit
+    def kernel(in_ptr, idx_ptr, out_ptr, shape0, shape1, mask_step, XBLOCK: tl.constexpr):
+        xoffset = tl.program_id(0) * XBLOCK
+        x_idx = xoffset + tl.arange(0, XBLOCK)[:]
+        mask = x_idx < shape0 * shape1
+        mask = mask and (x_idx % mask_step != 0)
+        idx_base = shape1 * (x_idx // shape1)
+        idx_offset = tl.load(idx_ptr + x_idx, mask)
+        in_elem = tl.load(in_ptr + x_idx, mask)
+        tl.atomic_add(out_ptr + (idx_offset + idx_base), in_elem, mask, sem='relaxed')
+
+    shape0, shape1 = shape
+    idx_row = torch.arange(0, shape1, device=device)
+    if idx_order == 'increase':
+        idx = torch.stack([idx_row.repeat_interleave(i + 1)[:shape1] for i in range(shape0)])
+    if idx_order == 'decrease':
+        idx = torch.stack([idx_row.flip(0).repeat_interleave(i + 1)[:shape1] for i in range(shape0)])
+    if idx_order == 'random_no_duplication':
+        idx = torch.stack([torch.randperm(shape1, device=device) for _ in idx_row])
+    if idx_order == 'random':
+        idx = torch.randint(0, shape1, size=(shape0, shape1), device=device)
+
+    val = torch.randn((shape0, shape1), dtype=getattr(torch, dtype_x_str), device=device)
+    dst = torch.randn((shape0, shape1), dtype=getattr(torch, dtype_x_str), device=device)
+
+    dst_ref = dst.clone()
+
+    cnt = 0
+    for i, row in enumerate(idx):
+        for j, elem in enumerate(row):
+            if cnt % mask_step != 0:
+                dst_ref[i][elem] += val[i][j]
+            cnt += 1
+
+    kernel[(1, )](val, idx, dst, shape0, shape1, mask_step, 64, num_ctas=num_ctas)
+    np.testing.assert_allclose(to_numpy(dst_ref), to_numpy(dst), atol=1e-2)
+
+
 @pytest.mark.interpreter
 @pytest.mark.parametrize("num_ctas", num_ctas_list)
 def test_tensor_atomic_rmw_block(num_ctas, device):

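As a quick reading aid for test_tensor_atomic_add_non_exclusive_offset above: because of the `offset // 2` addressing, every two consecutive values land on the same destination element, so the expected result is a pairwise sum. A plain PyTorch restatement of the reference, assuming size = 4 and no Triton involved:

import torch

# Pairwise-sum reference equivalent to the test's `ref = val[0::2] + val[1::2]`,
# spelled out element by element (assumed size = 4).
size = 4
val = torch.randn(size * size, dtype=torch.float16)
out = torch.zeros(size * size // 2, dtype=torch.float16)
for i in range(size * size):
    out[i // 2] += val[i]          # same mapping as tl.atomic_add(X + offset // 2, val)
assert torch.allclose(out, val[0::2] + val[1::2], atol=1e-2)
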
third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 95 additions & 33 deletions
@@ -176,8 +176,7 @@ struct LoadStoreConversionBase {
     return axisAnalysisPass.getMaskAlignment(mask);
   }
 
-  std::optional<const std::string>
-  getAMDGPUMemScopeStr(MemSyncScope scope) const {
+  std::optional<const char *> getAMDGPUMemScopeStr(MemSyncScope scope) const {
     // See: https://llvm.org/docs/AMDGPUUsage.html#memory-scopes
     auto scopeStr = "";
     switch (scope) {
@@ -1295,44 +1294,43 @@ struct AtomicRMWOpConversion
     // tt::atomicRmwOp(%ptr, %val, %mask):
     // 0. Group thread by pairs. Master thread is (tid % 2 == 0);
     // 1. All the threads send %val to (tid - 1) thread via dppUpdateOp shl, so
-    //    all the masters recieve value from secondary threads;
+    //    all the masters receive value from secondary threads;
     // 2. Take into account parity in the %mask value, build control flow
     //    structures according to it;
     // 3. Generate llvm::atomicRmwOp in the threads enabled by %mask value;
     // 4. All the threads send result of generated operation to (tid + 1) thread
-    //    via dppUpdateOp shl, so all secondary thread also recieve their
+    //    via dppUpdateOp shl, so all secondary thread also receive their
    //    result.
     //
     // This approach enables us to use half the active threads committing atomic
     // requests to avoid generating of code providing unified access to f16
-    // element and reduce contantion.
+    // element and reduce contention.
     bool useDppForPackedF16 = false;
     // tensor
     if (tensorTy) {
       auto valTy = cast<RankedTensorType>(val.getType());
       bool isF16Ty = valueElemTy.isF16() || valueElemTy.isBF16();
       unsigned availableVecSize = isF16Ty ? 2 : 1;
       vec = std::min<unsigned>(vec, availableVecSize);
-      // Force F16 packing in the case it's not comming in as packed, but the
+      // Force F16 packing in the case it's not coming in as packed, but the
       // ISA can support packed atomic instructions.
       useDppForPackedF16 =
           supportsGlobalAtomicF16PackedAndDpp(targetInfo.getISAFamily()) &&
-          vec == 1 && isF16Ty && atomicRmwAttr == RMWOp::FADD;
+          vec == 1 && isF16Ty && atomicRmwAttr == RMWOp::FADD &&
+          !enableIntraWaveReduce;
       // mask
       numElems = tensorTy.getNumElements();
     }
-    Value mask = b.int_val(1, 1);
+    Value mask = b.true_val();
     auto tid = getThreadId(rewriter, loc);
     mask = b.and_(mask, b.icmp_slt(b.mul(tid, b.i32_val(elemsPerThread)),
                                    b.i32_val(numElems)));
-    if (useDppForPackedF16)
-      mask = b.and_(mask, b.icmp_eq(b.urem(tid, b.i32_val(2)), b.i32_val(0)));
 
     auto memOrdering = op.getSem();
     auto scope = op.getScope();
     auto atomicMemOrdering = getMemoryOrdering(memOrdering);
 
-    auto scopeStr = getAMDGPUMemScopeStr(scope);
+    std::optional<const char *> scopeStr = getAMDGPUMemScopeStr(scope);
     if (!scopeStr)
       return rewriter.notifyMatchFailure(op, "Unknown AMDGPU memory scope");

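To make the comment block above more concrete, here is a small Python simulation of the pairing scheme, with plain lists standing in for wavefront lanes and the dpp shl value exchange modeled as a list shift. This is a conceptual sketch only; it assumes each odd lane's element sits right after its even neighbour's, which is exactly what the new checks verify.

# Conceptual simulation (not GPU code): odd lanes forward their value one lane
# to the left, and only even "master" lanes commit an update covering two
# adjacent fp16 elements, standing in for one packed 2 x fp16 atomic fadd.
def simulate_packed_fadd(memory, index, vals):
    # memory: destination buffer, index[lane]: element index each lane targets,
    # vals[lane]: value each lane wants to add atomically.
    shifted = vals[1:] + [0.0]                    # "dpp shl": lane i sees lane i + 1's value
    for lane in range(0, len(vals), 2):           # masters only (tid % 2 == 0)
        memory[index[lane]] += vals[lane]         # low half of the packed add
        memory[index[lane] + 1] += shifted[lane]  # high half: the odd neighbour's value

mem = [0.0] * 4
simulate_packed_fadd(mem, index=[0, 1, 2, 3], vals=[1.0, 2.0, 3.0, 4.0])
assert mem == [1.0, 2.0, 3.0, 4.0]
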
@@ -1346,24 +1344,59 @@ struct AtomicRMWOpConversion
       // elemsPerThread.
       Value rmwMask = llMask ? b.and_(mask, maskElements[i]) : mask;
 
+      Value i64Ones = b.i64_val(~uint64_t(0));
+      Value i64Zeros = b.i64_val(0);
       Value operand;
+      Value rightNeighbourPtr;
+      Value enablePackedOpt;
       if (useDppForPackedF16) {
+        Value isOddI32 = b.urem(tid, b.i32_val(2));
+        // First check if odd threads hold adjacent ptrs to even ones.
+        Value castedAddr = b.ptrtoint(i64_ty, rmwPtr);
+        // Set casted addr to all ones if the thread is disabled.
+        castedAddr = b.select(rmwMask, castedAddr, i64Ones);
+
         // Move %val to left neighbour to proceed packed atomic further.
         Value packedVal = b.null(packF16Ty);
-        packedVal = b.insert_element(packF16Ty, packedVal, valElements[i],
-                                     b.i32_val(0));
-        // Pack to i32 type to simplify transaction
+        packedVal =
+            b.insert_element(packF16Ty, packedVal, valElements[i], isOddI32);
+        // Pack to i32 type to simplify transaction.
         packedVal = b.bitcast(packedVal, i32_ty);
+        // Zero operands for disabled threads to make addition no op.
+        packedVal = b.select(rmwMask, packedVal, b.i32_val(0));
         Value dppMoveRes = shiftLeftI32ByDpp(rewriter, packedVal);
+
+        Value rightNeighbourAddr =
+            genI32TiledOp(rewriter, shiftLeftI32ByDpp, castedAddr);
+
+        // Packing optimization only supported if following conditions are true:
+        // 1. address is aligned by 4 bytes
+        // 2. right neighbour has adjacent address
+        // 3. both threads are active
+        Value isAligned =
+            b.icmp_eq(b.urem(castedAddr, b.i64_val(4)), b.i64_val(0));
+        Value neighbourAddrAdjacent = b.icmp_eq(
+            rightNeighbourAddr,
+            b.add(castedAddr,
+                  b.i64_val(valueElemTy.getIntOrFloatBitWidth() / 8)));
+        Value neighbourEnabled = b.icmp_ne(i64Ones, rightNeighbourAddr);
+        Value bothEnabled = b.and_(neighbourEnabled, rmwMask);
+        enablePackedOpt =
+            b.and_(b.and_(isAligned, bothEnabled), neighbourAddrAdjacent);
+
+        // Enable only the even threads.
+        Value anyEnabled = b.or_(neighbourEnabled, rmwMask);
+        // If one of the threads is disabled, use the neighbour's addr.
+        rightNeighbourAddr =
+            b.select(neighbourEnabled, rightNeighbourAddr, castedAddr);
+        castedAddr = b.select(rmwMask, castedAddr, rightNeighbourAddr);
+
+        rmwMask = b.and_(anyEnabled, b.icmp_eq(isOddI32, b.i32_val(0)));
+
         // Unpack results back
-        Value unpackedDppRes = b.bitcast(dppMoveRes, packF16Ty);
-        operand = b.undef(packF16Ty);
-        operand =
-            b.insert_element(packF16Ty, operand, valElements[i], b.i32_val(0));
-        operand = b.insert_element(
-            packF16Ty, operand,
-            b.extract_element(valueElemTy, unpackedDppRes, b.i32_val(0)),
-            b.i32_val(1));
+        rightNeighbourPtr = b.inttoptr(rmwPtr.getType(), rightNeighbourAddr);
+        rmwPtr = b.inttoptr(rmwPtr.getType(), castedAddr);
+        operand = b.bitcast(b.or_(packedVal, dppMoveRes), packF16Ty);
       } else if (vec == 1) {
         operand = valElements[i];
       } else {
@@ -1388,15 +1421,47 @@ struct AtomicRMWOpConversion
       rewriter.setInsertionPointToEnd(atomicBlock);
       auto maybeKind = matchAtomicOp(atomicRmwAttr);
       Value atom;
+      Value isVecOp;
       if (enableIntraWaveReduce) {
         atom = atomicIntraWaveReduce(rewriter, rmwPtr, operand, *maybeKind,
-                                     atomicMemOrdering, scopeStr.value());
+                                     atomicMemOrdering, *scopeStr);
       } else {
-        atom = rewriter
-                   .create<LLVM::AtomicRMWOp>(loc, *maybeKind, rmwPtr, operand,
-                                              atomicMemOrdering,
-                                              StringRef(scopeStr.value()))
-                   .getResult();
+        if (useDppForPackedF16) {
+          // Determine on the runtime what atomic intrinsic to execute:
+          // packed or regular.
+          auto *packedBlock =
+              atomicBlock->splitBlock(rewriter.getInsertionPoint());
+          auto *regularBlock =
+              rewriter.createBlock(atomicBlock->getParent(),
+                                   std::next(Region::iterator(atomicBlock)));
+          rewriter.setInsertionPointToEnd(atomicBlock);
+          rewriter.create<LLVM::CondBrOp>(loc, enablePackedOpt, packedBlock,
+                                          regularBlock);
+
+          // Fill out the regular block, where we issue two atomic ops.
+          rewriter.setInsertionPointToEnd(regularBlock);
+          Value pairedOperand0 =
+              b.extract_element(valueElemTy, operand, b.i32_val(0));
+          Value pairedOperand1 =
+              b.extract_element(valueElemTy, operand, b.i32_val(1));
+          Value atomNonVec0 = rewriter.create<LLVM::AtomicRMWOp>(
+              loc, *maybeKind, rmwPtr, pairedOperand0, atomicMemOrdering,
+              *scopeStr);
+          Value atomNonVec1 = rewriter.create<LLVM::AtomicRMWOp>(
+              loc, *maybeKind, rightNeighbourPtr, pairedOperand1,
+              atomicMemOrdering, *scopeStr);
+          Value packedRes = b.undef(packF16Ty);
+          packedRes =
+              b.insert_element(packF16Ty, packedRes, atomNonVec0, b.i32_val(0));
+          packedRes =
+              b.insert_element(packF16Ty, packedRes, atomNonVec1, b.i32_val(1));
+          rewriter.create<LLVM::BrOp>(loc, packedRes, endBlock);
+
+          // Start to fill out the packed block.
+          rewriter.setInsertionPointToEnd(packedBlock);
+        }
+        atom = rewriter.create<LLVM::AtomicRMWOp>(
+            loc, *maybeKind, rmwPtr, operand, atomicMemOrdering, *scopeStr);
       }
 
       if (!tensorTy) {
@@ -1623,11 +1688,8 @@ struct AtomicRMWOpConversion
       rewriter.setInsertionPointToEnd(leaderBlock);
       // Utilize global atomic only by leader threads
      rmwPtr = b.inttoptr(origPtrType, rmwPtr);
-      Value atom = rewriter
-                       .create<LLVM::AtomicRMWOp>(loc, opKind, rmwPtr,
-                                                  afterRedBlock->getArgument(0),
-                                                  memOrdering, scope)
-                       .getResult();
+      Value atom = rewriter.create<LLVM::AtomicRMWOp>(
+          loc, opKind, rmwPtr, afterRedBlock->getArgument(0), memOrdering, scope);
       rewriter.create<LLVM::BrOp>(loc, atom, endBlock);
       rewriter.setInsertionPointToStart(endBlock);

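Taken together, the hunks above make the packed-versus-scalar decision at runtime per lane pair. A rough pseudo-Python sketch of the emitted control flow follows; the names are illustrative, and the real code builds LLVM basic blocks and AtomicRMW ops rather than calling a helper.

# Illustrative control flow per even lane after the change: one packed atomic
# when the pair qualifies, otherwise two independent scalar fp16 atomics.
def lower_pair(ptr, neighbour_ptr, operand_lo, operand_hi, enable_packed_opt, atomic_fadd):
    if enable_packed_opt:
        # Packed path: a single atomic covering both adjacent fp16 elements.
        return atomic_fadd(ptr, (operand_lo, operand_hi))
    # Regular path: misaligned, non-adjacent, or a partially inactive pair.
    res_lo = atomic_fadd(ptr, operand_lo)
    res_hi = atomic_fadd(neighbour_ptr, operand_hi)
    return (res_lo, res_hi)
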