
Commit 36b3473

[AMD] Always CheckPairs for packed fp16/bf16 atomic instructions (#7326)
Packed atomic ops are only valid when a few conditions are met. PR triton-lang/triton#6258, which refactored these checks, added a shortcut that could skip the runtime check and thereby introduced a bug: it missed the case where an even-lane tid accesses an address that is not 4-byte aligned. We therefore need to always perform that check; this PR changes the lowering to always run the `checkPairs` logic.
1 parent 23b8a7d commit 36b3473
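
For context, the failure mode the restored check guards against can be sketched with a minimal Triton kernel (illustrative only and not part of this commit; the kernel name, grid, and the assumption that the paired fp16 atomic path is selected are all hypothetical). Adjacent lanes write adjacent fp16 elements, but the run starts at an odd element index, so the even lane's address is only 2-byte aligned and a packed 32-bit atomic would be misaligned:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def shifted_atomic_add(X, V, N: tl.constexpr):
    # Each lane adds one fp16 value at offset 1 + lane, so pairs of lanes
    # straddle 4-byte boundaries: lane 0 touches byte offset 2, lane 1 byte 4.
    offs = tl.arange(0, N)
    vals = tl.load(V + offs)
    tl.atomic_add(X + 1 + offs, vals)


def run(n=64, device="cuda"):  # device assumed visible as "cuda" (ROCm maps HIP here)
    x = torch.zeros(n + 1, dtype=torch.float16, device=device)
    v = torch.randn(n, dtype=torch.float16, device=device)
    shifted_atomic_add[(1, )](x, v, n, num_warps=1)
    ref = torch.zeros_like(x)
    ref[1:n + 1] = v
    torch.testing.assert_close(x, ref)
```

Before this fix, a pattern like the shifted row in the new test below could be lowered to a packed atomic without re-verifying at runtime that the even lane's address is 4-byte aligned.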

File tree

- python/test/unit/language/test_core.py
- third_party/amd/lib/TritonAMDGPUToLLVM/AtomicRMWOpsEmitter.cpp
- third_party/amd/lib/TritonAMDGPUToLLVM/AtomicRMWOpsEmitter.h
- third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

4 files changed: +61 −45 lines changed

python/test/unit/language/test_core.py

Lines changed: 30 additions & 0 deletions
@@ -1731,6 +1731,36 @@ def kernel(X, val, NUM: tl.constexpr):
     torch.testing.assert_close(ref, x.reshape(math.prod(shape)))


+@pytest.mark.interpreter
+@pytest.mark.parametrize("size, num_ctas, dtype_x_str", [(size, num_ctas, dtype_x_str)
+                                                         for size in [2, 4, 8, 32, 64, 128]
+                                                         for num_ctas in num_ctas_list
+                                                         for dtype_x_str in ['bfloat16', 'float16', 'float32']])
+def test_tensor_atomic_add_shift_1(size, num_ctas, dtype_x_str, device):
+    check_type_supported(dtype_x_str, device)
+
+    @triton.jit
+    def kernel(X, val, NUM: tl.constexpr):
+        off_x = tl.arange(0, 2)
+        off_y = tl.arange(0, NUM)
+        off_in = off_x[:, None] * NUM + off_y[None, :]
+        off_out = off_x[:, None] + off_y[None, :]
+
+        val = tl.load(val + off_in)
+        tl.atomic_add(X + off_out, val)
+
+    s = (2, size)
+    dtype = getattr(torch, dtype_x_str)
+    x = torch.zeros(s, dtype=dtype, device=device)
+    ref = torch.flatten(x)
+    val = torch.randn(s, dtype=dtype, device=device)
+    kernel[(1, )](x, val, size, num_warps=1, num_ctas=num_ctas)
+    val = torch.flatten(val)
+    ref[0:size] = val[0:size]
+    ref[1:size + 1] += val[size:2 * size]
+    torch.testing.assert_close(ref, torch.flatten(x))
+
+
 @pytest.mark.interpreter
 @pytest.mark.parametrize("shape, idx_order, mask_step, num_ctas, dtype_x_str",
                          [(shape, idx_order, mask_step, num_ctas, dtype_x_str)

third_party/amd/lib/TritonAMDGPUToLLVM/AtomicRMWOpsEmitter.cpp

Lines changed: 29 additions & 40 deletions
@@ -207,8 +207,7 @@ Value AtomicRMWEmitter::emitAtomicRMW(RewriterBase &rewriter, Value rmwPtr,

 Value AtomicRMWEmitter::emitPairedAtomicForEvenTID(RewriterBase &rewriter,
                                                    Value rmwPtr, Value valElem,
-                                                   Value rmwMask,
-                                                   bool checkPairs) const {
+                                                   Value rmwMask) const {
   auto loc = rmwPtr.getLoc();
   auto b = TritonLLVMOpBuilder(loc, rewriter);
   Value i64Ones = b.i64_val(~uint64_t(0));
@@ -231,44 +230,34 @@ Value AtomicRMWEmitter::emitPairedAtomicForEvenTID(RewriterBase &rewriter,
   Value dppMoveRes = shiftLeftI32ByDpp(rewriter, packedVal);
   Value operand = b.bitcast(b.or_(packedVal, dppMoveRes), packF16Ty);

-  // If a runtime check is unnecessary (`checkPairs` is `false`),
-  // `rightNeighbourPtr` is irrelevant.
-  // Set the conditional value `enablePackedOpt` to `true` to enable DCE on the
-  // runtime check branch.
-  Value rightNeighbourPtr = rmwPtr;
-  Value enablePackedOpt = b.true_val();
-  if (checkPairs) {
-    Value rightNeighbourAddr =
-        genI32TiledOp(rewriter, shiftLeftI32ByDpp, castedAddr);
-
-    // Packing optimization only supported if following conditions are true:
-    // 1. address is aligned by 4 bytes
-    // 2. right neighbour has adjacent address
-    // 3. both threads are active
-    Value isAligned = b.icmp_eq(b.urem(castedAddr, b.i64_val(4)), b.i64_val(0));
-    Value neighbourAddrAdjacent = b.icmp_eq(
-        rightNeighbourAddr,
-        b.add(castedAddr, b.i64_val(valueElemTy.getIntOrFloatBitWidth() / 8)));
-    Value neighbourEnabled = b.icmp_ne(i64Ones, rightNeighbourAddr);
-    Value bothEnabled = b.and_(neighbourEnabled, rmwMask);
-    enablePackedOpt =
-        b.and_(b.and_(isAligned, bothEnabled), neighbourAddrAdjacent);
-
-    // Enable only the even threads.
-    Value anyEnabled = b.or_(neighbourEnabled, rmwMask);
-    // If one of the threads is disabled, use the neighbour's addr.
-    rightNeighbourAddr =
-        b.select(neighbourEnabled, rightNeighbourAddr, castedAddr);
-    castedAddr = b.select(rmwMask, castedAddr, rightNeighbourAddr);
-
-    rmwMask = b.and_(anyEnabled, b.icmp_eq(isOddI32, b.i32_val(0)));
-
-    // Unpack results back
-    rightNeighbourPtr = b.inttoptr(rmwPtr.getType(), rightNeighbourAddr);
-    rmwPtr = b.inttoptr(rmwPtr.getType(), castedAddr);
-  } else {
-    rmwMask = b.and_(rmwMask, b.icmp_eq(isOddI32, b.i32_val(0)));
-  }
+  Value rightNeighbourAddr =
+      genI32TiledOp(rewriter, shiftLeftI32ByDpp, castedAddr);
+
+  // Packing optimization only supported if following conditions are true:
+  // 1. address is aligned by 4 bytes
+  // 2. right neighbour has adjacent address
+  // 3. both threads are active
+  Value isAligned = b.icmp_eq(b.urem(castedAddr, b.i64_val(4)), b.i64_val(0));
+  Value neighbourAddrAdjacent = b.icmp_eq(
+      rightNeighbourAddr,
+      b.add(castedAddr, b.i64_val(valueElemTy.getIntOrFloatBitWidth() / 8)));
+  Value neighbourEnabled = b.icmp_ne(i64Ones, rightNeighbourAddr);
+  Value bothEnabled = b.and_(neighbourEnabled, rmwMask);
+  Value enablePackedOpt =
+      b.and_(b.and_(isAligned, bothEnabled), neighbourAddrAdjacent);
+
+  // Enable only the even threads.
+  Value anyEnabled = b.or_(neighbourEnabled, rmwMask);
+  // If one of the threads is disabled, use the neighbour's addr.
+  rightNeighbourAddr =
+      b.select(neighbourEnabled, rightNeighbourAddr, castedAddr);
+  castedAddr = b.select(rmwMask, castedAddr, rightNeighbourAddr);
+
+  rmwMask = b.and_(anyEnabled, b.icmp_eq(isOddI32, b.i32_val(0)));
+
+  // Unpack results back
+  Value rightNeighbourPtr = b.inttoptr(rmwPtr.getType(), rightNeighbourAddr);
+  rmwPtr = b.inttoptr(rmwPtr.getType(), castedAddr);

   Value undefVal = b.undef(packF16Ty);
   // Build blocks to bypass the atomic instruction for ~rmwMask.
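
In other words, with `checkPairs` removed, the emitter now always evaluates the three conditions listed in the comment above at runtime. A rough Python sketch of the per-even-lane predicate (function and parameter names are illustrative, not the emitter's API):

```python
def enable_packed_opt(addr: int, right_addr: int,
                      lane_active: bool, neighbour_active: bool,
                      elem_bytes: int = 2) -> bool:
    """Sketch of the predicate that gates the packed fp16/bf16 atomic.

    `addr` and `right_addr` stand for the byte addresses held by an even lane
    and its odd right neighbour (exchanged via the DPP shift in the real code).
    """
    is_aligned = addr % 4 == 0                             # 1. 4-byte aligned
    neighbour_adjacent = right_addr == addr + elem_bytes   # 2. adjacent address
    both_enabled = lane_active and neighbour_active        # 3. both threads active
    return is_aligned and neighbour_adjacent and both_enabled
```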

third_party/amd/lib/TritonAMDGPUToLLVM/AtomicRMWOpsEmitter.h

Lines changed: 1 addition & 2 deletions
@@ -21,8 +21,7 @@ class AtomicRMWEmitter {
                        bool enableIntraWaveReduce) const;

   Value emitPairedAtomicForEvenTID(RewriterBase &rewriter, Value rmwPtr,
-                                   Value valElem, Value rmwMask,
-                                   bool checkPairs = true) const;
+                                   Value valElem, Value rmwMask) const;

 private:
   const mlir::triton::AMD::TargetInfo &targetInfo;

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 1 addition & 3 deletions
@@ -1489,7 +1489,6 @@ struct AtomicRMWOpConversion
     // TODO: support data types less than 32 bits
     enableIntraWaveReduce &= valueElemTy.getIntOrFloatBitWidth() >= 32;

-    bool checkPairs = true;
     if (tensorTy) {
       bool isF16Ty = valueElemTy.isF16() || valueElemTy.isBF16();
       unsigned availableVecSize = isF16Ty ? 2 : 1;
@@ -1505,7 +1504,6 @@
       auto threadOrder = getThreadOrder(tensorTy);
       unsigned contigWithinLanes =
           axisAnalysisPass.getAxisInfo(ptr)->getContiguity(threadOrder.front());
-      checkPairs = !(contigWithinLanes > 1 && contigWithinLanes % 2 == 0);
       enableIntraWaveReduce &= contigWithinLanes == 1;
     }

@@ -1530,7 +1528,7 @@
       Value rmwMask = llMask ? b.and_(mask, maskElements[i]) : mask;
       if (applyPackingF16) {
         resultVals[i] = emitter.emitPairedAtomicForEvenTID(
-            rewriter, ptrElements[i], valElements[i], rmwMask, checkPairs);
+            rewriter, ptrElements[i], valElements[i], rmwMask);
       } else {
         Value valElement;
         if (vec == 1) {
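
The deleted shortcut assumed that an even, greater-than-one contiguity within lanes implies the even lane's address is 4-byte aligned, which fails when the contiguous run starts at an odd fp16 element (exactly the shifted row exercised by the new test). A small sketch of the arithmetic (hypothetical helper, not compiler code):

```python
def even_lane_aligned(base_elem: int, lane: int, elem_bytes: int = 2) -> bool:
    # Byte address of the element accessed by `lane`, assuming unit-stride
    # fp16 accesses (contiguity >= 2) starting at element index `base_elem`.
    addr = (base_elem + lane) * elem_bytes
    return lane % 2 == 0 and addr % 4 == 0


# base_elem = 0: every even lane is 4-byte aligned, so packing is safe.
assert all(even_lane_aligned(0, lane) for lane in range(0, 8, 2))
# base_elem = 1 (a shifted row): even lanes land on odd element indices,
# so their addresses are only 2-byte aligned and packing must be skipped.
assert not any(even_lane_aligned(1, lane) for lane in range(0, 8, 2))
```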
