Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3018,6 +3018,7 @@ def Vector_ScanOp :

def Vector_StepOp : Vector_Op<"step", [
Pure,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>
]> {
let summary = "A linear sequence of values from 0 to N";
Expand Down
77 changes: 76 additions & 1 deletion mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,81 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
vector::UnrollVectorOptions options;
};

/// This pattern unrolls `vector.step` operations according to the provided
/// target unroll shape. It decomposes a large step vector into smaller step
/// vectors (segments) and assembles the result by inserting each computed
/// segment into the appropriate offset of the original vector.
///
/// The pattern does not support scalable vectors and will fail to match them.
///
/// For each segment, it adds the base step vector and the segment's offset,
/// then inserts the result into the output vector at the corresponding
/// position.
///
/// Example:
/// Given a step operation:
/// %0 = vector.step : vector<8xindex>
///
/// and a target unroll shape of <4>, the pattern produces:
///
/// %base = vector.step : vector<4xindex>
/// %zero = arith.constant dense<0> : vector<8xindex>
/// %result0 = vector.insert_strided_slice %base, %zero
/// {offsets = [0], strides = [1]} : vector<4xindex> into vector<8xindex>
/// %offset = arith.constant dense<4> : vector<4xindex>
/// %segment1 = arith.addi %base, %offset : vector<4xindex>
/// %result1 = vector.insert_strided_slice %segment1, %result0
/// {offsets = [4], strides = [1]} : vector<4xindex> into vector<8xindex>
///
struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
UnrollStepPattern(MLIRContext *context,
const vector::UnrollVectorOptions &options,
PatternBenefit benefit = 1)
: OpRewritePattern<vector::StepOp>(context, benefit), options(options) {}

LogicalResult matchAndRewrite(vector::StepOp stepOp,
PatternRewriter &rewriter) const override {
std::optional<SmallVector<int64_t>> targetShape =
getTargetShape(options, stepOp);
if (!targetShape)
return failure();

VectorType vecType = stepOp.getType();
if (vecType.isScalable()) {
// Scalable vectors are not supported by this pattern.
return failure();
}
int64_t originalSize = vecType.getShape()[0];
Location loc = stepOp.getLoc();
SmallVector<int64_t> strides(1, 1);

Value result = arith::ConstantOp::create(rewriter, loc, vecType,
rewriter.getZeroAttr(vecType));

VectorType targetVecType =
VectorType::get(*targetShape, vecType.getElementType());
Value baseStep = vector::StepOp::create(rewriter, loc, targetVecType);
for (SmallVector<int64_t> offsets :
StaticTileOffsetRange({originalSize}, *targetShape)) {
Value bcastOffset = arith::ConstantOp::create(
rewriter, loc, targetVecType,
DenseElementsAttr::get(
targetVecType,
IntegerAttr::get(targetVecType.getElementType(), offsets[0])));
Value tileStep =
arith::AddIOp::create(rewriter, loc, baseStep, bcastOffset);

result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, tileStep, result, offsets, strides);
}
rewriter.replaceOp(stepOp, result);
return success();
}

private:
vector::UnrollVectorOptions options;
};

} // namespace

void mlir::vector::populateVectorUnrollPatterns(
Expand All @@ -818,6 +893,6 @@ void mlir::vector::populateVectorUnrollPatterns(
UnrollContractionPattern, UnrollElementwisePattern,
UnrollReductionPattern, UnrollMultiReductionPattern,
UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
UnrollStorePattern, UnrollBroadcastPattern>(
UnrollStorePattern, UnrollBroadcastPattern, UnrollStepPattern>(
patterns.getContext(), options, benefit);
}
20 changes: 20 additions & 0 deletions mlir/test/Dialect/Vector/vector-unroll-options.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -420,3 +420,23 @@ func.func @vector_store_2D(%mem: memref<4x4xf16>, %v: vector<4x4xf16>) {
// CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16>
// CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16>
// CHECK: vector.store %[[V3]], %[[ARG0]][%[[C2]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16>


func.func @vector_step() -> vector<32xindex> {
%0 = vector.step : vector<32xindex>
return %0 : vector<32xindex>
}
// CHECK-LABEL: func @vector_step
// CHECK: %[[CST:.*]] = arith.constant dense<24> : vector<8xindex>
// CHECK: %[[CST0:.*]] = arith.constant dense<16> : vector<8xindex>
// CHECK: %[[CST1:.*]] = arith.constant dense<8> : vector<8xindex>
// CHECK: %[[CST2:.*]] = arith.constant dense<0> : vector<32xindex>
// CHECK: %[[STEP:.*]] = vector.step : vector<8xindex>
// CHECK: %[[INS0:.*]] = vector.insert_strided_slice %[[STEP]], %[[CST2]] {offsets = [0], strides = [1]} : vector<8xindex> into vector<32xindex>
// CHECK: %[[ADD1:.*]] = arith.addi %[[STEP]], %[[CST1]] : vector<8xindex>
// CHECK: %[[INS1:.*]] = vector.insert_strided_slice %[[ADD1]], %[[INS0]] {offsets = [8], strides = [1]} : vector<8xindex> into vector<32xindex>
// CHECK: %[[ADD2:.*]] = arith.addi %[[STEP]], %[[CST0]] : vector<8xindex>
// CHECK: %[[INS2:.*]] = vector.insert_strided_slice %[[ADD2]], %[[INS1]] {offsets = [16], strides = [1]} : vector<8xindex> into vector<32xindex>
// CHECK: %[[ADD3:.*]] = arith.addi %[[STEP]], %[[CST]] : vector<8xindex>
// CHECK: %[[INS3:.*]] = vector.insert_strided_slice %[[ADD3]], %[[INS2]] {offsets = [24], strides = [1]} : vector<8xindex> into vector<32xindex>
// CHECK: return %[[INS3]] : vector<32xindex>
6 changes: 6 additions & 0 deletions mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,12 @@ struct TestVectorUnrollingPatterns
.setFilterConstraint([](Operation *op) {
return success(isa<vector::ReductionOp>(op));
}));
populateVectorUnrollPatterns(patterns,
UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{8})
.setFilterConstraint([](Operation *op) {
return success(isa<vector::StepOp>(op));
}));
populateVectorUnrollPatterns(
patterns, UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
Expand Down