Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,11 @@ struct UnrollMultiReductionPattern

LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
PatternRewriter &rewriter) const override {
auto resultType = reductionOp->getResult(0).getType();
if (resultType.isIntOrFloat()) {
return rewriter.notifyMatchFailure(reductionOp,
"Unrolling scalars is not supported");
}
std::optional<SmallVector<int64_t>> targetShape =
getTargetShape(options, reductionOp);
if (!targetShape)
Expand Down
9 changes: 9 additions & 0 deletions mlir/test/Dialect/Vector/vector-unroll-options.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,15 @@ func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) ->
// CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[R5]], %[[V1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
// CHECK: return %[[V2]] : vector<4xf32>

// This is a negative test case to ensure that further unrolling is not performed. Since the vector.multi_reduction
// operation has already been unrolled, attempting additional unrolling should not be allowed.
func.func @negative_vector_multi_reduction(%v: vector<4x2xf32>, %acc: f32) -> f32 {
%0 = vector.multi_reduction #vector.kind<add>, %v, %acc [0, 1] : vector<4x2xf32> to f32
return %0 : f32
}
// CHECK-LABEL: func @negative_vector_multi_reduction
// CHECK-NEXT: %[[R0:.*]] = vector.multi_reduction <add>, %{{.*}}, %{{.*}} [0, 1] : vector<4x2xf32> to f32
// CHECK-NEXT: return %[[R0]] : f32

func.func @vector_reduction(%v : vector<8xf32>) -> f32 {
%0 = vector.reduction <add>, %v : vector<8xf32> into f32
Expand Down