Skip to content

Commit 4707325

Browse files
committed
Allow unrolling to drop leading unit dimensions
1 parent 08619d8 commit 4707325

File tree

2 files changed

+36
-2
lines changed

2 files changed

+36
-2
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,27 @@ getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
122122
return std::nullopt;
123123
}
124124
if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
125-
LDBG() << "--no unrolling needed -> SKIP";
126-
return std::nullopt;
125+
// If maybeShapeRatio are all 1s, only allow unrolling for leading unit
126+
// dimension removal: [1,1,...,n] -> [n]
127+
if (maybeUnrollShape->size() <= targetShape->size()) {
128+
LDBG() << "--no dimension reduction -> SKIP";
129+
return std::nullopt;
130+
}
131+
132+
size_t dimDiff = maybeUnrollShape->size() - targetShape->size();
133+
ArrayRef<int64_t> srcShape = *maybeUnrollShape;
134+
ArrayRef<int64_t> tgtShape = *targetShape;
135+
136+
// Check leading dimensions are 1s and remaining matches target
137+
bool isValidRemoval = llvm::all_of(srcShape.slice(0, dimDiff),
138+
[](int64_t dim) { return dim == 1; }) &&
139+
srcShape.slice(dimDiff) == tgtShape;
140+
141+
if (!isValidRemoval) {
142+
LDBG() << "--not a valid leading unit dimension removal -> SKIP";
143+
return std::nullopt;
144+
}
145+
LDBG() << "--leading unit dimension removal -> CONTINUE";
127146
}
128147
LDBG() << "--found an integral shape ratio to unroll to -> SUCCESS";
129148
return targetShape;

mlir/test/Dialect/Vector/vector-unroll-options.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,3 +496,18 @@ func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf3
496496
// CHECK-COUNT-4: arith.addf %{{.*}}, %{{.*}} : vector<2x2xf32>
497497
// CHECK-NOT: arith.addf
498498
// CHECK: return
499+
500+
501+
func.func @elementwise_leading_unit_dim(%v1: vector<1x2x2xf32>, %v2: vector<1x2x2xf32>) -> vector<1x2x2xf32> {
502+
%0 = arith.addf %v1, %v2 : vector<1x2x2xf32>
503+
return %0 : vector<1x2x2xf32>
504+
}
505+
506+
// CHECK-LABEL: func @elementwise_leading_unit_dim
507+
// CHECK-SAME: (%[[ARG0:.*]]: vector<1x2x2xf32>, %[[ARG1:.*]]: vector<1x2x2xf32>) -> vector<1x2x2xf32> {
508+
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1x2x2xf32>
509+
// CHECK: %[[S_LHS:.*]] = vector.shape_cast %[[ARG0]] : vector<1x2x2xf32> to vector<2x2xf32>
510+
// CHECK: %[[S_RHS:.*]] = vector.shape_cast %[[ARG1]] : vector<1x2x2xf32> to vector<2x2xf32>
511+
// CHECK: %[[ADD:.*]] = arith.addf %[[S_LHS]], %[[S_RHS]] : vector<2x2xf32>
512+
// CHECK: %[[INS:.*]] = vector.insert_strided_slice %[[ADD]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<1x2x2xf32>
513+
// CHECK: return %[[INS]] : vector<1x2x2xf32>

0 commit comments

Comments
 (0)