Skip to content

Commit 2535f83

Browse files
committed
Update return method in VectorUnroll.cpp and modify test in vector-unroll-options.mlir
1 parent 77b4b65 commit 2535f83

File tree

3 files changed

+6
-7
lines changed

3 files changed

+6
-7
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,7 @@ struct UnrollElementwisePattern : public RewritePattern {
441441
// fact that `ExtractStridedSlice` requires the rank for the input and
442442
// output to match. If needed, we can relax this later.
443443
if (originalSize.size() != targetShape->size())
444-
return failure();
444+
return rewriter.notifyMatchFailure(op, "expected input vector rank to match target shape rank");
445445
Location loc = op->getLoc();
446446
// Prepare the result vector.
447447
Value result = rewriter.create<arith::ConstantOp>(

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -769,9 +769,8 @@ func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {
769769
// -----
770770

771771
func.func @extract_strided_slice(%arg0: vector<3x2x2xf32>) {
772-
// expected-error@+1 {{expected input vector rank to match target shape rank}}
773-
%1 = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}:
774-
vector<3x2x2xf32> to vector<2x2xf32>
772+
// expected-error@+1 {{op expected input vector rank to match target shape rank}}
773+
%1 = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}: vector<3x2x2xf32> to vector<2x2xf32>
775774
return
776775
}
777776

mlir/test/Dialect/Vector/vector-unroll-options.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -189,13 +189,13 @@ func.func @vector_fma(%a: vector<4x4xf32>, %b: vector<4x4xf32>, %c: vector<4x4xf
189189
// CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32>
190190

191191
// TODO: We should be able to unroll this like the example above - this will require extending UnrollElementwisePattern.
192-
func.func @negative_vector_fma_3d(%a: vector<3x2x2xf32>) {
192+
func.func @negative_vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
193193
%0 = vector.fma %a, %a, %a : vector<3x2x2xf32>
194-
return
194+
return %0 : vector<3x2x2xf32>
195195
}
196196
// CHECK-LABEL: func @negative_vector_fma_3d
197197
// CHECK-NOT: vector.extract_strided_slice
198-
// CHECK: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<3x2x2xf32>
198+
// CHECK: %[[R0:.*]] = vector.fma %{{.+}} : vector<3x2x2xf32>
199199
// CHECK: return
200200

201201
func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> {

0 commit comments

Comments
 (0)