Skip to content

Commit 6b0f5a0

Browse files
lxwangrucLinxiao Wang
andauthored
Fix MaskAnalysis bool type workaround (#346)
In MaskAnalysis, a Boolean type shouldn't be passed around as Index. This replaces the workaround added in PR #318. We use the following two tests as baseline, they needed the workaround to pass before and now passing with our fix. ``` test/Conversion/TritonToStructured/mask_ld_st_scalar_dim.mlir python/examples/test_mask.py ``` Co-authored-by: Linxiao Wang <[email protected]>
1 parent 8f3b43e commit 6b0f5a0

File tree

3 files changed

+22
-14
lines changed

3 files changed

+22
-14
lines changed

include/triton-shared/Analysis/OpFoldResultUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ std::optional<int64_t> getIntAttr(const OpFoldResult ofr);
2727
// attribute or a constant value.
2828
bool hasConstZero(const OpFoldResult ofr);
2929

30+
// Cast OpFoldResult to Value.
31+
Value ofrToValue(const OpFoldResult ofr, const Location loc, OpBuilder &b);
32+
3033
// Create a value of index type if necessary from an OpFoldResult.
3134
Value ofrToIndexValue(const OpFoldResult ofr, const Location loc, OpBuilder &b);
3235

lib/Analysis/MaskAnalysis.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -302,9 +302,13 @@ LogicalResult MaskState::parseConstant(arith::ConstantOp constOp,
302302
LogicalResult MaskState::parseIntScalar(Value scalar, const Location loc,
303303
OpBuilder &builder) {
304304
assert(this->isEmpty());
305-
auto castOp =
306-
builder.create<arith::IndexCastOp>(loc, builder.getIndexType(), scalar);
307-
this->scalar = castOp.getResult();
305+
if (scalar.getType().isInteger(1)) {
306+
this->scalar = scalar;
307+
} else {
308+
auto castOp =
309+
builder.create<arith::IndexCastOp>(loc, builder.getIndexType(), scalar);
310+
this->scalar = castOp.getResult();
311+
}
308312
return success();
309313
}
310314

lib/Analysis/OpFoldResultUtils.cpp

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,16 @@ bool hasConstZero(const OpFoldResult ofr) {
5252
return false;
5353
}
5454

55+
Value ofrToValue(const OpFoldResult ofr, const Location loc, OpBuilder &b) {
56+
if (Value val = dyn_cast<Value>(ofr)) {
57+
return val;
58+
}
59+
60+
auto attr = dyn_cast<Attribute>(ofr);
61+
auto typedAttr = dyn_cast<TypedAttr>(attr);
62+
return b.create<arith::ConstantOp>(loc, typedAttr);
63+
}
64+
5565
Value ofrToIndexValue(const OpFoldResult ofr, const Location loc,
5666
OpBuilder &b) {
5767
if (Value val = dyn_cast<Value>(ofr)) {
@@ -344,17 +354,8 @@ OpFoldResult selectOFRs(const OpFoldResult condOFR, const OpFoldResult trueOFR,
344354
OpBuilder &b) {
345355
auto trueValue = ofrToIndexValue(trueOFR, loc, b);
346356
auto falseValue = ofrToIndexValue(falseOFR, loc, b);
347-
auto condValue = ofrToIndexValue(condOFR, loc, b);
348-
349-
// Ideally we should not be passing around everything as index type since mask
350-
// analysis can come across i1 values, but that improvement is being left for
351-
// future work. For now we just unwrap an index back into it's i1 value if
352-
// necessary.
353-
if (!condValue.getType().isInteger(1)) {
354-
assert(condValue.getDefiningOp<arith::IndexCastOp>());
355-
condValue = condValue.getDefiningOp<arith::IndexCastOp>().getOperand();
356-
assert(condValue.getType().isInteger(1));
357-
}
357+
auto condValue = ofrToValue(condOFR, loc, b);
358+
assert(condValue.getType().isInteger(1) && "Condition for selectOp must be a bool type");
358359

359360
auto selectOp =
360361
b.create<arith::SelectOp>(loc, condValue, trueValue, falseValue);

0 commit comments

Comments
 (0)