
Commit 21119e3

[AMD] Support warp-level reduction with DPP (#5019)
This commit adds support for warp-level reduction with DPP instructions, which can improve performance. See https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/
1 parent: f737843

File tree: 8 files changed (+404, -40 lines)

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 4 additions & 0 deletions
@@ -731,6 +731,10 @@ def TT_ReduceOp: TT_Op<"reduce",
     llvm::SmallVector<RankedTensorType> getInputTypes();
     llvm::SmallVector<Type> getElementTypes();
     unsigned getNumOperands();
+
+    // Returns the CombineOp iff this ReduceOp's region contains only
+    // one CombineOp other than the return, or nullptr if not applicable.
+    ::mlir::Operation *getSingleCombiner();
   }];
 }

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 16 additions & 0 deletions
@@ -503,6 +503,22 @@ llvm::SmallVector<Type> ReduceOp::getElementTypes() {
   return getElementTypesImpl(this->getOperands());
 }
 
+::mlir::Operation *ReduceOp::getSingleCombiner() {
+  if (getNumOperands() != 1 || getNumResults() != 1)
+    return nullptr;
+  Block *block = &(*getCombineOp().begin());
+  Operation *yield = block->getTerminator();
+  Operation *reduceOp = yield->getOperand(0).getDefiningOp();
+  if (!reduceOp || reduceOp->getNumOperands() != 2 ||
+      reduceOp->getNumResults() != 1)
+    return nullptr;
+  if (reduceOp->getOperand(0) != block->getArgument(0) ||
+      reduceOp->getOperand(1) != block->getArgument(1))
+    return nullptr;
+
+  return reduceOp;
+}
+
 unsigned ReduceOp::getNumOperands() { return this->getOperands().size(); }
 
 //-- ScanOp --
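The shape this matcher accepts is a region holding exactly one combining op that consumes the two block arguments in order (e.g. `^bb0(%arg1: f32, %arg2: f32):` followed by `%1 = arith.maxnumf %arg1, %arg2 : f32` and `tt.reduce.return %1 : f32`, in which case the `arith.maxnumf` op is returned). A hedged sketch of the call-site pattern this enables; the helper name and surrounding setup are illustrative, not part of this diff:

// Re-apply a ReduceOp's single combiner to an arbitrary (lhs, rhs) pair by
// cloning it with its operands remapped -- the same trick warpReduce in
// TargetInfo.cpp (below) uses to build one reduction step per DPP move.
static Value applyCombiner(RewriterBase &rewriter, triton::ReduceOp op,
                           Value lhs, Value rhs) {
  Operation *combiner = op.getSingleCombiner();
  if (!combiner)
    return Value(); // region is not a single-op combiner; caller must bail
  IRMapping mapping;
  mapping.map(combiner->getOperand(0), lhs);
  mapping.map(combiner->getOperand(1), rhs);
  return rewriter.clone(*combiner, mapping)->getResult(0);
}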

test/Conversion/amd/tritongpu_to_llvm.mlir

Lines changed: 76 additions & 0 deletions
@@ -132,3 +132,79 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     tt.return
   }
 }
+
+// -----
+
+#blocked3 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: reduce_dpp_max
+  tt.func @reduce_dpp_max(%arg0: tensor<64xf32, #blocked3>) {
+    // CHECK: rocdl.update.dpp
+    // CHECK-SAME: with 280, 15, 15, true : f32
+    // CHECK-NEXT: llvm.intr.maxnum
+
+    // CHECK-NEXT: rocdl.update.dpp
+    // CHECK-SAME: with 276, 15, 15, true : f32
+    // CHECK-NEXT: llvm.intr.maxnum
+
+    // CHECK-NEXT: rocdl.update.dpp
+    // CHECK-SAME: with 274, 15, 15, true : f32
+    // CHECK-NEXT: llvm.intr.maxnum
+
+    // CHECK-NEXT: rocdl.update.dpp
+    // CHECK-SAME: with 273, 15, 15, true : f32
+    // CHECK-NEXT: llvm.intr.maxnum
+
+    // CHECK-NEXT: rocdl.update.dpp
+    // CHECK-SAME: with 322, 10, 15, true : f32
+    // CHECK-NEXT: llvm.intr.maxnum
+
+    // CHECK-NEXT: rocdl.update.dpp
+    // CHECK-SAME: with 323, 15, 15, true : f32
+    // CHECK-NEXT: llvm.intr.maxnum
+
+    // CHECK: llvm.amdgcn.readlane
+    %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({
+    ^bb0(%arg1: f32, %arg2: f32):
+      %1 = arith.maxnumf %arg1, %arg2 : f32
+      tt.reduce.return %1 : f32
+    }) : (tensor<64xf32, #blocked3>) -> f32
+    tt.return
+  }
+}
+
+// -----
+
+#blocked4 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: reduce_xor_max
+  tt.func @reduce_xor_max(%arg0: tensor<32xf32, #blocked4>) {
+    // CHECK: rocdl.ds_swizzle
+    // CHECK: llvm.intr.maxnum
+
+    // CHECK: rocdl.update.dpp
+    // CHECK-SAME: with 280, 15, 12, false : i32
+    // CHECK: rocdl.update.dpp
+    // CHECK-SAME: with 264, 15, 3, false : i32
+    // CHECK: llvm.intr.maxnum
+
+    // CHECK: rocdl.update.dpp
+    // CHECK-SAME: with 276, 15, 10, false : i32
+    // CHECK: rocdl.update.dpp
+    // CHECK-SAME: with 260, 15, 5, false : i32
+    // CHECK: llvm.intr.maxnum
+
+    // CHECK: rocdl.update.dpp
+    // CHECK-SAME: with 78, 15, 15, false : i32
+    // CHECK: llvm.intr.maxnum
+
+    // CHECK: rocdl.update.dpp
+    // CHECK-SAME: with 177, 15, 15, false : i32
+    %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({
+    ^bb0(%arg1: f32, %arg2: f32):
+      %1 = arith.maxnumf %arg1, %arg2 : f32
+      tt.reduce.return %1 : f32
+    }) : (tensor<32xf32, #blocked4>) -> f32
+    tt.return
+  }
+}
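A note on the magic numbers: the operands FileCheck matches after `with` are the dppCtrl, row-mask, and bank-mask immediates of rocdl.update.dpp printed in decimal, and the trailing `true`/`false` is the bound-ctrl flag. A minimal standalone C++ check, assuming only the partial DppCtrl encoding from TargetUtils.h below, decodes the reduce_dpp_max constants:

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t ROW_SHR0 = 0x110; // row_shr:n encodes as 0x110 + n
  const uint32_t BCAST15 = 0x142;  // row_bcast:15
  const uint32_t BCAST31 = 0x143;  // row_bcast:31

  assert(ROW_SHR0 + 8 == 280); // step 1: row_shr:8, row/bank masks 15 = 0xf
  assert(ROW_SHR0 + 4 == 276); // step 2: row_shr:4
  assert(ROW_SHR0 + 2 == 274); // step 3: row_shr:2
  assert(ROW_SHR0 + 1 == 273); // step 4: row_shr:1
  assert(BCAST15 == 322);      // step 5: row_bcast:15, row mask 10 = 0xa
  assert(BCAST31 == 323);      // step 6: row_bcast:31
  return 0;
}

The reduce_xor_max constants come from the DPP-based shuffleXor lowering (the 32-lane case falls back to the generic shuffle reduction) rather than from this schedule, so they are left undecoded here.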

third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h

Lines changed: 11 additions & 0 deletions
@@ -19,6 +19,17 @@ enum class ISAFamily {
 // Deduces the corresponding ISA family for the given target gfx |arch|.
 ISAFamily deduceISAFamily(llvm::StringRef arch);
 
+// Here is a partial definition of DppCtrl enums. For the complete definition,
+// please check:
+// https://github.com/llvm/llvm-project/blob/8c75290/llvm/lib/Target/AMDGPU/SIDefines.h#L939
+enum class DppCtrl : uint32_t {
+  QUAD_PERM_FIRST = 0,
+  ROW_SHL0 = 0x100,
+  ROW_SHR0 = 0x110,
+  BCAST15 = 0x142,
+  BCAST31 = 0x143
+};
+
 } // namespace mlir::triton::AMD
 
 #endif // TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETUTILS_H
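The ROW_SHL0/ROW_SHR0 bases carry the shift amount in their low four bits, which is why the lowering below composes controls as `n + dppCtrlRowShr`. A small self-contained sketch; the local enum mirrors the header's values and the rowShr helper is hypothetical, not part of this diff:

#include <cstdint>

enum class DppCtrl : uint32_t { ROW_SHR0 = 0x110 }; // mirror of the header

// Hypothetical helper: encode row_shr:n for 1 <= n <= 15.
constexpr uint32_t rowShr(uint32_t n) {
  return static_cast<uint32_t>(DppCtrl::ROW_SHR0) + n;
}

static_assert(rowShr(8) == 280, "the dppCtrl immediate checked in the lit test");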

third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp

Lines changed: 179 additions & 5 deletions
@@ -5,6 +5,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"
 
+using mlir::triton::AMD::DppCtrl;
 namespace mlir::triton::AMD {
 
 namespace {
@@ -103,34 +104,207 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr,
 
 Value TargetInfo::shuffleXor(RewriterBase &rewriter, Location loc, Value val,
                              int i) const {
-  return LLVM::AMD::shuffleXor(loc, rewriter, val, i);
+  return LLVM::AMD::shuffleXor(loc, rewriter, val, i, getISAFamily());
 }
 
 Value TargetInfo::shuffleUp(RewriterBase &rewriter, Location loc, Value val,
                             int i) const {
-  return LLVM::AMD::shuffleUp(loc, rewriter, val, i);
+  return LLVM::AMD::shuffleUp(loc, rewriter, val, i, getISAFamily());
 }
 
 Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val,
                              int i) const {
-  return LLVM::AMD::shuffleIdx(loc, rewriter, val, i);
+  return LLVM::AMD::shuffleIdx(loc, rewriter, val, i, getISAFamily());
 }
 
 Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val,
                              Value i) const {
-  return LLVM::AMD::shuffleIdx(loc, rewriter, val, i);
+  return LLVM::AMD::shuffleIdx(loc, rewriter, val, i, getISAFamily());
 }
 
 Value TargetInfo::programId(RewriterBase &rewriter, Location loc,
                             ModuleOp moduleOp, int axis) const {
   return LLVM::AMD::llGetPid(loc, rewriter, moduleOp, axis);
 }
 
+// Cast and sign-extend values to a specific-width integer when required by
+// instructions like UpdateDpp or readlane.
+static inline Type castToAndSExtInt(RewriterBase &rewriter, Location loc,
+                                    Value &val, Type fromType,
+                                    unsigned toBits) {
+  unsigned originalBits = fromType.getIntOrFloatBitWidth();
+  Type toType = fromType;
+
+  if (!fromType.isIntOrIndex()) {
+    val = bitcast(val, int_ty(originalBits));
+    toType = int_ty(originalBits);
+  }
+
+  if (originalBits < toBits) {
+    val = sext(int_ty(toBits), val);
+    toType = int_ty(toBits);
+  }
+
+  return toType;
+}
+
+// Truncate the value to a specific width and then cast it back to the given
+// type if necessary. This function is typically used in conjunction with
+// castToAndSExtInt.
+static inline Value truncAndCastFromInt(RewriterBase &rewriter, Location loc,
+                                        Value val, Type valType,
+                                        unsigned fromBits) {
+  unsigned originalBits = valType.getIntOrFloatBitWidth();
+  Value toVal = val;
+
+  if (originalBits < fromBits) {
+    toVal = trunc(int_ty(originalBits), toVal);
+  }
+
+  if (!valType.isIntOrIndex()) {
+    toVal = bitcast(toVal, valType);
+  }
+
+  return toVal;
+}
+
 bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc,
                             SmallVector<Value> &acc, triton::ReduceOp op,
                             unsigned numLaneToReduce,
                             unsigned interleave) const {
-  return false;
+  if (numLaneToReduce != 64)
+    return false;
+
+  if (auto family = getISAFamily();
+      family != ISAFamily::CDNA3 && family != ISAFamily::CDNA2) {
+    return false;
+  }
+
+  Operation *reduxOp = op.getSingleCombiner();
+  if (!reduxOp)
+    return false;
+
+  auto createDppReduxOpWithBoundCtrl = [&](Type valType, Value &src,
+                                           uint32_t dppCtrl, int rowMask,
+                                           int bankMask) -> Value {
+    // DPP has limited support for data types, so here we need to
+    // cast non-integer types or integer types shorter than 32 bits
+    // to int32, except for fp32.
+    Type actualType = valType;
+    if (!valType.isF32()) {
+      actualType = castToAndSExtInt(rewriter, loc, src, valType, 32);
+    }
+
+    Value dppResult =
+        rewriter
+            .create<ROCDL::DPPUpdateOp>(loc, actualType, src, src,
+                                        rewriter.getI32IntegerAttr(dppCtrl),
+                                        rewriter.getI32IntegerAttr(rowMask),
+                                        rewriter.getI32IntegerAttr(bankMask),
+                                        rewriter.getBoolAttr(true))
+            .getRes();
+
+    if (!valType.isF32()) {
+      src = truncAndCastFromInt(rewriter, loc, src, valType, 32);
+      dppResult = truncAndCastFromInt(rewriter, loc, dppResult, valType, 32);
+    }
+
+    IRMapping mapping;
+    mapping.map(reduxOp->getOperand(0), src);
+    mapping.map(reduxOp->getOperand(1), dppResult);
+    return rewriter.clone(*reduxOp, mapping)->getResult(0);
+  };
+
+  for (int i = 0; i < acc.size(); i++) {
+    Value buf;
+    auto valType = acc[i].getType();
+
+    /*
+    Here's the implementation of full-wavefront reduction using dpp.
+    https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/
+
+    Each step pairs a v_mov_dpp with the redux op that follows it. In some
+    cases, the lower-level compiler can merge them into a single instruction.
+    For example, v_mov_dpp + v_max => v_max_dpp.
+
+    For gfx9, we have 64 threads per warp. These 64 threads are arranged
+    into 4 rows of 16 threads, and each row is further split into 4 banks
+    of 4 threads, giving a (row, bank, thread) structure. When shuffling,
+    the row/bank masks select which rows/banks participate, and modifiers
+    like row_shr and row_bcast specify the exact data movement scheme. In
+    the following instructions, taking row 0 as an example:
+
+    Step 1: Right shift for 8 lanes.
+            lane 8-15 = redux(lane 0-7, lane 8-15)
+
+    Step 2: Right shift for 4 lanes.
+            lane 12-15 = redux(lane 8-11, lane 12-15)
+
+    Step 3: Right shift for 2 lanes.
+            lane 14-15 = redux(lane 12-13, lane 14-15)
+
+    Step 4: Right shift for 1 lane.
+            lane 15 = redux(lane 14, lane 15)
+
+    Step 5: Broadcast lane 15 of each row to all the lanes of its next row.
+            lane 16-31 = redux(lane 15, lane 16-31)
+
+    Step 6: Broadcast lane 31 to lane 32-63.
+            lane 32-63 = redux(lane 31, lane 32-63)
+
+    Now the reduction result is stored in lane 63.
+
+    Step 7: Read the reduction result from lane 63 and broadcast with
+            readlane.
+    */
+
+    const int allRows = 0xf;
+    const int allBanks = 0xf;
+
+    const uint32_t dppCtrlRowShr = static_cast<uint32_t>(DppCtrl::ROW_SHR0);
+
+    // row_shr:8
+    buf = createDppReduxOpWithBoundCtrl(valType, acc[i], 8 + dppCtrlRowShr,
+                                        allRows, allBanks);
+
+    // row_shr:4
+    buf = createDppReduxOpWithBoundCtrl(valType, buf, 4 + dppCtrlRowShr,
+                                        allRows, allBanks);
+
+    // row_shr:2
+    buf = createDppReduxOpWithBoundCtrl(valType, buf, 2 + dppCtrlRowShr,
+                                        allRows, allBanks);
+
+    // row_shr:1
+    buf = createDppReduxOpWithBoundCtrl(valType, buf, 1 + dppCtrlRowShr,
+                                        allRows, allBanks);
+
+    // row_bcast:15 row_mask:0xa
+    buf = createDppReduxOpWithBoundCtrl(
+        valType, buf, static_cast<uint32_t>(DppCtrl::BCAST15), 0xa, allBanks);
+
+    // row_bcast:31
+    buf = createDppReduxOpWithBoundCtrl(valType, buf,
+                                        static_cast<uint32_t>(DppCtrl::BCAST31),
+                                        allRows, allBanks);
+
+    // Similarly, we need to cast data types for the readlane instruction.
+    Type actualType = castToAndSExtInt(rewriter, loc, buf, valType, 16);
+
+    // Get the reduction result from lane 63.
+    std::string intrinsic = "llvm.amdgcn.readlane";
+    Value result =
+        LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, actualType,
+                                        ValueRange{buf, i32_val(63)})
+            ->getResult(0);
+
+    result = truncAndCastFromInt(rewriter, loc, result, valType, 16);
+
+    acc[i] = result;
+  }
+
+  return true;
 }
 
 void TargetInfo::printfImpl(Value formatStrStart, int formatStrByteCount,
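To make the seven-step schedule concrete, here is a small lane-level model in plain C++ (illustrative only, not code from this commit). rowShr, rowBcast15, rowBcast31, and redux are hypothetical helpers; lanes without a source keep their old value in this model, whereas the real instructions use bound_ctrl and read 0, but neither choice feeds the lane-63 chain, so the final answer is the same.

#include <algorithm>
#include <array>
#include <cassert>

using Wave = std::array<float, 64>;

// Model of row_shr:n with all rows/banks enabled: lane L reads lane L-n
// within its 16-lane row; lanes without a valid source keep their value.
static Wave rowShr(const Wave &in, int n) {
  Wave out = in;
  for (int lane = 0; lane < 64; ++lane) {
    int row = lane / 16, idx = lane % 16;
    if (idx - n >= 0)
      out[lane] = in[row * 16 + idx - n];
  }
  return out;
}

// Model of row_bcast:15 with row_mask 0xa: only rows 1 and 3 are written,
// each reading the last lane of the preceding row.
static Wave rowBcast15(const Wave &in) {
  Wave out = in;
  for (int lane = 16; lane < 32; ++lane) out[lane] = in[15]; // row 1 <- lane 15
  for (int lane = 48; lane < 64; ++lane) out[lane] = in[47]; // row 3 <- lane 47
  return out;
}

// Model of row_bcast:31: rows 2 and 3 read lane 31.
static Wave rowBcast31(const Wave &in) {
  Wave out = in;
  for (int lane = 32; lane < 64; ++lane) out[lane] = in[31];
  return out;
}

// Per-lane combiner; max here, matching the reduce_dpp_max test.
static Wave redux(const Wave &a, const Wave &b) {
  Wave out;
  for (int lane = 0; lane < 64; ++lane) out[lane] = std::max(a[lane], b[lane]);
  return out;
}

int main() {
  Wave v;
  for (int lane = 0; lane < 64; ++lane)
    v[lane] = float((lane * 37) % 64); // a permutation of 0..63, max is 63

  v = redux(v, rowShr(v, 8));  // Step 1
  v = redux(v, rowShr(v, 4));  // Step 2
  v = redux(v, rowShr(v, 2));  // Step 3
  v = redux(v, rowShr(v, 1));  // Step 4
  v = redux(v, rowBcast15(v)); // Step 5
  v = redux(v, rowBcast31(v)); // Step 6

  assert(v[63] == 63.0f); // Step 7: readlane(63) returns the wavefront max
  return 0;
}

The assert documents the invariant: after steps 1-4 lane 15 of each row holds that row's reduction, step 5 merges row pairs into lanes 31 and 63, and step 6 merges both halves into lane 63, which the final readlane extracts and broadcasts.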
