
Commit 4207ca4

[BACKEND] Add min/max redux optimization for Blackwell (#7465)
Implements the new op described here: https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-redux-sync
1 parent 46f4bbd commit 4207ca4
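
For reference, the f32 forms of the redux.sync instruction documented at that link take optional .abs and .NaN qualifiers. Below is a minimal CUDA sketch of the NaN-propagating variant this commit emits for arith.maximumf, assuming an sm_100-class (Blackwell) target; the wrapper name is hypothetical, not a Triton API.

```cuda
// Hedged sketch: warp-wide f32 max via the new PTX redux.sync form.
// The .NaN qualifier makes the result NaN if any lane's input is NaN,
// matching arith.maximumf's propagating semantics. Illustrative only.
__device__ float warp_fmax_nan(float x) {
  float r;
  asm volatile("redux.sync.max.NaN.f32 %0, %1, 0xffffffff;"
               : "=f"(r)
               : "f"(x));
  return r;
}
```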

File tree

3 files changed: +74 -30 lines

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 3 additions & 19 deletions
```diff
@@ -1,4 +1,4 @@
-// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm | FileCheck %s --dump-input-context 20
+// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm 2>/dev/null | FileCheck %s --dump-input-context 20
 
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 // CHECK: llvm.func @test_empty_kernel(%arg0: i32, %arg1: !llvm.ptr<1>, %arg2: !llvm.ptr<1>)
@@ -1739,28 +1739,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
 #blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
 module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
-  tt.func public @sum_reduction(%arg0: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
-    %cst = arith.constant dense<1024> : tensor<1x1xi32, #blocked>
-    %0 = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32, #blocked1>
-    %1 = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
-    %2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<1x1xi32, #blocked>
-    %3 = arith.muli %2, %cst : tensor<1x1xi32, #blocked>
-    %4 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<1x1x!tt.ptr<i32>, #blocked>
-    %5 = tt.addptr %4, %3 : tensor<1x1x!tt.ptr<i32>, #blocked>, tensor<1x1xi32, #blocked>
-    %6 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
-    %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<1024xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x1024xi32, #blocked>
-    %8 = tt.broadcast %5 : tensor<1x1x!tt.ptr<i32>, #blocked> -> tensor<1x1024x!tt.ptr<i32>, #blocked>
-    %9 = tt.addptr %8, %7 : tensor<1x1024x!tt.ptr<i32>, #blocked>, tensor<1x1024xi32, #blocked>
-    %10 = tt.load %9 : tensor<1x1024x!tt.ptr<i32>, #blocked>
-    %11 = "tt.reduce"(%10) <{axis = 1 : i32}> ({
+  tt.func public @sum_reduction(%arg0: tensor<1x1024xi32, #blocked>) {
+    %11 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({
     ^bb0(%arg2: i32, %arg3: i32):
       %15 = arith.addi %arg2, %arg3 : i32
       tt.reduce.return %15 : i32
     }) : (tensor<1x1024xi32, #blocked>) -> tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
-    %12 = ttg.convert_layout %11 : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<1xi32, #blocked1>
-    %13 = tt.splat %arg1 : !tt.ptr<i32> -> tensor<1x!tt.ptr<i32>, #blocked1>
-    %14 = tt.addptr %13, %0 : tensor<1x!tt.ptr<i32>, #blocked1>, tensor<1xi32, #blocked1>
-    tt.store %14, %12 : tensor<1x!tt.ptr<i32>, #blocked1>
     tt.return
   }
 }
```
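
The slimmed-down test still exercises the pre-existing sm_80 integer redux path. Roughly, the warp-level step it checks corresponds to CUDA's __reduce_add_sync intrinsic; a sketch under that assumption (the helper name is illustrative, not generated code):

```cuda
// Approximate CUDA equivalent of the warp-level step the i32 sum
// reduction lowers to on sm_80+ targets (sketch, not generated code).
__device__ int warp_sum_i32(int x) {
  return __reduce_add_sync(0xffffffffu, x); // warp-wide add, CC >= 8.0
}
```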

test/Conversion/tritongpu_to_llvm_blackwell.mlir

Lines changed: 44 additions & 0 deletions
```diff
@@ -683,3 +683,47 @@ tt.func @load_store_x1_unpacked(%arg0: !ttg.memdesc<128x2xf16, #tmem_x1_unpacked
 }
 
 }
+
+// -----
+
+// CHECK-LABEL: max_reduction
+// CHECK: %[[M:.+]] = llvm.mlir.constant(-1 : i32) : i32
+// CHECK: nvvm.redux.sync fmax %{{.*}}, %[[M]] {nan = true} : f32 -> f32
+// CHECK: nvvm.barrier0
+// CHECK: nvvm.shfl.sync bfly
+// CHECK: nvvm.shfl.sync bfly
+// CHECK: nvvm.barrier0
+#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
+module attributes {"ttg.target" = "cuda:100", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @max_reduction(%arg0: tensor<1x1024xf32, #blocked>) {
+    %11 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({
+    ^bb0(%arg2: f32, %arg3: f32):
+      %15 = arith.maximumf %arg2, %arg3 : f32
+      tt.reduce.return %15 : f32
+    }) {allocation.offset = 0 : i32} : (tensor<1x1024xf32, #blocked>) -> tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    tt.return
+  }
+}
+
+// -----
+
+// CHECK-LABEL: maxnum_reduction
+// CHECK: %[[M:.+]] = llvm.mlir.constant(-1 : i32) : i32
+// CHECK: nvvm.redux.sync fmax %{{.*}}, %[[M]] : f32 -> f32
+// CHECK: nvvm.barrier0
+// CHECK: nvvm.shfl.sync bfly
+// CHECK: nvvm.shfl.sync bfly
+// CHECK: nvvm.barrier0
+#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
+module attributes {"ttg.target" = "cuda:100", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @maxnum_reduction(%arg0: tensor<1x1024xf32, #blocked>) {
+    %11 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({
+    ^bb0(%arg2: f32, %arg3: f32):
+      %15 = arith.maxnumf %arg2, %arg3 : f32
+      tt.reduce.return %15 : f32
+    }) {allocation.offset = 0 : i32} : (tensor<1x1024xf32, #blocked>) -> tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    tt.return
+  }
+}
```
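
The two tests differ only in the combiner: arith.maximumf propagates NaN, so its redux carries {nan = true}, while arith.maxnumf follows IEEE maxNum and ignores a single NaN operand, so it does not. A sketch of that distinction in plain CUDA; the helper names are hypothetical and exist only to contrast the semantics:

```cuda
#include <math.h>

// Illustrative helpers contrasting the two f32 max semantics the
// tests above distinguish (sketches, not Triton-generated code).
__device__ float max_like_maximumf(float a, float b) {
  // NaN-propagating, hence the {nan = true} redux in @max_reduction.
  return (isnan(a) || isnan(b)) ? nanf("") : fmaxf(a, b);
}

__device__ float max_like_maxnumf(float a, float b) {
  // IEEE maxNum: fmaxf already returns the non-NaN operand when
  // exactly one input is NaN, so no nan qualifier is needed.
  return fmaxf(a, b);
}
```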

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp

Lines changed: 27 additions & 11 deletions
```diff
@@ -87,12 +87,22 @@ namespace mlir::triton::NVIDIA {
 
 // Check if the reduction can use a redux op and return the kind.
 static std::optional<NVVM::ReduxKind> matchReduxKind(triton::ReduceOp op,
-                                                     int computeCapability) {
+                                                     int computeCapability,
+                                                     bool &useNanQualifier) {
+  useNanQualifier = false;
   if (computeCapability < 80)
     return std::nullopt;
   Operation *reduceOp = op.getSingleCombiner();
   if (!reduceOp)
     return std::nullopt;
+  if (computeCapability >= 100 && reduceOp->getResultTypes()[0].isF32()) {
+    if (isa<arith::MinimumFOp, arith::MaximumFOp>(reduceOp))
+      useNanQualifier = true;
+    if (isa<arith::MaxNumFOp, arith::MaximumFOp>(reduceOp))
+      return NVVM::ReduxKind::FMAX;
+    if (isa<arith::MinNumFOp, arith::MinimumFOp>(reduceOp))
+      return NVVM::ReduxKind::FMIN;
+  }
   auto intType = dyn_cast<IntegerType>(reduceOp->getResultTypes()[0]);
   if (!intType || intType.getWidth() > 32)
     return std::nullopt;
@@ -434,7 +444,8 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc,
                             unsigned numLaneToReduce,
                             unsigned interleave) const {
   auto b = TritonLLVMOpBuilder(loc, rewriter);
-  if (auto kind = matchReduxKind(op, computeCapability)) {
+  bool useNanQualifier = false;
+  if (auto kind = matchReduxKind(op, computeCapability, useNanQualifier)) {
     // Based on benchmarking on A100 redux op gives a speed up only when doing
     // a single reduction (not partitioned) and when the mask is static.
     // Therefore we currently only enable it to reduce across all the lanes.
@@ -452,17 +463,22 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc,
                    b.and_(laneId, b.i32_val(~(numLaneToReduce - 1))));
     }
     for (unsigned i = 0; i < acc.size(); ++i) {
-      unsigned bitwidth = cast<IntegerType>(acc[i].getType()).getWidth();
-      if (bitwidth < 32) {
-        if (*kind == NVVM::ReduxKind::MIN || *kind == NVVM::ReduxKind::MAX)
-          acc[i] = b.sext(i32_ty, acc[i]);
-        else
-          acc[i] = b.zext(i32_ty, acc[i]);
+      unsigned bitwidth = acc[i].getType().getIntOrFloatBitWidth();
+      if (acc[i].getType().isInteger()) {
+        if (bitwidth < 32) {
+          if (*kind == NVVM::ReduxKind::MIN || *kind == NVVM::ReduxKind::MAX)
+            acc[i] = b.sext(i32_ty, acc[i]);
+          else
+            acc[i] = b.zext(i32_ty, acc[i]);
+        }
       }
       acc[i] = rewriter.create<NVVM::ReduxOp>(loc, acc[i].getType(), acc[0],
-                                              *kind, mask);
-      if (bitwidth < 32)
-        acc[i] = b.trunc(int_ty(bitwidth), acc[i]);
+                                              *kind, mask, /*abs=*/false,
+                                              /*nan=*/useNanQualifier);
+      if (acc[i].getType().isInteger()) {
+        if (bitwidth < 32)
+          acc[i] = b.trunc(int_ty(bitwidth), acc[i]);
+      }
     }
     return true;
   }
```
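
The restructured loop now widens and narrows only integer accumulators: the redux primitives operate on 32-bit values, so a signed sub-32-bit min/max must be sign-extended to preserve ordering (other ops use zero-extension) and truncated back afterwards, while f32 values pass through unchanged. A sketch of the same idea using CUDA's sm_80+ intrinsic; the kernel name and shapes are illustrative:

```cuda
// Why the sext/trunc dance exists for sub-32-bit integers: the redux
// primitives are 32-bit only, so an i16 signed max is sign-extended,
// reduced, and truncated back. Illustrative kernel, not Triton output.
__global__ void warp_max_i16(const short *in, short *out) {
  int v = (int)in[threadIdx.x];            // sext: signed MIN/MAX
  v = __reduce_max_sync(0xffffffffu, v);   // 32-bit warp reduction
  if (threadIdx.x == 0)
    out[0] = (short)v;                     // trunc back to 16 bits
}
```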
