Skip to content

Commit 5003057

Browse files
committed
Add unroll patterns for vector.load and vector.store
1 parent 0dd2c9f commit 5003057

File tree

3 files changed

+234
-2
lines changed

3 files changed

+234
-2
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Lines changed: 121 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,33 @@ static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
5454
return slicedIndices;
5555
}
5656

57+
// compute the new indices for vector.load/store by adding offsets to
58+
// originalIndices.
59+
// It assumes m <= n (m = offsets.size(), n = originalIndices.size())
60+
// Last m of originalIndices will be updated.
61+
static SmallVector<Value> computeIndices(PatternRewriter &rewriter,
62+
Location loc,
63+
ArrayRef<Value> originalIndices,
64+
ArrayRef<int64_t> offsets) {
65+
assert(offsets.size() <= originalIndices.size() &&
66+
"Offsets should not exceed the number of original indices");
67+
SmallVector<Value> indices(originalIndices);
68+
auto originalIter = originalIndices.rbegin();
69+
auto offsetsIter = offsets.rbegin();
70+
auto indicesIter = indices.rbegin();
71+
while (offsetsIter != offsets.rend()) {
72+
Value original = *originalIter;
73+
int64_t offset = *offsetsIter;
74+
if (offset != 0)
75+
*indicesIter = rewriter.create<arith::AddIOp>(
76+
loc, original, rewriter.create<arith::ConstantIndexOp>(loc, offset));
77+
originalIter++;
78+
offsetsIter++;
79+
indicesIter++;
80+
}
81+
return indices;
82+
};
83+
5784
// Clones `op` into a new operations that takes `operands` and returns
5885
// `resultTypes`.
5986
static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
@@ -631,6 +658,98 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
631658
vector::UnrollVectorOptions options;
632659
};
633660

661+
/// Unrolls a rank > 1 vector.load into a sequence of 1-D loads of the
/// innermost dimension, assembling the result vector with
/// vector.insert_strided_slice.
struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
  UnrollLoadPattern(MLIRContext *context,
                    const vector::UnrollVectorOptions &options,
                    PatternBenefit benefit = 1)
      : OpRewritePattern<vector::LoadOp>(context, benefit), options(options) {}

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    VectorType vecType = loadOp.getVectorType();
    // Only unroll >1D loads.
    if (vecType.getRank() <= 1)
      return failure();

    // Honor the user-provided filter, consistent with the other Unroll*
    // patterns in this file (previously `options` was stored but ignored).
    if (options.filterConstraint && failed(options.filterConstraint(loadOp)))
      return failure();

    Location loc = loadOp.getLoc();
    ArrayRef<int64_t> originalShape = vecType.getShape();

    // Target type is a 1D vector of the innermost dimension.
    auto targetType =
        VectorType::get(originalShape.back(), vecType.getElementType());

    // Extend the targetShape to the same rank as the original shape by
    // padding 1s for leading dimensions, for convenience of computing
    // offsets.
    SmallVector<int64_t> targetShape(originalShape.size(), 1);
    targetShape.back() = originalShape.back();

    // Accumulate the unrolled slices into a zero-initialized vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, vecType, rewriter.getZeroAttr(vecType));

    SmallVector<Value> originalIndices(loadOp.getIndices().begin(),
                                       loadOp.getIndices().end());

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalShape, targetShape)) {
      SmallVector<Value> indices =
          computeIndices(rewriter, loc, originalIndices, offsets);
      Value slice = rewriter.create<vector::LoadOp>(loc, targetType,
                                                    loadOp.getBase(), indices);
      // Insert the 1-D slice into the result at the current offsets.
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, slice, result, offsets, SmallVector<int64_t>({1}));
    }
    rewriter.replaceOp(loadOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
709+
710+
/// Unrolls a rank > 1 vector.store into a sequence of 1-D stores of the
/// innermost dimension, extracting each 1-D slice with vector.extract.
struct UnrollStorePattern : public OpRewritePattern<vector::StoreOp> {
  UnrollStorePattern(MLIRContext *context,
                     const vector::UnrollVectorOptions &options,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<vector::StoreOp>(context, benefit), options(options) {}

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    VectorType vecType = storeOp.getVectorType();
    // Only unroll >1D stores.
    if (vecType.getRank() <= 1)
      return failure();

    // Honor the user-provided filter, consistent with the other Unroll*
    // patterns in this file (previously `options` was stored but ignored).
    if (options.filterConstraint && failed(options.filterConstraint(storeOp)))
      return failure();

    Location loc = storeOp.getLoc();
    ArrayRef<int64_t> originalShape = vecType.getShape();

    // Extend the targetShape to the same rank as the original shape by
    // padding 1s for leading dimensions, for convenience of computing
    // offsets.
    SmallVector<int64_t> targetShape(originalShape.size(), 1);
    targetShape.back() = originalShape.back();

    Value base = storeOp.getBase();
    Value vector = storeOp.getValueToStore();

    SmallVector<Value> originalIndices(storeOp.getIndices().begin(),
                                       storeOp.getIndices().end());

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalShape, targetShape)) {
      SmallVector<Value> indices =
          computeIndices(rewriter, loc, originalIndices, offsets);
      // Drop the trailing offset (always 0 for the innermost dim) so the
      // extract yields the 1-D innermost slice.
      offsets.pop_back();
      Value slice = rewriter.create<vector::ExtractOp>(loc, vector, offsets);
      rewriter.create<vector::StoreOp>(loc, slice, base, indices);
    }
    rewriter.eraseOp(storeOp);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
752+
634753
} // namespace
635754

636755
void mlir::vector::populateVectorUnrollPatterns(
@@ -639,6 +758,6 @@ void mlir::vector::populateVectorUnrollPatterns(
639758
patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
640759
UnrollContractionPattern, UnrollElementwisePattern,
641760
UnrollReductionPattern, UnrollMultiReductionPattern,
642-
UnrollTransposePattern, UnrollGatherPattern>(
643-
patterns.getContext(), options, benefit);
761+
UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
762+
UnrollStorePattern>(patterns.getContext(), options, benefit);
644763
}
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
// RUN: mlir-opt %s -test-vector-load-store-unroll --split-input-file | FileCheck %s
2+
3+
// CHECK-LABEL: func.func @unroll_2D_vector_load(
// CHECK-SAME: %[[ARG:.*]]: memref<4x4xf16>) -> vector<4x4xf16> {
func.func @unroll_2D_vector_load(%arg0: memref<4x4xf16>) -> vector<4x4xf16> {
  // CHECK: %[[C3:.*]] = arith.constant 3 : index
  // CHECK: %[[C2:.*]] = arith.constant 2 : index
  // CHECK: %[[C1:.*]] = arith.constant 1 : index
  // CHECK: %[[C0:.*]] = arith.constant 0 : index
  // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf16>
  // CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
  // CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
  // CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
  // CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [1, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
  // CHECK: %[[V4:.*]] = vector.load %[[ARG]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
  // CHECK: %[[V5:.*]] = vector.insert_strided_slice %[[V4]], %[[V3]] {offsets = [2, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
  // CHECK: %[[V6:.*]] = vector.load %[[ARG]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
  // CHECK: %[[V7:.*]] = vector.insert_strided_slice %[[V6]], %[[V5]] {offsets = [3, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
  // CHECK: return %[[V7]] : vector<4x4xf16>
  %c0 = arith.constant 0 : index
  %0 = vector.load %arg0[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
  return %0 : vector<4x4xf16>
}
24+
25+
// CHECK-LABEL: func.func @unroll_2D_vector_store(
// CHECK-SAME: %[[ARG0:.*]]: memref<4x4xf16>, %[[ARG1:.*]]: vector<4x4xf16>) {
func.func @unroll_2D_vector_store(%arg0: memref<4x4xf16>, %arg1: vector<4x4xf16>) {
  // CHECK: %[[C3:.*]] = arith.constant 3 : index
  // CHECK: %[[C2:.*]] = arith.constant 2 : index
  // CHECK: %[[C1:.*]] = arith.constant 1 : index
  // CHECK: %[[C0:.*]] = arith.constant 0 : index
  // CHECK: %[[V0:.*]] = vector.extract %[[ARG1]][0] : vector<4xf16> from vector<4x4xf16>
  // CHECK: vector.store %[[V0]], %[[ARG0]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
  // CHECK: %[[V1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf16> from vector<4x4xf16>
  // CHECK: vector.store %[[V1]], %[[ARG0]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
  // CHECK: %[[V2:.*]] = vector.extract %[[ARG1]][2] : vector<4xf16> from vector<4x4xf16>
  // CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
  // CHECK: %[[V3:.*]] = vector.extract %[[ARG1]][3] : vector<4xf16> from vector<4x4xf16>
  // CHECK: vector.store %[[V3]], %[[ARG0]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
  %c0 = arith.constant 0 : index
  vector.store %arg1, %arg0[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
  return
}
44+
45+
// CHECK-LABEL: func.func @unroll_vector_load(
// CHECK-SAME: %[[ARG:.*]]: memref<4x4x4x4xf16>) -> vector<2x2xf16> {
func.func @unroll_vector_load(%arg0: memref<4x4x4x4xf16>) -> vector<2x2xf16> {
  // CHECK: %[[C2:.*]] = arith.constant 2 : index
  // CHECK: %[[C1:.*]] = arith.constant 1 : index
  // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf16>
  // CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
  // CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<2xf16> into vector<2x2xf16>
  // CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
  // CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [1, 0], strides = [1]} : vector<2xf16> into vector<2x2xf16>
  // CHECK: return %[[V3]] : vector<2x2xf16>
  %c1 = arith.constant 1 : index
  %0 = vector.load %arg0[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16>
  return %0 : vector<2x2xf16>
}
60+
61+
// CHECK-LABEL: func.func @unroll_vector_store(
// CHECK-SAME: %[[ARG0:.*]]: memref<4x4x4x4xf16>, %[[ARG1:.*]]: vector<2x2xf16>) {
func.func @unroll_vector_store(%arg0: memref<4x4x4x4xf16>, %arg1: vector<2x2xf16>) {
  // CHECK: %[[C2:.*]] = arith.constant 2 : index
  // CHECK: %[[C1:.*]] = arith.constant 1 : index
  // CHECK: %[[V0:.*]] = vector.extract %[[ARG1]][0] : vector<2xf16> from vector<2x2xf16>
  // CHECK: vector.store %[[V0]], %[[ARG0]][%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
  // CHECK: %[[V1:.*]] = vector.extract %[[ARG1]][1] : vector<2xf16> from vector<2x2xf16>
  // CHECK: vector.store %[[V1]], %[[ARG0]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
  %c1 = arith.constant 1 : index
  vector.store %arg1, %arg0[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16>
  return
}

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,44 @@ struct TestVectorTransferUnrollingPatterns
289289
llvm::cl::init(false)};
290290
};
291291

292+
struct TestVectorLoadStoreUnrollPatterns
293+
: public PassWrapper<TestVectorLoadStoreUnrollPatterns,
294+
OperationPass<func::FuncOp>> {
295+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
296+
TestVectorLoadStoreUnrollPatterns)
297+
298+
StringRef getArgument() const final {
299+
return "test-vector-load-store-unroll";
300+
}
301+
StringRef getDescription() const final {
302+
return "Test unrolling patterns for vector.load and vector.store ops";
303+
}
304+
305+
void getDependentDialects(DialectRegistry &registry) const override {
306+
registry.insert<vector::VectorDialect, arith::ArithDialect>();
307+
}
308+
309+
void runOnOperation() override {
310+
MLIRContext *ctx = &getContext();
311+
RewritePatternSet patterns(ctx);
312+
313+
// Unroll all vector.load and vector.store ops with rank > 1 to 1D vectors
314+
vector::UnrollVectorOptions options;
315+
options.setFilterConstraint([](Operation *op) {
316+
if (auto loadOp = dyn_cast<vector::LoadOp>(op))
317+
return success(loadOp.getType().getRank() > 1);
318+
if (auto storeOp = dyn_cast<vector::StoreOp>(op))
319+
return success(storeOp.getVectorType().getRank() > 1);
320+
return failure();
321+
});
322+
323+
vector::populateVectorUnrollPatterns(patterns, options);
324+
325+
// Apply the patterns
326+
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
327+
}
328+
};
329+
292330
struct TestScalarVectorTransferLoweringPatterns
293331
: public PassWrapper<TestScalarVectorTransferLoweringPatterns,
294332
OperationPass<func::FuncOp>> {
@@ -1033,6 +1071,8 @@ void registerTestVectorLowerings() {
10331071

10341072
PassRegistration<TestVectorTransferUnrollingPatterns>();
10351073

1074+
PassRegistration<TestVectorLoadStoreUnrollPatterns>();
1075+
10361076
PassRegistration<TestScalarVectorTransferLoweringPatterns>();
10371077

10381078
PassRegistration<TestVectorTransferOpt>();

0 commit comments

Comments
 (0)