Skip to content

Commit 366de71

Browse files
authored
[TritonGPU] Tweaks to warp specialization to reduce register pressure (#6403)
* Place TMEM accumulator acquire over the entire epilogue to improve instruction scheduling * Give the load partition 2 warps This marginally improves the performance of the tutorial matmul (~2.5%) but is important for cases where spilling may occur
1 parent 8de17d2 commit 366de71

File tree

4 files changed

+46
-16
lines changed

4 files changed

+46
-16
lines changed

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/LoadMMASpecialization.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -381,17 +381,29 @@ LogicalResult triton::gpu::specializeLoadMMADependencies(scf::ForOp &loop,
381381
donePt.getPoint()->isBeforeInBlock(&*b.getInsertionPoint()));
382382
donePt = b.saveInsertionPoint();
383383

384-
// Acquire and get the accumulator result.
385-
b.setInsertionPoint(domOp);
386384
Partition *userPartition = schedule.addPartition(numStages + numMmaStages);
385+
// Acquire and get the accumulator result. Normally, we want to acquire the
386+
// accumulator for as small of a critical section as possible to unblock
387+
// dependents, but if the most dominating user is inside a conditional,
388+
// acquire the accumulator for the whole branch. This will improve
389+
// instruction scheduling and interleaving of the TMEM load.
390+
bool userInConditional = isa<scf::IfOp>(domOp->getParentOp());
391+
b.setInsertionPoint(domOp);
392+
if (userInConditional)
393+
b.setInsertionPointToStart(domOp->getBlock());
387394
createInPartition<ttng::WaitBarrierOp>(b, *userPartition, curAccReadyBar,
388395
accPhase);
396+
397+
b.setInsertionPoint(domOp);
389398
Value acc = createInPartition<ttng::TMEMLoadOp>(
390399
b, *userPartition, info.accLoad.getType(), curAccBuf);
391400
for (Operation *user : accUses)
392401
user->replaceUsesOfWith(info.accLoad, acc);
402+
393403
// Signal the accumulator buffer is ready for the next iteration. Because
394404
// the mbarriers got shifted over by 1, we have to signal the next mbarrier.
405+
if (userInConditional)
406+
b.setInsertionPoint(domOp->getBlock()->getTerminator());
395407
Value nextIndex =
396408
b.create<arith::AddIOp>(accIndex, intCst(numMmaStages - 1));
397409
nextIndex = b.create<arith::RemUIOp>(nextIndex, intCst(numMmaStages));

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/OptimizePartitionWarps.cpp

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55
#include "triton/Analysis/AxisInfo.h"
66
#include "triton/Conversion/TritonToTritonGPU/Passes.h"
77
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
8+
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
89
#include "llvm/ADT/ScopeExit.h"
910

1011
using namespace mlir;
1112
using namespace triton;
1213
using namespace triton::gpu;
14+
namespace ttng = triton::nvidia_gpu;
1315

1416
//===----------------------------------------------------------------------===//
1517
// relayoutWarps
@@ -182,14 +184,28 @@ static LogicalResult optimizePartitionNumWarps(ModuleAxisInfoAnalysis &axisInfo,
182184
// If the compiler could control that, then we could allow non-uniform
183185
// register distributions, mostly beneficial for single-warp warpgroups that
184186
// just do some arithmetic.
185-
constexpr unsigned nTotalRegs = 65536; // for Blackwell SMs
187+
constexpr unsigned nTotalRegs = 1 << 16; // for Blackwell SMs
186188
const unsigned threadsPerWarp =
187189
TritonGPUDialect::getThreadsPerWarp(axisInfo.getModuleOp());
188190
const unsigned defaultNumWarps = lookupNumWarps(wsOp);
189191

190192
SmallVector<int32_t> partitionNumWarps =
191193
llvm::to_vector(wsOp.getPartitionNumWarps());
192194

195+
// Some instructions have critical throughput if they have low register usage. Make
196+
// sure there are enough warps for these ops to execute quickly.
197+
SmallVector<int32_t> minWarpsForPartition(partitionNumWarps.size(), 1);
198+
for (auto [minWarps, region] :
199+
llvm::zip(minWarpsForPartition, wsOp.getPartitionRegions())) {
200+
region->walk([minWarps = &minWarps](Operation *op) {
201+
if (!isa<scf::ForOp>(op->getParentOp()))
202+
return;
203+
if (isa<ttng::AsyncTMAGatherOp, ttng::AsyncTMAScatterOp,
204+
ttng::AsyncTMACopyGlobalToLocalOp>(op))
205+
*minWarps = 2;
206+
});
207+
}
208+
193209
bool changed;
194210
do {
195211
changed = false;
@@ -215,9 +231,9 @@ static LogicalResult optimizePartitionNumWarps(ModuleAxisInfoAnalysis &axisInfo,
215231
int32_t curTotalNumWarps = std::accumulate(
216232
partitionNumWarps.begin(), partitionNumWarps.end(), defaultNumWarps);
217233

218-
for (auto [numWarps, tensorRegs] :
219-
llvm::zip(partitionNumWarps, maxTensorRegs)) {
220-
if (numWarps == 1)
234+
for (auto [minWarps, numWarps, tensorRegs] :
235+
llvm::zip(minWarpsForPartition, partitionNumWarps, maxTensorRegs)) {
236+
if (numWarps <= minWarps)
221237
continue;
222238
// Check if reducing the number of warps will still fit the tensor. If it
223239
// didn't fit to begin with, it won't fit after shrinking.

test/TritonGPU/automatic-warp-specialization.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ tt.func @matmul_change_desc_in_prologue(
3232
// BASE-NOT: tt.make_tensor_descriptor
3333
// PIPELINE-NOT: tt.experimental_tensormap_create
3434
// CHECK-LABEL: partition1
35-
// CHECK-SAME: num_warps(1)
35+
// CHECK-SAME: num_warps(2)
3636
// BASE-COUNT-2: tt.make_tensor_descriptor
3737
// PIPELINE-COUNT-2: ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 512 : i32}
3838
// PIPELINE-COUNT-2: tt.experimental_tensormap_create
@@ -87,7 +87,7 @@ tt.func @matmul_tma_acc_with_conditional_def_and_use(
8787
// CHECK-LABEL: partition0
8888
// CHECK-SAME: num_warps(1)
8989
// CHECK-LABEL: partition1
90-
// CHECK-SAME: num_warps(1)
90+
// CHECK-SAME: num_warps(2)
9191
// CHECK: [[INDICES:%.*]] = tt.splat %{{.*}} : i32 -> tensor<128xi32,
9292
// CHECK: ttng.async_tma_gather %{{.*}}[[[INDICES]],
9393
// CHECK-LABEL: partition2

test/TritonGPU/load-mma-specialization.mlir

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -484,10 +484,10 @@ tt.func @matmul_tma_acc_with_conditional_user(
484484
scf.if %do_epilogue {
485485
// CHECK-NEXT: ttng.wait_barrier [[CUR_ACC_READY_BAR]], [[ACC_PHASE]] {ttg.partition = 0 : i32}
486486
// CHECK-NEXT: [[C:%.*]] = ttng.tmem_load [[ACC_BUF]] {ttg.partition = 0 : i32}
487-
// CHECK-NEXT: [[NEXT_ACC_EMPTY_BAR:%.*]] = ttg.memdesc_subview [[ACC_EMPTY_BUFS]][[[NEXT_ACC_INDEX]]]
488-
// CHECK-NEXT: ttng.arrive_barrier [[NEXT_ACC_EMPTY_BAR]], 1 {ttg.partition = 0 : i32}
489487
// CHECK-NEXT: "acc_user"([[C]])
490488
"acc_user"(%c) : (tensor<128x128xf32, #acc_layout>) -> ()
489+
// CHECK-NEXT: [[NEXT_ACC_EMPTY_BAR:%.*]] = ttg.memdesc_subview [[ACC_EMPTY_BUFS]][[[NEXT_ACC_INDEX]]]
490+
// CHECK-NEXT: ttng.arrive_barrier [[NEXT_ACC_EMPTY_BAR]], 1 {ttg.partition = 0 : i32}
491491
// CHECK-NEXT: } {ttg.partition = 0 : i32}
492492
}
493493

@@ -513,7 +513,7 @@ tt.func @matmul_tma_acc_with_conditional_user(
513513

514514
// AWS: ttg.warp_specialize
515515
// AWS: num_warps(4)
516-
// AWS: num_warps(1)
516+
// AWS: num_warps(2)
517517
// AWS: num_warps(1)
518518

519519
// CHECK: @matmul_tma_acc_with_conditional_def
@@ -612,7 +612,7 @@ tt.func @matmul_tma_acc_with_conditional_def(
612612

613613
// AWS: ttg.warp_specialize
614614
// AWS: num_warps(4)
615-
// AWS: num_warps(1)
615+
// AWS: num_warps(2)
616616
// AWS: num_warps(1)
617617

618618
// CHECK: @matmul_tma_acc_with_conditional_def_and_use
@@ -682,10 +682,10 @@ tt.func @matmul_tma_acc_with_conditional_def_and_use(
682682
scf.if %do_epilogue {
683683
// CHECK-NEXT: ttng.wait_barrier [[CUR_ACC_READY_BAR]], [[ACC_PHASE]] {ttg.partition = 0 : i32}
684684
// CHECK-NEXT: [[C:%.*]] = ttng.tmem_load [[ACC_BUF]] {ttg.partition = 0 : i32}
685-
// CHECK-NEXT: [[NEXT_ACC_EMPTY_BAR:%.*]] = ttg.memdesc_subview [[ACC_EMPTY_BUFS]][[[NEXT_ACC_INDEX]]]
686-
// CHECK-NEXT: ttng.arrive_barrier [[NEXT_ACC_EMPTY_BAR]], 1 {ttg.partition = 0 : i32}
687685
// CHECK-NEXT: "acc_user"([[C]])
688686
"acc_user"(%c) : (tensor<128x128xf32, #acc_layout>) -> ()
687+
// CHECK-NEXT: [[NEXT_ACC_EMPTY_BAR:%.*]] = ttg.memdesc_subview [[ACC_EMPTY_BUFS]][[[NEXT_ACC_INDEX]]]
688+
// CHECK-NEXT: ttng.arrive_barrier [[NEXT_ACC_EMPTY_BAR]], 1 {ttg.partition = 0 : i32}
689689
// CHECK-NEXT: } {ttg.partition = 0 : i32}
690690
}
691691

@@ -714,7 +714,7 @@ tt.func @matmul_tma_acc_with_conditional_def_and_use(
714714

715715
// AWS: ttg.warp_specialize
716716
// AWS: num_warps(1)
717-
// AWS: num_warps(1)
717+
// AWS: num_warps(2)
718718
// AWS: num_warps(1)
719719

720720
// CHECK: @matmul_tma_acc_with_conditional_def_and_use_no_multibuf
@@ -791,10 +791,12 @@ tt.func @matmul_tma_acc_with_conditional_def_and_use_no_multibuf_flag(
791791
// CHECK-NEXT: scf.if [[DO_EPILOGUE]]
792792
scf.if %do_epilogue {
793793
// CHECK-NEXT: ttng.wait_barrier [[ACC_READY_BUF0]], [[ACC_PHASE]] {ttg.partition = 0 : i32}
794+
// CHECK-NEXT: "some_op"()
795+
"some_op"() : () -> ()
794796
// CHECK-NEXT: [[C:%.*]] = ttng.tmem_load [[ACC_BUF]] {ttg.partition = 0 : i32}
795-
// CHECK-NEXT: ttng.arrive_barrier [[ACC_EMPTY_BUF0]], 1 {ttg.partition = 0 : i32}
796797
// CHECK-NEXT: "acc_user"([[C]])
797798
"acc_user"(%c) : (tensor<128x128xf32, #acc_layout>) -> ()
799+
// CHECK-NEXT: ttng.arrive_barrier [[ACC_EMPTY_BUF0]], 1 {ttg.partition = 0 : i32}
798800
// CHECK-NEXT: } {ttg.partition = 0 : i32}
799801
}
800802

0 commit comments

Comments
 (0)