Skip to content

Commit a39389a

Browse files
[AMD] Add support for dpp instructions on RDNA (#6250)
Enables the dpp reduction path from CDNA on RDNA with small modifications to handle the missing BCAST15 mode. --------- Co-authored-by: paul <[email protected]>
1 parent 71e30cb commit a39389a

File tree

4 files changed

+109
-23
lines changed

4 files changed

+109
-23
lines changed
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx1100 --convert-builtin-func-to-llvm | FileCheck %s
2+
3+
#blocked3 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
4+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} {
5+
// CHECK-LABEL: reduce_dpp_max
6+
tt.func @reduce_dpp_max(%arg0: tensor<32xf32, #blocked3>) {
7+
// CHECK: rocdl.update.dpp
8+
// CHECK-SAME: with 280, 15, 15, true : f32
9+
// CHECK-NEXT: llvm.intr.maxnum
10+
11+
// CHECK-NEXT: rocdl.update.dpp
12+
// CHECK-SAME: with 276, 15, 15, true : f32
13+
// CHECK-NEXT: llvm.intr.maxnum
14+
15+
// CHECK-NEXT: rocdl.update.dpp
16+
// CHECK-SAME: with 274, 15, 15, true : f32
17+
// CHECK-NEXT: llvm.intr.maxnum
18+
19+
// CHECK-NEXT: rocdl.update.dpp
20+
// CHECK-SAME: with 273, 15, 15, true : f32
21+
// CHECK-NEXT: llvm.intr.maxnum
22+
23+
// CHECK: llvm.amdgcn.permlanex16
24+
// CHECK: llvm.intr.maxnum
25+
// CHECK: llvm.amdgcn.readlane
26+
%0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({
27+
^bb0(%arg1: f32, %arg2: f32):
28+
%1 = arith.maxnumf %arg1, %arg2 : f32
29+
tt.reduce.return %1 : f32
30+
}) : (tensor<32xf32, #blocked3>) -> f32
31+
tt.return
32+
}
33+
}

third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp

Lines changed: 65 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,35 @@ llvm::AMDGPU::GPUKind TargetInfo::getGPUKind() const {
6565
return llvm::AMDGPU::parseArchAMDGCN(arch);
6666
}
6767

68+
bool TargetInfo::isCDNA() const {
69+
switch (getISAFamily()) {
70+
case ISAFamily::CDNA1:
71+
case ISAFamily::CDNA2:
72+
case ISAFamily::CDNA3:
73+
case ISAFamily::CDNA4:
74+
return true;
75+
default:
76+
break;
77+
}
78+
79+
return false;
80+
}
81+
82+
bool TargetInfo::isRDNA() const {
83+
switch (getISAFamily()) {
84+
case ISAFamily::RDNA1:
85+
case ISAFamily::RDNA2:
86+
case ISAFamily::RDNA3:
87+
return true;
88+
default:
89+
break;
90+
}
91+
92+
return false;
93+
}
94+
95+
int TargetInfo::getWarpSize() const { return isCDNA() ? 64 : 32; }
96+
6897
int TargetInfo::getSharedMemorySize() const {
6998
int kbytes = getISAFamily() == ISAFamily::CDNA4 ? 160 : 64;
7099
return kbytes * 1024;
@@ -200,14 +229,13 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc,
200229
unsigned numLaneToReduce,
201230
unsigned interleave) const {
202231
auto b = TritonLLVMOpBuilder(loc, rewriter);
203-
if (numLaneToReduce != 64)
204-
return false;
205232

206-
if (!llvm::is_contained(
207-
{ISAFamily::CDNA2, ISAFamily::CDNA3, ISAFamily::CDNA4},
208-
getISAFamily())) {
233+
if (numLaneToReduce != getWarpSize())
234+
return false;
235+
if (isCDNA() && getISAFamily() == ISAFamily::CDNA1)
236+
return false;
237+
if (isRDNA() && getISAFamily() != ISAFamily::RDNA3)
209238
return false;
210-
}
211239

212240
Operation *reduxOp = op.getSingleCombiner();
213241
if (!reduxOp)
@@ -307,24 +335,43 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc,
307335
buf = createDppReduxOpWithBoundCtrl(valType, buf, 1 + dppCtrlRowShr,
308336
allRows, allBanks);
309337

310-
// row_bcast:15 row_mask:0xa
311-
buf = createDppReduxOpWithBoundCtrl(
312-
valType, buf, static_cast<uint32_t>(DppCtrl::BCAST15), 0xa, allBanks);
338+
if (isCDNA()) {
339+
// row_bcast:15 row_mask:0xa
340+
buf = createDppReduxOpWithBoundCtrl(
341+
valType, buf, static_cast<uint32_t>(DppCtrl::BCAST15), 0xa, allBanks);
313342

314-
// row_bcast:31
315-
buf = createDppReduxOpWithBoundCtrl(valType, buf,
316-
static_cast<uint32_t>(DppCtrl::BCAST31),
317-
allRows, allBanks);
343+
// row_bcast:31
344+
buf = createDppReduxOpWithBoundCtrl(
345+
valType, buf, static_cast<uint32_t>(DppCtrl::BCAST31), allRows,
346+
allBanks);
347+
} else {
348+
// RDNA doesn't have broadcast dpp mode
349+
Type actualType = castToAndSExtInt(rewriter, loc, buf, valType, 32);
350+
351+
Value permlaneResult =
352+
LLVM::createLLVMIntrinsicCallOp(
353+
rewriter, loc, "llvm.amdgcn.permlanex16", actualType,
354+
ValueRange{buf, buf, b.i32_val(-1), b.i32_val(-1), b.true_val(),
355+
b.false_val()})
356+
->getResult(0);
357+
buf = truncAndCastFromInt(rewriter, loc, buf, valType, 32);
358+
permlaneResult =
359+
truncAndCastFromInt(rewriter, loc, permlaneResult, valType, 32);
360+
IRMapping mapping;
361+
mapping.map(reduxOp->getOperand(0), buf);
362+
mapping.map(reduxOp->getOperand(1), permlaneResult);
363+
buf = rewriter.clone(*reduxOp, mapping)->getResult(0);
364+
}
318365

319366
// Similarly, we need to cast data types for readlane instruction.
320367
Type actualType = castToAndSExtInt(rewriter, loc, buf, valType, 16);
321368

322-
// Get reduction result from lane 63
369+
// Get reduction result from lane 63/31
323370
std::string intrinsic = "llvm.amdgcn.readlane";
324-
Value result =
325-
LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, actualType,
326-
ValueRange{buf, b.i32_val(63)})
327-
->getResult(0);
371+
Value result = LLVM::createLLVMIntrinsicCallOp(
372+
rewriter, loc, intrinsic, actualType,
373+
ValueRange{buf, b.i32_val(isCDNA() ? 63 : 31)})
374+
->getResult(0);
328375

329376
result = truncAndCastFromInt(rewriter, loc, result, valType, 16);
330377

third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@ class TargetInfo : public mlir::triton::TargetInfoBase {
1515

1616
llvm::AMDGPU::GPUKind getGPUKind() const;
1717

18+
bool isCDNA() const;
19+
20+
bool isRDNA() const;
21+
22+
int getWarpSize() const;
23+
1824
int getSharedMemorySize() const;
1925

2026
bool supportMaximumMinimum() const override;

third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -139,11 +139,11 @@ static Value shuffleCommonImpl(Location loc, RewriterBase &rewriter,
139139
Value offset = b.i32_val(0x401F);
140140
return rewriter.create<ROCDL::DsSwizzleOp>(loc, valType, val, offset);
141141
} else {
142-
if (!llvm::is_contained(
143-
{ISAFamily::CDNA2, ISAFamily::CDNA3, ISAFamily::CDNA4},
144-
isaFamily)) {
145-
// DPP is only supported for CDNA2/CDNA3/CDNA4 right now, so we fallback
146-
// to ds_swizzle for other architectures.
142+
if (!llvm::is_contained({ISAFamily::CDNA2, ISAFamily::CDNA3,
143+
ISAFamily::CDNA4, ISAFamily::RDNA3},
144+
isaFamily)) {
145+
// DPP is only supported for CDNA2/CDNA3/CDNA4/RDNA3 right now, so we
146+
// fallback to ds_swizzle for other architectures.
147147
//
148148
// This map facilates the butterfly shuffle pattern for a stride less
149149
// than 16. The pattern stride is the key of the map.

0 commit comments

Comments
 (0)