Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 38 additions & 14 deletions mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -465,41 +465,65 @@ struct UnrollElementwisePattern : public RewritePattern {
auto targetShape = getTargetShape(options, op);
if (!targetShape)
return failure();
int64_t targetShapeRank = targetShape->size();
auto dstVecType = cast<VectorType>(op->getResult(0).getType());
SmallVector<int64_t> originalSize =
*cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
// Bail-out if rank(source) != rank(target). The main limitation here is the
// fact that `ExtractStridedSlice` requires the rank for the input and
// output to match. If needed, we can relax this later.
if (originalSize.size() != targetShape->size())
return rewriter.notifyMatchFailure(
op, "expected input vector rank to match target shape rank");
int64_t originalShapeRank = originalSize.size();

Location loc = op->getLoc();

// Handle rank mismatch by adding leading unit dimensions to targetShape
SmallVector<int64_t> adjustedTargetShape(originalShapeRank);
int64_t rankDiff = originalShapeRank - targetShapeRank;
std::fill(adjustedTargetShape.begin(),
adjustedTargetShape.begin() + rankDiff, 1);
std::copy(targetShape->begin(), targetShape->end(),
adjustedTargetShape.begin() + rankDiff);

int64_t adjustedTargetShapeRank = adjustedTargetShape.size();
// Prepare the result vector.
Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
rewriter.getZeroAttr(dstVecType));
SmallVector<int64_t> strides(targetShape->size(), 1);
VectorType newVecType =
SmallVector<int64_t> strides(adjustedTargetShapeRank, 1);
VectorType unrolledVecType =
VectorType::get(*targetShape, dstVecType.getElementType());

// Create the unrolled computation.
for (SmallVector<int64_t> offsets :
StaticTileOffsetRange(originalSize, *targetShape)) {
StaticTileOffsetRange(originalSize, adjustedTargetShape)) {
SmallVector<Value> extractOperands;
for (OpOperand &operand : op->getOpOperands()) {
auto vecType = dyn_cast<VectorType>(operand.get().getType());
if (!vecType) {
extractOperands.push_back(operand.get());
continue;
}
extractOperands.push_back(
rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, operand.get(), offsets, *targetShape, strides));
Value extracted = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, operand.get(), offsets, adjustedTargetShape, strides);

// Reshape to remove leading unit dims if needed
if (adjustedTargetShapeRank > targetShapeRank) {
extracted = rewriter.createOrFold<vector::ShapeCastOp>(
loc, VectorType::get(*targetShape, vecType.getElementType()),
extracted);
}
extractOperands.push_back(extracted);
}

Operation *newOp = cloneOpWithOperandsAndTypes(
rewriter, loc, op, extractOperands, newVecType);
rewriter, loc, op, extractOperands, unrolledVecType);

Value computeResult = newOp->getResult(0);

// Use strides sized to targetShape for proper insertion
SmallVector<int64_t> insertStrides =
(adjustedTargetShapeRank > targetShapeRank)
? SmallVector<int64_t>(targetShapeRank, 1)
: strides;

result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, newOp->getResult(0), result, offsets, strides);
loc, computeResult, result, offsets, insertStrides);
}
rewriter.replaceOp(op, result);
return success();
Expand Down
68 changes: 62 additions & 6 deletions mlir/test/Dialect/Vector/vector-unroll-options.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -188,15 +188,38 @@ func.func @vector_fma(%a: vector<4x4xf32>, %b: vector<4x4xf32>, %c: vector<4x4xf
// CHECK-LABEL: func @vector_fma
// CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32>

// TODO: We should be able to unroll this like the example above - this will require extending UnrollElementwisePattern.
func.func @negative_vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
func.func @vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
%0 = vector.fma %a, %a, %a : vector<3x2x2xf32>
return %0 : vector<3x2x2xf32>
}
// CHECK-LABEL: func @negative_vector_fma_3d
// CHECK-NOT: vector.extract_strided_slice
// CHECK: %[[R0:.*]] = vector.fma %{{.+}} : vector<3x2x2xf32>
// CHECK: return
// CHECK-LABEL: func @vector_fma_3d
// CHECK-SAME: (%[[SRC:.*]]: vector<3x2x2xf32>) -> vector<3x2x2xf32> {
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<3x2x2xf32>
// CHECK: %[[E_LHS_0:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
// CHECK: %[[S_LHS_0:.*]] = vector.shape_cast %[[E_LHS_0]] : vector<1x2x2xf32> to vector<2x2xf32>
// CHECK: %[[E_RHS_0:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
// CHECK: %[[S_RHS_0:.*]] = vector.shape_cast %[[E_RHS_0]] : vector<1x2x2xf32> to vector<2x2xf32>
// CHECK: %[[E_OUT_0:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
// CHECK: %[[S_OUT_0:.*]] = vector.shape_cast %[[E_OUT_0]] : vector<1x2x2xf32> to vector<2x2xf32>
// CHECK: %[[FMA0:.*]] = vector.fma %[[S_LHS_0]], %[[S_RHS_0]], %[[S_OUT_0]] : vector<2x2xf32>
// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[FMA0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32>
// CHECK: %[[E_LHS_1:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
// CHECK: %[[S_LHS_1:.*]] = vector.shape_cast %[[E_LHS_1]] : vector<1x2x2xf32> to vector<2x2xf32>
// CHECK: %[[E_RHS_1:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
// CHECK: %[[S_RHS_1:.*]] = vector.shape_cast %[[E_RHS_1]] : vector<1x2x2xf32> to vector<2x2xf32>
// CHECK: %[[E_OUT_1:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
// CHECK: %[[S_OUT_1:.*]] = vector.shape_cast %[[E_OUT_1]] : vector<1x2x2xf32> to vector<2x2xf32>
// CHECK: %[[FMA1:.*]] = vector.fma %[[S_LHS_1]], %[[S_RHS_1]], %[[S_OUT_1]] : vector<2x2xf32>
// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[FMA1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32>
// CHECK: %[[E_LHS_2:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
// CHECK: %[[S_LHS_2:.*]] = vector.shape_cast %[[E_LHS_2]] : vector<1x2x2xf32> to vector<2x2xf32>
// CHECK: %[[E_RHS_2:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
// CHECK: %[[S_RHS_2:.*]] = vector.shape_cast %[[E_RHS_2]] : vector<1x2x2xf32> to vector<2x2xf32>
// CHECK: %[[E_OUT_2:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
// CHECK: %[[S_OUT_2:.*]] = vector.shape_cast %[[E_OUT_2]] : vector<1x2x2xf32> to vector<2x2xf32>
// CHECK: %[[FMA2:.*]] = vector.fma %[[S_LHS_2]], %[[S_RHS_2]], %[[S_OUT_2]] : vector<2x2xf32>
// CHECK: %[[I2:.*]] = vector.insert_strided_slice %[[FMA2]], %[[I1]] {offsets = [2, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32>
// CHECK: return %[[I2]] : vector<3x2x2xf32>

func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
%0 = vector.multi_reduction #vector.kind<add>, %v, %acc [1] : vector<4x6xf32> to vector<4xf32>
Expand Down Expand Up @@ -440,3 +463,36 @@ func.func @vector_step() -> vector<32xindex> {
// CHECK: %[[ADD3:.*]] = arith.addi %[[STEP]], %[[CST]] : vector<8xindex>
// CHECK: %[[INS3:.*]] = vector.insert_strided_slice %[[ADD3]], %[[INS2]] {offsets = [24], strides = [1]} : vector<8xindex> into vector<32xindex>
// CHECK: return %[[INS3]] : vector<32xindex>


// Element-wise `arith.addf` on rank-3 `vector<2x2x2xf32>` operands.
// The FileCheck expectations that follow this function require the add to be
// split into two `vector<2x2xf32>` additions: each operand is sliced with
// sizes [1, 2, 2] at offsets [0, 0, 0] and [1, 0, 0], shape-cast down to
// `vector<2x2xf32>`, and each partial sum is inserted back into a zero-filled
// `vector<2x2x2xf32>` accumulator.
// Presumably the test pass uses a rank-2 (2x2) target shape for this op;
// TODO confirm against the RUN line and test-pass options, which are not
// visible here.
func.func @elementwise_3D_to_2D(%v1: vector<2x2x2xf32>, %v2: vector<2x2x2xf32>) -> vector<2x2x2xf32> {
%0 = arith.addf %v1, %v2 : vector<2x2x2xf32>
return %0 : vector<2x2x2xf32>
}
// CHECK-LABEL: func @elementwise_3D_to_2D
// CHECK-SAME: (%[[ARG0:.*]]: vector<2x2x2xf32>, %[[ARG1:.*]]: vector<2x2x2xf32>) -> vector<2x2x2xf32> {
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x2xf32>
// CHECK: %[[E_LHS_0:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
// CHECK: %[[S_LHS_0:.*]] = vector.shape_cast %[[E_LHS_0]] : vector<1x2x2xf32> to vector<2x2xf32>
// CHECK: %[[E_RHS_0:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
// CHECK: %[[S_RHS_0:.*]] = vector.shape_cast %[[E_RHS_0]] : vector<1x2x2xf32> to vector<2x2xf32>
// CHECK: %[[ADD0:.*]] = arith.addf %[[S_LHS_0]], %[[S_RHS_0]] : vector<2x2xf32>
// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[ADD0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<2x2x2xf32>
// CHECK: %[[E_LHS_1:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
// CHECK: %[[S_LHS_1:.*]] = vector.shape_cast %[[E_LHS_1]] : vector<1x2x2xf32> to vector<2x2xf32>
// CHECK: %[[E_RHS_1:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
// CHECK: %[[S_RHS_1:.*]] = vector.shape_cast %[[E_RHS_1]] : vector<1x2x2xf32> to vector<2x2xf32>
// CHECK: %[[ADD1:.*]] = arith.addf %[[S_LHS_1]], %[[S_RHS_1]] : vector<2x2xf32>
// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[ADD1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<2x2x2xf32>
// CHECK: return %[[I1]] : vector<2x2x2xf32>


// Element-wise `arith.addf` on rank-4 `vector<2x2x2x2xf32>` operands.
// The FileCheck expectations that follow this function only count results:
// exactly four `arith.addf` ops on `vector<2x2xf32>` must appear, with no
// further `arith.addf` before the return. The extract/insert ops around them
// are not matched individually.
// Presumably the four adds correspond to the 2x2 leading-dimension positions
// under a rank-2 (2x2) target shape; TODO confirm against the RUN line and
// test-pass options, which are not visible here.
func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf32>) -> vector<2x2x2x2xf32> {
%0 = arith.addf %v1, %v2 : vector<2x2x2x2xf32>
return %0 : vector<2x2x2x2xf32>
}

// CHECK-LABEL: func @elementwise_4D_to_2D
// CHECK-COUNT-4: arith.addf %{{.*}}, %{{.*}} : vector<2x2xf32>
// CHECK-NOT: arith.addf
// CHECK: return