Skip to content

Commit 4dde55a

Browse files
committed
Add a negative test and modify the return statement
1 parent def5a8b commit 4dde55a

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -356,9 +356,9 @@ struct UnrollMultiReductionPattern
356356
LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
357357
PatternRewriter &rewriter) const override {
358358
auto resultType = reductionOp->getResult(0).getType();
359-
if (mlir::isa<mlir::FloatType>(resultType) ||
360-
mlir::isa<mlir::IntegerType>(resultType)) {
361-
return failure();
359+
if (resultType.isIntOrFloat()) {
360+
return rewriter.notifyMatchFailure(reductionOp,
361+
"Unrolling scalars is not supported");
362362
}
363363
std::optional<SmallVector<int64_t>> targetShape =
364364
getTargetShape(options, reductionOp);

mlir/test/Dialect/Vector/vector-unroll-options.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,13 @@ func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) ->
222222
// CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[R5]], %[[V1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
223223
// CHECK: return %[[V2]] : vector<4xf32>
224224

225+
func.func @negative_vector_multi_reduction(%v: vector<4x2xf32>, %acc: f32) -> f32 {
226+
%0 = vector.multi_reduction #vector.kind<add>, %v, %acc [0, 1] : vector<4x2xf32> to f32
227+
return %0 : f32
228+
}
229+
// CHECK-LABEL: func @negative_vector_multi_reduction
230+
// CHECK: %[[R0:.*]] = vector.multi_reduction <add>, %{{.*}}, %{{.*}} [0, 1] : vector<4x2xf32> to f32
231+
// CHECK: return %[[R0]] : f32
225232

226233
func.func @vector_reduction(%v : vector<8xf32>) -> f32 {
227234
%0 = vector.reduction <add>, %v : vector<8xf32> into f32

0 commit comments

Comments
 (0)