Skip to content

Commit 983c12b

Browse files
committed
Add unroll pattern for StepOp
1 parent ada9da7 commit 983c12b

File tree

5 files changed

+83
-1
lines changed

5 files changed

+83
-1
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3018,6 +3018,7 @@ def Vector_ScanOp :
30183018

30193019
def Vector_StepOp : Vector_Op<"step", [
30203020
Pure,
3021+
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
30213022
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>
30223023
]> {
30233024
let summary = "A linear sequence of values from 0 to N";

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7442,6 +7442,10 @@ void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
74427442
setResultRanges(getResult(), result);
74437443
}
74447444

7445+
std::optional<SmallVector<int64_t, 4>> StepOp::getShapeForUnroll() {
7446+
return llvm::to_vector<4>(llvm::cast<VectorType>(getType()).getShape());
7447+
}
7448+
74457449
//===----------------------------------------------------------------------===//
74467450
// Vector Masking Utilities
74477451
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -809,6 +809,54 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
809809
vector::UnrollVectorOptions options;
810810
};
811811

812+
struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
813+
UnrollStepPattern(MLIRContext *context,
814+
const vector::UnrollVectorOptions &options,
815+
PatternBenefit benefit = 1)
816+
: OpRewritePattern<vector::StepOp>(context, benefit), options(options) {}
817+
818+
LogicalResult matchAndRewrite(vector::StepOp stepOp,
819+
PatternRewriter &rewriter) const override {
820+
auto targetShape = getTargetShape(options, stepOp);
821+
if (!targetShape)
822+
return failure();
823+
824+
VectorType vecType = stepOp.getType();
825+
if (vecType.isScalable()) {
826+
// Scalable vectors are not supported by this pattern.
827+
return failure();
828+
}
829+
int64_t originalSize = vecType.getShape()[0];
830+
Location loc = stepOp.getLoc();
831+
SmallVector<int64_t> strides(1, 1);
832+
833+
Value result = arith::ConstantOp::create(rewriter, loc, vecType,
834+
rewriter.getZeroAttr(vecType));
835+
836+
for (SmallVector<int64_t> offsets :
837+
StaticTileOffsetRange({originalSize}, *targetShape)) {
838+
int64_t tileOffset = offsets[0];
839+
auto targetVecType =
840+
VectorType::get(*targetShape, vecType.getElementType());
841+
Value baseStep = rewriter.create<vector::StepOp>(loc, targetVecType);
842+
Value offsetVal =
843+
rewriter.create<arith::ConstantIndexOp>(loc, tileOffset);
844+
Value bcastOffset =
845+
rewriter.create<vector::BroadcastOp>(loc, targetVecType, offsetVal);
846+
Value tileStep =
847+
rewriter.create<arith::AddIOp>(loc, baseStep, bcastOffset);
848+
849+
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
850+
loc, tileStep, result, offsets, strides);
851+
}
852+
rewriter.replaceOp(stepOp, result);
853+
return success();
854+
}
855+
856+
private:
857+
vector::UnrollVectorOptions options;
858+
};
859+
812860
} // namespace
813861

814862
void mlir::vector::populateVectorUnrollPatterns(
@@ -818,6 +866,6 @@ void mlir::vector::populateVectorUnrollPatterns(
818866
UnrollContractionPattern, UnrollElementwisePattern,
819867
UnrollReductionPattern, UnrollMultiReductionPattern,
820868
UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
821-
UnrollStorePattern, UnrollBroadcastPattern>(
869+
UnrollStorePattern, UnrollBroadcastPattern, UnrollStepPattern>(
822870
patterns.getContext(), options, benefit);
823871
}

mlir/test/Dialect/Vector/vector-unroll-options.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,3 +420,26 @@ func.func @vector_store_2D(%mem: memref<4x4xf16>, %v: vector<4x4xf16>) {
420420
// CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16>
421421
// CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16>
422422
// CHECK: vector.store %[[V3]], %[[ARG0]][%[[C2]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16>
423+
424+
425+
func.func @vector_step() -> vector<32xindex> {
426+
%0 = vector.step : vector<32xindex>
427+
return %0 : vector<32xindex>
428+
}
429+
// CHECK-LABEL: func @vector_step
430+
// CHECK: %[[CST3:.*]] = arith.constant dense<24> : vector<8xindex>
431+
// CHECK: %[[CST2:.*]] = arith.constant dense<16> : vector<8xindex>
432+
// CHECK: %[[CST1:.*]] = arith.constant dense<8> : vector<8xindex>
433+
// CHECK: %[[CST0:.*]] = arith.constant dense<0> : vector<32xindex>
434+
// CHECK: %[[STEP0:.*]] = vector.step : vector<8xindex>
435+
// CHECK: %[[INS0:.*]] = vector.insert_strided_slice %[[STEP0]], %[[CST0]] {offsets = [0], strides = [1]} : vector<8xindex> into vector<32xindex>
436+
// CHECK: %[[STEP1:.*]] = vector.step : vector<8xindex>
437+
// CHECK: %[[ADD1:.*]] = arith.addi %[[STEP1]], %[[CST1]] : vector<8xindex>
438+
// CHECK: %[[INS1:.*]] = vector.insert_strided_slice %[[ADD1]], %[[INS0]] {offsets = [8], strides = [1]} : vector<8xindex> into vector<32xindex>
439+
// CHECK: %[[STEP2:.*]] = vector.step : vector<8xindex>
440+
// CHECK: %[[ADD2:.*]] = arith.addi %[[STEP2]], %[[CST2]] : vector<8xindex>
441+
// CHECK: %[[INS2:.*]] = vector.insert_strided_slice %[[ADD2]], %[[INS1]] {offsets = [16], strides = [1]} : vector<8xindex> into vector<32xindex>
442+
// CHECK: %[[STEP3:.*]] = vector.step : vector<8xindex>
443+
// CHECK: %[[ADD3:.*]] = arith.addi %[[STEP3]], %[[CST3]] : vector<8xindex>
444+
// CHECK: %[[INS3:.*]] = vector.insert_strided_slice %[[ADD3]], %[[INS2]] {offsets = [24], strides = [1]} : vector<8xindex> into vector<32xindex>
445+
// CHECK: return %[[INS3]] : vector<32xindex>

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,12 @@ struct TestVectorUnrollingPatterns
172172
.setFilterConstraint([](Operation *op) {
173173
return success(isa<vector::ReductionOp>(op));
174174
}));
175+
populateVectorUnrollPatterns(patterns,
176+
UnrollVectorOptions()
177+
.setNativeShape(ArrayRef<int64_t>{8})
178+
.setFilterConstraint([](Operation *op) {
179+
return success(isa<vector::StepOp>(op));
180+
}));
175181
populateVectorUnrollPatterns(
176182
patterns, UnrollVectorOptions()
177183
.setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})

0 commit comments

Comments
 (0)