Commit aeac283

[TritonGPU] Enable accum-init optimization for unconditionally zero-ed accumulators (#6395)
Currently, the pass doesn't fire when [there is no explicit op that conditionally clears the accumulator](https://github.com/triton-lang/triton/blob/main/lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp#L207-L211). It thus misses the simplest case where the optimization is applicable: the accumulator is initialized to zero, and after the first iteration it is always updated with `+=`. The motivation is IR like the following. We want to hoist the `tmem_alloc` outside of the tile loop, but that would require explicitly clearing the accumulator after the K loop for one tile completes. Enabling the optimization for this case lets us skip the explicit clearing.

```
for tile ...
  for k ... iter_args(arg9 = cst_zero)
    acc = tmem_alloc arg9
    mma A B acc
    next_acc = tmem_load acc
    ...
    yield next_acc
```

---------

Co-authored-by: Masahiro Masuda <[email protected]>
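For context, a sketch of the transformation this enables in the simplest case, modeled on the updated `@non_cond_override` lit test; layout attributes and surrounding constants are simplified, so treat it as illustrative rather than verbatim pass output:

```mlir
// Before: the accumulator iter_arg starts at zero and is accumulated into
// unconditionally on every iteration.
%cst = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1>
%res = scf.for %k = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%acc = %cst) -> (tensor<128x16xf32, #mma1>) : i32 {
  %next = ttng.warp_group_dot %A, %B, %acc : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1>
  scf.yield %next : tensor<128x16xf32, #mma1>
}

// After: an i1 flag (false on the first iteration, true afterwards) is
// threaded through the loop and passed as the dot's use-accumulator operand,
// so the zero init never needs to be materialized (or, for MMAv5, explicitly
// cleared in TMEM).
%res:2 = scf.for %k = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%acc = %cst, %flag = %false) -> (tensor<128x16xf32, #mma1>, i1) : i32 {
  %next = ttng.warp_group_dot %A, %B, %acc, %flag : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1>
  scf.yield %next, %true : tensor<128x16xf32, #mma1>, i1
}
```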

5 files changed (+136, -78)

lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp

Lines changed: 91 additions & 46 deletions
```diff
@@ -1,5 +1,6 @@
 #include "mlir/Transforms/Passes.h"
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"
+#include "triton/Dialect/Triton/IR/OpInterfaces.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
@@ -113,6 +114,19 @@ void setUseAccFlag(Operation *op, Value useAcc) {
   }
 }
 
+Value getUseAccFlag(Operation *op) {
+  assert(isa<DotOpInterface>(op) && "Expected a dot-like operation");
+  if (auto wgDotOp = dyn_cast<triton::nvidia_gpu::WarpGroupDotOp>(op)) {
+    return wgDotOp.getUseC();
+  } else if (auto tc05MmaOp =
+                 dyn_cast<triton::nvidia_gpu::MMAv5OpInterface>(op)) {
+    return tc05MmaOp.useAccumulator();
+  } else {
+    assert(false && "Unexpected dot-like operation");
+  }
+  return nullptr;
+}
+
 bool isConstantZeroTensor(Value v) {
   return (matchPattern(v, m_Zero()) || matchPattern(v, m_AnyZeroFloat()));
 }
@@ -157,6 +171,18 @@ findZeroInitOp(Value accUse, scf::ForOp forOp, bool &loopArgIsZero) {
   return std::nullopt;
 }
 
+std::optional<bool> getBoolFromConstant(Value cst) {
+  auto constantOp = cst.getDefiningOp<arith::ConstantOp>();
+  if (!constantOp) {
+    return std::nullopt;
+  }
+  assert(constantOp.getValue());
+  if (auto boolAttr = dyn_cast<BoolAttr>(constantOp.getValue())) {
+    return boolAttr.getValue();
+  }
+  return std::nullopt;
+}
+
 } // namespace
 
 class OptimizeAccumulatorInitPass
@@ -206,62 +232,81 @@ class OptimizeAccumulatorInitPass
       bool loopArgIsZero = false;
       std::optional<std::pair<Operation *, int>> zeroInitOp =
          findZeroInitOp(accUse, forOp, loopArgIsZero);
-      if (!zeroInitOp) {
+
+      if (!zeroInitOp && !loopArgIsZero) {
        continue;
      }
 
+      if (auto useAccValue = getUseAccFlag(mmaOp)) {
+        auto useAcc = getBoolFromConstant(useAccValue);
+        if (!useAcc || *useAcc == false) {
+          // Do not run this optimization if there is already a non-constant
+          // flag (this pass has already run), or if this MMA does not use the
+          // accumulator (e.g. the peeled MMA in the prologue, the first dot
+          // in attention)
+          continue;
+        }
+      }
+
       Value loopArgFlagValue = loopArgIsZero ? vFalse : vTrue;
       (void)addIterArgsToLoop(rewriter, forOp, {loopArgFlagValue});
       loopArgFlagValue =
          forOp.getRegionIterArg(forOp.getNumRegionIterArgs() - 1);
 
-      Value condition = nullptr;
-      Value oldValue = nullptr;
-      Value zeroValue = nullptr;
-      bool thenInitsToZero = false;
-      if (auto selOp = dyn_cast<arith::SelectOp>(zeroInitOp->first)) {
-        condition = selOp.getCondition();
-        oldValue = isConstantZeroTensor(selOp.getTrueValue())
-                       ? selOp.getFalseValue()
-                       : selOp.getTrueValue();
-        zeroValue = isConstantZeroTensor(selOp.getTrueValue())
-                        ? selOp.getTrueValue()
-                        : selOp.getFalseValue();
-        thenInitsToZero = isConstantZeroTensor(selOp.getTrueValue());
-      } else {
-        assert(isa<scf::IfOp>(*zeroInitOp->first) && "Expected an if op");
-        auto ifOp = cast<scf::IfOp>(zeroInitOp->first);
-        unsigned resultIndex = zeroInitOp->second;
-        condition = ifOp.getCondition();
-        Value thenVal = ifOp.thenYield()->getOperand(resultIndex);
-        Value elseVal = ifOp.elseYield()->getOperand(resultIndex);
-        oldValue = isConstantZeroTensor(thenVal) ? elseVal : thenVal;
-        zeroValue = isConstantZeroTensor(thenVal) ? thenVal : elseVal;
-        thenInitsToZero = isConstantZeroTensor(thenVal);
-      }
+      if (zeroInitOp) {
+        Value condition = nullptr;
+        Value oldValue = nullptr;
+        Value zeroValue = nullptr;
+        bool thenInitsToZero = false;
+        if (auto selOp = dyn_cast<arith::SelectOp>(zeroInitOp->first)) {
+          condition = selOp.getCondition();
+          oldValue = isConstantZeroTensor(selOp.getTrueValue())
+                         ? selOp.getFalseValue()
+                         : selOp.getTrueValue();
+          zeroValue = isConstantZeroTensor(selOp.getTrueValue())
+                          ? selOp.getTrueValue()
+                          : selOp.getFalseValue();
+          thenInitsToZero = isConstantZeroTensor(selOp.getTrueValue());
+        } else {
+          assert(isa<scf::IfOp>(*zeroInitOp->first) && "Expected an if op");
+          auto ifOp = cast<scf::IfOp>(zeroInitOp->first);
+          unsigned resultIndex = zeroInitOp->second;
+          condition = ifOp.getCondition();
+          Value thenVal = ifOp.thenYield()->getOperand(resultIndex);
+          Value elseVal = ifOp.elseYield()->getOperand(resultIndex);
+          oldValue = isConstantZeroTensor(thenVal) ? elseVal : thenVal;
+          zeroValue = isConstantZeroTensor(thenVal) ? thenVal : elseVal;
+          thenInitsToZero = isConstantZeroTensor(thenVal);
+        }
 
-      // Create a select op that updates the flag
-      rewriter.setInsertionPoint(zeroInitOp->first);
-      bool zeroingBeforeMMA = zeroInitOp->first->isBeforeInBlock(mmaOp);
-      Value prevFlagValue = zeroingBeforeMMA ? loopArgFlagValue : vTrue;
-      auto selectFlagOp = rewriter.create<arith::SelectOp>(
-          loc, condition, thenInitsToZero ? vFalse : prevFlagValue,
-          thenInitsToZero ? prevFlagValue : vFalse);
-      setUseAccFlag(mmaOp, zeroingBeforeMMA ? selectFlagOp : loopArgFlagValue);
-      auto forYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
-      forYield->insertOperands(forYield->getNumOperands(),
-                               {zeroingBeforeMMA ? vTrue : selectFlagOp});
+        // Create a select op that updates the flag
+        rewriter.setInsertionPoint(zeroInitOp->first);
+        bool zeroingBeforeMMA = zeroInitOp->first->isBeforeInBlock(mmaOp);
+        Value prevFlagValue = zeroingBeforeMMA ? loopArgFlagValue : vTrue;
+        auto selectFlagOp = rewriter.create<arith::SelectOp>(
+            loc, condition, thenInitsToZero ? vFalse : prevFlagValue,
+            thenInitsToZero ? prevFlagValue : vFalse);
+        setUseAccFlag(mmaOp,
+                      zeroingBeforeMMA ? selectFlagOp : loopArgFlagValue);
+        auto forYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+        forYield->insertOperands(forYield->getNumOperands(),
+                                 {zeroingBeforeMMA ? vTrue : selectFlagOp});
 
-      // Stop clearing out the accumulator with zero
-      if (auto selOp = dyn_cast<arith::SelectOp>(zeroInitOp->first)) {
-        rewriter.setInsertionPoint(selOp);
-        rewriter.replaceOp(selOp, oldValue);
-      } else {
-        auto ifOp = cast<scf::IfOp>(zeroInitOp->first);
-        int resultIndex = zeroInitOp->second;
-        auto zeroingYield =
-            thenInitsToZero ? ifOp.thenYield() : ifOp.elseYield();
-        zeroingYield.setOperand(resultIndex, oldValue);
+        // Stop clearing out the accumulator with zero
+        if (auto selOp = dyn_cast<arith::SelectOp>(zeroInitOp->first)) {
+          rewriter.setInsertionPoint(selOp);
+          rewriter.replaceOp(selOp, oldValue);
+        } else {
+          auto ifOp = cast<scf::IfOp>(zeroInitOp->first);
+          int resultIndex = zeroInitOp->second;
+          auto zeroingYield =
+              thenInitsToZero ? ifOp.thenYield() : ifOp.elseYield();
+          zeroingYield.setOperand(resultIndex, oldValue);
+        }
+      } else if (loopArgIsZero) {
+        setUseAccFlag(mmaOp, loopArgFlagValue);
+        auto forYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+        forYield->insertOperands(forYield->getNumOperands(), vTrue);
       }
     }
 
```
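The new `getUseAccFlag`/`getBoolFromConstant` guard inspects the dot's existing use-accumulator operand before rewriting. A sketch of the two shapes it now bails out on (operand types elided; `%flag` stands for a loop-carried i1 produced by a previous run of this pass):

```mlir
// Bail out: the flag is already non-constant, i.e. the pass has run before.
%acc1 = ttng.warp_group_dot %A, %B, %acc, %flag : ...

// Bail out: a constant-false flag marks an MMA that intentionally ignores
// the accumulator (e.g. the peeled MMA in the prologue, or the first dot
// in attention).
%false = arith.constant false
%acc2 = ttng.warp_group_dot %A, %B, %acc, %false : ...
```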

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionLoops.cpp

Lines changed: 10 additions & 4 deletions
```diff
@@ -25,12 +25,16 @@ using Partition = WarpSchedule::Partition;
 // partition or the provided `partition`.
 static void eraseOtherPartitions(scf::ForOp &loop, const WarpSchedule &schedule,
                                  const Partition *partition) {
+  auto inPartition = [&](Operation *op) {
+    const Partition *opPartition =
+        schedule.getPartition(loop.getBody()->findAncestorOpInBlock(*op));
+    return llvm::is_contained({partition, schedule.getRootPartition()},
+                              opPartition);
+  };
   llvm::BitVector toErase(loop.getNumRegionIterArgs(), true);
   for (Operation &op :
        llvm::make_early_inc_range(loop.getBody()->without_terminator())) {
-    const Partition *opPartition = schedule.getPartition(&op);
-    if (!llvm::is_contained({partition, schedule.getRootPartition()},
-                            opPartition)) {
+    if (!inPartition(&op)) {
       op.dropAllUses();
       op.erase();
       continue;
@@ -43,7 +47,9 @@ static void eraseOtherPartitions(scf::ForOp &loop, const WarpSchedule &schedule,
     }
   }
   for (auto [i, arg] : llvm::enumerate(loop.getRegionIterArgs())) {
-    if (toErase.test(i))
+    if (llvm::any_of(arg.getUsers(), inPartition))
+      toErase.reset(i);
+    else if (toErase.test(i))
       arg.dropAllUses();
   }
   eraseLoopCarriedValues(loop, std::move(toErase));
```
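The `findAncestorOpInBlock` call matters because a user of a region iter arg may sit below a nested region op; the partition query must be made on its top-level ancestor in the loop body. A minimal sketch of the shape this handles, with an illustrative generic op and the partition annotations omitted:

```mlir
scf.for %i = %lb to %ub step %c1 iter_args(%arg = %init) -> (tensor<128xf32>) {
  %r = scf.if %cond -> (tensor<128xf32>) {
    // The only user of %arg is nested inside the scf.if; the schedule is
    // queried for the top-level scf.if, not for this nested op.
    %u = "compute.use"(%arg) : (tensor<128xf32>) -> tensor<128xf32>
    scf.yield %u : tensor<128xf32>
  } else {
    scf.yield %arg : tensor<128xf32>
  }
  scf.yield %r : tensor<128xf32>
}
```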

python/test/unit/language/test_matmul.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -469,7 +469,7 @@ def block_scale_mxfp_matmul(  #
 @pytest.mark.parametrize("NUM_STAGES", [1, 2, 4])
 @pytest.mark.parametrize("USE_2D_SCALE_LOAD", [False, True])
 @pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 10, reason="Requires compute capability >= 10")
-def test_blocked_scale_mxfp(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, USE_2D_SCALE_LOAD, device, monkeypatch):
+def test_blocked_scale_mxfp(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, USE_2D_SCALE_LOAD, device):
     if BLOCK_N == 256 and BLOCK_K == 256:
         NUM_STAGES = min(NUM_STAGES, 2)
     elif BLOCK_K == 256:
```

test/TritonGPU/accumulator-init.mlir

Lines changed: 29 additions & 23 deletions
```diff
@@ -1,4 +1,4 @@
-// RUN: triton-opt %s -tritongpu-optimize-accumulator-init | FileCheck %s
+// RUN: triton-opt %s -split-input-file -tritongpu-optimize-accumulator-init | FileCheck %s
 
 #blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
@@ -292,42 +292,27 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
     tt.return %17 : tensor<128x16xf32, #mma1>
   }
 
-  // Check that we bail out in unsupported cases
-
-  // CHECK-LABEL: @non_zero_init
-  // CHECK-NOT: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !ttg.memdesc
-  tt.func @non_zero_init(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> {
-    %c0_i32 = arith.constant 0 : i32
-    %cst_2 = arith.constant dense<1.000000e+00> : tensor<128x16xf32, #mma1>
-    %c1_i32 = arith.constant 1 : i32
-    %c8_i32 = arith.constant 8 : i32
-    %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 {
-      %cnd = arith.cmpi slt, %arg3, %ext : i32
-      %acc = ttng.warp_group_dot %A, %B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1>
-      %acc_ = arith.select %cnd, %cst_2, %acc : tensor<128x16xf32, #mma1>
-      scf.yield %acc_: tensor<128x16xf32, #mma1>
-    }
-    tt.return %17 : tensor<128x16xf32, #mma1>
-  }
-
-  // CHECK-LABEL: @zero_init_dist_2
-  // CHECK-NOT: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !ttg.memdesc
+  // CHECK-LABEL: @zero_init_dist_2
   tt.func @zero_init_dist_2(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> {
     %c0_i32 = arith.constant 0 : i32
+    // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00>
     %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1>
     %c1_i32 = arith.constant 1 : i32
     %c8_i32 = arith.constant 8 : i32
+    // CHECK: scf.for {{.*}} = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg{{[1-9]+}} = %{{.*}}, %[[ACC:.*]] = %[[CST]], %[[INIT_FLAG:.*]] = %false)
     %17:2 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2, %arg5 = %cst_2) -> (tensor<128x16xf32, #mma1>, tensor<128x16xf32, #mma1>) : i32 {
       %cnd = arith.cmpi slt, %arg3, %ext : i32
+      // CHECK: %2 = ttng.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[INIT_FLAG]]
       %acc = ttng.warp_group_dot %A, %B, %arg5 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1>
       %acc_ = arith.select %cnd, %cst_2, %acc : tensor<128x16xf32, #mma1>
+      // CHECK: scf.yield {{.*}}, {{.*}}, %true
       scf.yield %acc_, %arg4: tensor<128x16xf32, #mma1>, tensor<128x16xf32, #mma1>
     }
     tt.return %17 : tensor<128x16xf32, #mma1>
   }
 
   // CHECK-LABEL: @if_defines_alternative
-  // CHECK-NOT: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !ttg.memdesc
+  // CHECK: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, %arg{{.*}} : !ttg.memdesc
   tt.func @if_defines_alternative(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> {
     %c0_i32 = arith.constant 0 : i32
     %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1>
@@ -343,13 +328,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
       %acc_alt = arith.addf %acc, %cst_3 : tensor<128x16xf32, #mma1>
       scf.yield %acc_alt : tensor<128x16xf32, #mma1>
     }
+    // CHECK: scf.yield {{.*}}, %true
     scf.yield %acc_: tensor<128x16xf32, #mma1>
   }
   tt.return %17 : tensor<128x16xf32, #mma1>
 }
 
 // CHECK-LABEL: @non_cond_override
-// CHECK-NOT: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !ttg.memdesc
+// CHECK: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, %arg{{.*}} : !ttg.memdesc
 tt.func @non_cond_override(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> {
   %c0_i32 = arith.constant 0 : i32
   %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1>
@@ -359,6 +345,26 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
   %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 {
     %acc = ttng.warp_group_dot %A, %B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1>
     %acc_ = arith.addf %acc, %cst_3 : tensor<128x16xf32, #mma1>
+    // CHECK: scf.yield {{.*}}, %true
+    scf.yield %acc_: tensor<128x16xf32, #mma1>
+  }
+  tt.return %17 : tensor<128x16xf32, #mma1>
+}
+
+
+// Check that we bail out in unsupported cases
+
+// CHECK-LABEL: @non_zero_init
+// CHECK-NOT: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !ttg.memdesc
+tt.func @non_zero_init(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> {
+  %c0_i32 = arith.constant 0 : i32
+  %cst_2 = arith.constant dense<1.000000e+00> : tensor<128x16xf32, #mma1>
+  %c1_i32 = arith.constant 1 : i32
+  %c8_i32 = arith.constant 8 : i32
+  %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 {
+    %cnd = arith.cmpi slt, %arg3, %ext : i32
+    %acc = ttng.warp_group_dot %A, %B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1>
+    %acc_ = arith.select %cnd, %cst_2, %acc : tensor<128x16xf32, #mma1>
     scf.yield %acc_: tensor<128x16xf32, #mma1>
   }
   tt.return %17 : tensor<128x16xf32, #mma1>
```
