
Commit 8cea3fe

Merge OpenAI Triton commit 36b3473 (#4614)
This PR changes the Triton base from 272188c to 36b3473 (Jun 26). Pass rate: 97.14%. Please do not squash and merge this PR.
2 parents 9752a56 + da1dad5 commit 8cea3fe

13 files changed, +152 -92 lines changed


include/triton/Analysis/Membar.h

Lines changed: 4 additions & 0 deletions
@@ -119,6 +119,8 @@ class MembarOrFenceAnalysis {
   explicit MembarOrFenceAnalysis(Allocation *allocation, MembarFilterFn filter)
       : allocation(allocation), filter(filter) {}

+  virtual ~MembarOrFenceAnalysis() = default;
+
   /// Runs the membar analysis to the given operation, inserts a barrier if
   /// necessary.
   void run(FuncBlockInfoMapT &funcBlockInfoMap);
@@ -160,6 +162,8 @@ class MembarAnalysis : public MembarOrFenceAnalysis {
   explicit MembarAnalysis(Allocation *allocation, MembarFilterFn filter)
       : MembarOrFenceAnalysis(allocation, filter) {}

+  ~MembarAnalysis() override = default;
+
 private:
   /// Updates the BlockInfo operation based on the operation.
   virtual void update(Operation *operation, BlockInfo *blockInfo,
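Context for the two added destructors: MembarOrFenceAnalysis is a polymorphic base (it has virtual member functions), and deleting a derived analysis such as MembarAnalysis through a base-class pointer is undefined behavior unless the base destructor is virtual. A minimal standalone sketch of that rule, assuming a base-pointer ownership pattern; the Sketch types below are illustrative stand-ins, not the real classes:

#include <memory>

struct MembarOrFenceAnalysisSketch {
  // Without this virtual destructor, deleting a derived object through a
  // base pointer is undefined behavior (the derived destructor never runs).
  virtual ~MembarOrFenceAnalysisSketch() = default;
  virtual void update() = 0;
};

struct MembarAnalysisSketch : MembarOrFenceAnalysisSketch {
  ~MembarAnalysisSketch() override = default;
  void update() override {}
};

int main() {
  // Hypothetical ownership pattern: a base-class pointer owning a derived object.
  std::unique_ptr<MembarOrFenceAnalysisSketch> analysis =
      std::make_unique<MembarAnalysisSketch>();
  analysis->update();
  // The delete performed by unique_ptr is well-defined only because the base
  // destructor is virtual.
  return 0;
}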

include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h

Lines changed: 1 addition & 1 deletion
@@ -51,8 +51,8 @@ struct TensorMemory : public SideEffects::Resource::Base<TensorMemory> {
 struct TMemAllocation {
   TMemAllocation(int numCols, int numRows)
       : numCols(numCols), numRows(numRows) {}
-  int numRows;
   int numCols;
+  int numRows;
 };

 TMemAllocation getTmemAllocSizes(gpu::MemDescType memDescType);
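The member swap lines the declaration order up with the constructor's initializer list: C++ constructs non-static data members in declaration order regardless of the order written in the initializer list, and compilers typically warn (e.g. -Wreorder) when the two disagree. A small standalone illustration using a hypothetical copy of the struct:

#include <iostream>

struct TMemAllocationSketch {
  TMemAllocationSketch(int numCols, int numRows)
      : numCols(numCols), numRows(numRows) {}
  // Members are constructed in declaration order, not initializer-list order;
  // declaring numCols first keeps the two orders consistent.
  int numCols;
  int numRows;
};

int main() {
  TMemAllocationSketch alloc(/*numCols=*/128, /*numRows=*/64);
  std::cout << alloc.numCols << "x" << alloc.numRows << "\n";  // prints 128x64
  return 0;
}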

python/test/unit/language/test_core.py

Lines changed: 30 additions & 0 deletions
@@ -1751,6 +1751,36 @@ def kernel(X, val, NUM: tl.constexpr):
     torch.testing.assert_close(ref, x.reshape(math.prod(shape)))


+@pytest.mark.interpreter
+@pytest.mark.parametrize("size, num_ctas, dtype_x_str", [(size, num_ctas, dtype_x_str)
+                                                         for size in [2, 4, 8, 32, 64, 128]
+                                                         for num_ctas in num_ctas_list
+                                                         for dtype_x_str in ['float16', 'float32']])
+def test_tensor_atomic_add_shift_1(size, num_ctas, dtype_x_str, device):
+    check_type_supported(dtype_x_str, device)
+
+    @triton.jit
+    def kernel(X, val, NUM: tl.constexpr):
+        off_x = tl.arange(0, 2)
+        off_y = tl.arange(0, NUM)
+        off_in = off_x[:, None] * NUM + off_y[None, :]
+        off_out = off_x[:, None] + off_y[None, :]
+
+        val = tl.load(val + off_in)
+        tl.atomic_add(X + off_out, val)
+
+    s = (2, size)
+    dtype = getattr(torch, dtype_x_str)
+    x = torch.zeros(s, dtype=dtype, device=device)
+    ref = torch.flatten(x)
+    val = torch.randn(s, dtype=dtype, device=device)
+    kernel[(1, )](x, val, size, num_warps=1, num_ctas=num_ctas)
+    val = torch.flatten(val)
+    ref[0:size] = val[0:size]
+    ref[1:size + 1] += val[size:2 * size]
+    torch.testing.assert_close(ref, torch.flatten(x))
+
+
 @pytest.mark.interpreter
 @pytest.mark.parametrize("shape, idx_order, mask_step, num_ctas, dtype_x_str",
                          [(shape, idx_order, mask_step, num_ctas, dtype_x_str)
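What the new test checks, in plain terms: row 0 of val is atomically added to X[0..size-1] and row 1 to X[1..size], so every interior output element receives two contributions; the adjacent, partially overlapping per-lane addresses appear aimed at the packed-f16 atomic path changed elsewhere in this commit (an inference, not stated in the diff). A hand-computed C++ sketch of the expected overlap for size = 4, with made-up values:

#include <cassert>
#include <vector>

int main() {
  const int size = 4;
  // Two input rows, mirroring the (2, size) `val` tensor in the test.
  std::vector<float> row0 = {1, 2, 3, 4};
  std::vector<float> row1 = {10, 20, 30, 40};

  // row0 is accumulated into X[0..size-1], row1 into X[1..size],
  // so the interior elements receive the sum of both rows.
  std::vector<float> x(2 * size, 0.0f);
  for (int i = 0; i < size; ++i) x[i] += row0[i];
  for (int i = 0; i < size; ++i) x[i + 1] += row1[i];

  assert(x[0] == 1.0f);           // only row0 contributes
  assert(x[1] == 2.0f + 10.0f);   // overlap of row0[1] and row1[0]
  assert(x[size] == 40.0f);       // only row1 contributes
  return 0;
}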

test/Conversion/amd/tritongpu_to_llvm_rdna.mlir

Lines changed: 31 additions & 1 deletion
@@ -20,7 +20,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
     // CHECK-SAME: with 273, 15, 15, true : f32
     // CHECK-NEXT: llvm.intr.maxnum

-    // CHECK: llvm.amdgcn.permlanex16
+    // CHECK: rocdl.permlanex16
     // CHECK: llvm.intr.maxnum
     // CHECK: rocdl.readlane
     %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({
@@ -31,3 +31,33 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
     tt.return
   }
 }
+
+#linear = #ttg.linear<{register = [[16, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1]], warp = [], block = []}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: @reduce_linear_layout
+  tt.func private @reduce_linear_layout(%arg0: tensor<32x2xi32, #linear>) -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #linear}>> {
+    // This tensor has 64 elements with the last dimension across the lower and upper 16 lanes.
+    // Therefore, we can reduce it with a 16 element butterfly shuffle.
+
+    // CHECK-DAG: [[result0:%.*]] = llvm.mlir.undef
+    // CHECK-DAG: [[select_lo:%.*]] = llvm.mlir.constant(1985229328 : i32)
+    // CHECK-DAG: [[select_hi:%.*]] = llvm.mlir.constant(-19088744 : i32)
+    // CHECK-DAG: [[reg0:%.*]] = llvm.extractvalue %arg0[0]
+    // CHECK-DAG: [[reg1:%.*]] = llvm.extractvalue %arg0[1]
+    // CHECK: [[permlane0:%.*]] = rocdl.permlanex16 [[reg0]], [[reg0]], [[select_lo]], [[select_hi]], true, false
+    // CHECK: [[sum0:%.*]] = llvm.add [[reg0]], [[permlane0]]
+    // CHECK: [[permlane1:%.*]] = rocdl.permlanex16 [[reg1]], [[reg1]], [[select_lo]], [[select_hi]], true, false
+    // CHECK: [[sum1:%.*]] = llvm.add [[reg1]], [[permlane1]]
+    // CHECK: [[result1:%.*]] = llvm.insertvalue [[sum0]], [[result0]][0]
+    // CHECK: [[result2:%.*]] = llvm.insertvalue [[sum1]], [[result1]][1]
+
+    %0 = "tt.reduce"(%arg0) ({
+    ^bb0(%arg1: i32, %arg2: i32):
+      %1 = arith.addi %arg1, %arg2 : i32
+      tt.reduce.return %1 : i32
+    }) {axis = 1 : i32} : (tensor<32x2xi32, #linear>) -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #linear}>>
+
+    // CHECK: llvm.return [[result2]]
+    tt.return %0 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #linear}>>
+  }
+}
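One way to read the new reduce_linear_layout test: the #linear layout's lane basis [0, 1] puts the second tensor dimension on the lane bit that separates the lower and upper 16 lanes, so the axis-1 reduction only needs each lane to add the value held by the matching lane in the other half of the 32-wide warp. Assuming the 0x76543210/0xFEDCBA98 selects make rocdl.permlanex16 read the same position in the opposite 16-lane group (my reading of the constants, not taken from the ROCDL documentation), a standalone model of the exchange-and-add looks like:

#include <array>
#include <cassert>
#include <cstdint>

// Assumed behaviour of the cross-half exchange generated for this test:
// lane i reads the value held by lane (i ^ 16), i.e. the same position in the
// opposite group of 16 lanes.
std::array<int32_t, 32>
permlanex16_same_position(const std::array<int32_t, 32> &src) {
  std::array<int32_t, 32> dst{};
  for (int lane = 0; lane < 32; ++lane)
    dst[lane] = src[lane ^ 16];
  return dst;
}

int main() {
  // Column 0 lives in lanes 0-15, column 1 in lanes 16-31 (one register shown).
  std::array<int32_t, 32> reg{};
  for (int lane = 0; lane < 32; ++lane)
    reg[lane] = (lane < 16) ? lane : 100 + (lane - 16);

  // rocdl.permlanex16 followed by llvm.add, as matched by the CHECK lines.
  std::array<int32_t, 32> exchanged = permlanex16_same_position(reg);
  for (int lane = 0; lane < 32; ++lane) {
    int32_t sum = reg[lane] + exchanged[lane];
    // Every lane now holds column0[row] + column1[row] for its row.
    assert(sum == (lane % 16) + 100 + (lane % 16));
  }
  return 0;
}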

third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h

Lines changed: 4 additions & 0 deletions
@@ -23,6 +23,10 @@ ISAFamily deduceISAFamily(llvm::StringRef arch);
 // Retursn true if given architecture support V_DOT instruction.
 bool supportsVDot(llvm::StringRef arch);

+bool isCDNA(ISAFamily isaFamily);
+
+bool isRDNA(ISAFamily isaFamily);
+
 // Here is a partial definition of DppCtrl enums. For the complete definition,
 // please check:
 // https://github.com/llvm/llvm-project/blob/8c75290/llvm/lib/Target/AMDGPU/SIDefines.h#L939
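The definitions of these two new free functions are not part of this diff; judging from the TargetInfo member functions they replace (removed from TargetInfo.cpp further down), they presumably look roughly like the sketch below. The ISAFamily enum is re-declared here only to keep the example self-contained and runnable; the real one lives in the AMD backend headers and likely has more enumerators:

#include <cassert>

// Self-contained stand-in for the real enum, for illustration only.
enum class ISAFamily { Unknown, CDNA1, CDNA2, CDNA3, CDNA4, RDNA1, RDNA2, RDNA3 };

bool isCDNA(ISAFamily isaFamily) {
  switch (isaFamily) {
  case ISAFamily::CDNA1:
  case ISAFamily::CDNA2:
  case ISAFamily::CDNA3:
  case ISAFamily::CDNA4:
    return true;
  default:
    break;
  }
  return false;
}

bool isRDNA(ISAFamily isaFamily) {
  switch (isaFamily) {
  case ISAFamily::RDNA1:
  case ISAFamily::RDNA2:
  case ISAFamily::RDNA3:
    return true;
  default:
    break;
  }
  return false;
}

int main() {
  assert(isCDNA(ISAFamily::CDNA3));
  assert(isRDNA(ISAFamily::RDNA3));
  assert(!isCDNA(ISAFamily::RDNA1));
  return 0;
}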

third_party/amd/lib/TritonAMDGPUToLLVM/AtomicRMWOpsEmitter.cpp

Lines changed: 29 additions & 40 deletions
@@ -207,8 +207,7 @@ Value AtomicRMWEmitter::emitAtomicRMW(RewriterBase &rewriter, Value rmwPtr,

 Value AtomicRMWEmitter::emitPairedAtomicForEvenTID(RewriterBase &rewriter,
                                                    Value rmwPtr, Value valElem,
-                                                   Value rmwMask,
-                                                   bool checkPairs) const {
+                                                   Value rmwMask) const {
   auto loc = rmwPtr.getLoc();
   auto b = TritonLLVMOpBuilder(loc, rewriter);
   Value i64Ones = b.i64_val(~uint64_t(0));
@@ -231,44 +230,34 @@ Value AtomicRMWEmitter::emitPairedAtomicForEvenTID(RewriterBase &rewriter,
   Value dppMoveRes = shiftLeftI32ByDpp(rewriter, packedVal);
   Value operand = b.bitcast(b.or_(packedVal, dppMoveRes), packF16Ty);

-  // If a runtime check is unnecessary (`checkPairs` is `false`),
-  // `rightNeighbourPtr` is irrelevant.
-  // Set the conditional value `enablePackedOpt` to `true` to enable DCE on the
-  // runtime check branch.
-  Value rightNeighbourPtr = rmwPtr;
-  Value enablePackedOpt = b.true_val();
-  if (checkPairs) {
-    Value rightNeighbourAddr =
-        genI32TiledOp(rewriter, shiftLeftI32ByDpp, castedAddr);
-
-    // Packing optimization only supported if following conditions are true:
-    // 1. address is aligned by 4 bytes
-    // 2. right neighbour has adjacent address
-    // 3. both threads are active
-    Value isAligned = b.icmp_eq(b.urem(castedAddr, b.i64_val(4)), b.i64_val(0));
-    Value neighbourAddrAdjacent = b.icmp_eq(
-        rightNeighbourAddr,
-        b.add(castedAddr, b.i64_val(valueElemTy.getIntOrFloatBitWidth() / 8)));
-    Value neighbourEnabled = b.icmp_ne(i64Ones, rightNeighbourAddr);
-    Value bothEnabled = b.and_(neighbourEnabled, rmwMask);
-    enablePackedOpt =
-        b.and_(b.and_(isAligned, bothEnabled), neighbourAddrAdjacent);
-
-    // Enable only the even threads.
-    Value anyEnabled = b.or_(neighbourEnabled, rmwMask);
-    // If one of the threads is disabled, use the neighbour's addr.
-    rightNeighbourAddr =
-        b.select(neighbourEnabled, rightNeighbourAddr, castedAddr);
-    castedAddr = b.select(rmwMask, castedAddr, rightNeighbourAddr);
-
-    rmwMask = b.and_(anyEnabled, b.icmp_eq(isOddI32, b.i32_val(0)));
-
-    // Unpack results back
-    rightNeighbourPtr = b.inttoptr(rmwPtr.getType(), rightNeighbourAddr);
-    rmwPtr = b.inttoptr(rmwPtr.getType(), castedAddr);
-  } else {
-    rmwMask = b.and_(rmwMask, b.icmp_eq(isOddI32, b.i32_val(0)));
-  }
+  Value rightNeighbourAddr =
+      genI32TiledOp(rewriter, shiftLeftI32ByDpp, castedAddr);
+
+  // Packing optimization only supported if following conditions are true:
+  // 1. address is aligned by 4 bytes
+  // 2. right neighbour has adjacent address
+  // 3. both threads are active
+  Value isAligned = b.icmp_eq(b.urem(castedAddr, b.i64_val(4)), b.i64_val(0));
+  Value neighbourAddrAdjacent = b.icmp_eq(
+      rightNeighbourAddr,
+      b.add(castedAddr, b.i64_val(valueElemTy.getIntOrFloatBitWidth() / 8)));
+  Value neighbourEnabled = b.icmp_ne(i64Ones, rightNeighbourAddr);
+  Value bothEnabled = b.and_(neighbourEnabled, rmwMask);
+  Value enablePackedOpt =
+      b.and_(b.and_(isAligned, bothEnabled), neighbourAddrAdjacent);
+
+  // Enable only the even threads.
+  Value anyEnabled = b.or_(neighbourEnabled, rmwMask);
+  // If one of the threads is disabled, use the neighbour's addr.
+  rightNeighbourAddr =
+      b.select(neighbourEnabled, rightNeighbourAddr, castedAddr);
+  castedAddr = b.select(rmwMask, castedAddr, rightNeighbourAddr);
+
+  rmwMask = b.and_(anyEnabled, b.icmp_eq(isOddI32, b.i32_val(0)));
+
+  // Unpack results back
+  Value rightNeighbourPtr = b.inttoptr(rmwPtr.getType(), rightNeighbourAddr);
+  rmwPtr = b.inttoptr(rmwPtr.getType(), castedAddr);

   Value undefVal = b.undef(packF16Ty);
   // Build blocks to bypass the atomic instruction for ~rmwMask.
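With checkPairs gone, the pairing conditions are evaluated unconditionally at runtime for every packed f16 atomic. A standalone model of the per-even-lane predicate, using plain integers in place of the emitted LLVM values (the helper name and the sample addresses are invented for illustration):

#include <cassert>
#include <cstdint>

// Model of the runtime predicate computed per even lane: the even lane and its
// right (odd) neighbour can share one packed f16x2 atomic only if
//   1. the even lane's address is aligned by 4 bytes,
//   2. the neighbour's address is exactly elemBytes further,
//   3. both lanes are active (an inactive neighbour reports ~0 as its address).
bool canPackPairedAtomic(uint64_t addr, uint64_t neighbourAddr, bool laneActive,
                         unsigned elemBytes) {
  const uint64_t kInactive = ~uint64_t(0);
  bool isAligned = (addr % 4) == 0;
  bool neighbourEnabled = neighbourAddr != kInactive;
  bool neighbourAdjacent = neighbourAddr == addr + elemBytes;
  return isAligned && neighbourEnabled && laneActive && neighbourAdjacent;
}

int main() {
  // Two contiguous f16 elements starting at an aligned address: packable.
  assert(canPackPairedAtomic(0x1000, 0x1002, /*laneActive=*/true, /*elemBytes=*/2));
  // Misaligned start: fall back to the unpacked path.
  assert(!canPackPairedAtomic(0x1002, 0x1004, true, 2));
  // Neighbour lane inactive: fall back as well.
  assert(!canPackPairedAtomic(0x1000, ~uint64_t(0), true, 2));
  return 0;
}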

third_party/amd/lib/TritonAMDGPUToLLVM/AtomicRMWOpsEmitter.h

Lines changed: 1 addition & 2 deletions
@@ -21,8 +21,7 @@ class AtomicRMWEmitter {
                     bool enableIntraWaveReduce) const;

   Value emitPairedAtomicForEvenTID(RewriterBase &rewriter, Value rmwPtr,
-                                   Value valElem, Value rmwMask,
-                                   bool checkPairs = true) const;
+                                   Value valElem, Value rmwMask) const;

 private:
   const mlir::triton::AMD::TargetInfo &targetInfo;

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 1 addition & 3 deletions
@@ -1489,7 +1489,6 @@ struct AtomicRMWOpConversion
     // TODO: support data types less than 32 bits
     enableIntraWaveReduce &= valueElemTy.getIntOrFloatBitWidth() >= 32;

-    bool checkPairs = true;
     if (tensorTy) {
       bool isF16Ty = valueElemTy.isF16() || valueElemTy.isBF16();
       unsigned availableVecSize = isF16Ty ? 2 : 1;
@@ -1505,7 +1504,6 @@ struct AtomicRMWOpConversion
       auto threadOrder = getThreadOrder(tensorTy);
       unsigned contigWithinLanes =
           axisAnalysisPass.getAxisInfo(ptr)->getContiguity(threadOrder.front());
-      checkPairs = !(contigWithinLanes > 1 && contigWithinLanes % 2 == 0);
       enableIntraWaveReduce &= contigWithinLanes == 1;
     }

@@ -1530,7 +1528,7 @@ struct AtomicRMWOpConversion
       Value rmwMask = llMask ? b.and_(mask, maskElements[i]) : mask;
       if (applyPackingF16) {
         resultVals[i] = emitter.emitPairedAtomicForEvenTID(
-            rewriter, ptrElements[i], valElements[i], rmwMask, checkPairs);
+            rewriter, ptrElements[i], valElements[i], rmwMask);
       } else {
         Value valElement;
         if (vec == 1) {

third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp

Lines changed: 11 additions & 38 deletions
@@ -66,34 +66,7 @@ llvm::AMDGPU::GPUKind TargetInfo::getGPUKind() const {
   return llvm::AMDGPU::parseArchAMDGCN(arch);
 }

-bool TargetInfo::isCDNA() const {
-  switch (getISAFamily()) {
-  case ISAFamily::CDNA1:
-  case ISAFamily::CDNA2:
-  case ISAFamily::CDNA3:
-  case ISAFamily::CDNA4:
-    return true;
-  default:
-    break;
-  }
-
-  return false;
-}
-
-bool TargetInfo::isRDNA() const {
-  switch (getISAFamily()) {
-  case ISAFamily::RDNA1:
-  case ISAFamily::RDNA2:
-  case ISAFamily::RDNA3:
-    return true;
-  default:
-    break;
-  }
-
-  return false;
-}
-
-int TargetInfo::getWarpSize() const { return isCDNA() ? 64 : 32; }
+int TargetInfo::getWarpSize() const { return isCDNA(getISAFamily()) ? 64 : 32; }

 int TargetInfo::getSharedMemorySize() const {
   int kbytes = getISAFamily() == ISAFamily::CDNA4 ? 160 : 64;
@@ -312,14 +285,14 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc,
                             unsigned interleave) const {
   auto b = TritonLLVMOpBuilder(loc, rewriter);

-  if (isCDNA() && getISAFamily() == ISAFamily::CDNA4 &&
+  if (getISAFamily() == ISAFamily::CDNA4 &&
       warpReduceSwap16or32(rewriter, loc, acc, op, numLaneToReduce, interleave))
     return true;
   if (numLaneToReduce != getWarpSize())
     return false;
-  if (isCDNA() && getISAFamily() == ISAFamily::CDNA1)
+  if (isCDNA(getISAFamily()) && getISAFamily() == ISAFamily::CDNA1)
     return false;
-  if (isRDNA() && getISAFamily() != ISAFamily::RDNA3)
+  if (isRDNA(getISAFamily()) && getISAFamily() != ISAFamily::RDNA3)
     return false;

   Operation *reduxOp = op.getSingleCombiner();
@@ -420,7 +393,7 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc,
     buf = createDppReduxOpWithBoundCtrl(valType, buf, 1 + dppCtrlRowShr,
                                         allRows, allBanks);

-    if (isCDNA()) {
+    if (isCDNA(getISAFamily())) {
       // row_bcast:15 row_mask:0xa
       buf = createDppReduxOpWithBoundCtrl(
           valType, buf, static_cast<uint32_t>(DppCtrl::BCAST15), 0xa, allBanks);
@@ -433,12 +406,12 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc,
     // RDNA doesn't have broadcast dpp mode
     Type actualType = castToAndSExtInt(rewriter, loc, buf, valType, 32);

-    Value permlaneResult =
-        LLVM::createLLVMIntrinsicCallOp(
-            rewriter, loc, "llvm.amdgcn.permlanex16", actualType,
-            ValueRange{buf, buf, b.i32_val(-1), b.i32_val(-1), b.true_val(),
-                       b.false_val()})
-            ->getResult(0);
+    // Lanes 0-15 read from lane 31 and lanes 16-31 read from lane 15.
+    Value permlaneResult = rewriter
+                               .create<ROCDL::PermlaneX16Op>(
+                                   loc, actualType, buf, buf, b.i32_val(-1),
+                                   b.i32_val(-1), true, false)
+                               .getRes();
     buf = truncAndCastFromInt(rewriter, loc, buf, valType, 32);
     permlaneResult =
         truncAndCastFromInt(rewriter, loc, permlaneResult, valType, 32);

third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h

Lines changed: 0 additions & 4 deletions
@@ -15,10 +15,6 @@ class TargetInfo : public mlir::triton::TargetInfoBase {

   llvm::AMDGPU::GPUKind getGPUKind() const;

-  bool isCDNA() const;
-
-  bool isRDNA() const;
-
   int getWarpSize() const;

   int getSharedMemorySize() const;
