Skip to content

Commit 57cc380

Browse files
committed
Use unroll options
1 parent 5a2070b commit 57cc380

File tree

5 files changed

+56
-115
lines changed

5 files changed

+56
-115
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1673,7 +1673,9 @@ def Vector_TransferWriteOp :
16731673
let hasVerifier = 1;
16741674
}
16751675

1676-
def Vector_LoadOp : Vector_Op<"load"> {
1676+
def Vector_LoadOp : Vector_Op<"load", [
1677+
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
1678+
]> {
16771679
let summary = "reads an n-D slice of memory into an n-D vector";
16781680
let description = [{
16791681
The 'vector.load' operation reads an n-D slice of memory into an n-D
@@ -1759,7 +1761,9 @@ def Vector_LoadOp : Vector_Op<"load"> {
17591761
"$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)";
17601762
}
17611763

1762-
def Vector_StoreOp : Vector_Op<"store"> {
1764+
def Vector_StoreOp : Vector_Op<"store", [
1765+
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
1766+
]> {
17631767
let summary = "writes an n-D vector to an n-D slice of memory";
17641768
let description = [{
17651769
The 'vector.store' operation writes an n-D vector to an n-D slice of memory.

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5266,6 +5266,10 @@ OpFoldResult LoadOp::fold(FoldAdaptor) {
52665266
return OpFoldResult();
52675267
}
52685268

5269+
std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
5270+
return llvm::to_vector<4>(getVectorType().getShape());
5271+
}
5272+
52695273
//===----------------------------------------------------------------------===//
52705274
// StoreOp
52715275
//===----------------------------------------------------------------------===//
@@ -5301,6 +5305,10 @@ LogicalResult StoreOp::fold(FoldAdaptor adaptor,
53015305
return memref::foldMemRefCast(*this);
53025306
}
53035307

5308+
std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
5309+
return llvm::to_vector<4>(getVectorType().getShape());
5310+
}
5311+
53045312
//===----------------------------------------------------------------------===//
53055313
// MaskedLoadOp
53065314
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Lines changed: 25 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -653,21 +653,6 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
653653
vector::UnrollVectorOptions options;
654654
};
655655

656-
// This pattern unrolls the vector load into multiple 1D vector loads by
657-
// extracting slices from the base memory and inserting them into the result
658-
// vector using vector.insert_strided_slice.
659-
// Following,
660-
// vector.load %base[%indices] : memref<4x4xf32>, vector<4x4xf32>
661-
// is converted to :
662-
// %cst = arith.constant dense<0.0> : vector<4x4xf32>
663-
// %slice_0 = vector.load %base[%indices] : memref<4x4xf32>, vector<4xf32>
664-
// %result_0 = vector.insert_strided_slice %slice_0, %cst
665-
// {offsets = [0, 0], strides = [1]} : vector<4xf32> into vector<4x4xf32>
666-
// %slice_1 = vector.load %base[%indices + 1]
667-
// : memref<4x4xf32>, vector<4xf32>
668-
// %result_1 = vector.insert_strided_slice %slice_1, %result_0
669-
// {offsets = [1, 0], strides = [1]} : vector<4xf32> into vector<4x4xf32>
670-
// ...
671656
struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
672657
UnrollLoadPattern(MLIRContext *context,
673658
const vector::UnrollVectorOptions &options,
@@ -677,37 +662,37 @@ struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
677662
LogicalResult matchAndRewrite(vector::LoadOp loadOp,
678663
PatternRewriter &rewriter) const override {
679664
VectorType vecType = loadOp.getVectorType();
680-
// Only unroll >1D loads
681665
if (vecType.getRank() <= 1)
682666
return failure();
683667

668+
auto targetShape = getTargetShape(options, loadOp);
669+
if (!targetShape)
670+
return failure();
671+
684672
Location loc = loadOp.getLoc();
685673
ArrayRef<int64_t> originalShape = vecType.getShape();
686-
687-
// Target type is a 1D vector of the innermost dimension.
688-
auto targetType =
689-
VectorType::get(originalShape.back(), vecType.getElementType());
690-
691-
// Extend the targetShape to the same rank of original shape by padding 1s
692-
// for leading dimensions for convenience of computing offsets
693-
SmallVector<int64_t> targetShape(originalShape.size(), 1);
694-
targetShape.back() = originalShape.back();
674+
SmallVector<int64_t> strides(targetShape->size(), 1);
695675

696676
Value result = rewriter.create<arith::ConstantOp>(
697677
loc, vecType, rewriter.getZeroAttr(vecType));
698678

699679
SmallVector<Value> originalIndices(loadOp.getIndices().begin(),
700680
loadOp.getIndices().end());
701681

682+
SmallVector<int64_t> loopOrder =
683+
getUnrollOrder(originalShape.size(), loadOp, options);
684+
685+
auto targetVecType =
686+
VectorType::get(*targetShape, vecType.getElementType());
687+
702688
for (SmallVector<int64_t> offsets :
703-
StaticTileOffsetRange(originalShape, targetShape)) {
689+
StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
704690
SmallVector<Value> indices =
705691
computeIndices(rewriter, loc, originalIndices, offsets);
706-
Value slice = rewriter.create<vector::LoadOp>(loc, targetType,
692+
Value slice = rewriter.create<vector::LoadOp>(loc, targetVecType,
707693
loadOp.getBase(), indices);
708-
// Insert the slice into the result at the correct position.
709694
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
710-
loc, slice, result, offsets, SmallVector<int64_t>({1}));
695+
loc, slice, result, offsets, strides);
711696
}
712697
rewriter.replaceOp(loadOp, result);
713698
return success();
@@ -717,17 +702,6 @@ struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
717702
vector::UnrollVectorOptions options;
718703
};
719704

720-
// This pattern unrolls the vector store into multiple 1D vector stores by
721-
// extracting slices from the source vector and storing them into the
722-
// destination.
723-
// Following,
724-
// vector.store %source, %base[%indices] : vector<4x4xf32>
725-
// is converted to :
726-
// %slice_0 = vector.extract %source[0] : vector<4xf32>
727-
// vector.store %slice_0, %base[%indices] : vector<4xf32>
728-
// %slice_1 = vector.extract %source[1] : vector<4xf32>
729-
// vector.store %slice_1, %base[%indices + 1] : vector<4xf32>
730-
// ...
731705
struct UnrollStorePattern : public OpRewritePattern<vector::StoreOp> {
732706
UnrollStorePattern(MLIRContext *context,
733707
const vector::UnrollVectorOptions &options,
@@ -737,30 +711,32 @@ struct UnrollStorePattern : public OpRewritePattern<vector::StoreOp> {
737711
LogicalResult matchAndRewrite(vector::StoreOp storeOp,
738712
PatternRewriter &rewriter) const override {
739713
VectorType vecType = storeOp.getVectorType();
740-
// Only unroll >1D stores.
741714
if (vecType.getRank() <= 1)
742715
return failure();
743716

717+
auto targetShape = getTargetShape(options, storeOp);
718+
if (!targetShape)
719+
return failure();
720+
744721
Location loc = storeOp.getLoc();
745722
ArrayRef<int64_t> originalShape = vecType.getShape();
746-
747-
// Extend the targetShape to the same rank of original shape by padding 1s
748-
// for leading dimensions for convenience of computing offsets
749-
SmallVector<int64_t> targetShape(originalShape.size(), 1);
750-
targetShape.back() = originalShape.back();
723+
SmallVector<int64_t> strides(targetShape->size(), 1);
751724

752725
Value base = storeOp.getBase();
753726
Value vector = storeOp.getValueToStore();
754727

755728
SmallVector<Value> originalIndices(storeOp.getIndices().begin(),
756729
storeOp.getIndices().end());
757730

731+
SmallVector<int64_t> loopOrder =
732+
getUnrollOrder(originalShape.size(), storeOp, options);
733+
758734
for (SmallVector<int64_t> offsets :
759-
StaticTileOffsetRange(originalShape, targetShape)) {
735+
StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
760736
SmallVector<Value> indices =
761737
computeIndices(rewriter, loc, originalIndices, offsets);
762-
offsets.pop_back();
763-
Value slice = rewriter.create<vector::ExtractOp>(loc, vector, offsets);
738+
Value slice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
739+
loc, vector, offsets, *targetShape, strides);
764740
rewriter.create<vector::StoreOp>(loc, slice, base, indices);
765741
}
766742
rewriter.eraseOp(storeOp);

mlir/test/Dialect/Vector/vector-unroll-options.mlir

Lines changed: 16 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -388,19 +388,17 @@ func.func @vector_load_2D(%mem: memref<4x4xf16>) -> vector<4x4xf16> {
388388

389389
// CHECK-LABEL: func.func @vector_load_2D(
390390
// CHECK-SAME: %[[ARG:.*]]: memref<4x4xf16>) -> vector<4x4xf16> {
391-
// CHECK: %[[C3:.*]] = arith.constant 3 : index
392391
// CHECK: %[[C2:.*]] = arith.constant 2 : index
393-
// CHECK: %[[C1:.*]] = arith.constant 1 : index
394392
// CHECK: %[[C0:.*]] = arith.constant 0 : index
395393
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf16>
396-
// CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
397-
// CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
398-
// CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
399-
// CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [1, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
400-
// CHECK: %[[V4:.*]] = vector.load %[[ARG]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
401-
// CHECK: %[[V5:.*]] = vector.insert_strided_slice %[[V4]], %[[V3]] {offsets = [2, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
402-
// CHECK: %[[V6:.*]] = vector.load %[[ARG]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
403-
// CHECK: %[[V7:.*]] = vector.insert_strided_slice %[[V6]], %[[V5]] {offsets = [3, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
394+
// CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16>
395+
// CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf16> into vector<4x4xf16>
396+
// CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C0]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16>
397+
// CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf16> into vector<4x4xf16>
398+
// CHECK: %[[V4:.*]] = vector.load %[[ARG]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16>
399+
// CHECK: %[[V5:.*]] = vector.insert_strided_slice %[[V4]], %[[V3]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf16> into vector<4x4xf16>
400+
// CHECK: %[[V6:.*]] = vector.load %[[ARG]][%[[C2]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16>
401+
// CHECK: %[[V7:.*]] = vector.insert_strided_slice %[[V6]], %[[V5]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf16> into vector<4x4xf16>
404402
// CHECK: return %[[V7]] : vector<4x4xf16>
405403

406404

@@ -412,48 +410,13 @@ func.func @vector_store_2D(%mem: memref<4x4xf16>, %v: vector<4x4xf16>) {
412410

413411
// CHECK-LABEL: func.func @vector_store_2D(
414412
// CHECK-SAME: %[[ARG0:.*]]: memref<4x4xf16>, %[[ARG1:.*]]: vector<4x4xf16>) {
415-
// CHECK: %[[C3:.*]] = arith.constant 3 : index
416413
// CHECK: %[[C2:.*]] = arith.constant 2 : index
417-
// CHECK: %[[C1:.*]] = arith.constant 1 : index
418414
// CHECK: %[[C0:.*]] = arith.constant 0 : index
419-
// CHECK: %[[V0:.*]] = vector.extract %[[ARG1]][0] : vector<4xf16> from vector<4x4xf16>
420-
// CHECK: vector.store %[[V0]], %[[ARG0]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
421-
// CHECK: %[[V1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf16> from vector<4x4xf16>
422-
// CHECK: vector.store %[[V1]], %[[ARG0]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
423-
// CHECK: %[[V2:.*]] = vector.extract %[[ARG1]][2] : vector<4xf16> from vector<4x4xf16>
424-
// CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
425-
// CHECK: %[[V3:.*]] = vector.extract %[[ARG1]][3] : vector<4xf16> from vector<4x4xf16>
426-
// CHECK: vector.store %[[V3]], %[[ARG0]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
427-
428-
429-
func.func @vector_load_4D_to_2D(%mem: memref<4x4x4x4xf16>) -> vector<2x2xf16> {
430-
%c1 = arith.constant 1 : index
431-
%0 = vector.load %mem[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16>
432-
return %0 : vector<2x2xf16>
433-
}
434-
435-
// CHECK-LABEL: func.func @vector_load_4D_to_2D(
436-
// CHECK-SAME: %[[ARG:.*]]: memref<4x4x4x4xf16>) -> vector<2x2xf16> {
437-
// CHECK: %[[C2:.*]] = arith.constant 2 : index
438-
// CHECK: %[[C1:.*]] = arith.constant 1 : index
439-
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf16>
440-
// CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
441-
// CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<2xf16> into vector<2x2xf16>
442-
// CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
443-
// CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [1, 0], strides = [1]} : vector<2xf16> into vector<2x2xf16>
444-
// CHECK: return %[[V3]] : vector<2x2xf16>
445-
446-
func.func @vector_store_2D_to_4D(%mem: memref<4x4x4x4xf16>, %v: vector<2x2xf16>) {
447-
%c1 = arith.constant 1 : index
448-
vector.store %v, %mem[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16>
449-
return
450-
}
451-
452-
// CHECK-LABEL: func.func @vector_store_2D_to_4D(
453-
// CHECK-SAME: %[[ARG0:.*]]: memref<4x4x4x4xf16>, %[[ARG1:.*]]: vector<2x2xf16>) {
454-
// CHECK: %[[C2:.*]] = arith.constant 2 : index
455-
// CHECK: %[[C1:.*]] = arith.constant 1 : index
456-
// CHECK: %[[V0:.*]] = vector.extract %[[ARG1]][0] : vector<2xf16> from vector<2x2xf16>
457-
// CHECK: vector.store %[[V0]], %[[ARG0]][%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
458-
// CHECK: %[[V1:.*]] = vector.extract %[[ARG1]][1] : vector<2xf16> from vector<2x2xf16>
459-
// CHECK: vector.store %[[V1]], %[[ARG0]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
415+
// CHECK: %[[V0:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16>
416+
// CHECK: vector.store %[[V0]], %[[ARG0]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16>
417+
// CHECK: %[[V1:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16>
418+
// CHECK: vector.store %[[V1]], %[[ARG0]][%[[C0]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16>
419+
// CHECK: %[[V2:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16>
420+
// CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16>
421+
// CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16>
422+
// CHECK: vector.store %[[V3]], %[[ARG0]][%[[C2]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16>

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ struct TestVectorUnrollingPatterns
163163
.setFilterConstraint([](Operation *op) {
164164
return success(
165165
isa<arith::AddFOp, vector::FMAOp, vector::MultiDimReductionOp,
166-
vector::BroadcastOp>(op));
166+
vector::BroadcastOp, vector::LoadOp, vector::StoreOp>(op));
167167
}));
168168
populateVectorUnrollPatterns(
169169
patterns, UnrollVectorOptions()
@@ -178,16 +178,6 @@ struct TestVectorUnrollingPatterns
178178
return success(isa<vector::TransposeOp>(op));
179179
}));
180180

181-
populateVectorUnrollPatterns(
182-
patterns, UnrollVectorOptions()
183-
.setNativeShape(ArrayRef<int64_t>{2, 2})
184-
.setFilterConstraint([](Operation *op) {
185-
if (auto loadOp = dyn_cast<vector::LoadOp>(op))
186-
return success(loadOp.getType().getRank() > 1);
187-
if (auto storeOp = dyn_cast<vector::StoreOp>(op))
188-
return success(storeOp.getVectorType().getRank() > 1);
189-
return failure();
190-
}));
191181
if (unrollBasedOnType) {
192182
UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
193183
[](Operation *op) -> std::optional<SmallVector<int64_t>> {

0 commit comments

Comments (0)