From 5003057b2010149a95fda72b6dd395c918329408 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Mon, 9 Jun 2025 17:56:30 +0000 Subject: [PATCH 1/7] Add unroll patterns for vector.load and vector.store --- .../Vector/Transforms/VectorUnroll.cpp | 123 +++++++++++++++++- .../Vector/vector-load-store-unroll.mlir | 73 +++++++++++ .../Dialect/Vector/TestVectorTransforms.cpp | 40 ++++++ 3 files changed, 234 insertions(+), 2 deletions(-) create mode 100644 mlir/test/Dialect/Vector/vector-load-store-unroll.mlir diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index 1cc477d9dca91..43abf84cd6428 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -54,6 +54,33 @@ static SmallVector sliceTransferIndices(ArrayRef elementOffsets, return slicedIndices; } +// compute the new indices for vector.load/store by adding offsets to +// originalIndices. +// It assumes m <= n (m = offsets.size(), n = originalIndices.size()) +// Last m of originalIndices will be updated. +static SmallVector computeIndices(PatternRewriter &rewriter, + Location loc, + ArrayRef originalIndices, + ArrayRef offsets) { + assert(offsets.size() <= originalIndices.size() && + "Offsets should not exceed the number of original indices"); + SmallVector indices(originalIndices); + auto originalIter = originalIndices.rbegin(); + auto offsetsIter = offsets.rbegin(); + auto indicesIter = indices.rbegin(); + while (offsetsIter != offsets.rend()) { + Value original = *originalIter; + int64_t offset = *offsetsIter; + if (offset != 0) + *indicesIter = rewriter.create( + loc, original, rewriter.create(loc, offset)); + originalIter++; + offsetsIter++; + indicesIter++; + } + return indices; +}; + // Clones `op` into a new operations that takes `operands` and returns // `resultTypes`. 
static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc, @@ -631,6 +658,98 @@ struct UnrollGatherPattern : public OpRewritePattern { vector::UnrollVectorOptions options; }; +struct UnrollLoadPattern : public OpRewritePattern { + UnrollLoadPattern(MLIRContext *context, + const vector::UnrollVectorOptions &options, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), options(options) {} + + LogicalResult matchAndRewrite(vector::LoadOp loadOp, + PatternRewriter &rewriter) const override { + VectorType vecType = loadOp.getVectorType(); + // Only unroll >1D loads + if (vecType.getRank() <= 1) + return failure(); + + Location loc = loadOp.getLoc(); + ArrayRef originalShape = vecType.getShape(); + + // Target type is a 1D vector of the innermost dimension. + auto targetType = + VectorType::get(originalShape.back(), vecType.getElementType()); + + // Extend the targetShape to the same rank of original shape by padding 1s + // for leading dimensions for convenience of computing offsets + SmallVector targetShape(originalShape.size(), 1); + targetShape.back() = originalShape.back(); + + Value result = rewriter.create( + loc, vecType, rewriter.getZeroAttr(vecType)); + + SmallVector originalIndices(loadOp.getIndices().begin(), + loadOp.getIndices().end()); + + for (SmallVector offsets : + StaticTileOffsetRange(originalShape, targetShape)) { + SmallVector indices = + computeIndices(rewriter, loc, originalIndices, offsets); + Value slice = rewriter.create(loc, targetType, + loadOp.getBase(), indices); + // Insert the slice into the result at the correct position. 
+ result = rewriter.createOrFold( + loc, slice, result, offsets, SmallVector({1})); + } + rewriter.replaceOp(loadOp, result); + return success(); + } + +private: + vector::UnrollVectorOptions options; +}; + +struct UnrollStorePattern : public OpRewritePattern { + UnrollStorePattern(MLIRContext *context, + const vector::UnrollVectorOptions &options, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), options(options) {} + + LogicalResult matchAndRewrite(vector::StoreOp storeOp, + PatternRewriter &rewriter) const override { + VectorType vecType = storeOp.getVectorType(); + // Only unroll >1D stores. + if (vecType.getRank() <= 1) + return failure(); + + Location loc = storeOp.getLoc(); + ArrayRef originalShape = vecType.getShape(); + + // Extend the targetShape to the same rank of original shape by padding 1s + // for leading dimensions for convenience of computing offsets + SmallVector targetShape(originalShape.size(), 1); + targetShape.back() = originalShape.back(); + + Value base = storeOp.getBase(); + Value vector = storeOp.getValueToStore(); + + SmallVector originalIndices(storeOp.getIndices().begin(), + storeOp.getIndices().end()); + + for (SmallVector offsets : + StaticTileOffsetRange(originalShape, targetShape)) { + SmallVector indices = + computeIndices(rewriter, loc, originalIndices, offsets); + offsets.pop_back(); + Value slice = rewriter.create(loc, vector, offsets); + rewriter.create(loc, slice, base, indices); + } + rewriter.eraseOp(storeOp); + return success(); + } + +private: + vector::UnrollVectorOptions options; +}; + } // namespace void mlir::vector::populateVectorUnrollPatterns( @@ -639,6 +758,6 @@ void mlir::vector::populateVectorUnrollPatterns( patterns.add( - patterns.getContext(), options, benefit); + UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern, + UnrollStorePattern>(patterns.getContext(), options, benefit); } diff --git a/mlir/test/Dialect/Vector/vector-load-store-unroll.mlir 
b/mlir/test/Dialect/Vector/vector-load-store-unroll.mlir new file mode 100644 index 0000000000000..3135268b8d61b --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-load-store-unroll.mlir @@ -0,0 +1,73 @@ +// RUN: mlir-opt %s -test-vector-load-store-unroll --split-input-file | FileCheck %s + +// CHECK-LABEL: func.func @unroll_2D_vector_load( +// CHECK-SAME: %[[ARG:.*]]: memref<4x4xf16>) -> vector<4x4xf16> { +func.func @unroll_2D_vector_load(%arg0: memref<4x4xf16>) -> vector<4x4xf16> { + // CHECK: %[[C3:.*]] = arith.constant 3 : index + // CHECK: %[[C2:.*]] = arith.constant 2 : index + // CHECK: %[[C1:.*]] = arith.constant 1 : index + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf16> + // CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> + // CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16> + // CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> + // CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [1, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16> + // CHECK: %[[V4:.*]] = vector.load %[[ARG]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> + // CHECK: %[[V5:.*]] = vector.insert_strided_slice %[[V4]], %[[V3]] {offsets = [2, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16> + // CHECK: %[[V6:.*]] = vector.load %[[ARG]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> + // CHECK: %[[V7:.*]] = vector.insert_strided_slice %[[V6]], %[[V5]] {offsets = [3, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16> + // CHECK: return %[[V7]] : vector<4x4xf16> + %c0 = arith.constant 0 : index + %0 = vector.load %arg0[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16> + return %0 : vector<4x4xf16> +} + +// CHECK-LABEL: func.func @unroll_2D_vector_store( +// CHECK-SAME: %[[ARG0:.*]]: 
memref<4x4xf16>, %[[ARG1:.*]]: vector<4x4xf16>) { +func.func @unroll_2D_vector_store(%arg0: memref<4x4xf16>, %arg1: vector<4x4xf16>) { + // CHECK: %[[C3:.*]] = arith.constant 3 : index + // CHECK: %[[C2:.*]] = arith.constant 2 : index + // CHECK: %[[C1:.*]] = arith.constant 1 : index + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[V0:.*]] = vector.extract %[[ARG1]][0] : vector<4xf16> from vector<4x4xf16> + // CHECK: vector.store %[[V0]], %[[ARG0]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> + // CHECK: %[[V1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf16> from vector<4x4xf16> + // CHECK: vector.store %[[V1]], %[[ARG0]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> + // CHECK: %[[V2:.*]] = vector.extract %[[ARG1]][2] : vector<4xf16> from vector<4x4xf16> + // CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> + // CHECK: %[[V3:.*]] = vector.extract %[[ARG1]][3] : vector<4xf16> from vector<4x4xf16> + // CHECK: vector.store %[[V3]], %[[ARG0]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> + %c0 = arith.constant 0 : index + vector.store %arg1, %arg0[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16> + return +} + +// CHECK-LABEL: func.func @unroll_vector_load( +// CHECK-SAME: %[[ARG:.*]]: memref<4x4x4x4xf16>) -> vector<2x2xf16> { +func.func @unroll_vector_load(%arg0: memref<4x4x4x4xf16>) -> vector<2x2xf16> { + // CHECK: %[[C2:.*]] = arith.constant 2 : index + // CHECK: %[[C1:.*]] = arith.constant 1 : index + // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf16> + // CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16> + // CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<2xf16> into vector<2x2xf16> + // CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16> + // CHECK: %[[V3:.*]] = 
vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [1, 0], strides = [1]} : vector<2xf16> into vector<2x2xf16> + // CHECK: return %[[V3]] : vector<2x2xf16> + %c1 = arith.constant 1 : index + %0 = vector.load %arg0[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16> + return %0 : vector<2x2xf16> +} + +// CHECK-LABEL: func.func @unroll_vector_store( +// CHECK-SAME: %[[ARG0:.*]]: memref<4x4x4x4xf16>, %[[ARG1:.*]]: vector<2x2xf16>) { +func.func @unroll_vector_store(%arg0: memref<4x4x4x4xf16>, %arg1: vector<2x2xf16>) { + // CHECK: %[[C2:.*]] = arith.constant 2 : index + // CHECK: %[[C1:.*]] = arith.constant 1 : index + // CHECK: %[[V0:.*]] = vector.extract %[[ARG1]][0] : vector<2xf16> from vector<2x2xf16> + // CHECK: vector.store %[[V0]], %[[ARG0]][%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16> + // CHECK: %[[V1:.*]] = vector.extract %[[ARG1]][1] : vector<2xf16> from vector<2x2xf16> + // CHECK: vector.store %[[V1]], %[[ARG0]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16> + %c1 = arith.constant 1 : index + vector.store %arg1, %arg0[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16> + return +} diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index eda2594fbc7c7..b2b2b4ece22cd 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -289,6 +289,44 @@ struct TestVectorTransferUnrollingPatterns llvm::cl::init(false)}; }; +struct TestVectorLoadStoreUnrollPatterns + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + TestVectorLoadStoreUnrollPatterns) + + StringRef getArgument() const final { + return "test-vector-load-store-unroll"; + } + StringRef getDescription() const final { + return "Test unrolling patterns for vector.load and vector.store ops"; + } + + void getDependentDialects(DialectRegistry ®istry) const override { + 
registry.insert(); + } + + void runOnOperation() override { + MLIRContext *ctx = &getContext(); + RewritePatternSet patterns(ctx); + + // Unroll all vector.load and vector.store ops with rank > 1 to 1D vectors + vector::UnrollVectorOptions options; + options.setFilterConstraint([](Operation *op) { + if (auto loadOp = dyn_cast(op)) + return success(loadOp.getType().getRank() > 1); + if (auto storeOp = dyn_cast(op)) + return success(storeOp.getVectorType().getRank() > 1); + return failure(); + }); + + vector::populateVectorUnrollPatterns(patterns, options); + + // Apply the patterns + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); + } +}; + struct TestScalarVectorTransferLoweringPatterns : public PassWrapper> { @@ -1033,6 +1071,8 @@ void registerTestVectorLowerings() { PassRegistration(); + PassRegistration(); + PassRegistration(); PassRegistration(); From 9d91abe8417b56bfb6b7e220b8fbbd050b8e03da Mon Sep 17 00:00:00 2001 From: nbpatel Date: Mon, 9 Jun 2025 20:43:56 +0000 Subject: [PATCH 2/7] Clean up --- .../Vector/vector-load-store-unroll.mlir | 73 ------------------- .../Dialect/Vector/vector-unroll-options.mlir | 73 +++++++++++++++++++ .../Dialect/Vector/TestVectorTransforms.cpp | 50 +++---------- 3 files changed, 83 insertions(+), 113 deletions(-) delete mode 100644 mlir/test/Dialect/Vector/vector-load-store-unroll.mlir diff --git a/mlir/test/Dialect/Vector/vector-load-store-unroll.mlir b/mlir/test/Dialect/Vector/vector-load-store-unroll.mlir deleted file mode 100644 index 3135268b8d61b..0000000000000 --- a/mlir/test/Dialect/Vector/vector-load-store-unroll.mlir +++ /dev/null @@ -1,73 +0,0 @@ -// RUN: mlir-opt %s -test-vector-load-store-unroll --split-input-file | FileCheck %s - -// CHECK-LABEL: func.func @unroll_2D_vector_load( -// CHECK-SAME: %[[ARG:.*]]: memref<4x4xf16>) -> vector<4x4xf16> { -func.func @unroll_2D_vector_load(%arg0: memref<4x4xf16>) -> vector<4x4xf16> { - // CHECK: %[[C3:.*]] = arith.constant 3 : index - // CHECK: %[[C2:.*]] = 
arith.constant 2 : index - // CHECK: %[[C1:.*]] = arith.constant 1 : index - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf16> - // CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> - // CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16> - // CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> - // CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [1, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16> - // CHECK: %[[V4:.*]] = vector.load %[[ARG]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> - // CHECK: %[[V5:.*]] = vector.insert_strided_slice %[[V4]], %[[V3]] {offsets = [2, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16> - // CHECK: %[[V6:.*]] = vector.load %[[ARG]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> - // CHECK: %[[V7:.*]] = vector.insert_strided_slice %[[V6]], %[[V5]] {offsets = [3, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16> - // CHECK: return %[[V7]] : vector<4x4xf16> - %c0 = arith.constant 0 : index - %0 = vector.load %arg0[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16> - return %0 : vector<4x4xf16> -} - -// CHECK-LABEL: func.func @unroll_2D_vector_store( -// CHECK-SAME: %[[ARG0:.*]]: memref<4x4xf16>, %[[ARG1:.*]]: vector<4x4xf16>) { -func.func @unroll_2D_vector_store(%arg0: memref<4x4xf16>, %arg1: vector<4x4xf16>) { - // CHECK: %[[C3:.*]] = arith.constant 3 : index - // CHECK: %[[C2:.*]] = arith.constant 2 : index - // CHECK: %[[C1:.*]] = arith.constant 1 : index - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[V0:.*]] = vector.extract %[[ARG1]][0] : vector<4xf16> from vector<4x4xf16> - // CHECK: vector.store %[[V0]], %[[ARG0]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> - // CHECK: %[[V1:.*]] = vector.extract %[[ARG1]][1] : 
vector<4xf16> from vector<4x4xf16> - // CHECK: vector.store %[[V1]], %[[ARG0]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> - // CHECK: %[[V2:.*]] = vector.extract %[[ARG1]][2] : vector<4xf16> from vector<4x4xf16> - // CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> - // CHECK: %[[V3:.*]] = vector.extract %[[ARG1]][3] : vector<4xf16> from vector<4x4xf16> - // CHECK: vector.store %[[V3]], %[[ARG0]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> - %c0 = arith.constant 0 : index - vector.store %arg1, %arg0[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16> - return -} - -// CHECK-LABEL: func.func @unroll_vector_load( -// CHECK-SAME: %[[ARG:.*]]: memref<4x4x4x4xf16>) -> vector<2x2xf16> { -func.func @unroll_vector_load(%arg0: memref<4x4x4x4xf16>) -> vector<2x2xf16> { - // CHECK: %[[C2:.*]] = arith.constant 2 : index - // CHECK: %[[C1:.*]] = arith.constant 1 : index - // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf16> - // CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16> - // CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<2xf16> into vector<2x2xf16> - // CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16> - // CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [1, 0], strides = [1]} : vector<2xf16> into vector<2x2xf16> - // CHECK: return %[[V3]] : vector<2x2xf16> - %c1 = arith.constant 1 : index - %0 = vector.load %arg0[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16> - return %0 : vector<2x2xf16> -} - -// CHECK-LABEL: func.func @unroll_vector_store( -// CHECK-SAME: %[[ARG0:.*]]: memref<4x4x4x4xf16>, %[[ARG1:.*]]: vector<2x2xf16>) { -func.func @unroll_vector_store(%arg0: memref<4x4x4x4xf16>, %arg1: vector<2x2xf16>) { - // CHECK: %[[C2:.*]] = arith.constant 2 : index - // CHECK: 
%[[C1:.*]] = arith.constant 1 : index - // CHECK: %[[V0:.*]] = vector.extract %[[ARG1]][0] : vector<2xf16> from vector<2x2xf16> - // CHECK: vector.store %[[V0]], %[[ARG0]][%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16> - // CHECK: %[[V1:.*]] = vector.extract %[[ARG1]][1] : vector<2xf16> from vector<2x2xf16> - // CHECK: vector.store %[[V1]], %[[ARG0]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16> - %c1 = arith.constant 1 : index - vector.store %arg1, %arg0[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16> - return -} diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir index fbb178fb49d87..efb709e41a69c 100644 --- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir +++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir @@ -378,3 +378,76 @@ func.func @vector_broadcast_with_tailing_unit_dim(%v: vector<4x1xf32>) -> vector // CHECK: [[b3:%.+]] = vector.broadcast [[s3]] : vector<2x1xf32> to vector<2x2xf32> // CHECK: [[r3:%.+]] = vector.insert_strided_slice [[b3]], [[r2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> // CHECK: return [[r3]] : vector<4x4xf32> + + +// CHECK-LABEL: func.func @unroll_2D_vector_load( +// CHECK-SAME: %[[ARG:.*]]: memref<4x4xf16>) -> vector<4x4xf16> { +func.func @unroll_2D_vector_load(%arg0: memref<4x4xf16>) -> vector<4x4xf16> { + // CHECK: %[[C3:.*]] = arith.constant 3 : index + // CHECK: %[[C2:.*]] = arith.constant 2 : index + // CHECK: %[[C1:.*]] = arith.constant 1 : index + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf16> + // CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> + // CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16> + // CHECK: %[[V2:.*]] = vector.load 
%[[ARG]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> + // CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [1, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16> + // CHECK: %[[V4:.*]] = vector.load %[[ARG]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> + // CHECK: %[[V5:.*]] = vector.insert_strided_slice %[[V4]], %[[V3]] {offsets = [2, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16> + // CHECK: %[[V6:.*]] = vector.load %[[ARG]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> + // CHECK: %[[V7:.*]] = vector.insert_strided_slice %[[V6]], %[[V5]] {offsets = [3, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16> + // CHECK: return %[[V7]] : vector<4x4xf16> + %c0 = arith.constant 0 : index + %0 = vector.load %arg0[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16> + return %0 : vector<4x4xf16> +} + +// CHECK-LABEL: func.func @unroll_2D_vector_store( +// CHECK-SAME: %[[ARG0:.*]]: memref<4x4xf16>, %[[ARG1:.*]]: vector<4x4xf16>) { +func.func @unroll_2D_vector_store(%arg0: memref<4x4xf16>, %arg1: vector<4x4xf16>) { + // CHECK: %[[C3:.*]] = arith.constant 3 : index + // CHECK: %[[C2:.*]] = arith.constant 2 : index + // CHECK: %[[C1:.*]] = arith.constant 1 : index + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[V0:.*]] = vector.extract %[[ARG1]][0] : vector<4xf16> from vector<4x4xf16> + // CHECK: vector.store %[[V0]], %[[ARG0]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> + // CHECK: %[[V1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf16> from vector<4x4xf16> + // CHECK: vector.store %[[V1]], %[[ARG0]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> + // CHECK: %[[V2:.*]] = vector.extract %[[ARG1]][2] : vector<4xf16> from vector<4x4xf16> + // CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> + // CHECK: %[[V3:.*]] = vector.extract %[[ARG1]][3] : vector<4xf16> from vector<4x4xf16> + // CHECK: vector.store %[[V3]], %[[ARG0]][%[[C3]], %[[C0]]] : 
memref<4x4xf16>, vector<4xf16> + %c0 = arith.constant 0 : index + vector.store %arg1, %arg0[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16> + return +} + +// CHECK-LABEL: func.func @unroll_vector_load( +// CHECK-SAME: %[[ARG:.*]]: memref<4x4x4x4xf16>) -> vector<2x2xf16> { +func.func @unroll_vector_load(%arg0: memref<4x4x4x4xf16>) -> vector<2x2xf16> { + // CHECK: %[[C2:.*]] = arith.constant 2 : index + // CHECK: %[[C1:.*]] = arith.constant 1 : index + // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf16> + // CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16> + // CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<2xf16> into vector<2x2xf16> + // CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16> + // CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [1, 0], strides = [1]} : vector<2xf16> into vector<2x2xf16> + // CHECK: return %[[V3]] : vector<2x2xf16> + %c1 = arith.constant 1 : index + %0 = vector.load %arg0[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16> + return %0 : vector<2x2xf16> +} + +// CHECK-LABEL: func.func @unroll_vector_store( +// CHECK-SAME: %[[ARG0:.*]]: memref<4x4x4x4xf16>, %[[ARG1:.*]]: vector<2x2xf16>) { +func.func @unroll_vector_store(%arg0: memref<4x4x4x4xf16>, %arg1: vector<2x2xf16>) { + // CHECK: %[[C2:.*]] = arith.constant 2 : index + // CHECK: %[[C1:.*]] = arith.constant 1 : index + // CHECK: %[[V0:.*]] = vector.extract %[[ARG1]][0] : vector<2xf16> from vector<2x2xf16> + // CHECK: vector.store %[[V0]], %[[ARG0]][%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16> + // CHECK: %[[V1:.*]] = vector.extract %[[ARG1]][1] : vector<2xf16> from vector<2x2xf16> + // CHECK: vector.store %[[V1]], %[[ARG0]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16> + %c1 = 
arith.constant 1 : index + vector.store %arg1, %arg0[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16> + return +} diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 8014362a1a6ec..023a6706b58be 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -178,6 +178,16 @@ struct TestVectorUnrollingPatterns return success(isa(op)); })); + populateVectorUnrollPatterns( + patterns, UnrollVectorOptions() + .setNativeShape(ArrayRef{2, 2}) + .setFilterConstraint([](Operation *op) { + if (auto loadOp = dyn_cast(op)) + return success(loadOp.getType().getRank() > 1); + if (auto storeOp = dyn_cast(op)) + return success(storeOp.getVectorType().getRank() > 1); + return failure(); + })); if (unrollBasedOnType) { UnrollVectorOptions::NativeShapeFnType nativeShapeFn = [](Operation *op) -> std::optional> { @@ -292,44 +302,6 @@ struct TestVectorTransferUnrollingPatterns llvm::cl::init(false)}; }; -struct TestVectorLoadStoreUnrollPatterns - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( - TestVectorLoadStoreUnrollPatterns) - - StringRef getArgument() const final { - return "test-vector-load-store-unroll"; - } - StringRef getDescription() const final { - return "Test unrolling patterns for vector.load and vector.store ops"; - } - - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - void runOnOperation() override { - MLIRContext *ctx = &getContext(); - RewritePatternSet patterns(ctx); - - // Unroll all vector.load and vector.store ops with rank > 1 to 1D vectors - vector::UnrollVectorOptions options; - options.setFilterConstraint([](Operation *op) { - if (auto loadOp = dyn_cast(op)) - return success(loadOp.getType().getRank() > 1); - if (auto storeOp = dyn_cast(op)) - return success(storeOp.getVectorType().getRank() > 1); - return failure(); - }); - - 
vector::populateVectorUnrollPatterns(patterns, options); - - // Apply the patterns - (void)applyPatternsGreedily(getOperation(), std::move(patterns)); - } -}; - struct TestScalarVectorTransferLoweringPatterns : public PassWrapper> { @@ -1070,8 +1042,6 @@ void registerTestVectorLowerings() { PassRegistration(); - PassRegistration(); - PassRegistration(); PassRegistration(); From 3f4094825463cc592415dc90f03013b9db5a5230 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Tue, 10 Jun 2025 18:36:24 +0000 Subject: [PATCH 3/7] Address feedback --- .../Vector/Transforms/VectorUnroll.cpp | 31 +++++++++++-- .../Dialect/Vector/vector-unroll-options.mlir | 44 +++++++++++-------- 2 files changed, 53 insertions(+), 22 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index e912a6ef29b21..6780a898b7fd5 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -54,10 +54,10 @@ static SmallVector sliceTransferIndices(ArrayRef elementOffsets, return slicedIndices; } -// compute the new indices for vector.load/store by adding offsets to -// originalIndices. +// Compute the new indices for vector.load/store by adding `offsets` to +// `originalIndices`. // It assumes m <= n (m = offsets.size(), n = originalIndices.size()) -// Last m of originalIndices will be updated. +// Last m of `originalIndices` will be updated. static SmallVector computeIndices(PatternRewriter &rewriter, Location loc, ArrayRef originalIndices, @@ -658,6 +658,20 @@ struct UnrollGatherPattern : public OpRewritePattern { vector::UnrollVectorOptions options; }; +// clang-format off +// This pattern unrolls the vector load into multiple 1D vector loads by +// extracting slices from the base memory and inserting them into the result +// vector using vector.insert_strided_slice. 
+// For example, +// vector.load %base[%indices] : memref<4x4xf32>, vector<4x4xf32> +// is converted to: +// %cst = arith.constant dense<0.0> : vector<4x4xf32> +// %slice_0 = vector.load %base[%indices] : memref<4x4xf32>, vector<4xf32> +// %result_0 = vector.insert_strided_slice %slice_0, %cst {offsets = [0, 0], strides = [1]} : vector<4xf32> into vector<4x4xf32> +// %slice_1 = vector.load %base[%indices + 1] : memref<4x4xf32>, vector<4xf32> +// %result_1 = vector.insert_strided_slice %slice_1, %result_0 {offsets = [1, 0], strides = [1]} : vector<4xf32> into vector<4x4xf32> +// ... +// clang-format on struct UnrollLoadPattern : public OpRewritePattern { UnrollLoadPattern(MLIRContext *context, const vector::UnrollVectorOptions &options, @@ -707,6 +721,17 @@ struct UnrollLoadPattern : public OpRewritePattern { vector::UnrollVectorOptions options; }; +// This pattern unrolls the vector store into multiple 1D vector stores by +// extracting slices from the source vector and storing them into the +// destination. +// For example, +// vector.store %source, %base[%indices] : memref<4x4xf32>, vector<4x4xf32> +// is converted to: +// %slice_0 = vector.extract %source[0] : vector<4xf32> from vector<4x4xf32> +// vector.store %slice_0, %base[%indices] : memref<4x4xf32>, vector<4xf32> +// %slice_1 = vector.extract %source[1] : vector<4xf32> from vector<4x4xf32> +// vector.store %slice_1, %base[%indices + 1] : memref<4x4xf32>, vector<4xf32> +// ... 
struct UnrollStorePattern : public OpRewritePattern { UnrollStorePattern(MLIRContext *context, const vector::UnrollVectorOptions &options, diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir index efb709e41a69c..23344a400bcc7 100644 --- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir +++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir @@ -380,9 +380,14 @@ func.func @vector_broadcast_with_tailing_unit_dim(%v: vector<4x1xf32>) -> vector // CHECK: return [[r3]] : vector<4x4xf32> -// CHECK-LABEL: func.func @unroll_2D_vector_load( +func.func @vector_load_2D(%mem: memref<4x4xf16>) -> vector<4x4xf16> { + %c0 = arith.constant 0 : index + %0 = vector.load %mem[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16> + return %0 : vector<4x4xf16> +} + +// CHECK-LABEL: func.func @vector_load_2D( // CHECK-SAME: %[[ARG:.*]]: memref<4x4xf16>) -> vector<4x4xf16> { -func.func @unroll_2D_vector_load(%arg0: memref<4x4xf16>) -> vector<4x4xf16> { // CHECK: %[[C3:.*]] = arith.constant 3 : index // CHECK: %[[C2:.*]] = arith.constant 2 : index // CHECK: %[[C1:.*]] = arith.constant 1 : index @@ -397,14 +402,16 @@ func.func @unroll_2D_vector_load(%arg0: memref<4x4xf16>) -> vector<4x4xf16> { // CHECK: %[[V6:.*]] = vector.load %[[ARG]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> // CHECK: %[[V7:.*]] = vector.insert_strided_slice %[[V6]], %[[V5]] {offsets = [3, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16> // CHECK: return %[[V7]] : vector<4x4xf16> + + +func.func @vector_store_2D(%mem: memref<4x4xf16>, %v: vector<4x4xf16>) { %c0 = arith.constant 0 : index - %0 = vector.load %arg0[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16> - return %0 : vector<4x4xf16> + vector.store %v, %mem[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16> + return } -// CHECK-LABEL: func.func @unroll_2D_vector_store( +// CHECK-LABEL: func.func @vector_store_2D( // CHECK-SAME: %[[ARG0:.*]]: memref<4x4xf16>, %[[ARG1:.*]]: vector<4x4xf16>) { 
-func.func @unroll_2D_vector_store(%arg0: memref<4x4xf16>, %arg1: vector<4x4xf16>) { // CHECK: %[[C3:.*]] = arith.constant 3 : index // CHECK: %[[C2:.*]] = arith.constant 2 : index // CHECK: %[[C1:.*]] = arith.constant 1 : index @@ -417,14 +424,16 @@ func.func @unroll_2D_vector_store(%arg0: memref<4x4xf16>, %arg1: vector<4x4xf16> // CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> // CHECK: %[[V3:.*]] = vector.extract %[[ARG1]][3] : vector<4xf16> from vector<4x4xf16> // CHECK: vector.store %[[V3]], %[[ARG0]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> - %c0 = arith.constant 0 : index - vector.store %arg1, %arg0[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16> - return + + +func.func @vector_load_4D_to_2D(%mem: memref<4x4x4x4xf16>) -> vector<2x2xf16> { + %c1 = arith.constant 1 : index + %0 = vector.load %mem[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16> + return %0 : vector<2x2xf16> } -// CHECK-LABEL: func.func @unroll_vector_load( +// CHECK-LABEL: func.func @vector_load_4D_to_2D( // CHECK-SAME: %[[ARG:.*]]: memref<4x4x4x4xf16>) -> vector<2x2xf16> { -func.func @unroll_vector_load(%arg0: memref<4x4x4x4xf16>) -> vector<2x2xf16> { // CHECK: %[[C2:.*]] = arith.constant 2 : index // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf16> @@ -433,21 +442,18 @@ func.func @unroll_vector_load(%arg0: memref<4x4x4x4xf16>) -> vector<2x2xf16> { // CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16> // CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [1, 0], strides = [1]} : vector<2xf16> into vector<2x2xf16> // CHECK: return %[[V3]] : vector<2x2xf16> + +func.func @vector_store_2D_to_4D(%mem: memref<4x4x4x4xf16>, %v: vector<2x2xf16>) { %c1 = arith.constant 1 : index - %0 = vector.load %arg0[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16> - return %0 : 
vector<2x2xf16> + vector.store %v, %mem[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16> + return } -// CHECK-LABEL: func.func @unroll_vector_store( +// CHECK-LABEL: func.func @vector_store_2D_to_4D( // CHECK-SAME: %[[ARG0:.*]]: memref<4x4x4x4xf16>, %[[ARG1:.*]]: vector<2x2xf16>) { -func.func @unroll_vector_store(%arg0: memref<4x4x4x4xf16>, %arg1: vector<2x2xf16>) { // CHECK: %[[C2:.*]] = arith.constant 2 : index // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[V0:.*]] = vector.extract %[[ARG1]][0] : vector<2xf16> from vector<2x2xf16> // CHECK: vector.store %[[V0]], %[[ARG0]][%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16> // CHECK: %[[V1:.*]] = vector.extract %[[ARG1]][1] : vector<2xf16> from vector<2x2xf16> // CHECK: vector.store %[[V1]], %[[ARG0]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16> - %c1 = arith.constant 1 : index - vector.store %arg1, %arg0[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16> - return -} From 5a2070b794db9932e1b24fd53e20a662f8212e2a Mon Sep 17 00:00:00 2001 From: nbpatel Date: Wed, 11 Jun 2025 22:30:05 +0000 Subject: [PATCH 4/7] Simplify computeIndices --- .../Vector/Transforms/VectorUnroll.cpp | 40 +++++++++---------- 1 file changed, 18 insertions(+), 22 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index 6780a898b7fd5..57e36e91f6e5e 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -54,10 +54,9 @@ static SmallVector sliceTransferIndices(ArrayRef elementOffsets, return slicedIndices; } -// Compute the new indices for vector.load/store by adding `offsets` to -// `originalIndices`. -// It assumes m <= n (m = offsets.size(), n = originalIndices.size()) -// Last m of `originalIndices` will be updated. +// Compute the new indices by adding `offsets` to `originalIndices`. 
+// If m < n (m = offsets.size(), n = originalIndices.size()), +// then only the trailing m values in `originalIndices` are updated. static SmallVector computeIndices(PatternRewriter &rewriter, Location loc, ArrayRef originalIndices, @@ -65,21 +64,17 @@ static SmallVector computeIndices(PatternRewriter &rewriter, assert(offsets.size() <= originalIndices.size() && "Offsets should not exceed the number of original indices"); SmallVector indices(originalIndices); - auto originalIter = originalIndices.rbegin(); - auto offsetsIter = offsets.rbegin(); - auto indicesIter = indices.rbegin(); - while (offsetsIter != offsets.rend()) { - Value original = *originalIter; - int64_t offset = *offsetsIter; - if (offset != 0) - *indicesIter = rewriter.create( - loc, original, rewriter.create(loc, offset)); - originalIter++; - offsetsIter++; - indicesIter++; + + auto start = indices.size() - offsets.size(); + for (auto [i, offset] : llvm::enumerate(offsets)) { + if (offset != 0) { + indices[start + i] = rewriter.create( + loc, originalIndices[start + i], + rewriter.create(loc, offset)); + } } return indices; -}; +} // Clones `op` into a new operations that takes `operands` and returns // `resultTypes`. @@ -658,7 +653,6 @@ struct UnrollGatherPattern : public OpRewritePattern { vector::UnrollVectorOptions options; }; -// clang-format off // This pattern unrolls the vector load into multiple 1D vector loads by // extracting slices from the base memory and inserting them into the result // vector using vector.insert_strided_slice. 
@@ -667,11 +661,13 @@ struct UnrollGatherPattern : public OpRewritePattern { // is converted to : // %cst = arith.constant dense<0.0> : vector<4x4xf32> // %slice_0 = vector.load %base[%indices] : memref<4x4xf32>, vector<4xf32> -// %result_0 = vector.insert_strided_slice %slice_0, %cst {offsets = [0, 0], strides = [1]} : vector<4xf32> into vector<4x4xf32> -// %slice_1 = vector.load %base[%indices + 1] : memref<4x4xf32>, vector<4xf32> -// %result_1 = vector.insert_strided_slice %slice_1, %result_0 {offsets = [1, 0], strides = [1]} : vector<4xf32> into vector<4x4xf32> +// %result_0 = vector.insert_strided_slice %slice_0, %cst +// {offsets = [0, 0], strides = [1]} : vector<4xf32> into vector<4x4xf32> +// %slice_1 = vector.load %base[%indices + 1] +// : memref<4x4xf32>, vector<4xf32> +// %result_1 = vector.insert_strided_slice %slice_1, %result_0 +// {offsets = [1, 0], strides = [1]} : vector<4xf32> into vector<4x4xf32> // ... -// clang-format on struct UnrollLoadPattern : public OpRewritePattern { UnrollLoadPattern(MLIRContext *context, const vector::UnrollVectorOptions &options, From 57cc380c625d9b1c344240d3715025f773ed9c46 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Thu, 19 Jun 2025 17:06:48 +0000 Subject: [PATCH 5/7] Use unroll options --- .../mlir/Dialect/Vector/IR/VectorOps.td | 8 +- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 8 ++ .../Vector/Transforms/VectorUnroll.cpp | 74 +++++++------------ .../Dialect/Vector/vector-unroll-options.mlir | 69 ++++------------- .../Dialect/Vector/TestVectorTransforms.cpp | 12 +-- 5 files changed, 56 insertions(+), 115 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 8353314ed958b..d05fea3a5d755 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -1673,7 +1673,9 @@ def Vector_TransferWriteOp : let hasVerifier = 1; } -def Vector_LoadOp : Vector_Op<"load"> { +def Vector_LoadOp : 
Vector_Op<"load", [ + DeclareOpInterfaceMethods, + ]> { let summary = "reads an n-D slice of memory into an n-D vector"; let description = [{ The 'vector.load' operation reads an n-D slice of memory into an n-D @@ -1759,7 +1761,9 @@ def Vector_LoadOp : Vector_Op<"load"> { "$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)"; } -def Vector_StoreOp : Vector_Op<"store"> { +def Vector_StoreOp : Vector_Op<"store", [ + DeclareOpInterfaceMethods, + ]> { let summary = "writes an n-D vector to an n-D slice of memory"; let description = [{ The 'vector.store' operation writes an n-D vector to an n-D slice of memory. diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 3179b4f975404..1d0d0ec3c2fc9 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5266,6 +5266,10 @@ OpFoldResult LoadOp::fold(FoldAdaptor) { return OpFoldResult(); } +std::optional> LoadOp::getShapeForUnroll() { + return llvm::to_vector<4>(getVectorType().getShape()); +} + //===----------------------------------------------------------------------===// // StoreOp //===----------------------------------------------------------------------===// @@ -5301,6 +5305,10 @@ LogicalResult StoreOp::fold(FoldAdaptor adaptor, return memref::foldMemRefCast(*this); } +std::optional> StoreOp::getShapeForUnroll() { + return llvm::to_vector<4>(getVectorType().getShape()); +} + //===----------------------------------------------------------------------===// // MaskedLoadOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index 57e36e91f6e5e..baee341f6768b 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -653,21 +653,6 @@ struct UnrollGatherPattern : public OpRewritePattern { 
vector::UnrollVectorOptions options; }; -// This pattern unrolls the vector load into multiple 1D vector loads by -// extracting slices from the base memory and inserting them into the result -// vector using vector.insert_strided_slice. -// Following, -// vector.load %base[%indices] : memref<4x4xf32>, vector<4x4xf32> -// is converted to : -// %cst = arith.constant dense<0.0> : vector<4x4xf32> -// %slice_0 = vector.load %base[%indices] : memref<4x4xf32>, vector<4xf32> -// %result_0 = vector.insert_strided_slice %slice_0, %cst -// {offsets = [0, 0], strides = [1]} : vector<4xf32> into vector<4x4xf32> -// %slice_1 = vector.load %base[%indices + 1] -// : memref<4x4xf32>, vector<4xf32> -// %result_1 = vector.insert_strided_slice %slice_1, %result_0 -// {offsets = [1, 0], strides = [1]} : vector<4xf32> into vector<4x4xf32> -// ... struct UnrollLoadPattern : public OpRewritePattern { UnrollLoadPattern(MLIRContext *context, const vector::UnrollVectorOptions &options, @@ -677,21 +662,16 @@ struct UnrollLoadPattern : public OpRewritePattern { LogicalResult matchAndRewrite(vector::LoadOp loadOp, PatternRewriter &rewriter) const override { VectorType vecType = loadOp.getVectorType(); - // Only unroll >1D loads if (vecType.getRank() <= 1) return failure(); + auto targetShape = getTargetShape(options, loadOp); + if (!targetShape) + return failure(); + Location loc = loadOp.getLoc(); ArrayRef originalShape = vecType.getShape(); - - // Target type is a 1D vector of the innermost dimension. 
- auto targetType = - VectorType::get(originalShape.back(), vecType.getElementType()); - - // Extend the targetShape to the same rank of original shape by padding 1s - // for leading dimensions for convenience of computing offsets - SmallVector targetShape(originalShape.size(), 1); - targetShape.back() = originalShape.back(); + SmallVector strides(targetShape->size(), 1); Value result = rewriter.create( loc, vecType, rewriter.getZeroAttr(vecType)); @@ -699,15 +679,20 @@ struct UnrollLoadPattern : public OpRewritePattern { SmallVector originalIndices(loadOp.getIndices().begin(), loadOp.getIndices().end()); + SmallVector loopOrder = + getUnrollOrder(originalShape.size(), loadOp, options); + + auto targetVecType = + VectorType::get(*targetShape, vecType.getElementType()); + for (SmallVector offsets : - StaticTileOffsetRange(originalShape, targetShape)) { + StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) { SmallVector indices = computeIndices(rewriter, loc, originalIndices, offsets); - Value slice = rewriter.create(loc, targetType, + Value slice = rewriter.create(loc, targetVecType, loadOp.getBase(), indices); - // Insert the slice into the result at the correct position. result = rewriter.createOrFold( - loc, slice, result, offsets, SmallVector({1})); + loc, slice, result, offsets, strides); } rewriter.replaceOp(loadOp, result); return success(); @@ -717,17 +702,6 @@ struct UnrollLoadPattern : public OpRewritePattern { vector::UnrollVectorOptions options; }; -// This pattern unrolls the vector store into multiple 1D vector stores by -// extracting slices from the source vector and storing them into the -// destination. -// Following, -// vector.store %source, %base[%indices] : vector<4x4xf32> -// is converted to : -// %slice_0 = vector.extract %source[0] : vector<4xf32> -// vector.store %slice_0, %base[%indices] : vector<4xf32> -// %slice_1 = vector.extract %source[1] : vector<4xf32> -// vector.store %slice_1, %base[%indices + 1] : vector<4xf32> -// ... 
struct UnrollStorePattern : public OpRewritePattern { UnrollStorePattern(MLIRContext *context, const vector::UnrollVectorOptions &options, @@ -737,17 +711,16 @@ struct UnrollStorePattern : public OpRewritePattern { LogicalResult matchAndRewrite(vector::StoreOp storeOp, PatternRewriter &rewriter) const override { VectorType vecType = storeOp.getVectorType(); - // Only unroll >1D stores. if (vecType.getRank() <= 1) return failure(); + auto targetShape = getTargetShape(options, storeOp); + if (!targetShape) + return failure(); + Location loc = storeOp.getLoc(); ArrayRef originalShape = vecType.getShape(); - - // Extend the targetShape to the same rank of original shape by padding 1s - // for leading dimensions for convenience of computing offsets - SmallVector targetShape(originalShape.size(), 1); - targetShape.back() = originalShape.back(); + SmallVector strides(targetShape->size(), 1); Value base = storeOp.getBase(); Value vector = storeOp.getValueToStore(); @@ -755,12 +728,15 @@ struct UnrollStorePattern : public OpRewritePattern { SmallVector originalIndices(storeOp.getIndices().begin(), storeOp.getIndices().end()); + SmallVector loopOrder = + getUnrollOrder(originalShape.size(), storeOp, options); + for (SmallVector offsets : - StaticTileOffsetRange(originalShape, targetShape)) { + StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) { SmallVector indices = computeIndices(rewriter, loc, originalIndices, offsets); - offsets.pop_back(); - Value slice = rewriter.create(loc, vector, offsets); + Value slice = rewriter.createOrFold( + loc, vector, offsets, *targetShape, strides); rewriter.create(loc, slice, base, indices); } rewriter.eraseOp(storeOp); diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir index 23344a400bcc7..e129cd5c40b9c 100644 --- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir +++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir @@ -388,19 +388,17 @@ func.func 
@vector_load_2D(%mem: memref<4x4xf16>) -> vector<4x4xf16> { // CHECK-LABEL: func.func @vector_load_2D( // CHECK-SAME: %[[ARG:.*]]: memref<4x4xf16>) -> vector<4x4xf16> { - // CHECK: %[[C3:.*]] = arith.constant 3 : index // CHECK: %[[C2:.*]] = arith.constant 2 : index - // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf16> - // CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> - // CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16> - // CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> - // CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [1, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16> - // CHECK: %[[V4:.*]] = vector.load %[[ARG]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> - // CHECK: %[[V5:.*]] = vector.insert_strided_slice %[[V4]], %[[V3]] {offsets = [2, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16> - // CHECK: %[[V6:.*]] = vector.load %[[ARG]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> - // CHECK: %[[V7:.*]] = vector.insert_strided_slice %[[V6]], %[[V5]] {offsets = [3, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16> + // CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16> + // CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf16> into vector<4x4xf16> + // CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C0]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16> + // CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf16> into vector<4x4xf16> + // CHECK: %[[V4:.*]] = vector.load %[[ARG]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16> + // CHECK: 
%[[V5:.*]] = vector.insert_strided_slice %[[V4]], %[[V3]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf16> into vector<4x4xf16> + // CHECK: %[[V6:.*]] = vector.load %[[ARG]][%[[C2]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16> + // CHECK: %[[V7:.*]] = vector.insert_strided_slice %[[V6]], %[[V5]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf16> into vector<4x4xf16> // CHECK: return %[[V7]] : vector<4x4xf16> @@ -412,48 +410,13 @@ func.func @vector_store_2D(%mem: memref<4x4xf16>, %v: vector<4x4xf16>) { // CHECK-LABEL: func.func @vector_store_2D( // CHECK-SAME: %[[ARG0:.*]]: memref<4x4xf16>, %[[ARG1:.*]]: vector<4x4xf16>) { - // CHECK: %[[C3:.*]] = arith.constant 3 : index // CHECK: %[[C2:.*]] = arith.constant 2 : index - // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[V0:.*]] = vector.extract %[[ARG1]][0] : vector<4xf16> from vector<4x4xf16> - // CHECK: vector.store %[[V0]], %[[ARG0]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> - // CHECK: %[[V1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf16> from vector<4x4xf16> - // CHECK: vector.store %[[V1]], %[[ARG0]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> - // CHECK: %[[V2:.*]] = vector.extract %[[ARG1]][2] : vector<4xf16> from vector<4x4xf16> - // CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> - // CHECK: %[[V3:.*]] = vector.extract %[[ARG1]][3] : vector<4xf16> from vector<4x4xf16> - // CHECK: vector.store %[[V3]], %[[ARG0]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> - - -func.func @vector_load_4D_to_2D(%mem: memref<4x4x4x4xf16>) -> vector<2x2xf16> { - %c1 = arith.constant 1 : index - %0 = vector.load %mem[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16> - return %0 : vector<2x2xf16> -} - -// CHECK-LABEL: func.func @vector_load_4D_to_2D( -// CHECK-SAME: %[[ARG:.*]]: memref<4x4x4x4xf16>) -> vector<2x2xf16> { - // CHECK: %[[C2:.*]] = arith.constant 2 : index - // CHECK: 
%[[C1:.*]] = arith.constant 1 : index - // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf16> - // CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16> - // CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<2xf16> into vector<2x2xf16> - // CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16> - // CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [1, 0], strides = [1]} : vector<2xf16> into vector<2x2xf16> - // CHECK: return %[[V3]] : vector<2x2xf16> - -func.func @vector_store_2D_to_4D(%mem: memref<4x4x4x4xf16>, %v: vector<2x2xf16>) { - %c1 = arith.constant 1 : index - vector.store %v, %mem[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16> - return -} - -// CHECK-LABEL: func.func @vector_store_2D_to_4D( -// CHECK-SAME: %[[ARG0:.*]]: memref<4x4x4x4xf16>, %[[ARG1:.*]]: vector<2x2xf16>) { - // CHECK: %[[C2:.*]] = arith.constant 2 : index - // CHECK: %[[C1:.*]] = arith.constant 1 : index - // CHECK: %[[V0:.*]] = vector.extract %[[ARG1]][0] : vector<2xf16> from vector<2x2xf16> - // CHECK: vector.store %[[V0]], %[[ARG0]][%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16> - // CHECK: %[[V1:.*]] = vector.extract %[[ARG1]][1] : vector<2xf16> from vector<2x2xf16> - // CHECK: vector.store %[[V1]], %[[ARG0]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16> + // CHECK: %[[V0:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16> + // CHECK: vector.store %[[V0]], %[[ARG0]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16> + // CHECK: %[[V1:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16> + // CHECK: vector.store %[[V1]], 
%[[ARG0]][%[[C0]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16> + // CHECK: %[[V2:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16> + // CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16> + // CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16> + // CHECK: vector.store %[[V3]], %[[ARG0]][%[[C2]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16> diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 023a6706b58be..fc75b273a057b 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -163,7 +163,7 @@ struct TestVectorUnrollingPatterns .setFilterConstraint([](Operation *op) { return success( isa(op)); + vector::BroadcastOp, vector::LoadOp, vector::StoreOp>(op)); })); populateVectorUnrollPatterns( patterns, UnrollVectorOptions() @@ -178,16 +178,6 @@ struct TestVectorUnrollingPatterns return success(isa(op)); })); - populateVectorUnrollPatterns( - patterns, UnrollVectorOptions() - .setNativeShape(ArrayRef{2, 2}) - .setFilterConstraint([](Operation *op) { - if (auto loadOp = dyn_cast(op)) - return success(loadOp.getType().getRank() > 1); - if (auto storeOp = dyn_cast(op)) - return success(storeOp.getVectorType().getRank() > 1); - return failure(); - })); if (unrollBasedOnType) { UnrollVectorOptions::NativeShapeFnType nativeShapeFn = [](Operation *op) -> std::optional> { From 2731a187fef41b96bcf4501d3c0f1488ccc6e644 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Fri, 20 Jun 2025 14:37:52 +0000 Subject: [PATCH 6/7] Formatting --- mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp 
b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index fc75b273a057b..71000b98fb8f9 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -163,7 +163,8 @@ struct TestVectorUnrollingPatterns .setFilterConstraint([](Operation *op) { return success( isa(op)); + vector::BroadcastOp, vector::LoadOp, vector::StoreOp>( + op)); })); populateVectorUnrollPatterns( patterns, UnrollVectorOptions() From 291693297e009193f2f74c73ba4b0f67aa5667ca Mon Sep 17 00:00:00 2001 From: nbpatel Date: Fri, 20 Jun 2025 20:28:19 +0000 Subject: [PATCH 7/7] Address comments --- .../Vector/Transforms/VectorUnroll.cpp | 28 ++++++------------- 1 file changed, 9 insertions(+), 19 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index baee341f6768b..693f4f955994d 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -57,10 +57,10 @@ static SmallVector sliceTransferIndices(ArrayRef elementOffsets, // Compute the new indices by adding `offsets` to `originalIndices`. // If m < n (m = offsets.size(), n = originalIndices.size()), // then only the trailing m values in `originalIndices` are updated. 
-static SmallVector computeIndices(PatternRewriter &rewriter, - Location loc, - ArrayRef originalIndices, - ArrayRef offsets) { +static SmallVector sliceLoadStoreIndices(PatternRewriter &rewriter, + Location loc, + OperandRange originalIndices, + ArrayRef offsets) { assert(offsets.size() <= originalIndices.size() && "Offsets should not exceed the number of original indices"); SmallVector indices(originalIndices); @@ -662,8 +662,6 @@ struct UnrollLoadPattern : public OpRewritePattern { LogicalResult matchAndRewrite(vector::LoadOp loadOp, PatternRewriter &rewriter) const override { VectorType vecType = loadOp.getVectorType(); - if (vecType.getRank() <= 1) - return failure(); auto targetShape = getTargetShape(options, loadOp); if (!targetShape) @@ -676,9 +674,6 @@ struct UnrollLoadPattern : public OpRewritePattern { Value result = rewriter.create( loc, vecType, rewriter.getZeroAttr(vecType)); - SmallVector originalIndices(loadOp.getIndices().begin(), - loadOp.getIndices().end()); - SmallVector loopOrder = getUnrollOrder(originalShape.size(), loadOp, options); @@ -688,11 +683,11 @@ struct UnrollLoadPattern : public OpRewritePattern { for (SmallVector offsets : StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) { SmallVector indices = - computeIndices(rewriter, loc, originalIndices, offsets); - Value slice = rewriter.create(loc, targetVecType, - loadOp.getBase(), indices); + sliceLoadStoreIndices(rewriter, loc, loadOp.getIndices(), offsets); + Value slicedLoad = rewriter.create( + loc, targetVecType, loadOp.getBase(), indices); result = rewriter.createOrFold( - loc, slice, result, offsets, strides); + loc, slicedLoad, result, offsets, strides); } rewriter.replaceOp(loadOp, result); return success(); @@ -711,8 +706,6 @@ struct UnrollStorePattern : public OpRewritePattern { LogicalResult matchAndRewrite(vector::StoreOp storeOp, PatternRewriter &rewriter) const override { VectorType vecType = storeOp.getVectorType(); - if (vecType.getRank() <= 1) - return 
failure(); auto targetShape = getTargetShape(options, storeOp); if (!targetShape) @@ -725,16 +718,13 @@ struct UnrollStorePattern : public OpRewritePattern { Value base = storeOp.getBase(); Value vector = storeOp.getValueToStore(); - SmallVector originalIndices(storeOp.getIndices().begin(), - storeOp.getIndices().end()); - SmallVector loopOrder = getUnrollOrder(originalShape.size(), storeOp, options); for (SmallVector offsets : StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) { SmallVector indices = - computeIndices(rewriter, loc, originalIndices, offsets); + sliceLoadStoreIndices(rewriter, loc, storeOp.getIndices(), offsets); Value slice = rewriter.createOrFold( loc, vector, offsets, *targetShape, strides); rewriter.create(loc, slice, base, indices);