Skip to content

Commit 72e73a0

Browse files
authored
Support mask analysis against 0, and between two scalars (#202)
Support was added for two new scenarios: `arith.cmpi ge %scalar, %c0`: aka offset comparison to the lower bound of 0. Mask analysis already has an implicit assumption that the beginning of a mask starts at 0, so support was added to allow this case through and assumes that this comparison evaluates to true. `arith.cmpi slt, %scalar_1, %scalar_2`: offset comparison between two scalars. E.g.: ``` %11 = tt.expand_dims %offset %cst_4 = arith.constant dense<324> : tensor<16x1xi64> %23 = arith.cmpi slt, %11, %cst_4 : tensor<16x1xi64> ``` This example is notable in that we cannot take the normal approach of computing the minimum of the lhs and rhs as the new dimension (the lhs offset may be 0). To handle this, a ternary operator is inserted to evaluate the comparison at runtime. If it succeeds, we keep the existing dimensions from the lhs, otherwise we assume nothing should be loaded/stored. This change also adds a dump method to both `MaskState` and `PtrState` as a small QOL improvement.
1 parent a7ffd7d commit 72e73a0

File tree

7 files changed

+218
-25
lines changed

7 files changed

+218
-25
lines changed

include/triton-shared/Analysis/MaskAnalysis.h

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ struct MaskState {
5151
OpFoldResult scalar;
5252
const bool useUnsafeMask;
5353

54+
void dump() const;
55+
5456
MaskState(bool useUnsafeMask = false) : useUnsafeMask(useUnsafeMask) {}
5557

5658
int64_t getRank() const { return dims.size(); }
@@ -118,9 +120,17 @@ struct MaskState {
118120
OpBuilder &builder);
119121

120122
// Operand is the result of cmpi
121-
// Assume only of the dimensions have size > 1. Only support slt for now.
122-
// For that dimension, calculate this new dim as: dim = min(end, value) -
123-
// start
123+
// Assume only one of the dimensions has size > 1. Only support slt/ult, and
124+
// sge against 0 for now. For that dimension, we have three cases:
125+
// 1. Constant comparison with both left and right-hand sides being scalars.
126+
// Calculate this new dim as a compare and select.
127+
// I.e. dim = lhs < rhs ? end : 0
128+
// 2. Left-hand side is not a scalar, and the right-hand side is.
129+
// 2.a. Predicate is slt/ult. Calculate this new dim as:
130+
// dim = max(min(end, value), start) - start
131+
// 2.b. Predicate is sge against 0. Mask analysis already has an
132+
// assumption that the mask starts at 0, so evaluate this to true
133+
// and calculate this new dim as: dim = end
124134
LogicalResult parseCmp(arith::CmpIOp cmpOp, const Location loc,
125135
OpBuilder &builder);
126136
// Operand is the result of make_range

include/triton-shared/Analysis/OpFoldResultUtils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include "mlir/IR/Location.h"
1212
#include "mlir/IR/OpDefinition.h"
13+
#include "mlir/Dialect/Arith/IR/Arith.h"
1314

1415
#include <optional>
1516

@@ -57,6 +58,10 @@ OpFoldResult minOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
5758

5859
OpFoldResult maxOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
5960
const Location loc, OpBuilder &b);
61+
62+
OpFoldResult compareOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
63+
const arith::CmpIPredicate pred, const OpFoldResult trueVal,
64+
const OpFoldResult falseVal, const Location loc, OpBuilder &b);
6065
} // namespace mlir
6166

6267
#endif

include/triton-shared/AnalysisStructured/PtrAnalysis.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ struct PtrState {
5959

6060
bool isBlockPtr() const;
6161

62+
void dump() const;
63+
6264
// Process addition of two PtrStates.
6365
LogicalResult addState(const PtrState &lhsState, const PtrState &rhsState,
6466
Operation *op, OpBuilder &builder);

lib/Analysis/MaskAnalysis.cpp

Lines changed: 59 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,17 @@ LogicalResult MaskState::parseIntScalar(Value scalar, const Location loc,
265265
return success();
266266
}
267267

268+
void MaskState::dump() const {
269+
llvm::dbgs() << "start: " << start << "\n";
270+
llvm::dbgs() << "end: " << end << "\n";
271+
llvm::dbgs() << "scalar: " << scalar << "\n";
272+
llvm::dbgs() << "useUnsafeMask: " << useUnsafeMask << "\n";
273+
llvm::dbgs() << "dims: ";
274+
for (auto dim : dims)
275+
llvm::dbgs() << "\t" << dim << "\n";
276+
llvm::dbgs() << "\n";
277+
}
278+
268279
LogicalResult MaskState::parseAdd(arith::AddIOp addOp, const Location loc,
269280
OpBuilder &builder) {
270281
assert(this->isEmpty());
@@ -308,7 +319,8 @@ LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc,
308319
assert(this->isEmpty());
309320

310321
if (cmpOp.getPredicate() != arith::CmpIPredicate::slt &&
311-
cmpOp.getPredicate() != arith::CmpIPredicate::ult) {
322+
cmpOp.getPredicate() != arith::CmpIPredicate::ult &&
323+
cmpOp.getPredicate() != arith::CmpIPredicate::sge) {
312324
InFlightDiagnostic diag = emitError(loc) << "Unsupported cmpi";
313325
return failure();
314326
}
@@ -321,9 +333,17 @@ LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc,
321333
if (failed(rhsState.parse(cmpOp.getRhs(), loc, builder)))
322334
return failure();
323335

324-
assert((!lhsState.scalar && rhsState.scalar) && "Unsupported cmpi scenario");
336+
// We only support sge against 0 for lower bounds. Dims already has an
337+
// implicit assumption that the lower bound is 0, so if we see this, assume
338+
// the comparison evaluates to true.
339+
if (cmpOp.getPredicate() == arith::CmpIPredicate::sge
340+
&& !(rhsState.scalar && hasConstZero(rhsState.scalar))) {
341+
InFlightDiagnostic diag = emitError(loc)
342+
<< "Unsupported cmpi with rhs not equal to 0";
343+
return failure();
344+
}
325345

326-
int32_t cmpDim = -1;
346+
int32_t cmpDim = lhsState.scalar && rhsState.scalar ? 0 : -1;
327347
for (int32_t i = 0; i < lhsState.getRank(); i++) {
328348
auto dimIntAttr = getIntAttr(lhsState.dims[i]);
329349
if (!dimIntAttr || dimIntAttr.value() != 1) {
@@ -339,22 +359,42 @@ LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc,
339359
assert(cmpDim != -1 &&
340360
"Unexpected case where no dimension has size larger than 1");
341361

342-
// Important:
343-
// In the case where the values we are loading are entirely masked off like
344-
// the following:
345-
//
346-
// ---|-------|-----------|
347-
// ^ ^ ^
348-
// scalar start end
349-
//
350-
// newEnd = min(end, scalar) = scalar
351-
// Now scalar < start, so simply doing dim = newEnd - start is incorrect.
352-
//
353-
// The correct formula is to optionally move `newDim` back to `start` using
354-
// max(newEnd, start).
355-
auto newEnd = minOFRs(lhsState.end, rhsState.scalar, loc, builder);
356-
newEnd = maxOFRs(newEnd, lhsState.start, loc, builder);
357-
auto newDim = subOFRs(newEnd, lhsState.start, loc, builder);
362+
OpFoldResult newDim;
363+
if (lhsState.scalar) {
364+
assert(rhsState.scalar && "Unexpected case where rhs is not a scalar");
365+
// If both lhs and rhs are scalars, we can't just derive the dimension of
366+
// the mask as the minimum value: lhs/rhs could be 0 and then we don't
367+
// load/store anything.
368+
//
369+
// Instead treat the comparison as a scalar that determines if anything
370+
// should be loaded/stored by inserting a comparison + select:
371+
// dim = lhs < rhs ? lhs.dim : 0
372+
newDim = compareOFRs(lhsState.scalar, rhsState.scalar, cmpOp.getPredicate(),
373+
lhsState.dims[cmpDim], builder.getIndexAttr(0),
374+
loc, builder);
375+
} else if (cmpOp.getPredicate() == arith::CmpIPredicate::slt ||
376+
cmpOp.getPredicate() == arith::CmpIPredicate::ult) {
377+
// Important:
378+
// In the case where the values we are loading are entirely masked off like
379+
// the following:
380+
//
381+
// ---|-------|-----------|
382+
// ^ ^ ^
383+
// scalar start end
384+
//
385+
// newEnd = min(end, scalar) = scalar
386+
// Now scalar < start, so simply doing dim = newEnd - start is incorrect.
387+
//
388+
// The correct formula is to optionally move `newDim` back to `start` using
389+
// max(newEnd, start).
390+
auto newEnd = minOFRs(lhsState.end, rhsState.scalar, loc, builder);
391+
newEnd = maxOFRs(newEnd, lhsState.start, loc, builder);
392+
newDim = subOFRs(newEnd, lhsState.start, loc, builder);
393+
} else {
394+
assert(cmpOp.getPredicate() == arith::CmpIPredicate::sge && rhsState.scalar
395+
&& hasConstZero(rhsState.scalar));
396+
newDim = lhsState.dims[cmpDim];
397+
}
358398

359399
for (int32_t i = 0; i < lhsState.getRank(); i++) {
360400
if (i == cmpDim)

lib/Analysis/OpFoldResultUtils.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,4 +245,43 @@ OpFoldResult maxOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
245245
return maxOp.getResult();
246246
}
247247

248+
OpFoldResult compareOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
249+
const arith::CmpIPredicate pred, const OpFoldResult trueOFR,
250+
const OpFoldResult falseOFR, const Location loc, OpBuilder &b) {
251+
auto lhsIntAttr = getIntAttr(lhs);
252+
auto rhsIntAttr = getIntAttr(rhs);
253+
254+
// both lhs and rhs are constants, return the result directly
255+
if (lhsIntAttr && rhsIntAttr) {
256+
switch (pred) {
257+
case arith::CmpIPredicate::eq:
258+
return *lhsIntAttr == *rhsIntAttr ? trueOFR : falseOFR;
259+
case arith::CmpIPredicate::ne:
260+
return *lhsIntAttr != *rhsIntAttr ? trueOFR : falseOFR;
261+
case arith::CmpIPredicate::slt:
262+
case arith::CmpIPredicate::ult:
263+
return *lhsIntAttr < *rhsIntAttr ? trueOFR : falseOFR;
264+
case arith::CmpIPredicate::sle:
265+
case arith::CmpIPredicate::ule:
266+
return *lhsIntAttr <= *rhsIntAttr ? trueOFR : falseOFR;
267+
case arith::CmpIPredicate::sgt:
268+
case arith::CmpIPredicate::ugt:
269+
return *lhsIntAttr > *rhsIntAttr ? trueOFR : falseOFR;
270+
case arith::CmpIPredicate::sge:
271+
case arith::CmpIPredicate::uge:
272+
return *lhsIntAttr >= *rhsIntAttr ? trueOFR : falseOFR;
273+
default:
274+
llvm_unreachable("Unsupported predicate");
275+
}
276+
}
277+
278+
auto lhsValue = ofrToIndexValue(lhs, loc, b);
279+
auto rhsValue = ofrToIndexValue(rhs, loc, b);
280+
auto trueValue = ofrToIndexValue(trueOFR, loc, b);
281+
auto falseValue = ofrToIndexValue(falseOFR, loc, b);
282+
283+
auto cmpOp = b.create<arith::CmpIOp>(loc, pred, lhsValue, rhsValue);
284+
auto selectOp = b.create<arith::SelectOp>(loc, cmpOp, trueValue, falseValue);
285+
return selectOp.getResult();
286+
}
248287
} // namespace mlir

lib/AnalysisStructured/PtrAnalysis.cpp

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,28 @@ LogicalResult PtrState::addState(const PtrState &lhsState,
251251
return success();
252252
}
253253

254+
void PtrState::dump() const {
255+
llvm::dbgs() << "PtrState: ";
256+
if (source) {
257+
llvm::dbgs() << "source: " << source << "\n";
258+
}
259+
if (scalar) {
260+
llvm::dbgs() << "scalar: " << scalar << "\n";
261+
}
262+
263+
llvm::dbgs() << "offsets: ";
264+
llvm::interleave(offsets, llvm::dbgs(), "\n");
265+
llvm::dbgs() << "\nstrides: ";
266+
llvm::interleave(strides, llvm::dbgs(), "\n");
267+
llvm::dbgs() << "\nsizes: ";
268+
llvm::interleave(sizes, llvm::dbgs(), "\n");
269+
llvm::dbgs() << "\nshape: ";
270+
llvm::interleave(shape, llvm::dbgs(), "\n");
271+
llvm::dbgs() << "\norder: ";
272+
llvm::interleave(order, llvm::dbgs(), "\n");
273+
llvm::dbgs() << "\n";
274+
}
275+
254276
LogicalResult PtrState::mulState(const PtrState &lhsState,
255277
const PtrState &rhsState, Operation *op,
256278
OpBuilder &builder) {
@@ -265,9 +287,6 @@ LogicalResult PtrState::mulState(const PtrState &lhsState,
265287
return failure();
266288
}
267289

268-
assert(!(lhsState.scalar && rhsState.scalar) &&
269-
"do not expect to see both lhs and rhs are scalars");
270-
271290
// currently do not support both tensors are effectively non-scalar
272291
if (!lhsState.scalar && !rhsState.scalar) {
273292
op->emitRemark(
@@ -283,6 +302,11 @@ LogicalResult PtrState::mulState(const PtrState &lhsState,
283302
std::swap(lhs, rhs);
284303
}
285304

305+
if (lhsState.scalar && rhsState.scalar) {
306+
scalar = builder.create<arith::MulIOp>(
307+
loc, lhsState.scalar, rhsState.scalar);
308+
}
309+
286310
for (uint64_t i = 0; i < lhs->sizes.size(); i++) {
287311
OpFoldResult newOffset =
288312
mulOFRValue(lhs->offsets[i], rhs->scalar, loc, builder);
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
// RUN: triton-shared-opt --triton-to-structured --split-input-file %s | FileCheck %s
2+
3+
// These tests check that loads/stores that exhibit a cmp ge against 0 work
4+
// correctly with the pointer analysis pass
5+
6+
// Example of the triton kernel that generates the loads/stores with cmp ge 0.
7+
// The boundary_check fields of the load/stores, along with preprocessing the
8+
// kernel through --triton-rewrite-tensor-pointer before calling the
9+
// --triton-to-structured pass results in those cmp ge 0 instructions.
10+
//
11+
// def kernel(in_ptr0, out_ptr0, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
12+
// yoffset = tl.program_id(1) * YBLOCK
13+
// xoffset = tl.program_id(0) * XBLOCK
14+
// tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[16640, 10],
15+
// strides=[1, 16640], block_shape=[XBLOCK, YBLOCK],
16+
// order=[1, 0], offsets=[xoffset, yoffset]),
17+
// boundary_check=[0, 1])
18+
// tl.store(tl.make_block_ptr(out_ptr0, shape=[16640, 10],
19+
// strides=[1, 16640], block_shape=[XBLOCK, YBLOCK],
20+
// order=[1, 0], offsets=[xoffset, yoffset]),
21+
// tl.broadcast_to(tmp0, [XBLOCK, YBLOCK]).to(tl.float16),
22+
// boundary_check=[0, 1])
23+
24+
tt.func public @test_masked_load(%arg0: !tt.ptr<f16>) -> tensor<16x16xf16> {
25+
%cst = arith.constant dense<0> : tensor<1x16xi64>
26+
%c16_i32 = arith.constant 16 : i32
27+
%0 = tt.get_program_id y : i32
28+
%1 = arith.muli %0, %c16_i32 : i32
29+
%2 = arith.extsi %1 : i32 to i64
30+
%3 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<16x16x!tt.ptr<f16>>
31+
%4 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32>
32+
%5 = arith.extsi %4 : tensor<16xi32> to tensor<16xi64>
33+
%6 = tt.expand_dims %5 {axis = 1 : i32} : tensor<16xi64> -> tensor<16x1xi64>
34+
%7 = tt.broadcast %6 : tensor<16x1xi64> -> tensor<16x16xi64>
35+
%8 = tt.addptr %3, %7 : tensor<16x16x!tt.ptr<f16>>, tensor<16x16xi64>
36+
%9 = tt.splat %2 : i64 -> tensor<16xi64>
37+
%10 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32>
38+
%11 = arith.extsi %10 : tensor<16xi32> to tensor<16xi64>
39+
%12 = arith.addi %9, %11 : tensor<16xi64>
40+
%13 = tt.expand_dims %12 {axis = 0 : i32} : tensor<16xi64> -> tensor<1x16xi64>
41+
%14 = arith.cmpi sge, %13, %cst : tensor<1x16xi64>
42+
%15 = tt.broadcast %14 : tensor<1x16xi1> -> tensor<16x16xi1>
43+
%16 = tt.load %8, %15 evictionPolicy = evict_last : tensor<16x16x!tt.ptr<f16>>
44+
tt.return %16 : tensor<16x16xf16>
45+
}
46+
47+
// CHECK: tt.func public @test_masked_load([[arg0_:%.+]]: !tt.ptr<f16>) -> tensor<16x16xf16> {
48+
// CHECK: [[VAR_0_:%.+]] = tts.make_tptr [[arg0_]] to sizes: [16, 16], strides: [1, 0], offsets: [0, 0], shape: [0, 0], order: [] : <f16> to tensor<16x16x!tt.ptr<f16>>
49+
// CHECK: [[VAR_1_:%.+]] = "tts.load"([[VAR_0_]]) <{operandSegmentSizes = array<i32: 1, 0, 0>, static_mask_dims = array<i64: 16, 16>}> : (tensor<16x16x!tt.ptr<f16>>) -> tensor<16x16xf16>
50+
// CHECK: }
51+
52+
// -----
53+
54+
tt.func public @test_masked_store(%arg0: !tt.ptr<f16>) {
55+
%cst = arith.constant dense<0> : tensor<16x1xi64>
56+
%cst_0 = arith.constant dense<1.500000e+01> : tensor<16x16xf16>
57+
%0 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<16x16x!tt.ptr<f16>>
58+
%1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32>
59+
%2 = arith.extsi %1 : tensor<16xi32> to tensor<16xi64>
60+
%3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<16xi64> -> tensor<16x1xi64>
61+
%4 = tt.broadcast %3 : tensor<16x1xi64> -> tensor<16x16xi64>
62+
%5 = tt.addptr %0, %4 : tensor<16x16x!tt.ptr<f16>>, tensor<16x16xi64>
63+
%6 = arith.cmpi sge, %3, %cst : tensor<16x1xi64>
64+
%7 = tt.broadcast %6 : tensor<16x1xi1> -> tensor<16x16xi1>
65+
tt.store %5, %cst_0, %7 : tensor<16x16x!tt.ptr<f16>>
66+
tt.return
67+
}
68+
69+
// CHECK: tt.func public @test_masked_store([[arg0_:%.+]]: !tt.ptr<f16>) {
70+
// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<1.500000e+01> : tensor<16x16xf16>
71+
// CHECK-DAG: [[VAR_0_:%.+]] = tts.make_tptr [[arg0_]] to sizes: [16, 16], strides: [1, 0], offsets: [0, 0], shape: [0, 0], order: [] : <f16> to tensor<16x16x!tt.ptr<f16>>
72+
// CHECK: "tts.store"([[VAR_0_]], [[VAR_cst_]]) <{static_mask_dims = array<i64: 16, 16>}> : (tensor<16x16x!tt.ptr<f16>>, tensor<16x16xf16>) -> ()
73+
// CHECK: }

0 commit comments

Comments
 (0)