Commit 5014ca9

knwngjataylo authored and committed

[AMD] Support warp-level reduction with DPP (triton-lang#5019)

This commit adds support for warp-level reduction with DPP instructions, which can improve performance. See https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/
(cherry picked from commit 21119e3)

1 parent d99536d commit 5014ca9

File tree

8 files changed: +404 -40 lines changed

include/triton/Dialect/Triton/IR/TritonOps.td
Lines changed: 4 additions & 0 deletions

@@ -727,6 +727,10 @@ def TT_ReduceOp: TT_Op<"reduce",
     llvm::SmallVector<RankedTensorType> getInputTypes();
     llvm::SmallVector<Type> getElementTypes();
     unsigned getNumOperands();
+
+    // Returns the CombineOp iff this ReduceOp's region contains only
+    // one CombineOp other than the return, or nullptr if not applicable.
+    ::mlir::Operation *getSingleCombiner();
   }];
 }

lib/Dialect/Triton/IR/Ops.cpp
Lines changed: 16 additions & 0 deletions

@@ -503,6 +503,22 @@ llvm::SmallVector<Type> ReduceOp::getElementTypes() {
   return getElementTypesImpl(this->getOperands());
 }
 
+::mlir::Operation *ReduceOp::getSingleCombiner() {
+  if (getNumOperands() != 1 || getNumResults() != 1)
+    return nullptr;
+  Block *block = &(*getCombineOp().begin());
+  Operation *yield = block->getTerminator();
+  Operation *reduceOp = yield->getOperand(0).getDefiningOp();
+  if (!reduceOp || reduceOp->getNumOperands() != 2 ||
+      reduceOp->getNumResults() != 1)
+    return nullptr;
+  if (reduceOp->getOperand(0) != block->getArgument(0) ||
+      reduceOp->getOperand(1) != block->getArgument(1))
+    return nullptr;
+
+  return reduceOp;
+}
+
 unsigned ReduceOp::getNumOperands() { return this->getOperands().size(); }
 
 //-- ScanOp --
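A usage sketch (illustrative, not part of this commit): a lowering pass could use getSingleCombiner to gate a specialized reduction path on the matched combiner op. The helper name and the arith::MaxNumFOp check below are assumptions for illustration; includes are omitted.

// Hypothetical helper, not from this commit: true when the reduce region
// is a single arith.maxnumf over both block arguments, i.e. a shape a
// specialized lowering can handle.
static bool isSingleMaxNumFCombiner(mlir::triton::ReduceOp op) {
  mlir::Operation *combiner = op.getSingleCombiner();
  return combiner && llvm::isa<mlir::arith::MaxNumFOp>(combiner);
}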

test/Conversion/amd/tritongpu_to_llvm.mlir
Lines changed: 76 additions & 0 deletions

@@ -94,3 +94,79 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     tt.return
   }
 }
+
+// -----
+
+#blocked3 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: reduce_dpp_max
+  tt.func @reduce_dpp_max(%arg0: tensor<64xf32, #blocked3>) {
+    // CHECK: rocdl.update.dpp
+    // CHECK-SAME: with 280, 15, 15, true : f32
+    // CHECK-NEXT: llvm.intr.maxnum
+
+    // CHECK-NEXT: rocdl.update.dpp
+    // CHECK-SAME: with 276, 15, 15, true : f32
+    // CHECK-NEXT: llvm.intr.maxnum
+
+    // CHECK-NEXT: rocdl.update.dpp
+    // CHECK-SAME: with 274, 15, 15, true : f32
+    // CHECK-NEXT: llvm.intr.maxnum
+
+    // CHECK-NEXT: rocdl.update.dpp
+    // CHECK-SAME: with 273, 15, 15, true : f32
+    // CHECK-NEXT: llvm.intr.maxnum
+
+    // CHECK-NEXT: rocdl.update.dpp
+    // CHECK-SAME: with 322, 10, 15, true : f32
+    // CHECK-NEXT: llvm.intr.maxnum
+
+    // CHECK-NEXT: rocdl.update.dpp
+    // CHECK-SAME: with 323, 15, 15, true : f32
+    // CHECK-NEXT: llvm.intr.maxnum
+
+    // CHECK: llvm.amdgcn.readlane
+    %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({
+    ^bb0(%arg1: f32, %arg2: f32):
+      %1 = arith.maxnumf %arg1, %arg2 : f32
+      tt.reduce.return %1 : f32
+    }) : (tensor<64xf32, #blocked3>) -> f32
+    tt.return
+  }
+}
+
+// -----
+
+#blocked4 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: reduce_xor_max
+  tt.func @reduce_xor_max(%arg0: tensor<32xf32, #blocked4>) {
+    // CHECK: rocdl.ds_swizzle
+    // CHECK: llvm.intr.maxnum
+
+    // CHECK: rocdl.update.dpp
+    // CHECK-SAME: with 280, 15, 12, false : i32
+    // CHECK: rocdl.update.dpp
+    // CHECK-SAME: with 264, 15, 3, false : i32
+    // CHECK: llvm.intr.maxnum
+
+    // CHECK: rocdl.update.dpp
+    // CHECK-SAME: with 276, 15, 10, false : i32
+    // CHECK: rocdl.update.dpp
+    // CHECK-SAME: with 260, 15, 5, false : i32
+    // CHECK: llvm.intr.maxnum
+
+    // CHECK: rocdl.update.dpp
+    // CHECK-SAME: with 78, 15, 15, false : i32
+    // CHECK: llvm.intr.maxnum
+
+    // CHECK: rocdl.update.dpp
+    // CHECK-SAME: with 177, 15, 15, false : i32
+    %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({
+    ^bb0(%arg1: f32, %arg2: f32):
+      %1 = arith.maxnumf %arg1, %arg2 : f32
+      tt.reduce.return %1 : f32
+    }) : (tensor<32xf32, #blocked4>) -> f32
+    tt.return
+  }
+}

third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h
Lines changed: 11 additions & 0 deletions

@@ -19,6 +19,17 @@ enum class ISAFamily {
 // Deduces the corresponding ISA family for the given target gfx |arch|.
 ISAFamily deduceISAFamily(llvm::StringRef arch);
 
+// Here is a partial definition of DppCtrl enums. For the complete definition,
+// please check:
+// https://github.com/llvm/llvm-project/blob/8c75290/llvm/lib/Target/AMDGPU/SIDefines.h#L939
+enum class DppCtrl : uint32_t {
+  QUAD_PERM_FIRST = 0,
+  ROW_SHL0 = 0x100,
+  ROW_SHR0 = 0x110,
+  BCAST15 = 0x142,
+  BCAST31 = 0x143
+};
+
 } // namespace mlir::triton::AMD
 
 #endif // TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETUTILS_H
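As a cross-check on the lit test above (my arithmetic, not part of the commit): each DPP control immediate in the CHECK lines is one of these enum bases plus a shift amount. A minimal standalone sketch, mirroring the enum values added here:

#include <cassert>
#include <cstdint>

// Mirrors the partial DppCtrl enum from TargetUtils.h.
enum class DppCtrl : uint32_t {
  QUAD_PERM_FIRST = 0,
  ROW_SHL0 = 0x100,
  ROW_SHR0 = 0x110,
  BCAST15 = 0x142,
  BCAST31 = 0x143
};

int main() {
  auto u = [](DppCtrl c) { return static_cast<uint32_t>(c); };
  // reduce_dpp_max: row_shr:8/4/2/1, then the two row broadcasts.
  assert(u(DppCtrl::ROW_SHR0) + 8 == 280);
  assert(u(DppCtrl::ROW_SHR0) + 4 == 276);
  assert(u(DppCtrl::ROW_SHR0) + 2 == 274);
  assert(u(DppCtrl::ROW_SHR0) + 1 == 273);
  assert(u(DppCtrl::BCAST15) == 322);
  assert(u(DppCtrl::BCAST31) == 323);
  // reduce_xor_max pairs row_shl with row_shr shifts.
  assert(u(DppCtrl::ROW_SHL0) + 8 == 264);
  assert(u(DppCtrl::ROW_SHL0) + 4 == 260);
  // 78 and 177 fall below ROW_SHL0, i.e. in the quad_perm range; they
  // appear to encode the selectors [2,3,0,1] and [1,0,3,2] respectively.
  return 0;
}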

third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp
Lines changed: 179 additions & 5 deletions

@@ -5,6 +5,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"
 
+using mlir::triton::AMD::DppCtrl;
 namespace mlir::triton::AMD {
 
 namespace {
@@ -103,34 +104,207 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr,
 
 Value TargetInfo::shuffleXor(RewriterBase &rewriter, Location loc, Value val,
                              int i) const {
-  return LLVM::AMD::shuffleXor(loc, rewriter, val, i);
+  return LLVM::AMD::shuffleXor(loc, rewriter, val, i, getISAFamily());
 }
 
 Value TargetInfo::shuffleUp(RewriterBase &rewriter, Location loc, Value val,
                             int i) const {
-  return LLVM::AMD::shuffleUp(loc, rewriter, val, i);
+  return LLVM::AMD::shuffleUp(loc, rewriter, val, i, getISAFamily());
 }
 
 Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val,
                              int i) const {
-  return LLVM::AMD::shuffleIdx(loc, rewriter, val, i);
+  return LLVM::AMD::shuffleIdx(loc, rewriter, val, i, getISAFamily());
 }
 
 Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val,
                              Value i) const {
-  return LLVM::AMD::shuffleIdx(loc, rewriter, val, i);
+  return LLVM::AMD::shuffleIdx(loc, rewriter, val, i, getISAFamily());
 }
 
 Value TargetInfo::programId(RewriterBase &rewriter, Location loc,
                             ModuleOp moduleOp, int axis) const {
   return LLVM::AMD::llGetPid(loc, rewriter, moduleOp, axis);
 }

+// Cast and sext values into specific-length int to meet the requirements of
+// instructions like UpdateDpp or readlane if necessary.
+static inline Type castToAndSExtInt(RewriterBase &rewriter, Location loc,
+                                    Value &val, Type fromType,
+                                    unsigned toBits) {
+  unsigned originalBits = fromType.getIntOrFloatBitWidth();
+  Type toType = fromType;
+
+  if (!fromType.isIntOrIndex()) {
+    val = bitcast(val, int_ty(originalBits));
+    toType = int_ty(originalBits);
+  }
+
+  if (originalBits < toBits) {
+    val = sext(int_ty(toBits), val);
+    toType = int_ty(toBits);
+  }
+
+  return toType;
+}
+
+// Trunc the value to specific length and then cast it to given type if
+// necessary. This function is typically used in conjunction with
+// castToAndSExtInt.
+static inline Value truncAndCastFromInt(RewriterBase &rewriter, Location loc,
+                                        Value val, Type valType,
+                                        unsigned fromBits) {
+  unsigned originalBits = valType.getIntOrFloatBitWidth();
+  Value toVal = val;
+
+  if (originalBits < fromBits) {
+    toVal = trunc(int_ty(originalBits), toVal);
+  }
+
+  if (!valType.isIntOrIndex()) {
+    toVal = bitcast(toVal, valType);
+  }
+
+  return toVal;
+}
+
 bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc,
                             SmallVector<Value> &acc, triton::ReduceOp op,
                             unsigned numLaneToReduce,
                             unsigned interleave) const {
-  return false;
+  if (numLaneToReduce != 64)
+    return false;
+
+  if (auto family = getISAFamily();
+      family != ISAFamily::CDNA3 && family != ISAFamily::CDNA2) {
+    return false;
+  }
+
+  Operation *reduxOp = op.getSingleCombiner();
+  if (!reduxOp)
+    return false;
+
+  auto createDppReduxOpWithBoundCtrl = [&](Type valType, Value &src,
+                                           uint32_t dppCtrl, int rowMask,
+                                           int bankMask) -> Value {
+    // DPP has limited support for data types, so here we need to
+    // cast non-integer types or integer types shorter than 32 bits
+    // to int32, except for fp32.
+    Type actualType = valType;
+    if (!valType.isF32()) {
+      actualType = castToAndSExtInt(rewriter, loc, src, valType, 32);
+    }
+
+    Value dppResult =
+        rewriter
+            .create<ROCDL::DPPUpdateOp>(loc, actualType, src, src,
+                                        rewriter.getI32IntegerAttr(dppCtrl),
+                                        rewriter.getI32IntegerAttr(rowMask),
+                                        rewriter.getI32IntegerAttr(bankMask),
+                                        rewriter.getBoolAttr(true))
+            .getRes();
+
+    if (!valType.isF32()) {
+      src = truncAndCastFromInt(rewriter, loc, src, valType, 32);
+      dppResult = truncAndCastFromInt(rewriter, loc, dppResult, valType, 32);
+    }
+
+    IRMapping mapping;
+    mapping.map(reduxOp->getOperand(0), src);
+    mapping.map(reduxOp->getOperand(1), dppResult);
+    return rewriter.clone(*reduxOp, mapping)->getResult(0);
+  };
+
+  for (int i = 0; i < acc.size(); i++) {
+    Value buf;
+    auto valType = acc[i].getType();
+
+    /*
+    Here's the implementation of full-wavefront reduction using dpp.
+    https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/
+
+    Each step has a v_mov_dpp instruction following the redux op. In
+    some cases, the lower-level compiler could merge them into single
+    instruction. For example, v_mov_dpp + max => v_max_dpp.
+
+    For gfx9, we have 64 threads per warp. These 64 threads are arranged
+    into 4 rows, with each row being 16 threads. Each 16 threads are arranged
+    further into 4 banks, with each bank being 4 threads. Overall it's in a
+    (row, bank, thread) structure. When shuffling, we use row/bank mask to
+    indicate which row/bank to participate. Then modifier like row_shr and
+    row_bcast means exact data movement schemes. In the following
+    instructions, taking row 0 as an example:
+
+    Step 1: Right shift for 8 lanes.
+      lane 8-15 = redux(lane 0-7, lane 8-15)
+
+    Step 2: Right shift for 4 lanes.
+      lane 12-15 = redux(lane 8-11, lane 12-15)
+
+    Step 3: Right shift for 2 lanes.
+      lane 14-15 = redux(lane 12-13, lane 14-15)
+
+    Step 4: Right shift for 1 lane.
+      lane 15 = redux(lane 14, lane 15)
+
+    Step 5: Broadcast lane 15 of each row to all the lanes of its next row.
+      lane 16-31 = redux(lane 15, lane 16-31)
+
+    Step 6: Broadcast lane 31 to lane 32-63.
+      lane 32-63 = redux(lane 31, lane 32-63)
+
+    Now the reduction result is stored in lane 63.
+
+    Step 7: Read the reduction result from lane 63 and broadcast with
+    readlane.
+    */
+
+    const int allRows = 0xf;
+    const int allBanks = 0xf;
+
+    const uint32_t dppCtrlRowShr = static_cast<uint32_t>(DppCtrl::ROW_SHR0);
+
+    // row_shr:8
+    buf = createDppReduxOpWithBoundCtrl(valType, acc[i], 8 + dppCtrlRowShr,
+                                        allRows, allBanks);
+
+    // row_shr:4
+    buf = createDppReduxOpWithBoundCtrl(valType, buf, 4 + dppCtrlRowShr,
+                                        allRows, allBanks);
+
+    // row_shr:2
+    buf = createDppReduxOpWithBoundCtrl(valType, buf, 2 + dppCtrlRowShr,
+                                        allRows, allBanks);
+
+    // row_shr:1
+    buf = createDppReduxOpWithBoundCtrl(valType, buf, 1 + dppCtrlRowShr,
+                                        allRows, allBanks);
+
+    // row_bcast:15 row_mask:0xa
+    buf = createDppReduxOpWithBoundCtrl(
+        valType, buf, static_cast<uint32_t>(DppCtrl::BCAST15), 0xa, allBanks);
+
+    // row_bcast:31
+    buf = createDppReduxOpWithBoundCtrl(valType, buf,
+                                        static_cast<uint32_t>(DppCtrl::BCAST31),
+                                        allRows, allBanks);
+
+    // Similarly, we need to cast data types for readlane instruction.
+    Type actualType = castToAndSExtInt(rewriter, loc, buf, valType, 16);
+
+    // Get reduction result from lane 63
+    std::string intrinsic = "llvm.amdgcn.readlane";
+    Value result =
+        LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, actualType,
+                                        ValueRange{buf, i32_val(63)})
+            ->getResult(0);
+
+    result = truncAndCastFromInt(rewriter, loc, result, valType, 16);
+
+    acc[i] = result;
+  }
+
+  return true;
 }
 
 void TargetInfo::printfImpl(Value formatStrStart, int formatStrByteCount,
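To sanity-check the lane schedule described in the comment above, here is a standalone model (mine, not part of the commit) that simulates only the lane-to-lane data movement of the six DPP steps on a 64-lane array and asserts that the full-wavefront max lands in lane 63. Bound-ctrl semantics, where an invalid source reads 0, are approximated by keeping the lane's own value, which is equivalent for a max over non-negative inputs.

#include <algorithm>
#include <array>
#include <cassert>

int main() {
  // 64 "lanes"; the values are a permutation of 0..63, so the max is 63.
  std::array<float, 64> v{};
  for (int i = 0; i < 64; ++i)
    v[i] = static_cast<float>((i * 37) % 64);

  // One DPP step: each active lane combines its own value with the value
  // read from srcLane(lane); -1 marks an invalid source (kept as-is here).
  auto step = [&v](auto srcLane, int rowMask) {
    std::array<float, 64> out = v;
    for (int lane = 0; lane < 64; ++lane) {
      int src = srcLane(lane);
      bool rowActive = (rowMask >> (lane / 16)) & 1;
      if (src >= 0 && rowActive)
        out[lane] = std::max(v[lane], v[src]);
    }
    v = out;
  };

  // Steps 1-4: row_shr:8/4/2/1 within each 16-lane row.
  for (int s : {8, 4, 2, 1})
    step([s](int lane) { return lane % 16 >= s ? lane - s : -1; }, 0xf);

  // Step 5: row_bcast:15 with row_mask 0xa; lane 15 feeds row 1 and
  // lane 47 feeds row 3.
  step([](int lane) { return (lane / 16) % 2 ? (lane / 16) * 16 - 1 : -1; },
       0xa);

  // Step 6: row_bcast:31; lane 31 feeds lanes 32-63.
  step([](int lane) { return lane >= 32 ? 31 : -1; }, 0xf);

  assert(v[63] == 63.0f); // only lane 63 is guaranteed fully reduced
  return 0;
}

The model also shows why step 7 needs readlane: only lane 63 is guaranteed to hold the fully reduced value. Note how non-f32 types round-trip in the real code: castToAndSExtInt bitcasts to a same-width integer and sign-extends (to 32 bits for the DPP op, at least 16 for readlane), and truncAndCastFromInt truncates and bitcasts back.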
