Skip to content

Commit 3a17b5c

Browse files
[AMD] Use permlanex16 for shuffleXor on rdna (#7269)
On RDNA, permlanex16 works similarly to DPP operations, but has more flexible lane selection. Each lane in the upper/lower block of 16 contiguous lanes can select an arbitrary lane in the other block to read from. With 4 bits per lane, we construct the identity mapping 0xfedcba9876543210 so that lane i in the upper 16 lanes reads data from lane i in the lower 16 lanes and vice versa. This does not require a round trip to LDS, as was necessary with the previously used ds_swizzle instruction. Co-authored-by: Paul Trojahn <[email protected]>
1 parent 8dec0ed commit 3a17b5c

File tree

7 files changed

+85
-45
lines changed

7 files changed

+85
-45
lines changed

test/Conversion/amd/tritongpu_to_llvm_rdna.mlir

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
2020
// CHECK-SAME: with 273, 15, 15, true : f32
2121
// CHECK-NEXT: llvm.intr.maxnum
2222

23-
// CHECK: llvm.amdgcn.permlanex16
23+
// CHECK: rocdl.permlanex16
2424
// CHECK: llvm.intr.maxnum
2525
// CHECK: rocdl.readlane
2626
%0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({
@@ -31,3 +31,33 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
3131
tt.return
3232
}
3333
}
34+
35+
#linear = #ttg.linear<{register = [[16, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1]], warp = [], block = []}>
36+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} {
37+
// CHECK-LABEL: @reduce_linear_layout
38+
tt.func private @reduce_linear_layout(%arg0: tensor<32x2xi32, #linear>) -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #linear}>> {
39+
// This tensor has 64 elements with the last dimension across the lower and upper 16 lanes.
40+
// Therefore, we can reduce it with a 16 element butterfly shuffle.
41+
42+
// CHECK-DAG: [[result0:%.*]] = llvm.mlir.undef
43+
// CHECK-DAG: [[select_lo:%.*]] = llvm.mlir.constant(1985229328 : i32)
44+
// CHECK-DAG: [[select_hi:%.*]] = llvm.mlir.constant(-19088744 : i32)
45+
// CHECK-DAG: [[reg0:%.*]] = llvm.extractvalue %arg0[0]
46+
// CHECK-DAG: [[reg1:%.*]] = llvm.extractvalue %arg0[1]
47+
// CHECK: [[permlane0:%.*]] = rocdl.permlanex16 [[reg0]], [[reg0]], [[select_lo]], [[select_hi]], true, false
48+
// CHECK: [[sum0:%.*]] = llvm.add [[reg0]], [[permlane0]]
49+
// CHECK: [[permlane1:%.*]] = rocdl.permlanex16 [[reg1]], [[reg1]], [[select_lo]], [[select_hi]], true, false
50+
// CHECK: [[sum1:%.*]] = llvm.add [[reg1]], [[permlane1]]
51+
// CHECK: [[result1:%.*]] = llvm.insertvalue [[sum0]], [[result0]][0]
52+
// CHECK: [[result2:%.*]] = llvm.insertvalue [[sum1]], [[result1]][1]
53+
54+
%0 = "tt.reduce"(%arg0) ({
55+
^bb0(%arg1: i32, %arg2: i32):
56+
%1 = arith.addi %arg1, %arg2 : i32
57+
tt.reduce.return %1 : i32
58+
}) {axis = 1 : i32} : (tensor<32x2xi32, #linear>) -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #linear}>>
59+
60+
// CHECK: llvm.return [[result2]]
61+
tt.return %0 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #linear}>>
62+
}
63+
}

third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ ISAFamily deduceISAFamily(llvm::StringRef arch);
2323
// Returns true if the given architecture supports the V_DOT instruction.
2424
bool supportsVDot(llvm::StringRef arch);
2525

26+
bool isCDNA(ISAFamily isaFamily);
27+
28+
bool isRDNA(ISAFamily isaFamily);
29+
2630
// Here is a partial definition of DppCtrl enums. For the complete definition,
2731
// please check:
2832
// https://github.com/llvm/llvm-project/blob/8c75290/llvm/lib/Target/AMDGPU/SIDefines.h#L939

third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp

Lines changed: 10 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -66,34 +66,7 @@ llvm::AMDGPU::GPUKind TargetInfo::getGPUKind() const {
6666
return llvm::AMDGPU::parseArchAMDGCN(arch);
6767
}
6868

69-
bool TargetInfo::isCDNA() const {
70-
switch (getISAFamily()) {
71-
case ISAFamily::CDNA1:
72-
case ISAFamily::CDNA2:
73-
case ISAFamily::CDNA3:
74-
case ISAFamily::CDNA4:
75-
return true;
76-
default:
77-
break;
78-
}
79-
80-
return false;
81-
}
82-
83-
bool TargetInfo::isRDNA() const {
84-
switch (getISAFamily()) {
85-
case ISAFamily::RDNA1:
86-
case ISAFamily::RDNA2:
87-
case ISAFamily::RDNA3:
88-
return true;
89-
default:
90-
break;
91-
}
92-
93-
return false;
94-
}
95-
96-
int TargetInfo::getWarpSize() const { return isCDNA() ? 64 : 32; }
69+
int TargetInfo::getWarpSize() const { return isCDNA(getISAFamily()) ? 64 : 32; }
9770

9871
int TargetInfo::getSharedMemorySize() const {
9972
int kbytes = getISAFamily() == ISAFamily::CDNA4 ? 160 : 64;
@@ -317,9 +290,9 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc,
317290
return true;
318291
if (numLaneToReduce != getWarpSize())
319292
return false;
320-
if (isCDNA() && getISAFamily() == ISAFamily::CDNA1)
293+
if (isCDNA(getISAFamily()) && getISAFamily() == ISAFamily::CDNA1)
321294
return false;
322-
if (isRDNA() && getISAFamily() != ISAFamily::RDNA3)
295+
if (isRDNA(getISAFamily()) && getISAFamily() != ISAFamily::RDNA3)
323296
return false;
324297

325298
Operation *reduxOp = op.getSingleCombiner();
@@ -420,7 +393,7 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc,
420393
buf = createDppReduxOpWithBoundCtrl(valType, buf, 1 + dppCtrlRowShr,
421394
allRows, allBanks);
422395

423-
if (isCDNA()) {
396+
if (isCDNA(getISAFamily())) {
424397
// row_bcast:15 row_mask:0xa
425398
buf = createDppReduxOpWithBoundCtrl(
426399
valType, buf, static_cast<uint32_t>(DppCtrl::BCAST15), 0xa, allBanks);
@@ -433,12 +406,12 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc,
433406
// RDNA doesn't have broadcast dpp mode
434407
Type actualType = castToAndSExtInt(rewriter, loc, buf, valType, 32);
435408

436-
Value permlaneResult =
437-
LLVM::createLLVMIntrinsicCallOp(
438-
rewriter, loc, "llvm.amdgcn.permlanex16", actualType,
439-
ValueRange{buf, buf, b.i32_val(-1), b.i32_val(-1), b.true_val(),
440-
b.false_val()})
441-
->getResult(0);
409+
// Lanes 0-15 read from lane 31 and lanes 16-31 read from lane 15.
410+
Value permlaneResult = rewriter
411+
.create<ROCDL::PermlaneX16Op>(
412+
loc, actualType, buf, buf, b.i32_val(-1),
413+
b.i32_val(-1), true, false)
414+
.getRes();
442415
buf = truncAndCastFromInt(rewriter, loc, buf, valType, 32);
443416
permlaneResult =
444417
truncAndCastFromInt(rewriter, loc, permlaneResult, valType, 32);

third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,6 @@ class TargetInfo : public mlir::triton::TargetInfoBase {
1515

1616
llvm::AMDGPU::GPUKind getGPUKind() const;
1717

18-
bool isCDNA() const;
19-
20-
bool isRDNA() const;
21-
2218
int getWarpSize() const;
2319

2420
int getSharedMemorySize() const;

third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,31 @@ bool supportsVDot(llvm::StringRef arch) {
4949
return false;
5050
}
5151

52+
bool isCDNA(ISAFamily isaFamily) {
53+
switch (isaFamily) {
54+
case ISAFamily::CDNA1:
55+
case ISAFamily::CDNA2:
56+
case ISAFamily::CDNA3:
57+
case ISAFamily::CDNA4:
58+
return true;
59+
default:
60+
break;
61+
}
62+
63+
return false;
64+
}
65+
66+
bool isRDNA(ISAFamily isaFamily) {
67+
switch (isaFamily) {
68+
case ISAFamily::RDNA1:
69+
case ISAFamily::RDNA2:
70+
case ISAFamily::RDNA3:
71+
return true;
72+
default:
73+
break;
74+
}
75+
76+
return false;
77+
}
78+
5279
} // namespace mlir::triton::AMD

third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "Utility.h"
22
#include "Dialect/TritonAMDGPU/IR/Dialect.h"
33
#include "TritonAMDGPUToLLVM/GCNAsmFormat.h"
4+
#include "TritonAMDGPUToLLVM/TargetUtils.h"
45
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
56
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
67
#include "mlir/IR/PatternMatch.h"
@@ -137,8 +138,17 @@ static Value shuffleCommonImpl(Location loc, RewriterBase &rewriter,
137138
Value lineId = b.xor_(threadId, stride);
138139
return bpermute(lineId);
139140
} else if (strideInt == 16) {
140-
Value offset = b.i32_val(0x401F);
141-
return rewriter.create<ROCDL::DsSwizzleOp>(loc, valType, val, offset);
141+
if (isRDNA(isaFamily)) {
142+
// Lane i in the upper 16 lanes reads the value from lane i in the lower
143+
// 16 lanes and vice versa.
144+
Value select_lo = b.i32_val(0x76543210);
145+
Value select_hi = b.i32_val(0xfedcba98);
146+
return rewriter.create<ROCDL::PermlaneX16Op>(
147+
loc, valType, val, val, select_lo, select_hi, true, false);
148+
} else {
149+
Value offset = b.i32_val(0x401F);
150+
return rewriter.create<ROCDL::DsSwizzleOp>(loc, valType, val, offset);
151+
}
142152
} else {
143153
if (!llvm::is_contained({ISAFamily::CDNA2, ISAFamily::CDNA3,
144154
ISAFamily::CDNA4, ISAFamily::RDNA3},

third_party/amd/lib/TritonAMDGPUTransforms/UpdateAsyncWaitCount.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ struct TritonAMDGPUUpdateAsyncWaitCountPass
119119

120120
void runOnOperation() override {
121121
tt::AMD::TargetInfo targetInfo(archGenerationName);
122-
if (!targetInfo.isCDNA()) {
122+
if (!isCDNA(targetInfo.getISAFamily())) {
123123
return;
124124
}
125125

0 commit comments

Comments
 (0)