Skip to content

Commit 427d774

Browse files
authored
Fix incorrect codegen for masks dependent on loop induction variable (#310)
Currently, for mask dependent on loop induction variable, we compute the mask offset by adding the mask offset *before* coming into the loop by the loop iter-arg. This is not correct when the offset has an initial value other than 0 because then the value of the offset will always be one iteration *after* the current iteration. This patch fixes the codegen and adds tests for these scenarios.
1 parent 403a8a1 commit 427d774

File tree

3 files changed

+164
-18
lines changed

3 files changed

+164
-18
lines changed

lib/Analysis/MaskAnalysis.cpp

Lines changed: 52 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "triton-shared/Analysis/MaskAnalysis.h"
99
#include "mlir/Dialect/Arith/IR/Arith.h"
1010
#include "mlir/Dialect/SCF/IR/SCF.h"
11+
#include "mlir/IR/Builders.h"
1112
#include "mlir/Support/LogicalResult.h"
1213

1314
#include "triton-shared/Analysis/OpFoldResultUtils.h"
@@ -452,32 +453,65 @@ LogicalResult MaskState::parseLoopIterArg(Value v, const Location loc,
452453
return failure();
453454
}
454455

456+
// This is a bit of a hack!!
457+
//
458+
// The offset (MaskState::start) of a mask can now depend on a loop's
459+
// iter-arg like the following example:
460+
//
461+
// idx = offset + tl.arange(0, 4)
462+
// for it in range(n):
463+
// mask = idx < size
464+
// x = tl.load(x_ptr + idx, mask=mask)
465+
// tl.store(y_ptr + idx, x, mask=mask)
466+
// idx += 4
467+
//
468+
// See
469+
// test/Conversion/TritonToStructured/mask_loop_iter_arg.mlir and
470+
// and
471+
// python/examples/test_mask_loop_iter_arg.py
472+
// for IR and full triton code.
473+
//
474+
// To support this case, we first make the following assumptions:
475+
// - MaskAnalysis is runs after PtrAnalysis's prepass finishes, which means
476+
// the offset for the load and store pointers have already been set up
477+
// at `argIndex + 1`
478+
// - The tensor of indices used by the load / store and the mask are the same
479+
// (see above where `idx` appears in both the mask and the pointer
480+
// arithmetic). This allows us to use the offset at `argIndex + 1` in the
481+
// above assumption. In the future, to make this more robust, we need to
482+
// verify that the offsets are indeed the same. Or alternatively, make sure
483+
// to generate a separate start and end offset for each mask that is being
484+
// updated in loops.
485+
//
486+
// Now to generate the mask state in each loop iteration, we first construct
487+
// the mask state *before* coming into the loop by parsing the init-arg. A
488+
// mask dimensions stay consistent throughout each loop iteration, but its
489+
// starting offset (`MaskState::start`) will change. So to construct the mask
490+
// state for each iteration, we need to make MaskState::state be the offset
491+
// iter-arg at `argIndex + 1`. Now for `MaskState::end`, we can first compute
492+
// the distance between `start` and `end` before coming into the loop, then
493+
// use this distance to compute the actual `end` in each loop.
455494
auto argIndex = std::distance(forOp.getRegionIterArgs().begin(), it);
456495
auto initArg = forOp.getInitArgs()[argIndex];
457496
if (auto getStateOp = initArg.getDefiningOp<tts::GetStructuredStateOp>()) {
458497
auto tritonValue = getStateOp->getOperand(0);
459498
MaskState lhsState;
460-
if (failed(lhsState.parse(tritonValue, loc, builder))) {
461-
return failure();
462-
}
463499

464-
// This is a bit of a hack!!
465-
//
466-
// The offsets and dimensions of a MaskState can now depend on a loop's
467-
// iter-arg.
468-
//
469-
// Because the PtrAnalysis's pre-pass already sets up the offsets,
470-
// we can create a new MaskState for each loop iteration by adding the
471-
// original MaskState with the current iter-arg, which is at `argIndex +
472-
// 1`.
473-
//
474-
// This will not work for nested loop scenarios, which would need a
475-
// more robust implementation.
476-
if (failed(this->addStateScalar(
477-
lhsState, forOp.getRegionIterArgs()[argIndex + 1], loc, builder))) {
478-
return failure();
500+
{
501+
OpBuilder::InsertionGuard guard(builder);
502+
// Make sure all ops generated for the mask state are inserted before
503+
// the current loop
504+
builder.setInsertionPoint(forOp);
505+
if (failed(lhsState.parse(tritonValue, loc, builder))) {
506+
return failure();
507+
}
479508
}
480509

510+
auto dist = subOFRs(lhsState.end, lhsState.start, loc, builder);
511+
this->start = forOp.getRegionIterArg(argIndex + 1);
512+
this->end = addOFRs(this->start, dist, loc, builder);
513+
this->dims = lhsState.dims;
514+
481515
return success();
482516
}
483517

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import torch
2+
import triton
3+
import pytest
4+
5+
import triton.language as tl
6+
7+
@triton.jit
8+
def mask_loop(
9+
y_ptr,
10+
x_ptr,
11+
scale_ptr,
12+
size: torch.int64,
13+
BLOCK_SIZE: tl.constexpr,
14+
):
15+
bidx = tl.program_id(0)
16+
tidx = tl.arange(0, BLOCK_SIZE)
17+
18+
grid_stride = tl.num_programs(0) * BLOCK_SIZE
19+
iterations = tl.cdiv(size, 4)
20+
21+
idx = bidx * BLOCK_SIZE + tidx
22+
idy = idx + 1
23+
for it in range(iterations):
24+
mask = idx < size
25+
x = tl.load(x_ptr + idx, mask=mask).to(tl.float32)
26+
tl.store(y_ptr + idx, x, mask=mask)
27+
idx += grid_stride
28+
29+
30+
@pytest.mark.parametrize(
31+
"b",
32+
[
33+
1,
34+
2,
35+
3,
36+
8,
37+
2048,
38+
4096,
39+
],
40+
)
41+
@pytest.mark.parametrize(
42+
"h",
43+
[
44+
16,
45+
128,
46+
1024,
47+
5120,
48+
7680,
49+
8192,
50+
],
51+
)
52+
def test_mask_loop(b, h, device):
53+
x = torch.randn((b, h), dtype=torch.float32, device=device)
54+
y = torch.empty_like(x, dtype=torch.float32, device=device)
55+
scale_ones = torch.ones(1, dtype=torch.float32, device=device)
56+
57+
BLOCK_SIZE = 2
58+
59+
grid = (2,)
60+
61+
compiled = mask_loop[grid](
62+
y,
63+
x,
64+
scale_ones,
65+
x.numel(),
66+
BLOCK_SIZE=BLOCK_SIZE,
67+
)
68+
69+
torch.testing.assert_close(x, y)
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize %s | FileCheck %s
2+
3+
module {
4+
tt.func public @mask_loop(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
5+
%c3_i32 = arith.constant 3 : i32
6+
%c4_i32 = arith.constant 4 : i32
7+
%c1_i32 = arith.constant 1 : i32
8+
%c0_i32 = arith.constant 0 : i32
9+
%c2_i32 = arith.constant 2 : i32
10+
%0 = tt.get_program_id x : i32
11+
%1 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32>
12+
%2 = tt.get_num_programs x : i32
13+
%3 = arith.muli %2, %c2_i32 : i32
14+
%4 = arith.addi %arg3, %c3_i32 : i32
15+
%5 = arith.divsi %4, %c4_i32 : i32
16+
%6 = arith.muli %0, %c2_i32 : i32
17+
%7 = tt.splat %6 : i32 -> tensor<2xi32>
18+
%8 = arith.addi %7, %1 : tensor<2xi32>
19+
%9 = tt.splat %arg3 : i32 -> tensor<2xi32>
20+
%10 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<2x!tt.ptr<f32>>
21+
%11 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<2x!tt.ptr<f32>>
22+
%12 = tt.splat %3 : i32 -> tensor<2xi32>
23+
%13 = scf.for %arg4 = %c0_i32 to %5 step %c1_i32 iter_args(%arg5 = %8) -> (tensor<2xi32>) : i32 {
24+
%14 = arith.cmpi slt, %arg5, %9 : tensor<2xi32>
25+
%15 = tt.addptr %10, %arg5 : tensor<2x!tt.ptr<f32>>, tensor<2xi32>
26+
%16 = tt.load %15, %14 : tensor<2x!tt.ptr<f32>>
27+
%17 = tt.addptr %11, %arg5 : tensor<2x!tt.ptr<f32>>, tensor<2xi32>
28+
tt.store %17, %16, %14 : tensor<2x!tt.ptr<f32>>
29+
%18 = arith.addi %arg5, %12 : tensor<2xi32>
30+
scf.yield %18 : tensor<2xi32>
31+
}
32+
tt.return
33+
}
34+
}
35+
36+
// CHECK: %8 = scf.for %arg4 = %c0_i32 to %5 step %c1_i32 iter_args(%arg5 = %7) -> (index) : i32 {
37+
// CHECK: %9 = tts.make_tptr %arg1
38+
// CHECK: %10 = arith.addi %arg5, %c2 : index
39+
// CHECK: %11 = arith.index_cast %arg3 : i32 to index
40+
// CHECK: %12 = arith.minsi %10, %11 : index
41+
// CHECK: %13 = arith.maxsi %12, %arg5 : index
42+
// CHECK: %14 = arith.subi %13, %arg5 : index
43+
// CHECK: %15 = "tts.load"(%9, %14)

0 commit comments

Comments
 (0)