Skip to content

Commit d52153c

Browse files
authored
[AMD] Optimize reduction with v_permlane intrinsics in GFX950 (#6594)
This helps to improve attention performance on gfx950.
1 parent 60ceeff commit d52153c

File tree

2 files changed

+115
-0
lines changed

2 files changed

+115
-0
lines changed

test/Conversion/amd/tritongpu_to_llvm.mlir

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx942 --convert-builtin-func-to-llvm | FileCheck %s
2+
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx950 | FileCheck %s --check-prefix=GFX950
23

34
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
45
// CHECK-LABEL: atomic_add_f32_scalar
@@ -380,3 +381,33 @@ module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.n
380381
tt.return
381382
}
382383
}
384+
385+
// -----
386+
// GFX950-LABEL: reduce_32x32
387+
// GFX950: llvm.call_intrinsic "llvm.amdgcn.permlane32.swap"
388+
// Lowering test: max-reduction (arith.maxnumf) along axis 1 of a 64x32 f32
// tensor that uses an amd_mfma layout with instrShape = [32, 32].
// The gfx950 run line expects the lowered output to contain a call to the
// permlane32.swap intrinsic for this function.
// NOTE(review): the module declares "ttg.num-warps" = 1 while the layout uses
// warpsPerCTA = [4, 1] - confirm this mismatch is intentional.
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} {
389+
tt.func @reduce_32x32(%arg0: tensor<64x32xf32, #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>>) {
390+
%3101 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({
391+
^bb0(%arg24: f32, %arg25: f32):
392+
// Combiner region: a single maxnumf over the two f32 block arguments.
%3166 = "arith.maxnumf"(%arg24, %arg25) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
393+
"tt.reduce.return"(%3166) : (f32) -> ()
394+
}) : (tensor<64x32xf32, #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>>) -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>}>>
395+
tt.return
396+
}
397+
}
398+
399+
// -----
400+
401+
// GFX950-LABEL: reduce_16x16
402+
// GFX950: llvm.call_intrinsic "llvm.amdgcn.permlane32.swap"
403+
// GFX950: llvm.call_intrinsic "llvm.amdgcn.permlane16.swap"
404+
// Lowering test: max-reduction (arith.maxnumf) along axis 1 of a 64x16 f32
// tensor that uses an amd_mfma layout with instrShape = [16, 16].
// The gfx950 run line expects the lowered output to contain a permlane32.swap
// intrinsic call followed by a permlane16.swap intrinsic call.
// NOTE(review): the module declares "ttg.num-warps" = 1 while the layout uses
// warpsPerCTA = [4, 1] - confirm this mismatch is intentional.
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} {
405+
tt.func @reduce_16x16(%arg0: tensor<64x16xf32, #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>>){
406+
%1 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({
407+
^bb0(%arg24: f32, %arg25: f32):
408+
// Combiner region: a single maxnumf over the two f32 block arguments.
%3166 = "arith.maxnumf"(%arg24, %arg25) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
409+
"tt.reduce.return"(%3166) : (f32) -> ()
410+
}) : (tensor<64x16xf32, #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>>) -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>}>>
411+
tt.return
412+
}
413+
}

third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,12 +224,96 @@ static inline Value truncAndCastFromInt(RewriterBase &rewriter, Location loc,
224224
return toVal;
225225
}
226226

227+
// Permute lanes of the input val and apply reduction to permuted values.
228+
static Value permuteAndReduce(RewriterBase &rewriter, Location loc,
229+
StringRef intrinsic, Value val,
230+
Operation *reduxOp) {
231+
Type valType = val.getType();
232+
assert(valType.getIntOrFloatBitWidth() <= 32);
233+
234+
Type actualType = valType;
235+
if (!valType.isInteger(32))
236+
actualType = castToAndSExtInt(rewriter, loc, val, valType, 32);
237+
238+
auto b = TritonLLVMOpBuilder(loc, rewriter);
239+
Value falseVal = b.false_val();
240+
MLIRContext *ctx = rewriter.getContext();
241+
Type retType = struct_ty({i32_ty, i32_ty});
242+
Value perm =
243+
LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, retType,
244+
ValueRange{val, val, falseVal, falseVal})
245+
->getResult(0);
246+
Value v0 = b.extract_val(i32_ty, perm, 0);
247+
Value v1 = b.extract_val(i32_ty, perm, 1);
248+
249+
if (!valType.isInteger(32)) {
250+
v0 = truncAndCastFromInt(rewriter, loc, v0, valType, 32);
251+
v1 = truncAndCastFromInt(rewriter, loc, v1, valType, 32);
252+
}
253+
IRMapping mapping;
254+
mapping.map(reduxOp->getOperand(0), v0);
255+
mapping.map(reduxOp->getOperand(1), v1);
256+
Value redx = rewriter.clone(*reduxOp, mapping)->getResult(0);
257+
return redx;
258+
}
259+
260+
// Apply warp reduction across lanes using llvm intrinsics in GFX950.
261+
// The input acc has the partial accumulated values from reduction within
262+
// threads. The output acc has the final accumulated values.
263+
//
264+
// Two special cases are supported:
265+
// When numLaneToReduce == 2 && interleave == 32:
266+
// step 1: use permlane32_swap() to swap the row 2 and 3 of acc and
267+
// the row 0 and 1 of the copy of acc
268+
// step 2: apply reduction to the result values to get final result
269+
// When numLaneToReduce == 4 && interleave == 16:
270+
// step 1: use permlane32_swap() to swap the row 2 and 3 of acc and
271+
// the row 0 and 1 of the copy of acc
272+
// step 2: apply reduction to the result values to get the partial result
273+
// step 3: use permlane16_swap() to swap the odd and even rows of
274+
// the partial results
275+
// step 4: apply reduction to get the final results
276+
static bool warpReduceSwap16or32(RewriterBase &rewriter, Location loc,
277+
SmallVector<Value> &acc, triton::ReduceOp op,
278+
unsigned numLaneToReduce,
279+
unsigned interleave) {
280+
Operation *reduxOp = op.getSingleCombiner();
281+
if (!reduxOp)
282+
return false;
283+
284+
bool mfma32Case = numLaneToReduce == 2 && interleave == 32;
285+
bool mfma16Case = numLaneToReduce == 4 && interleave == 16;
286+
if (!(mfma32Case || mfma16Case))
287+
return false;
288+
289+
Value val = acc[0];
290+
unsigned bits = val.getType().getIntOrFloatBitWidth();
291+
if (bits > 32)
292+
return false;
293+
294+
StringRef intrinsic = "llvm.amdgcn.permlane32.swap";
295+
for (auto i = 0; i < acc.size(); i++) {
296+
Value redx = permuteAndReduce(rewriter, loc, intrinsic, acc[i], reduxOp);
297+
298+
if (mfma16Case) {
299+
intrinsic = "llvm.amdgcn.permlane16.swap";
300+
redx = permuteAndReduce(rewriter, loc, intrinsic, redx, reduxOp);
301+
}
302+
303+
acc[i] = redx;
304+
}
305+
return true;
306+
}
307+
227308
bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc,
228309
SmallVector<Value> &acc, triton::ReduceOp op,
229310
unsigned numLaneToReduce,
230311
unsigned interleave) const {
231312
auto b = TritonLLVMOpBuilder(loc, rewriter);
232313

314+
if (isCDNA() && getISAFamily() == ISAFamily::CDNA4 &&
315+
warpReduceSwap16or32(rewriter, loc, acc, op, numLaneToReduce, interleave))
316+
return true;
233317
if (numLaneToReduce != getWarpSize())
234318
return false;
235319
if (isCDNA() && getISAFamily() == ISAFamily::CDNA1)

0 commit comments

Comments
 (0)