Commit eaeb554
[MaskAnalysis] Fix implementation of scalar conjunction in mask analysis. (#318)

The original implementation of conjunction with a scalar mistakenly assumed that the scalar argument would supply its own dimension as the mask dimension. This does not work because a scalar argument is a binary condition, not a dimension bound. To support scalar arguments, we must either keep the dimensions of the mask argument of the conjunction or zero the mask out completely. This PR changes code generation to emit select statements that implement this pattern.
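To make the intended semantics concrete, here is a minimal Python sketch (an editorial illustration, not code from this PR; the function names old_conjunction/new_conjunction are hypothetical) of the dimension computation before and after the fix, for a mask whose dimensions count the enabled elements:

    # Hypothetical model of the mask-dimension math for `scalar & mask`.
    def old_conjunction(scalar_dim, mask_dims):
        # Old behavior: treat the scalar as if it carried a dimension and
        # take an elementwise min -- wrong, since a scalar condition is
        # all-or-nothing, not an extent along an axis.
        return [min(scalar_dim, d) for d in mask_dims]

    def new_conjunction(scalar_cond, mask_dims):
        # New behavior: a per-dimension select -- keep the mask's extent
        # when the scalar condition holds, otherwise zero out the access.
        return [d if scalar_cond else 0 for d in mask_dims]

    assert new_conjunction(True, [4]) == [4]   # scalar true: mask wins
    assert new_conjunction(False, [4]) == [0]  # scalar false: load nothing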
1 parent 38b93c1 commit eaeb554

File tree: 5 files changed, +126 -53 lines

include/triton-shared/Analysis/OpFoldResultUtils.h

Lines changed: 10 additions & 4 deletions

@@ -8,9 +8,9 @@
 #ifndef TRITON_ANALYSIS_OPFOLDRESULT_UTILS_H
 #define TRITON_ANALYSIS_OPFOLDRESULT_UTILS_H

+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/OpDefinition.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"

 #include <optional>

@@ -55,17 +55,23 @@ OpFoldResult subOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
 // result is an Integer Attribtue. Otherwise, insert the arith.muli
 // instruction if needed and use its result Value.
 OpFoldResult mulOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
-                      const Location loc, OpBuilder &b);
+                     const Location loc, OpBuilder &b);

 OpFoldResult minOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
                      const Location loc, OpBuilder &b);

 OpFoldResult maxOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
                      const Location loc, OpBuilder &b);

+OpFoldResult selectOFRs(const OpFoldResult cond, const OpFoldResult trueOFR,
+                        const OpFoldResult falseOFR, const Location loc,
+                        OpBuilder &b);
+
 OpFoldResult compareOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
-                         const arith::CmpIPredicate pred, const OpFoldResult trueVal,
-                         const OpFoldResult falseVal, const Location loc, OpBuilder &b);
+                         const arith::CmpIPredicate pred,
+                         const OpFoldResult trueVal,
+                         const OpFoldResult falseVal, const Location loc,
+                         OpBuilder &b);
 } // namespace mlir

 #endif
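Editorial note: the new selectOFRs helper follows the same OpFoldResult-in, OpFoldResult-out convention as minOFRs and maxOFRs above; as its implementation in OpFoldResultUtils.cpp below shows, it materializes its operands as values and emits an arith.select rather than attempting constant folding.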

lib/Analysis/MaskAnalysis.cpp

Lines changed: 30 additions & 14 deletions

@@ -219,26 +219,42 @@ LogicalResult MaskState::addStates(const MaskState &lhsState,
 LogicalResult MaskState::minStateScalar(const MaskState &lhsState,
                                         const MaskState &rhsState, Location loc,
                                         OpBuilder &builder) {
+  // Conjunction where both sides are scalar should not be done after splats. We
+  // should ensure that code generation pushes the splat as late as possible.
   if (lhsState.scalar && rhsState.scalar) {
-    dims.push_back(minOFRs(lhsState.dims[0], rhsState.dims[0], loc, builder));
-  } else if (lhsState.scalar) {
-    for (uint32_t i = 0; i < rhsState.getRank(); i++) {
-      auto lhsDim = lhsState.dims[0];
-      auto rhsDim = rhsState.dims[i];
-      dims.push_back(minOFRs(lhsDim, rhsDim, loc, builder));
-    }
-  } else if (rhsState.scalar) {
-    for (uint32_t i = 0; i < lhsState.getRank(); i++) {
-      auto lhsDim = lhsState.dims[i];
-      auto rhsDim = rhsState.dims[0];
-      dims.push_back(minOFRs(lhsDim, rhsDim, loc, builder));
-    }
-  } else {
+    InFlightDiagnostic diag =
+        emitError(loc) << "Unexpected case where both lhs and rhs are scalars";
+    return failure();
+  }
+
+  // Caller should ensure that at least one side is scalar.
+  if (!lhsState.scalar && !rhsState.scalar) {
     InFlightDiagnostic diag =
         emitError(loc)
        << "Unexpected case where both lhs and rhs are not scalars";
     return failure();
   }
+
+  // If we see a scalar condition in a conjunction with a mask, this means we
+  // are either going to take the mask dimension or take nothing at all. To do
+  // that we use a select on the scalar value with the mask dimension in the
+  // true case and zero in the false case.
+  //
+  // Example:
+  //   def kernel(..., index: i32, ...):
+  //     ...
+  //     offs = tl.arange(0, 8)
+  //     mask = offs < 4
+  //     scalar = index < 4
+  //     ... = tl.load(some_ptr, mask=scalar & mask, other=0)
+  auto &scalarState = lhsState.scalar ? lhsState : rhsState;
+  auto &nonScalarState = lhsState.scalar ? rhsState : lhsState;
+  for (uint32_t i = 0; i < nonScalarState.getRank(); i++) {
+    auto nonScalarDim = nonScalarState.dims[i];
+    dims.push_back(selectOFRs(scalarState.scalar, nonScalarDim,
+                              builder.getZeroAttr(builder.getIndexType()), loc,
+                              builder));
+  }
   return success();
 }
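For reference, the control flow of the rewritten minStateScalar can be modeled in a few lines of Python (an editorial sketch of the logic above, not the pass itself; min_state_scalar is a hypothetical stand-in and None marks a non-scalar side):

    def min_state_scalar(lhs_scalar, lhs_dims, rhs_scalar, rhs_dims):
        # Mirrors the two emitError paths: exactly one side may be scalar.
        if lhs_scalar is not None and rhs_scalar is not None:
            raise ValueError("both lhs and rhs are scalars")
        if lhs_scalar is None and rhs_scalar is None:
            raise ValueError("both lhs and rhs are not scalars")
        scalar = lhs_scalar if lhs_scalar is not None else rhs_scalar
        dims = rhs_dims if lhs_scalar is not None else lhs_dims
        # selectOFRs(scalar, dim, 0): keep each mask dim, or take nothing.
        return [d if scalar else 0 for d in dims]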

lib/Analysis/OpFoldResultUtils.cpp

Lines changed: 54 additions & 30 deletions

@@ -104,7 +104,8 @@ OpFoldResult expandOFRIndex(OpFoldResult ofr, OpFoldResult targetForTy,

   Value v = dyn_cast<Value>(ofr);
   if (!v)
-    v = b.create<arith::ConstantOp>(loc, cast<IntegerAttr>(cast<Attribute>(ofr)));
+    v = b.create<arith::ConstantOp>(loc,
+                                    cast<IntegerAttr>(cast<Attribute>(ofr)));

   Type ty = v.getType();
   if (targetTy == ty)
@@ -126,7 +127,8 @@ OpFoldResult expandOFRIndex(OpFoldResult ofr, OpFoldResult targetForTy,
     // This path is for case like:
     //   input_ptr + (row_indices[:, None] + row_offsets[:,None] % mod_offset) *
     //   stride_m + col_offsets[None, :] * stride_n
-    // The modulo will be in shape of [ROW_SIZE, 1] while row_indices is in shape of [ROW_SIZE,].
+    // The modulo will be in shape of [ROW_SIZE, 1] while row_indices is in
+    // shape of [ROW_SIZE,].
     LLVM_DEBUG({
       llvm::dbgs() << "Reshaping ";
       shapedTy.dump();
@@ -135,14 +137,15 @@ OpFoldResult expandOFRIndex(OpFoldResult ofr, OpFoldResult targetForTy,
     });
     SmallVector<Value> shapeValues;
     for (auto dim : targetShapedTy.getShape()) {
-      shapeValues.push_back(b.create<arith::ConstantOp>(
-          loc, b.getIndexAttr(dim)));
+      shapeValues.push_back(
+          b.create<arith::ConstantOp>(loc, b.getIndexAttr(dim)));
     }
     RankedTensorType targetShapeTensorTy = RankedTensorType::get(
         targetShapedTy.getShape().size(), b.getIndexType());
     auto shapeTensor = b.create<tensor::FromElementsOp>(
         loc, targetShapeTensorTy, shapeValues);
-    return b.create<triton::ReshapeOp>(loc, targetTy, v, shapeTensor).getResult();
+    return b.create<triton::ReshapeOp>(loc, targetTy, v, shapeTensor)
+        .getResult();
   }
   if (isa<IndexType>(targetEltTy) || isa<IndexType>(eltTy)) {
     assert((isa<IntegerType>(targetEltTy) || isa<IntegerType>(eltTy)) &&
@@ -228,7 +231,7 @@ OpFoldResult subOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
 }

 OpFoldResult mulOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
-                      const Location loc, OpBuilder &b) {
+                     const Location loc, OpBuilder &b) {
   auto lhsIntAttr = getIntAttr(lhs);
   auto rhsIntAttr = getIntAttr(rhs);

@@ -336,44 +339,65 @@ OpFoldResult maxOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
   return maxOp.getResult();
 }

+OpFoldResult selectOFRs(const OpFoldResult condOFR, const OpFoldResult trueOFR,
+                        const OpFoldResult falseOFR, const Location loc,
+                        OpBuilder &b) {
+  auto trueValue = ofrToIndexValue(trueOFR, loc, b);
+  auto falseValue = ofrToIndexValue(falseOFR, loc, b);
+  auto condValue = ofrToIndexValue(condOFR, loc, b);
+
+  // Ideally we should not be passing around everything as index type since mask
+  // analysis can come across i1 values, but that improvement is being left for
+  // future work. For now we just unwrap an index back into it's i1 value if
+  // necessary.
+  if (!condValue.getType().isInteger(1)) {
+    assert(condValue.getDefiningOp<arith::IndexCastOp>());
+    condValue = condValue.getDefiningOp<arith::IndexCastOp>().getOperand();
+    assert(condValue.getType().isInteger(1));
+  }
+
+  auto selectOp =
+      b.create<arith::SelectOp>(loc, condValue, trueValue, falseValue);
+  return selectOp.getResult();
+}
+
 OpFoldResult compareOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
-                         const arith::CmpIPredicate pred, const OpFoldResult trueOFR,
-                         const OpFoldResult falseOFR, const Location loc, OpBuilder &b) {
+                         const arith::CmpIPredicate pred,
+                         const OpFoldResult trueOFR,
+                         const OpFoldResult falseOFR, const Location loc,
+                         OpBuilder &b) {
   auto lhsIntAttr = getIntAttr(lhs);
   auto rhsIntAttr = getIntAttr(rhs);

   // both lhs and rhs are constants, return the result directly
   if (lhsIntAttr && rhsIntAttr) {
     switch (pred) {
-      case arith::CmpIPredicate::eq:
-        return *lhsIntAttr == *rhsIntAttr ? trueOFR : falseOFR;
-      case arith::CmpIPredicate::ne:
-        return *lhsIntAttr != *rhsIntAttr ? trueOFR : falseOFR;
-      case arith::CmpIPredicate::slt:
-      case arith::CmpIPredicate::ult:
-        return *lhsIntAttr < *rhsIntAttr ? trueOFR : falseOFR;
-      case arith::CmpIPredicate::sle:
-      case arith::CmpIPredicate::ule:
-        return *lhsIntAttr <= *rhsIntAttr ? trueOFR : falseOFR;
-      case arith::CmpIPredicate::sgt:
-      case arith::CmpIPredicate::ugt:
-        return *lhsIntAttr > *rhsIntAttr ? trueOFR : falseOFR;
-      case arith::CmpIPredicate::sge:
-      case arith::CmpIPredicate::uge:
-        return *lhsIntAttr >= *rhsIntAttr ? trueOFR : falseOFR;
-      default:
-        llvm_unreachable("Unsupported predicate");
+    case arith::CmpIPredicate::eq:
+      return *lhsIntAttr == *rhsIntAttr ? trueOFR : falseOFR;
+    case arith::CmpIPredicate::ne:
+      return *lhsIntAttr != *rhsIntAttr ? trueOFR : falseOFR;
+    case arith::CmpIPredicate::slt:
+    case arith::CmpIPredicate::ult:
+      return *lhsIntAttr < *rhsIntAttr ? trueOFR : falseOFR;
+    case arith::CmpIPredicate::sle:
+    case arith::CmpIPredicate::ule:
+      return *lhsIntAttr <= *rhsIntAttr ? trueOFR : falseOFR;
+    case arith::CmpIPredicate::sgt:
+    case arith::CmpIPredicate::ugt:
+      return *lhsIntAttr > *rhsIntAttr ? trueOFR : falseOFR;
+    case arith::CmpIPredicate::sge:
+    case arith::CmpIPredicate::uge:
+      return *lhsIntAttr >= *rhsIntAttr ? trueOFR : falseOFR;
+    default:
+      llvm_unreachable("Unsupported predicate");
     }
   }

   auto lhsValue = ofrToIndexValue(lhs, loc, b);
   auto rhsValue = ofrToIndexValue(rhs, loc, b);
-  auto trueValue = ofrToIndexValue(trueOFR, loc, b);
-  auto falseValue = ofrToIndexValue(falseOFR, loc, b);

   auto cmpOp = b.create<arith::CmpIOp>(loc, pred, lhsValue, rhsValue);
-  auto selectOp = b.create<arith::SelectOp>(loc, cmpOp, trueValue, falseValue);
-  return selectOp.getResult();
+  return selectOFRs(cmpOp.getResult(), trueOFR, falseOFR, loc, b);
 }

 } // namespace mlir
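With selectOFRs factored out, compareOFRs now delegates select construction to it instead of building its own arith.select, so the constant-folding fast path is unchanged while the value path shares the i1-unwrapping logic.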

python/examples/test_mask.py

Lines changed: 26 additions & 0 deletions

@@ -37,3 +37,29 @@ def test(in0, out0):
     print(input)
     print(output)
     torch.testing.assert_close(output, torch.tensor([-1, -1, -1, -1, -2, -2, -2, -2], device=device, dtype=torch.int32))
+
+
+def test_mask_with_scalar_in_conjunction(device):
+    if device == 'cpu':
+        triton.runtime.driver.set_active(CPUDriver())
+
+    @triton.jit
+    def kernel(in0, out0, mask, value):
+        offs = tl.arange(0, 8)
+        out_offs = tl.arange(0, 8)
+        a = tl.load(in0 + offs, mask=(value < 5) & (offs < mask), other=-1)
+        tl.store(out0 + out_offs, a)
+
+    # Test scalar mask evaluate to True
+    SIZE = 8
+    input = torch.arange(0, SIZE, device=device, dtype=torch.int32)
+    output = torch.full((SIZE,), -2, device=device, dtype=torch.int32)
+    kernel[(1,)](input, output, 4, 3)
+    torch.testing.assert_close(output, torch.tensor([0, 1, 2, 3, -1, -1, -1, -1], device=device, dtype=torch.int32))
+
+    # Test scalar mask evaluate to False
+    SIZE = 8
+    input = torch.arange(0, SIZE, device=device, dtype=torch.int32)
+    output = torch.full((SIZE,), -2, device=device, dtype=torch.int32)
+    kernel[(1,)](input, output, 4, 8)
+    torch.testing.assert_close(output, torch.tensor([-1, -1, -1, -1, -1, -1, -1, -1], device=device, dtype=torch.int32))
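Both launches pass mask=4, so the tensor-side condition offs < mask always covers the first four lanes; only the scalar operand value changes between the two calls (3 < 5 holds, 8 < 5 does not), exercising both arms of the generated select.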

test/Conversion/TritonToStructured/mask_ld_st_scalar_dim.mlir

Lines changed: 6 additions & 5 deletions

@@ -1,4 +1,4 @@
-// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize %s | FileCheck %s
+// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize --cse %s | FileCheck %s

 module {
   tt.func @mask_ld_st_scalar(
@@ -38,10 +38,11 @@ module {
   }
 }

-// CHECK: %{{.*}} = "tts.load"(%{{.*}}, %{{.*}}) <{operandSegmentSizes = array<i32: 1, 1, 0>, static_mask_dims = array<i64: -9223372036854775808, 1>}> : (tensor<2x1x!tt.ptr<f32>>, index) -> tensor<2x1xf32>
-// CHECK: %{{.*}} = "tts.load"(%{{.*}}, %{{.*}}) <{operandSegmentSizes = array<i32: 1, 1, 0>, static_mask_dims = array<i64: -9223372036854775808, 1>}> : (tensor<2x1x!tt.ptr<f32>>, index) -> tensor<2x1xf32>
-// CHECK: "tts.store"(%{{.*}}, %{{.*}}, %{{.*}}) <{static_mask_dims = array<i64: -9223372036854775808, 1>}> : (tensor<2x1x!tt.ptr<f32>>, tensor<2x1xf32>, index) -> ()
-// CHECK: "tts.store"(%{{.*}}, %{{.*}}, %{{.*}}) <{static_mask_dims = array<i64: -9223372036854775808, 1>}> : (tensor<2x1x!tt.ptr<f32>>, tensor<2x1xf32>, index) -> ()
+// CHECK: %{{.*}} = "tts.load"(%{{.*}}, %{{.*}}, %{{.*}}) <{operandSegmentSizes = array<i32: 1, 2, 0>, static_mask_dims = array<i64: -9223372036854775808, -9223372036854775808>}> : (tensor<2x1x!tt.ptr<f32>>, index, index) -> tensor<2x1xf32>
+// CHECK: %{{.*}} = "tts.load"(%{{.*}}, %{{.*}}, %{{.*}}) <{operandSegmentSizes = array<i32: 1, 2, 0>, static_mask_dims = array<i64: -9223372036854775808, -9223372036854775808>}> : (tensor<2x1x!tt.ptr<f32>>, index, index) -> tensor<2x1xf32>
+// CHECK: "tts.store"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) <{static_mask_dims = array<i64: -9223372036854775808, -9223372036854775808>}> : (tensor<2x1x!tt.ptr<f32>>, tensor<2x1xf32>, index, index) -> ()
+// CHECK: "tts.store"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) <{static_mask_dims = array<i64: -9223372036854775808, -9223372036854775808>}> : (tensor<2x1x!tt.ptr<f32>>, tensor<2x1xf32>, index, index) -> ()
+

 // Original Triton Function:
 // def test_masked_ld_st(
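The updated CHECK lines reflect the new lowering: -9223372036854775808 is MLIR's ShapedType::kDynamic sentinel, so both mask dimensions are now dynamic, and the loads carry two mask index operands (operandSegmentSizes = array<i32: 1, 2, 0>) because the select-derived dimension is a runtime value rather than the static 1.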
