Skip to content

Commit ca37374

Browse files
authored
Reland "[AMD] Optimize reduction with v_permlane intrinsics in GFX950" (#7321)
triton-lang/triton#7291 fixed the LLVM issue that caused correctness problems. Now we can reland this patch.
1 parent 7c68944 commit ca37374

File tree

2 files changed

+116
-0
lines changed

2 files changed

+116
-0
lines changed

test/Conversion/amd/tritongpu_to_llvm.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx942 --convert-builtin-func-to-llvm | FileCheck %s
2+
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx950 | FileCheck %s --check-prefix=GFX950
23

34
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
45
// CHECK-LABEL: atomic_add_f32_scalar
@@ -409,3 +410,34 @@ module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.n
409410
tt.return
410411
}
411412
}
413+
414+
// -----
415+
416+
// GFX950-LABEL: reduce_32x32
417+
// GFX950: llvm.call_intrinsic "llvm.amdgcn.permlane32.swap"
418+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
419+
tt.func @reduce_32x32(%arg0: tensor<64x32xf32, #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>>) {
420+
%3101 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({
421+
^bb0(%arg24: f32, %arg25: f32):
422+
%3166 = "arith.maxnumf"(%arg24, %arg25) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
423+
"tt.reduce.return"(%3166) : (f32) -> ()
424+
}) : (tensor<64x32xf32, #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>>) -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>}>>
425+
tt.return
426+
}
427+
}
428+
429+
// -----
430+
431+
// GFX950-LABEL: reduce_16x16
432+
// GFX950: llvm.call_intrinsic "llvm.amdgcn.permlane32.swap"
433+
// GFX950: llvm.call_intrinsic "llvm.amdgcn.permlane16.swap"
434+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
435+
tt.func @reduce_16x16(%arg0: tensor<64x16xf32, #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>>){
436+
%1 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({
437+
^bb0(%arg24: f32, %arg25: f32):
438+
%3166 = "arith.maxnumf"(%arg24, %arg25) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
439+
"tt.reduce.return"(%3166) : (f32) -> ()
440+
}) : (tensor<64x16xf32, #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>>) -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>}>>
441+
tt.return
442+
}
443+
}

third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,12 +225,96 @@ static inline Value truncAndCastFromInt(RewriterBase &rewriter, Location loc,
225225
return toVal;
226226
}
227227

228+
// Permute lanes of the input val and apply reduction to permuted values.
229+
static Value permuteAndReduce(RewriterBase &rewriter, Location loc,
230+
StringRef intrinsic, Value val,
231+
Operation *reduxOp) {
232+
Type valType = val.getType();
233+
assert(valType.getIntOrFloatBitWidth() <= 32);
234+
235+
Type actualType = valType;
236+
if (!valType.isInteger(32))
237+
actualType = castToAndSExtInt(rewriter, loc, val, valType, 32);
238+
239+
auto b = TritonLLVMOpBuilder(loc, rewriter);
240+
Value falseVal = b.false_val();
241+
MLIRContext *ctx = rewriter.getContext();
242+
Type retType = struct_ty({i32_ty, i32_ty});
243+
Value perm =
244+
LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, retType,
245+
ValueRange{val, val, falseVal, falseVal})
246+
->getResult(0);
247+
Value v0 = b.extract_val(i32_ty, perm, 0);
248+
Value v1 = b.extract_val(i32_ty, perm, 1);
249+
250+
if (!valType.isInteger(32)) {
251+
v0 = truncAndCastFromInt(rewriter, loc, v0, valType, 32);
252+
v1 = truncAndCastFromInt(rewriter, loc, v1, valType, 32);
253+
}
254+
IRMapping mapping;
255+
mapping.map(reduxOp->getOperand(0), v0);
256+
mapping.map(reduxOp->getOperand(1), v1);
257+
Value redx = rewriter.clone(*reduxOp, mapping)->getResult(0);
258+
return redx;
259+
}
260+
261+
// Apply warp reduction across lanes using llvm intrinsics in GFX950.
262+
// The input acc has the partial accumulated values from reduction within
263+
// threads. The output acc has the final accumulated values.
264+
//
265+
// Two special cases are supported:
266+
// When numLaneToReduce == 2 && interleave == 32:
267+
// step 1: use permlane32_swap() to swap the row 2 and 3 of acc and
268+
// the row 0 and 1 of the copy of acc
269+
// step 2: apply reduction to the result values to get final result
270+
// When numLaneToReduce == 4 && interleave == 16:
271+
// step 1: use permlane32_swap() to swap the row 2 and 3 of acc and
272+
// the row 0 and 1 of the copy of acc
273+
// step 2: apply reduction to the result values to get the partial result
274+
// step 3: use permlane16_swap() to swap the odd and even rows of
275+
// the partial results
276+
// step 4: apply reduction to get the final results
277+
static bool warpReduceSwap16or32(RewriterBase &rewriter, Location loc,
278+
SmallVector<Value> &acc, triton::ReduceOp op,
279+
unsigned numLaneToReduce,
280+
unsigned interleave) {
281+
Operation *reduxOp = op.getSingleCombiner();
282+
if (!reduxOp)
283+
return false;
284+
285+
bool mfma32Case = numLaneToReduce == 2 && interleave == 32;
286+
bool mfma16Case = numLaneToReduce == 4 && interleave == 16;
287+
if (!(mfma32Case || mfma16Case))
288+
return false;
289+
290+
Value val = acc[0];
291+
unsigned bits = val.getType().getIntOrFloatBitWidth();
292+
if (bits > 32)
293+
return false;
294+
295+
StringRef intrinsic = "llvm.amdgcn.permlane32.swap";
296+
for (auto i = 0; i < acc.size(); i++) {
297+
Value redx = permuteAndReduce(rewriter, loc, intrinsic, acc[i], reduxOp);
298+
299+
if (mfma16Case) {
300+
intrinsic = "llvm.amdgcn.permlane16.swap";
301+
redx = permuteAndReduce(rewriter, loc, intrinsic, redx, reduxOp);
302+
}
303+
304+
acc[i] = redx;
305+
}
306+
return true;
307+
}
308+
228309
bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc,
229310
SmallVector<Value> &acc, triton::ReduceOp op,
230311
unsigned numLaneToReduce,
231312
unsigned interleave) const {
232313
auto b = TritonLLVMOpBuilder(loc, rewriter);
233314

315+
if (isCDNA() && getISAFamily() == ISAFamily::CDNA4 &&
316+
warpReduceSwap16or32(rewriter, loc, acc, op, numLaneToReduce, interleave))
317+
return true;
234318
if (numLaneToReduce != getWarpSize())
235319
return false;
236320
if (isCDNA() && getISAFamily() == ISAFamily::CDNA1)

0 commit comments

Comments
 (0)