diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 8353314ed958b..d05fea3a5d755 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -1673,7 +1673,9 @@ def Vector_TransferWriteOp : let hasVerifier = 1; } -def Vector_LoadOp : Vector_Op<"load"> { +def Vector_LoadOp : Vector_Op<"load", [ + DeclareOpInterfaceMethods, + ]> { let summary = "reads an n-D slice of memory into an n-D vector"; let description = [{ The 'vector.load' operation reads an n-D slice of memory into an n-D @@ -1759,7 +1761,9 @@ def Vector_LoadOp : Vector_Op<"load"> { "$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)"; } -def Vector_StoreOp : Vector_Op<"store"> { +def Vector_StoreOp : Vector_Op<"store", [ + DeclareOpInterfaceMethods, + ]> { let summary = "writes an n-D vector to an n-D slice of memory"; let description = [{ The 'vector.store' operation writes an n-D vector to an n-D slice of memory. diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 3179b4f975404..1d0d0ec3c2fc9 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5266,6 +5266,10 @@ OpFoldResult LoadOp::fold(FoldAdaptor) { return OpFoldResult(); } +std::optional> LoadOp::getShapeForUnroll() { + return llvm::to_vector<4>(getVectorType().getShape()); +} + //===----------------------------------------------------------------------===// // StoreOp //===----------------------------------------------------------------------===// @@ -5301,6 +5305,10 @@ LogicalResult StoreOp::fold(FoldAdaptor adaptor, return memref::foldMemRefCast(*this); } +std::optional> StoreOp::getShapeForUnroll() { + return llvm::to_vector<4>(getVectorType().getShape()); +} + //===----------------------------------------------------------------------===// // MaskedLoadOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index fc443ab0d138e..693f4f955994d 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -54,6 +54,28 @@ static SmallVector sliceTransferIndices(ArrayRef elementOffsets, return slicedIndices; } +// Compute the new indices by adding `offsets` to `originalIndices`. +// If m < n (m = offsets.size(), n = originalIndices.size()), +// then only the trailing m values in `originalIndices` are updated. +static SmallVector sliceLoadStoreIndices(PatternRewriter &rewriter, + Location loc, + OperandRange originalIndices, + ArrayRef offsets) { + assert(offsets.size() <= originalIndices.size() && + "Offsets should not exceed the number of original indices"); + SmallVector indices(originalIndices); + + auto start = indices.size() - offsets.size(); + for (auto [i, offset] : llvm::enumerate(offsets)) { + if (offset != 0) { + indices[start + i] = rewriter.create( + loc, originalIndices[start + i], + rewriter.create(loc, offset)); + } + } + return indices; +} + // Clones `op` into a new operations that takes `operands` and returns // `resultTypes`. static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc, @@ -631,6 +653,90 @@ struct UnrollGatherPattern : public OpRewritePattern { vector::UnrollVectorOptions options; }; +struct UnrollLoadPattern : public OpRewritePattern { + UnrollLoadPattern(MLIRContext *context, + const vector::UnrollVectorOptions &options, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), options(options) {} + + LogicalResult matchAndRewrite(vector::LoadOp loadOp, + PatternRewriter &rewriter) const override { + VectorType vecType = loadOp.getVectorType(); + + auto targetShape = getTargetShape(options, loadOp); + if (!targetShape) + return failure(); + + Location loc = loadOp.getLoc(); + ArrayRef originalShape = vecType.getShape(); + SmallVector strides(targetShape->size(), 1); + + Value result = rewriter.create( + loc, vecType, rewriter.getZeroAttr(vecType)); + + SmallVector loopOrder = + getUnrollOrder(originalShape.size(), loadOp, options); + + auto targetVecType = + VectorType::get(*targetShape, vecType.getElementType()); + + for (SmallVector offsets : + StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) { + SmallVector indices = + sliceLoadStoreIndices(rewriter, loc, loadOp.getIndices(), offsets); + Value slicedLoad = rewriter.create( + loc, targetVecType, loadOp.getBase(), indices); + result = rewriter.createOrFold( + loc, slicedLoad, result, offsets, strides); + } + rewriter.replaceOp(loadOp, result); + return success(); + } + +private: + vector::UnrollVectorOptions options; +}; + +struct UnrollStorePattern : public OpRewritePattern { + UnrollStorePattern(MLIRContext *context, + const vector::UnrollVectorOptions &options, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), options(options) {} + + LogicalResult matchAndRewrite(vector::StoreOp storeOp, + PatternRewriter &rewriter) const override { + VectorType vecType = storeOp.getVectorType(); + + auto targetShape = getTargetShape(options, storeOp); + if (!targetShape) + return failure(); + + Location loc = storeOp.getLoc(); + ArrayRef originalShape = vecType.getShape(); + SmallVector strides(targetShape->size(), 1); + + Value base = storeOp.getBase(); + Value vector = storeOp.getValueToStore(); + + SmallVector loopOrder = + getUnrollOrder(originalShape.size(), storeOp, options); + + for (SmallVector offsets : + StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) { + SmallVector indices = + sliceLoadStoreIndices(rewriter, loc, storeOp.getIndices(), offsets); + Value slice = rewriter.createOrFold( + loc, vector, offsets, *targetShape, strides); + rewriter.create(loc, slice, base, indices); + } + rewriter.eraseOp(storeOp); + return success(); + } + +private: + vector::UnrollVectorOptions options; +}; + struct UnrollBroadcastPattern : public OpRewritePattern { UnrollBroadcastPattern(MLIRContext *context, const vector::UnrollVectorOptions &options, @@ -699,10 +805,10 @@ struct UnrollBroadcastPattern : public OpRewritePattern { void mlir::vector::populateVectorUnrollPatterns( RewritePatternSet &patterns, const UnrollVectorOptions &options, PatternBenefit benefit) { - patterns - .add( - patterns.getContext(), options, benefit); + patterns.add( + patterns.getContext(), options, benefit); } diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir index fbb178fb49d87..e129cd5c40b9c 100644 --- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir +++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir @@ -378,3 +378,45 @@ func.func @vector_broadcast_with_tailing_unit_dim(%v: vector<4x1xf32>) -> vector // CHECK: [[b3:%.+]] = vector.broadcast [[s3]] : vector<2x1xf32> to vector<2x2xf32> // CHECK: [[r3:%.+]] = vector.insert_strided_slice [[b3]], [[r2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> // CHECK: return [[r3]] : vector<4x4xf32> + + +func.func @vector_load_2D(%mem: memref<4x4xf16>) -> vector<4x4xf16> { + %c0 = arith.constant 0 : index + %0 = vector.load %mem[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16> + return %0 : vector<4x4xf16> +} + +// CHECK-LABEL: func.func @vector_load_2D( +// CHECK-SAME: %[[ARG:.*]]: memref<4x4xf16>) -> vector<4x4xf16> { + // CHECK: %[[C2:.*]] = arith.constant 2 : index + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf16> + // CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16> + // CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf16> into vector<4x4xf16> + // CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C0]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16> + // CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf16> into vector<4x4xf16> + // CHECK: %[[V4:.*]] = vector.load %[[ARG]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16> + // CHECK: %[[V5:.*]] = vector.insert_strided_slice %[[V4]], %[[V3]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf16> into vector<4x4xf16> + // CHECK: %[[V6:.*]] = vector.load %[[ARG]][%[[C2]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16> + // CHECK: %[[V7:.*]] = vector.insert_strided_slice %[[V6]], %[[V5]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf16> into vector<4x4xf16> + // CHECK: return %[[V7]] : vector<4x4xf16> + + +func.func @vector_store_2D(%mem: memref<4x4xf16>, %v: vector<4x4xf16>) { + %c0 = arith.constant 0 : index + vector.store %v, %mem[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16> + return +} + +// CHECK-LABEL: func.func @vector_store_2D( +// CHECK-SAME: %[[ARG0:.*]]: memref<4x4xf16>, %[[ARG1:.*]]: vector<4x4xf16>) { + // CHECK: %[[C2:.*]] = arith.constant 2 : index + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[V0:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16> + // CHECK: vector.store %[[V0]], %[[ARG0]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16> + // CHECK: %[[V1:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16> + // CHECK: vector.store %[[V1]], %[[ARG0]][%[[C0]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16> + // CHECK: %[[V2:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16> + // CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16> + // CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16> + // CHECK: vector.store %[[V3]], %[[ARG0]][%[[C2]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16> diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 54aa96ba89a00..71000b98fb8f9 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -163,7 +163,8 @@ struct TestVectorUnrollingPatterns .setFilterConstraint([](Operation *op) { return success( isa(op)); + vector::BroadcastOp, vector::LoadOp, vector::StoreOp>( + op)); })); populateVectorUnrollPatterns( patterns, UnrollVectorOptions()