Skip to content

Commit 42522d1

Browse files
committed
Address Feedback
1 parent f2af423 commit 42522d1

File tree

2 files changed

+25
-8
lines changed

2 files changed

+25
-8
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,7 +1055,8 @@ struct UnrollCreateMaskPattern : public OpRewritePattern<vector::CreateMaskOp> {
10551055

10561056
Value result = arith::ConstantOp::create(rewriter, loc, resultType,
10571057
rewriter.getZeroAttr(resultType));
1058-
auto targetVectorType = VectorType::get(*targetShape, rewriter.getI1Type());
1058+
VectorType targetVectorType =
1059+
VectorType::get(*targetShape, rewriter.getI1Type());
10591060
SmallVector<int64_t> strides(targetShape->size(), 1);
10601061

10611062
// In each dimension (d), each unrolled vector computes its mask size as:
@@ -1068,20 +1069,20 @@ struct UnrollCreateMaskPattern : public OpRewritePattern<vector::CreateMaskOp> {
10681069
llvm::enumerate(createMaskOp.getOperands())) {
10691070
Value offsetVal =
10701071
arith::ConstantIndexOp::create(rewriter, loc, offsets[i]);
1071-
Value adjustedMaskSize = arith::SubIOp::create(
1072-
rewriter, loc, originalMaskOperand, offsetVal);
1072+
Value adjustedMaskSize = rewriter.createOrFold<arith::SubIOp>(
1073+
loc, originalMaskOperand, offsetVal);
10731074
Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
10741075
Value unrolledDimSize =
10751076
arith::ConstantIndexOp::create(rewriter, loc, (*targetShape)[i]);
10761077
Value nonNegative =
1077-
arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
1078-
Value unrolledOperand =
1079-
arith::MinSIOp::create(rewriter, loc, nonNegative, unrolledDimSize);
1078+
rewriter.createOrFold<arith::MaxSIOp>(loc, adjustedMaskSize, zero);
1079+
Value unrolledOperand = rewriter.createOrFold<arith::MinSIOp>(
1080+
loc, nonNegative, unrolledDimSize);
10801081
unrolledOperands.push_back(unrolledOperand);
10811082
}
10821083

1083-
auto unrolledMask = vector::CreateMaskOp::create(
1084-
rewriter, loc, targetVectorType, unrolledOperands);
1084+
auto unrolledMask = rewriter.createOrFold<vector::CreateMaskOp>(
1085+
loc, targetVectorType, unrolledOperands);
10851086
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
10861087
loc, unrolledMask, result, offsets, strides);
10871088
}

mlir/test/Dialect/Vector/vector-unroll-options.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,22 @@ func.func @vector_create_mask(%size1: index, %size2: index) -> vector<16x16xi1>
537537
// CHECK: %[[INS11:.*]] = vector.insert_strided_slice %[[MASK11]], %[[INS10]] {offsets = [8, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
538538
// CHECK: return %[[INS11]] : vector<16x16xi1>
539539

540+
func.func @vector_create_mask_constant_dim_sizes() -> vector<16x16xi1> {
541+
%cst16 = arith.constant 16 : index
542+
%0 = vector.create_mask %cst16, %cst16 : vector<16x16xi1>
543+
return %0 : vector<16x16xi1>
544+
}
545+
546+
// CHECK-LABEL: func @vector_create_mask_constant_dim_sizes() -> vector<16x16xi1> {
547+
// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<16x16xi1>
548+
// CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<8x8xi1>
549+
// CHECK: %[[S0:.*]] = vector.insert_strided_slice %[[CST_0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
550+
// CHECK: %[[S1:.*]] = vector.insert_strided_slice %[[CST_0]], %[[S0]] {offsets = [0, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
551+
// CHECK: %[[S2:.*]] = vector.insert_strided_slice %[[CST_0]], %[[S1]] {offsets = [8, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
552+
// CHECK: %[[S3:.*]] = vector.insert_strided_slice %[[CST_0]], %[[S2]] {offsets = [8, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
553+
// CHECK: return %[[S3]] : vector<16x16xi1>
554+
555+
540556
func.func @shape_cast_1D(%v: vector<16xf32>) -> vector<2x2x4xf32> {
541557
%0 = vector.shape_cast %v : vector<16xf32> to vector<2x2x4xf32>
542558
return %0 : vector<2x2x4xf32>

0 commit comments

Comments
 (0)