
Commit dab5917

[release/3.4] "[AMD] Always CheckPairs for packed fp16/bf16 atomic instructions (triton-lang#7326)" (triton-lang#7339)
Packed atomic ops are only legal when a few conditions hold. The refactoring PR triton-lang#6258 added a shortcut that skips the runtime check of these conditions, and that introduced a bug: it missed the case where an even-lane thread accesses an address that is not 4-byte aligned. We therefore need to always perform the check. This PR changes the lowering to always run `checkPairs`.

(cherry picked from commit 36b3473)

I don't know whether the checklist below applies to a cherry-pick, but I copied it from the original PR.

# New contributor declaration

- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [x] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.
- Select one of the following.
  - [x] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
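For context on the missed case: pairwise-contiguous lane addresses do not guarantee that the even lane is 4-byte aligned, which is what the removed shortcut implicitly assumed. A minimal sketch of the failure mode (plain Python with hypothetical addresses, not Triton code):

```python
# fp16/bf16 elements are 2 bytes. Lanes are pairwise contiguous
# (contiguity == 2), yet every even lane is misaligned because the whole
# access is shifted by one element off a 4-byte boundary.
base = 0x1000 + 2                        # hypothetical: one element past an aligned base
addrs = [base + 2 * lane for lane in range(8)]
print([hex(a) for a in addrs[::2]])      # even-lane addresses
print([a % 4 == 0 for a in addrs[::2]])  # [False, False, ...] -> packing is illegal
```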
1 parent 4972f45 commit dab5917

File tree

4 files changed: +61, -45 lines changed


python/test/unit/language/test_core.py

Lines changed: 30 additions & 0 deletions

```diff
@@ -1730,6 +1730,36 @@ def kernel(X, val, NUM: tl.constexpr):
     torch.testing.assert_close(ref, x.reshape(math.prod(shape)))


+@pytest.mark.interpreter
+@pytest.mark.parametrize("size, num_ctas, dtype_x_str", [(size, num_ctas, dtype_x_str)
+                                                         for size in [2, 4, 8, 32, 64, 128]
+                                                         for num_ctas in num_ctas_list
+                                                         for dtype_x_str in ['bfloat16', 'float16', 'float32']])
+def test_tensor_atomic_add_shift_1(size, num_ctas, dtype_x_str, device):
+    check_type_supported(dtype_x_str, device)
+
+    @triton.jit
+    def kernel(X, val, NUM: tl.constexpr):
+        off_x = tl.arange(0, 2)
+        off_y = tl.arange(0, NUM)
+        off_in = off_x[:, None] * NUM + off_y[None, :]
+        off_out = off_x[:, None] + off_y[None, :]
+
+        val = tl.load(val + off_in)
+        tl.atomic_add(X + off_out, val)
+
+    s = (2, size)
+    dtype = getattr(torch, dtype_x_str)
+    x = torch.zeros(s, dtype=dtype, device=device)
+    ref = torch.flatten(x)
+    val = torch.randn(s, dtype=dtype, device=device)
+    kernel[(1, )](x, val, size, num_warps=1, num_ctas=num_ctas)
+    val = torch.flatten(val)
+    ref[0:size] = val[0:size]
+    ref[1:size + 1] += val[size:2 * size]
+    torch.testing.assert_close(ref, torch.flatten(x))
+
+
 @pytest.mark.interpreter
 @pytest.mark.parametrize("shape, idx_order, mask_step, num_ctas, dtype_x_str",
                          [(shape, idx_order, mask_step, num_ctas, dtype_x_str)
```
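A note on what the new test exercises (my reading of the diff, not text from the PR): `off_out` writes row 1 shifted by one element relative to row 0, so with 2-byte fp16/bf16 elements the row 1 atomics start at byte offset 2, which is 2-byte but not 4-byte aligned: the exact case the restored runtime check must detect. A small sketch of the offsets for `size = 4`:

```python
import numpy as np

NUM = 4
off_x = np.arange(2)
off_y = np.arange(NUM)
off_out = off_x[:, None] + off_y[None, :]
print(off_out)          # [[0 1 2 3]
                        #  [1 2 3 4]] -> row 1 overlaps row 0, shifted by one
print(off_out * 2 % 4)  # byte offsets mod 4: row 1 starts at 2, not 0
```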

third_party/amd/lib/TritonAMDGPUToLLVM/AtomicRMWOpsEmitter.cpp

Lines changed: 29 additions & 40 deletions

```diff
@@ -207,8 +207,7 @@ Value AtomicRMWEmitter::emitAtomicRMW(RewriterBase &rewriter, Value rmwPtr,

 Value AtomicRMWEmitter::emitPairedAtomicForEvenTID(RewriterBase &rewriter,
                                                    Value rmwPtr, Value valElem,
-                                                   Value rmwMask,
-                                                   bool checkPairs) const {
+                                                   Value rmwMask) const {
   auto loc = rmwPtr.getLoc();
   auto b = TritonLLVMOpBuilder(loc, rewriter);
   Value i64Ones = b.i64_val(~uint64_t(0));
@@ -231,44 +230,34 @@ Value AtomicRMWEmitter::emitPairedAtomicForEvenTID(RewriterBase &rewriter,
   Value dppMoveRes = shiftLeftI32ByDpp(rewriter, packedVal);
   Value operand = b.bitcast(b.or_(packedVal, dppMoveRes), packF16Ty);

-  // If a runtime check is unnecessary (`checkPairs` is `false`),
-  // `rightNeighbourPtr` is irrelevant.
-  // Set the conditional value `enablePackedOpt` to `true` to enable DCE on the
-  // runtime check branch.
-  Value rightNeighbourPtr = rmwPtr;
-  Value enablePackedOpt = b.true_val();
-  if (checkPairs) {
-    Value rightNeighbourAddr =
-        genI32TiledOp(rewriter, shiftLeftI32ByDpp, castedAddr);
-
-    // Packing optimization only supported if following conditions are true:
-    // 1. address is aligned by 4 bytes
-    // 2. right neighbour has adjacent address
-    // 3. both threads are active
-    Value isAligned = b.icmp_eq(b.urem(castedAddr, b.i64_val(4)), b.i64_val(0));
-    Value neighbourAddrAdjacent = b.icmp_eq(
-        rightNeighbourAddr,
-        b.add(castedAddr, b.i64_val(valueElemTy.getIntOrFloatBitWidth() / 8)));
-    Value neighbourEnabled = b.icmp_ne(i64Ones, rightNeighbourAddr);
-    Value bothEnabled = b.and_(neighbourEnabled, rmwMask);
-    enablePackedOpt =
-        b.and_(b.and_(isAligned, bothEnabled), neighbourAddrAdjacent);
-
-    // Enable only the even threads.
-    Value anyEnabled = b.or_(neighbourEnabled, rmwMask);
-    // If one of the threads is disabled, use the neighbour's addr.
-    rightNeighbourAddr =
-        b.select(neighbourEnabled, rightNeighbourAddr, castedAddr);
-    castedAddr = b.select(rmwMask, castedAddr, rightNeighbourAddr);
-
-    rmwMask = b.and_(anyEnabled, b.icmp_eq(isOddI32, b.i32_val(0)));
-
-    // Unpack results back
-    rightNeighbourPtr = b.inttoptr(rmwPtr.getType(), rightNeighbourAddr);
-    rmwPtr = b.inttoptr(rmwPtr.getType(), castedAddr);
-  } else {
-    rmwMask = b.and_(rmwMask, b.icmp_eq(isOddI32, b.i32_val(0)));
-  }
+  Value rightNeighbourAddr =
+      genI32TiledOp(rewriter, shiftLeftI32ByDpp, castedAddr);
+
+  // Packing optimization only supported if following conditions are true:
+  // 1. address is aligned by 4 bytes
+  // 2. right neighbour has adjacent address
+  // 3. both threads are active
+  Value isAligned = b.icmp_eq(b.urem(castedAddr, b.i64_val(4)), b.i64_val(0));
+  Value neighbourAddrAdjacent = b.icmp_eq(
+      rightNeighbourAddr,
+      b.add(castedAddr, b.i64_val(valueElemTy.getIntOrFloatBitWidth() / 8)));
+  Value neighbourEnabled = b.icmp_ne(i64Ones, rightNeighbourAddr);
+  Value bothEnabled = b.and_(neighbourEnabled, rmwMask);
+  Value enablePackedOpt =
+      b.and_(b.and_(isAligned, bothEnabled), neighbourAddrAdjacent);
+
+  // Enable only the even threads.
+  Value anyEnabled = b.or_(neighbourEnabled, rmwMask);
+  // If one of the threads is disabled, use the neighbour's addr.
+  rightNeighbourAddr =
+      b.select(neighbourEnabled, rightNeighbourAddr, castedAddr);
+  castedAddr = b.select(rmwMask, castedAddr, rightNeighbourAddr);
+
+  rmwMask = b.and_(anyEnabled, b.icmp_eq(isOddI32, b.i32_val(0)));
+
+  // Unpack results back
+  Value rightNeighbourPtr = b.inttoptr(rmwPtr.getType(), rightNeighbourAddr);
+  rmwPtr = b.inttoptr(rmwPtr.getType(), castedAddr);

   Value undefVal = b.undef(packF16Ty);
   // Build blocks to bypass the atomic instruction for ~rmwMask.
```
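As a reading aid for the branchless sequence above: a scalar Python model of one even/odd lane pair (my paraphrase, not code from the commit; names are mine). Each even lane decides whether it may fuse its own 16-bit atomic with its right neighbour's into a single packed 32-bit atomic, and otherwise falls back to issuing plain atomics for both addresses:

```python
DISABLED = (1 << 64) - 1   # all-ones address: what the DPP shuffle yields
                           # for an inactive (masked-off) neighbour lane

def pair_even_lane(addr: int, neighbour_addr: int, self_active: bool):
    """Scalar model of the even lane's bookkeeping; 2-byte fp16/bf16 elements."""
    is_aligned = addr % 4 == 0
    neighbour_enabled = neighbour_addr != DISABLED
    adjacent = neighbour_addr == addr + 2
    # Packed op is legal only if: aligned, both lanes active, and adjacent.
    enable_packed = is_aligned and self_active and neighbour_enabled and adjacent

    # If one lane of the pair is inactive, reuse the other lane's address so
    # the even lane can still issue ordinary (unpacked) atomics.
    if not neighbour_enabled:
        neighbour_addr = addr
    if not self_active:
        addr = neighbour_addr

    # The even lane stays active if either lane of the pair had work to do.
    mask = self_active or neighbour_enabled
    return enable_packed, mask, addr, neighbour_addr
```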

third_party/amd/lib/TritonAMDGPUToLLVM/AtomicRMWOpsEmitter.h

Lines changed: 1 addition & 2 deletions

```diff
@@ -21,8 +21,7 @@ class AtomicRMWEmitter {
                      bool enableIntraWaveReduce) const;

  Value emitPairedAtomicForEvenTID(RewriterBase &rewriter, Value rmwPtr,
-                                  Value valElem, Value rmwMask,
-                                  bool checkPairs = true) const;
+                                  Value valElem, Value rmwMask) const;

private:
  const mlir::triton::AMD::TargetInfo &targetInfo;
```

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 1 addition & 3 deletions

```diff
@@ -1489,7 +1489,6 @@ struct AtomicRMWOpConversion
     // TODO: support data types less than 32 bits
     enableIntraWaveReduce &= valueElemTy.getIntOrFloatBitWidth() >= 32;

-    bool checkPairs = true;
     if (tensorTy) {
       bool isF16Ty = valueElemTy.isF16() || valueElemTy.isBF16();
       unsigned availableVecSize = isF16Ty ? 2 : 1;
@@ -1505,7 +1504,6 @@ struct AtomicRMWOpConversion
       auto threadOrder = getThreadOrder(tensorTy);
       unsigned contigWithinLanes =
           axisAnalysisPass.getAxisInfo(ptr)->getContiguity(threadOrder.front());
-      checkPairs = !(contigWithinLanes > 1 && contigWithinLanes % 2 == 0);
       enableIntraWaveReduce &= contigWithinLanes == 1;
     }

@@ -1530,7 +1528,7 @@ struct AtomicRMWOpConversion
       Value rmwMask = llMask ? b.and_(mask, maskElements[i]) : mask;
       if (applyPackingF16) {
         resultVals[i] = emitter.emitPairedAtomicForEvenTID(
-            rewriter, ptrElements[i], valElements[i], rmwMask, checkPairs);
+            rewriter, ptrElements[i], valElements[i], rmwMask);
       } else {
         Value valElement;
         if (vec == 1) {
```
