Skip to content

Commit 7f23683

Browse files
nbpatel authored and github-actions[bot] committed
Automerge: [MLIR][Vector] Extend elementwise pattern to support unrolling from higher rank to lower rank (#162515)
This PR enhances the elementwise unrolling pattern to support higher rank to lower rank unroll. The approach is to add leading unit dims to lower rank targetShape to match the rank of original vector (because ExtractStridedSlice requires same rank to extractSlices), extract slice, reshape to targetShape's rank and perform the operation.
2 parents 1e40d92 + 4ff8f11 commit 7f23683

File tree

2 files changed

+100
-20
lines changed

2 files changed

+100
-20
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -465,41 +465,65 @@ struct UnrollElementwisePattern : public RewritePattern {
465465
auto targetShape = getTargetShape(options, op);
466466
if (!targetShape)
467467
return failure();
468+
int64_t targetShapeRank = targetShape->size();
468469
auto dstVecType = cast<VectorType>(op->getResult(0).getType());
469470
SmallVector<int64_t> originalSize =
470471
*cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
471-
// Bail-out if rank(source) != rank(target). The main limitation here is the
472-
// fact that `ExtractStridedSlice` requires the rank for the input and
473-
// output to match. If needed, we can relax this later.
474-
if (originalSize.size() != targetShape->size())
475-
return rewriter.notifyMatchFailure(
476-
op, "expected input vector rank to match target shape rank");
472+
int64_t originalShapeRank = originalSize.size();
473+
477474
Location loc = op->getLoc();
475+
476+
// Handle rank mismatch by adding leading unit dimensions to targetShape
477+
SmallVector<int64_t> adjustedTargetShape(originalShapeRank);
478+
int64_t rankDiff = originalShapeRank - targetShapeRank;
479+
std::fill(adjustedTargetShape.begin(),
480+
adjustedTargetShape.begin() + rankDiff, 1);
481+
std::copy(targetShape->begin(), targetShape->end(),
482+
adjustedTargetShape.begin() + rankDiff);
483+
484+
int64_t adjustedTargetShapeRank = adjustedTargetShape.size();
478485
// Prepare the result vector.
479486
Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
480487
rewriter.getZeroAttr(dstVecType));
481-
SmallVector<int64_t> strides(targetShape->size(), 1);
482-
VectorType newVecType =
488+
SmallVector<int64_t> strides(adjustedTargetShapeRank, 1);
489+
VectorType unrolledVecType =
483490
VectorType::get(*targetShape, dstVecType.getElementType());
484491

485492
// Create the unrolled computation.
486493
for (SmallVector<int64_t> offsets :
487-
StaticTileOffsetRange(originalSize, *targetShape)) {
494+
StaticTileOffsetRange(originalSize, adjustedTargetShape)) {
488495
SmallVector<Value> extractOperands;
489496
for (OpOperand &operand : op->getOpOperands()) {
490497
auto vecType = dyn_cast<VectorType>(operand.get().getType());
491498
if (!vecType) {
492499
extractOperands.push_back(operand.get());
493500
continue;
494501
}
495-
extractOperands.push_back(
496-
rewriter.createOrFold<vector::ExtractStridedSliceOp>(
497-
loc, operand.get(), offsets, *targetShape, strides));
502+
Value extracted = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
503+
loc, operand.get(), offsets, adjustedTargetShape, strides);
504+
505+
// Reshape to remove leading unit dims if needed
506+
if (adjustedTargetShapeRank > targetShapeRank) {
507+
extracted = rewriter.createOrFold<vector::ShapeCastOp>(
508+
loc, VectorType::get(*targetShape, vecType.getElementType()),
509+
extracted);
510+
}
511+
extractOperands.push_back(extracted);
498512
}
513+
499514
Operation *newOp = cloneOpWithOperandsAndTypes(
500-
rewriter, loc, op, extractOperands, newVecType);
515+
rewriter, loc, op, extractOperands, unrolledVecType);
516+
517+
Value computeResult = newOp->getResult(0);
518+
519+
// Use strides sized to targetShape for proper insertion
520+
SmallVector<int64_t> insertStrides =
521+
(adjustedTargetShapeRank > targetShapeRank)
522+
? SmallVector<int64_t>(targetShapeRank, 1)
523+
: strides;
524+
501525
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
502-
loc, newOp->getResult(0), result, offsets, strides);
526+
loc, computeResult, result, offsets, insertStrides);
503527
}
504528
rewriter.replaceOp(op, result);
505529
return success();

mlir/test/Dialect/Vector/vector-unroll-options.mlir

Lines changed: 62 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -188,15 +188,38 @@ func.func @vector_fma(%a: vector<4x4xf32>, %b: vector<4x4xf32>, %c: vector<4x4xf
188188
// CHECK-LABEL: func @vector_fma
189189
// CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32>
190190

191-
// TODO: We should be able to unroll this like the example above - this will require extending UnrollElementwisePattern.
192-
func.func @negative_vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
191+
func.func @vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
193192
%0 = vector.fma %a, %a, %a : vector<3x2x2xf32>
194193
return %0 : vector<3x2x2xf32>
195194
}
196-
// CHECK-LABEL: func @negative_vector_fma_3d
197-
// CHECK-NOT: vector.extract_strided_slice
198-
// CHECK: %[[R0:.*]] = vector.fma %{{.+}} : vector<3x2x2xf32>
199-
// CHECK: return
195+
// CHECK-LABEL: func @vector_fma_3d
196+
// CHECK-SAME: (%[[SRC:.*]]: vector<3x2x2xf32>) -> vector<3x2x2xf32> {
197+
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<3x2x2xf32>
198+
// CHECK: %[[E_LHS_0:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
199+
// CHECK: %[[S_LHS_0:.*]] = vector.shape_cast %[[E_LHS_0]] : vector<1x2x2xf32> to vector<2x2xf32>
200+
// CHECK: %[[E_RHS_0:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
201+
// CHECK: %[[S_RHS_0:.*]] = vector.shape_cast %[[E_RHS_0]] : vector<1x2x2xf32> to vector<2x2xf32>
202+
// CHECK: %[[E_OUT_0:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
203+
// CHECK: %[[S_OUT_0:.*]] = vector.shape_cast %[[E_OUT_0]] : vector<1x2x2xf32> to vector<2x2xf32>
204+
// CHECK: %[[FMA0:.*]] = vector.fma %[[S_LHS_0]], %[[S_RHS_0]], %[[S_OUT_0]] : vector<2x2xf32>
205+
// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[FMA0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32>
206+
// CHECK: %[[E_LHS_1:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
207+
// CHECK: %[[S_LHS_1:.*]] = vector.shape_cast %[[E_LHS_1]] : vector<1x2x2xf32> to vector<2x2xf32>
208+
// CHECK: %[[E_RHS_1:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
209+
// CHECK: %[[S_RHS_1:.*]] = vector.shape_cast %[[E_RHS_1]] : vector<1x2x2xf32> to vector<2x2xf32>
210+
// CHECK: %[[E_OUT_1:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
211+
// CHECK: %[[S_OUT_1:.*]] = vector.shape_cast %[[E_OUT_1]] : vector<1x2x2xf32> to vector<2x2xf32>
212+
// CHECK: %[[FMA1:.*]] = vector.fma %[[S_LHS_1]], %[[S_RHS_1]], %[[S_OUT_1]] : vector<2x2xf32>
213+
// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[FMA1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32>
214+
// CHECK: %[[E_LHS_2:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
215+
// CHECK: %[[S_LHS_2:.*]] = vector.shape_cast %[[E_LHS_2]] : vector<1x2x2xf32> to vector<2x2xf32>
216+
// CHECK: %[[E_RHS_2:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
217+
// CHECK: %[[S_RHS_2:.*]] = vector.shape_cast %[[E_RHS_2]] : vector<1x2x2xf32> to vector<2x2xf32>
218+
// CHECK: %[[E_OUT_2:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
219+
// CHECK: %[[S_OUT_2:.*]] = vector.shape_cast %[[E_OUT_2]] : vector<1x2x2xf32> to vector<2x2xf32>
220+
// CHECK: %[[FMA2:.*]] = vector.fma %[[S_LHS_2]], %[[S_RHS_2]], %[[S_OUT_2]] : vector<2x2xf32>
221+
// CHECK: %[[I2:.*]] = vector.insert_strided_slice %[[FMA2]], %[[I1]] {offsets = [2, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32>
222+
// CHECK: return %[[I2]] : vector<3x2x2xf32>
200223

201224
func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
202225
%0 = vector.multi_reduction #vector.kind<add>, %v, %acc [1] : vector<4x6xf32> to vector<4xf32>
@@ -440,3 +463,36 @@ func.func @vector_step() -> vector<32xindex> {
440463
// CHECK: %[[ADD3:.*]] = arith.addi %[[STEP]], %[[CST]] : vector<8xindex>
441464
// CHECK: %[[INS3:.*]] = vector.insert_strided_slice %[[ADD3]], %[[INS2]] {offsets = [24], strides = [1]} : vector<8xindex> into vector<32xindex>
442465
// CHECK: return %[[INS3]] : vector<32xindex>
466+
467+
468+
func.func @elementwise_3D_to_2D(%v1: vector<2x2x2xf32>, %v2: vector<2x2x2xf32>) -> vector<2x2x2xf32> {
469+
%0 = arith.addf %v1, %v2 : vector<2x2x2xf32>
470+
return %0 : vector<2x2x2xf32>
471+
}
472+
// CHECK-LABEL: func @elementwise_3D_to_2D
473+
// CHECK-SAME: (%[[ARG0:.*]]: vector<2x2x2xf32>, %[[ARG1:.*]]: vector<2x2x2xf32>) -> vector<2x2x2xf32> {
474+
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x2xf32>
475+
// CHECK: %[[E_LHS_0:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
476+
// CHECK: %[[S_LHS_0:.*]] = vector.shape_cast %[[E_LHS_0]] : vector<1x2x2xf32> to vector<2x2xf32>
477+
// CHECK: %[[E_RHS_0:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
478+
// CHECK: %[[S_RHS_0:.*]] = vector.shape_cast %[[E_RHS_0]] : vector<1x2x2xf32> to vector<2x2xf32>
479+
// CHECK: %[[ADD0:.*]] = arith.addf %[[S_LHS_0]], %[[S_RHS_0]] : vector<2x2xf32>
480+
// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[ADD0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<2x2x2xf32>
481+
// CHECK: %[[E_LHS_1:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
482+
// CHECK: %[[S_LHS_1:.*]] = vector.shape_cast %[[E_LHS_1]] : vector<1x2x2xf32> to vector<2x2xf32>
483+
// CHECK: %[[E_RHS_1:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
484+
// CHECK: %[[S_RHS_1:.*]] = vector.shape_cast %[[E_RHS_1]] : vector<1x2x2xf32> to vector<2x2xf32>
485+
// CHECK: %[[ADD1:.*]] = arith.addf %[[S_LHS_1]], %[[S_RHS_1]] : vector<2x2xf32>
486+
// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[ADD1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<2x2x2xf32>
487+
// CHECK: return %[[I1]] : vector<2x2x2xf32>
488+
489+
490+
func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf32>) -> vector<2x2x2x2xf32> {
491+
%0 = arith.addf %v1, %v2 : vector<2x2x2x2xf32>
492+
return %0 : vector<2x2x2x2xf32>
493+
}
494+
495+
// CHECK-LABEL: func @elementwise_4D_to_2D
496+
// CHECK-COUNT-4: arith.addf %{{.*}}, %{{.*}} : vector<2x2xf32>
497+
// CHECK-NOT: arith.addf
498+
// CHECK: return

0 commit comments

Comments (0)