Skip to content

Commit f7b5485

Browse files
committed
add vector support
1 parent 1f8358b commit f7b5485

File tree

2 files changed

+40
-8
lines changed

2 files changed

+40
-8
lines changed

mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -197,11 +197,11 @@ struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
197197

198198
/// Check if `type` is index or integer type with `getWidth() > targetBitwidth`.
199199
static Type checkIntType(Type type, unsigned targetBitwidth) {
200-
type = getElementTypeOrSelf(type);
201-
if (isa<IndexType>(type))
200+
Type elemType = getElementTypeOrSelf(type);
201+
if (isa<IndexType>(elemType))
202202
return type;
203203

204-
if (auto intType = dyn_cast<IntegerType>(type))
204+
if (auto intType = dyn_cast<IntegerType>(elemType))
205205
if (intType.getWidth() > targetBitwidth)
206206
return type;
207207

@@ -298,16 +298,20 @@ static bool checkRange(const ConstantIntRanges &range, APInt smin, APInt smax,
298298

299299
static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) {
300300
Type srcType = src.getType();
301-
assert(srcType.isIntOrIndex() && "Invalid src type");
302-
assert(dstType.isIntOrIndex() && "Invalid dst type");
301+
assert(isa<VectorType>(srcType) == isa<VectorType>(dstType) &&
302+
"Mixing vector and non-vector types");
303+
Type srcElemType = getElementTypeOrSelf(srcType);
304+
Type dstElemType = getElementTypeOrSelf(dstType);
305+
assert(srcElemType.isIntOrIndex() && "Invalid src type");
306+
assert(dstElemType.isIntOrIndex() && "Invalid dst type");
303307
if (srcType == dstType)
304308
return src;
305309

306-
if (isa<IndexType>(srcType) || isa<IndexType>(dstType))
310+
if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType))
307311
return builder.create<arith::IndexCastUIOp>(loc, dstType, src);
308312

309-
auto srcInt = cast<IntegerType>(srcType);
310-
auto dstInt = cast<IntegerType>(dstType);
313+
auto srcInt = cast<IntegerType>(srcElemType);
314+
auto dstInt = cast<IntegerType>(dstElemType);
311315
if (dstInt.getWidth() < srcInt.getWidth())
312316
return builder.create<arith::TruncIOp>(loc, dstType, src);
313317

mlir/test/Dialect/Arith/int-range-narrowing.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,20 @@ func.func @test_addi() -> index {
3030
return %2 : index
3131
}
3232

33+
// CHECK-LABEL: func @test_addi_vec
34+
// CHECK: %[[A:.*]] = test.with_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index} : vector<4xindex>
35+
// CHECK: %[[B:.*]] = test.with_bounds {smax = 7 : index, smin = 6 : index, umax = 7 : index, umin = 6 : index} : vector<4xindex>
36+
// CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] : vector<4xindex> to vector<4xi8>
37+
// CHECK: %[[B_CASTED:.*]] = arith.index_castui %[[B]] : vector<4xindex> to vector<4xi8>
38+
// CHECK: %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : vector<4xi8>
39+
// CHECK: %[[RES_CASTED:.*]] = arith.index_castui %[[RES]] : vector<4xi8> to vector<4xindex>
40+
// CHECK: return %[[RES_CASTED]] : vector<4xindex>
41+
func.func @test_addi_vec() -> vector<4xindex> {
42+
%0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : vector<4xindex>
43+
%1 = test.with_bounds { umin = 6 : index, umax = 7 : index, smin = 6 : index, smax = 7 : index } : vector<4xindex>
44+
%2 = arith.addi %0, %1 : vector<4xindex>
45+
return %2 : vector<4xindex>
46+
}
3347

3448
// CHECK-LABEL: func @test_addi_i64
3549
// CHECK: %[[A:.*]] = test.with_bounds {smax = 5 : i64, smin = 4 : i64, umax = 5 : i64, umin = 4 : i64} : i64
@@ -60,6 +74,20 @@ func.func @test_cmpi() -> i1 {
6074
return %2 : i1
6175
}
6276

77+
// CHECK-LABEL: func @test_cmpi_vec
78+
// CHECK: %[[A:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : vector<4xindex>
79+
// CHECK: %[[B:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : vector<4xindex>
80+
// CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] : vector<4xindex> to vector<4xi8>
81+
// CHECK: %[[B_CASTED:.*]] = arith.index_castui %[[B]] : vector<4xindex> to vector<4xi8>
82+
// CHECK: %[[RES:.*]] = arith.cmpi slt, %[[A_CASTED]], %[[B_CASTED]] : vector<4xi8>
83+
// CHECK: return %[[RES]] : vector<4xi1>
84+
func.func @test_cmpi_vec() -> vector<4xi1> {
85+
%0 = test.with_bounds { umin = 0 : index, umax = 10 : index, smin = 0 : index, smax = 10 : index } : vector<4xindex>
86+
%1 = test.with_bounds { umin = 0 : index, umax = 10 : index, smin = 0 : index, smax = 10 : index } : vector<4xindex>
87+
%2 = arith.cmpi slt, %0, %1 : vector<4xindex>
88+
return %2 : vector<4xi1>
89+
}
90+
6391
//===----------------------------------------------------------------------===//
6492
// arith.addi
6593
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)