Skip to content

Commit 6d10d94

Browse files
committed
Fix incorrect trunc of constants in SelectOfExt
1 parent ded532c commit 6d10d94

File tree

2 files changed

+125
-14
lines changed

2 files changed

+125
-14
lines changed

lib/polygeist/Ops.cpp

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2295,7 +2295,7 @@ class OrIExcludedMiddle final : public OpRewritePattern<arith::OrIOp> {
22952295
}
22962296
};
22972297

2298-
class SelectI1Ext final : public OpRewritePattern<arith::SelectOp> {
2298+
class SelectOfExt final : public OpRewritePattern<arith::SelectOp> {
22992299
public:
23002300
using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
23012301

@@ -2304,30 +2304,48 @@ class SelectI1Ext final : public OpRewritePattern<arith::SelectOp> {
23042304
auto ty = op.getType().dyn_cast<IntegerType>();
23052305
if (!ty)
23062306
return failure();
2307-
if (ty.getWidth() == 1)
2308-
return failure();
23092307
IntegerAttr lhs, rhs;
23102308
Value lhs_v = nullptr, rhs_v = nullptr;
2309+
unsigned lhs_w = 0, rhs_w = 0;
23112310
if (auto ext = op.getTrueValue().getDefiningOp<arith::ExtUIOp>()) {
23122311
lhs_v = ext.getIn();
2313-
if (lhs_v.getType().cast<IntegerType>().getWidth() != 1)
2314-
return failure();
2312+
lhs_w = lhs_v.getType().cast<IntegerType>().getWidth();
23152313
} else if (matchPattern(op.getTrueValue(), m_Constant(&lhs))) {
2316-
} else
2314+
} else {
23172315
return failure();
2316+
}
23182317

23192318
if (auto ext = op.getFalseValue().getDefiningOp<arith::ExtUIOp>()) {
23202319
rhs_v = ext.getIn();
2321-
if (rhs_v.getType().cast<IntegerType>().getWidth() != 1)
2322-
return failure();
2320+
rhs_w = rhs_v.getType().cast<IntegerType>().getWidth();
23232321
} else if (matchPattern(op.getFalseValue(), m_Constant(&rhs))) {
2324-
} else
2322+
} else {
23252323
return failure();
2324+
}
23262325

2327-
if (!lhs_v)
2328-
lhs_v = rewriter.create<ConstantIntOp>(op.getLoc(), lhs.getInt(), 1);
2329-
if (!rhs_v)
2330-
rhs_v = rewriter.create<ConstantIntOp>(op.getLoc(), rhs.getInt(), 1);
2326+
// Neither operand is an extui, so there is nothing to narrow
2327+
if (!lhs_v && !rhs_v)
2328+
return failure();
2329+
2330+
// Returns true when `i` can be represented as an unsigned value in `width`
// bits, i.e. when none of the bits above the low `width` bits are set. The
// callers below use this to decide whether a constant may be rebuilt at the
// narrower type feeding an arith.extui without changing the selected value.
//
// The check is done on the unsigned bit pattern so that the mask computation
// never overflows a signed integer. A negative `i` always has its high bits
// set and is therefore rejected, which matches zero-extension semantics: a
// zero-extended narrow value can never be negative at the wide type.
auto fitsIn = [](auto i, int width) {
  using Bits = unsigned long long;
  constexpr int kBitCount = static_cast<int>(sizeof(Bits) * 8);

  // A negative width is meaningless; refuse rather than shift by it.
  if (width < 0)
    return false;

  // Shifting by the full bit width (or more) is undefined. At that width
  // every non-negative value fits, while a negative (sign-extended) value
  // still cannot be the result of a zero-extension.
  if (width >= kBitCount)
    return !(i < 0);

  // Mask of the bits that would be discarded by truncating to `width` bits.
  const Bits discarded = ~((Bits{1} << width) - 1);
  return (static_cast<Bits>(i) & discarded) == 0;
};
2334+
2335+
// Truncate the other non-extended const but only if the original constant
2336+
// fits in the new width
2337+
if (!lhs_v && fitsIn(lhs.getInt(), rhs_w)) {
2338+
assert(rhs_v);
2339+
lhs_v = rewriter.create<ConstantIntOp>(op.getLoc(), lhs.getInt(), rhs_w);
2340+
}
2341+
if (!rhs_v && fitsIn(rhs.getInt(), lhs_w)) {
2342+
assert(lhs_v);
2343+
rhs_v = rewriter.create<ConstantIntOp>(op.getLoc(), rhs.getInt(), lhs_w);
2344+
}
2345+
2346+
// If we weren't able to truncate the constant
2347+
if (!(lhs_v && rhs_v))
2348+
return failure();
23312349

23322350
rewriter.replaceOpWithNewOp<ExtUIOp>(
23332351
op, op.getType(),
@@ -5526,7 +5544,7 @@ static llvm::cl::opt<bool>
55265544
void TypeAlignOp::getCanonicalizationPatterns(RewritePatternSet &results,
55275545
MLIRContext *context) {
55285546
results.insert<
5529-
TypeAlignCanonicalize, OrIExcludedMiddle, SelectI1Ext, UndefProp<ExtUIOp>,
5547+
TypeAlignCanonicalize, OrIExcludedMiddle, SelectOfExt, UndefProp<ExtUIOp>,
55305548
UndefProp<ExtSIOp>, UndefProp<TruncIOp>, CmpProp, UndefCmpProp,
55315549
AlwaysAllocaScopeHoister<memref::AllocaScopeOp>,
55325550
AlwaysAllocaScopeHoister<scf::ForOp>,
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
// RUN: polygeist-opt --canonicalize --split-input-file %s | FileCheck %s
2+
module {
3+
func.func @foo(%arg0: i1) -> i32 {
4+
%c512_i32 = arith.constant 512 : i32
5+
%c16_i32 = arith.constant 16 : i32
6+
%1 = arith.select %arg0, %c512_i32, %c16_i32 : i32
7+
return %1 : i32
8+
}
9+
func.func @foo1(%arg0: i1, %c1_i32: i32, %c2_i32: i32) -> i64 {
10+
%c1_i64 = arith.extui %c1_i32 : i32 to i64
11+
%c2_i64 = arith.extui %c2_i32 : i32 to i64
12+
%1 = arith.select %arg0, %c1_i64, %c2_i64 : i64
13+
return %1 : i64
14+
}
15+
func.func @foo2(%arg0: i1, %c1_i64: i64, %c2_i32: i32) -> i64 {
16+
%c2_i64 = arith.extui %c2_i32 : i32 to i64
17+
%1 = arith.select %arg0, %c1_i64, %c2_i64 : i64
18+
return %1 : i64
19+
}
20+
func.func @foo3(%arg0: i1, %c1_i64: i1, %i8: i8) -> i32 {
21+
%c255 = arith.constant 255 : i32
22+
%i32 = arith.extui %i8 : i8 to i32
23+
%1 = arith.select %arg0, %c255, %i32 : i32
24+
return %1 : i32
25+
}
26+
func.func @foo3.1() -> i8 {
27+
%c255 = arith.constant 255 : i32
28+
%i8 = arith.trunci %c255 : i32 to i8
29+
return %i8 : i8
30+
}
31+
func.func @foo4(%arg0: i1, %c1_i64: i1, %i8: i8) -> i32 {
32+
%c256 = arith.constant 256 : i32
33+
%i32 = arith.extui %i8 : i8 to i32
34+
%1 = arith.select %arg0, %c256, %i32 : i32
35+
return %1 : i32
36+
}
37+
38+
}
39+
40+
// CHECK-LABEL: func.func @foo(
41+
// CHECK-SAME: %[[VAL_0:.*]]: i1) -> i32 {
42+
// CHECK: %[[VAL_1:.*]] = arith.constant 512 : i32
43+
// CHECK: %[[VAL_2:.*]] = arith.constant 16 : i32
44+
// CHECK: %[[VAL_3:.*]] = arith.select %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : i32
45+
// CHECK: return %[[VAL_3]] : i32
46+
// CHECK: }
47+
48+
// CHECK-LABEL: func.func @foo1(
49+
// CHECK-SAME: %[[VAL_0:.*]]: i1,
50+
// CHECK-SAME: %[[VAL_1:.*]]: i32,
51+
// CHECK-SAME: %[[VAL_2:.*]]: i32) -> i64 {
52+
// CHECK: %[[VAL_3:.*]] = arith.select %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : i32
53+
// CHECK: %[[VAL_4:.*]] = arith.extui %[[VAL_3]] : i32 to i64
54+
// CHECK: return %[[VAL_4]] : i64
55+
// CHECK: }
56+
57+
// CHECK-LABEL: func.func @foo2(
58+
// CHECK-SAME: %[[VAL_0:.*]]: i1,
59+
// CHECK-SAME: %[[VAL_1:.*]]: i64,
60+
// CHECK-SAME: %[[VAL_2:.*]]: i32) -> i64 {
61+
// CHECK: %[[VAL_3:.*]] = arith.extui %[[VAL_2]] : i32 to i64
62+
// CHECK: %[[VAL_4:.*]] = arith.select %[[VAL_0]], %[[VAL_1]], %[[VAL_3]] : i64
63+
// CHECK: return %[[VAL_4]] : i64
64+
// CHECK: }
65+
66+
67+
// NOTE (i8) 255 = (i8) -1
68+
69+
// CHECK-LABEL: func.func @foo3(
70+
// CHECK-SAME: %[[VAL_0:[a-zA-Z0-9]*]]: i1,
71+
// CHECK-SAME: %[[VAL_1:[a-zA-Z0-9]*]]: i1,
72+
// CHECK-SAME: %[[VAL_2:.*]]: i8) -> i32 {
73+
// CHECK: %[[VAL_3:.*]] = arith.constant -1 : i8
74+
// CHECK: %[[VAL_4:.*]] = arith.select %[[VAL_0]], %[[VAL_3]], %[[VAL_2]] : i8
75+
// CHECK: %[[VAL_5:.*]] = arith.extui %[[VAL_4]] : i8 to i32
76+
// CHECK: return %[[VAL_5]] : i32
77+
// CHECK: }
78+
79+
// CHECK-LABEL: func.func @foo3.1() -> i8 {
80+
// CHECK: %[[VAL_0:.*]] = arith.constant -1 : i8
81+
// CHECK: return %[[VAL_0]] : i8
82+
// CHECK: }
83+
84+
// CHECK-LABEL: func.func @foo4(
85+
// CHECK-SAME: %[[VAL_0:[a-zA-Z0-9]*]]: i1,
86+
// CHECK-SAME: %[[VAL_1:[a-zA-Z0-9]*]]: i1,
87+
// CHECK-SAME: %[[VAL_2:.*]]: i8) -> i32 {
88+
// CHECK: %[[VAL_3:.*]] = arith.constant 256 : i32
89+
// CHECK: %[[VAL_4:.*]] = arith.extui %[[VAL_2]] : i8 to i32
90+
// CHECK: %[[VAL_5:.*]] = arith.select %[[VAL_0]], %[[VAL_3]], %[[VAL_4]] : i32
91+
// CHECK: return %[[VAL_5]] : i32
92+
// CHECK: }
93+

0 commit comments

Comments
 (0)