-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][vector] Add vector.to_elements unrolling #157142
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] Add vector.to_elements unrolling #157142
Conversation
The revision adds a pattern that flattens 2 or more dimensional `vector.to_elements` ops by `vector.shape_cast` + `vector.to_elements`. It also adds the lowering pattern to ConvertVectorToLLVMPass and complete the tests. It recovers the e2e lowering breakage from llvm@b4c31dc on LLVM path. Signed-off-by: hanhanW <[email protected]>
Extract n vector<axbx...> from vector<nxaxbx...> This patch adds a utility function that will unroll vector values. This is different from the current utility function that focuses on unrolling vector operations.
|
@llvm/pr-subscribers-mlir-vector Author: Erick Ochoa Lopez (amd-eochoalo) ChangesThis PR adds support for unrolling It transforms %0:8 = vector.to_elements %v : vector<2x2x2xf32>to %v0 = vector.extract %v[0] : vector<2x2xf32> from vector<2x2x2xf32>
%v1 = vector.extract %v[1] : vector<2x2xf32> from vector<2x2x2xf32>
%0:4 = vector.to_elements %v0 : vector<2x2xf32>
%1:4 = vector.to_elements %v1 : vector<2x2xf32>
// %0:8 = %0:4 - %1:4This pattern will be applied until there are only 1-D vectors left. Full diff: https://github.com/llvm/llvm-project/pull/157142.diff 10 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 47f96112a9433..31150a2afc19f 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -311,6 +311,18 @@ void populateVectorToFromElementsToShuffleTreePatterns(
void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
+/// Populate the pattern set with the following patterns:
+///
+/// [UnrollToElements]
+void populateVectorToElementsLoweringPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
+/// Populate the pattern set with the following patterns:
+///
+/// [FlattenToElements]
+void populateVectorToElementsFlatteningPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
/// Populate the pattern set with the following patterns:
///
/// [ContractionOpToMatmulOpLowering]
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index ace26990601c8..95f2ee5a7ac1d 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -255,6 +255,9 @@ using UnrollVectorOpFn =
LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter,
UnrollVectorOpFn unrollFn);
+LogicalResult unrollVectorValue(Value vector, PatternRewriter &rewriter,
+ SmallVector<Value> &values);
+
} // namespace vector
/// Constructs a permutation map of invariant memref indices to vector
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 9852df6970fdc..0b44ca7ceee42 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -95,6 +95,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorRankReducingFMAPattern(patterns);
populateVectorGatherLoweringPatterns(patterns);
populateVectorFromElementsLoweringPatterns(patterns);
+ populateVectorToElementsLoweringPatterns(patterns);
if (armI8MM) {
if (armNeon)
arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index acbf2b746037b..d74007f13a95b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
LowerVectorScan.cpp
LowerVectorShapeCast.cpp
LowerVectorStep.cpp
+ LowerVectorToElements.cpp
LowerVectorToFromElementsToShuffleTree.cpp
LowerVectorTransfer.cpp
LowerVectorTranspose.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
new file mode 100644
index 0000000000000..b86e8b274770f
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
@@ -0,0 +1,83 @@
+//===- LowerVectorToElements.cpp - Lower 'vector.to_elements' op ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.to_elements' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+
+#define DEBUG_TYPE "lower-vector-to-elements"
+
+using namespace mlir;
+
+namespace {
+
+struct UnrollToElements : OpRewritePattern<vector::ToElementsOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ToElementsOp op,
+ PatternRewriter &rewriter) const override {
+ SmallVector<Value> vectors;
+ LogicalResult match =
+ mlir::vector::unrollVectorValue(op.getSource(), rewriter, vectors);
+ if (failed(match)) {
+ return match;
+ }
+
+ // May be large vector.
+ std::vector<Value> results;
+ for (const auto &vector : vectors) {
+ // we need to replace the current result
+ auto subElements =
+ rewriter.create<vector::ToElementsOp>(op.getLoc(), vector);
+ results.insert(results.end(), subElements.getResults().begin(),
+ subElements.getResults().end());
+ }
+ rewriter.replaceOp(op, results);
+ return success();
+ }
+};
+
+/// Flattens 2 or more dimensional `vector.to_elements` ops by
+/// `vector.shape_cast` + `vector.to_elements`.
+struct FlattenToElements : OpRewritePattern<vector::ToElementsOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ToElementsOp op,
+ PatternRewriter &rewriter) const override {
+ VectorType vecType = op.getSource().getType();
+ if (vecType.getRank() <= 1)
+ return rewriter.notifyMatchFailure(
+ op, "the rank is already less than or equal to 1");
+ if (vecType.getNumScalableDims() > 0)
+ return rewriter.notifyMatchFailure(
+ op, "scalable vector is not yet supported");
+ auto vec1DType =
+ VectorType::get({vecType.getNumElements()}, vecType.getElementType());
+ Value shapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
+ vec1DType, op.getSource());
+ rewriter.replaceOpWithNewOp<vector::ToElementsOp>(op, op.getResultTypes(),
+ shapeCast);
+ return success();
+ }
+};
+
+} // namespace
+
+void mlir::vector::populateVectorToElementsLoweringPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<UnrollToElements>(patterns.getContext(), benefit);
+}
+
+void mlir::vector::populateVectorToElementsFlatteningPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<FlattenToElements>(patterns.getContext(), benefit);
+}
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 841e1384e03b3..cbedd9563fc29 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -393,6 +393,27 @@ vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
return success();
}
+LogicalResult vector::unrollVectorValue(Value vector, PatternRewriter &rewriter,
+ SmallVector<Value> &subvectors) {
+ assert(isa<VectorType>(vector.getType()) && "expected vector type");
+ VectorType ty = cast<VectorType>(vector.getType());
+ Location loc = vector.getLoc();
+ if (ty.getRank() < 2)
+ return rewriter.notifyMatchFailure(loc, "already 1-D");
+
+ // Unrolling doesn't take vscale into account. Pattern is disabled for
+ // vectors with leading scalable dim(s).
+ if (ty.getScalableDims().front())
+ return rewriter.notifyMatchFailure(loc, "cannot unroll scalable dim");
+
+ // We just need zero indices for the all dimensions except the leading one.
+ for (int64_t i = 0, e = ty.getShape().front(); i < e; ++i) {
+ subvectors.push_back(vector::ExtractOp::create(rewriter, loc, vector, i));
+ }
+
+ return success();
+}
+
LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter,
vector::UnrollVectorOpFn unrollFn) {
assert(op->getNumResults() == 1 && "expected single result");
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 07d335117de01..2d33888854ea7 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1774,3 +1774,45 @@ func.func @from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> v
%0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32>
return %0 : vector<2x1x2xf32>
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.to_elements
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @to_elements_1d(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>
+// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[V0:.+]] = llvm.extractelement %[[ARG0]][%[[C0]] : i64] : vector<2xf32>
+// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[V1:.+]] = llvm.extractelement %[[ARG0]][%[[C1]] : i64] : vector<2xf32>
+// CHECK: return %[[V0]], %[[V1]]
+func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
+ %0:2 = vector.to_elements %arg0 : vector<2xf32>
+ return %0#0, %0#1 : f32, f32
+}
+
+// -----
+
+// NOTE: We unroll multi-dimensional to_elements ops with pattern
+// `UnrollToElements` and then convert the 1-D to_elements ops to llvm.
+
+// CHECK-LABEL: func @to_elements_2d(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32>
+// CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<2x2xf32> to !llvm.array<2 x vector<2xf32>>
+// CHECK: %[[V0:.+]] = llvm.extractvalue %[[CAST]][0] : !llvm.array<2 x vector<2xf32>>
+// CHECK: %[[V1:.+]] = llvm.extractvalue %[[CAST]][1] : !llvm.array<2 x vector<2xf32>>
+// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[R0:.+]] = llvm.extractelement %[[V0]][%[[C0]] : i64] : vector<2xf32>
+// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[R1:.+]] = llvm.extractelement %[[V0]][%[[C1]] : i64] : vector<2xf32>
+// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[R2:.+]] = llvm.extractelement %[[V1]][%[[C0]] : i64] : vector<2xf32>
+// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[R3:.+]] = llvm.extractelement %[[V1]][%[[C1]] : i64] : vector<2xf32>
+// CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]]
+func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
+ %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
+ return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
diff --git a/mlir/test/Dialect/Vector/vector-to-elements-flattening.mlir b/mlir/test/Dialect/Vector/vector-to-elements-flattening.mlir
new file mode 100644
index 0000000000000..a57521c4db467
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-to-elements-flattening.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt %s -test-flatten-vector-to-elements -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @to_elements_1d(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>
+// CHECK: %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32>
+// CHECK: return %[[RES]]#0, %[[RES]]#1
+func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
+ %0:2 = vector.to_elements %arg0 : vector<2xf32>
+ return %0#0, %0#1 : f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @to_elements_2d(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32>
+// CHECK: %[[CAST:.+]] = vector.shape_cast %[[ARG0]]
+// CHECK: %[[RES:.+]]:4 = vector.to_elements %[[CAST]] : vector<4xf32>
+// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3
+func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
+ %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
+ return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
diff --git a/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
new file mode 100644
index 0000000000000..e302dbd174322
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt %s -test-unroll-vector-to-elements -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @to_elements_1d(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>
+// CHECK: %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32>
+// CHECK: return %[[RES]]#0, %[[RES]]#1
+func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
+ %0:2 = vector.to_elements %arg0 : vector<2xf32>
+ return %0#0, %0#1 : f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @to_elements_2d(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32>
+// CHECK: %[[VEC0:.+]] = vector.extract %[[ARG0]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[VEC1:.+]] = vector.extract %[[ARG0]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[RES0:.+]]:2 = vector.to_elements %[[VEC0]] : vector<2xf32>
+// CHECK: %[[RES1:.+]]:2 = vector.to_elements %[[VEC1]] : vector<2xf32>
+// CHECK: return %[[RES0]]#0, %[[RES0]]#1, %[[RES1]]#0, %[[RES1]]#1
+func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
+ %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
+ return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index bb1598ee3efe5..093134c119cea 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -808,6 +808,50 @@ struct TestUnrollVectorFromElements
}
};
+struct TestFlattenVectorToElements
+ : public PassWrapper<TestFlattenVectorToElements,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFlattenVectorToElements)
+
+ StringRef getArgument() const final {
+ return "test-flatten-vector-to-elements";
+ }
+ StringRef getDescription() const final {
+ return "Test flattening patterns for to_elements ops";
+ }
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<func::FuncDialect, vector::VectorDialect>();
+ }
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateVectorToElementsFlatteningPatterns(patterns);
+ (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+ }
+};
+
+struct TestUnrollVectorToElements
+ : public PassWrapper<TestUnrollVectorToElements,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnrollVectorToElements)
+
+ StringRef getArgument() const final {
+ return "test-unroll-vector-to-elements";
+ }
+ StringRef getDescription() const final {
+ return "Test unrolling patterns for to_elements ops";
+ }
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<func::FuncDialect, vector::VectorDialect>();
+ }
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateVectorToElementsLoweringPatterns(patterns);
+ (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+ }
+};
+
struct TestFoldArithExtensionIntoVectorContractPatterns
: public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
OperationPass<func::FuncOp>> {
@@ -1083,6 +1127,10 @@ void registerTestVectorLowerings() {
PassRegistration<TestUnrollVectorFromElements>();
+ PassRegistration<TestUnrollVectorToElements>();
+
+ PassRegistration<TestFlattenVectorToElements>();
+
PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
PassRegistration<TestVectorEmulateMaskedLoadStore>();
|
|
@llvm/pr-subscribers-mlir Author: Erick Ochoa Lopez (amd-eochoalo) ChangesThis PR adds support for unrolling It transforms %0:8 = vector.to_elements %v : vector<2x2x2xf32>to %v0 = vector.extract %v[0] : vector<2x2xf32> from vector<2x2x2xf32>
%v1 = vector.extract %v[1] : vector<2x2xf32> from vector<2x2x2xf32>
%0:4 = vector.to_elements %v0 : vector<2x2xf32>
%1:4 = vector.to_elements %v1 : vector<2x2xf32>
// %0:8 = %0:4 - %1:4This pattern will be applied until there are only 1-D vectors left. Full diff: https://github.com/llvm/llvm-project/pull/157142.diff 10 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 47f96112a9433..31150a2afc19f 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -311,6 +311,18 @@ void populateVectorToFromElementsToShuffleTreePatterns(
void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
+/// Populate the pattern set with the following patterns:
+///
+/// [UnrollToElements]
+void populateVectorToElementsLoweringPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
+/// Populate the pattern set with the following patterns:
+///
+/// [FlattenToElements]
+void populateVectorToElementsFlatteningPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
/// Populate the pattern set with the following patterns:
///
/// [ContractionOpToMatmulOpLowering]
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index ace26990601c8..95f2ee5a7ac1d 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -255,6 +255,9 @@ using UnrollVectorOpFn =
LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter,
UnrollVectorOpFn unrollFn);
+LogicalResult unrollVectorValue(Value vector, PatternRewriter &rewriter,
+ SmallVector<Value> &values);
+
} // namespace vector
/// Constructs a permutation map of invariant memref indices to vector
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 9852df6970fdc..0b44ca7ceee42 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -95,6 +95,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorRankReducingFMAPattern(patterns);
populateVectorGatherLoweringPatterns(patterns);
populateVectorFromElementsLoweringPatterns(patterns);
+ populateVectorToElementsLoweringPatterns(patterns);
if (armI8MM) {
if (armNeon)
arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index acbf2b746037b..d74007f13a95b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
LowerVectorScan.cpp
LowerVectorShapeCast.cpp
LowerVectorStep.cpp
+ LowerVectorToElements.cpp
LowerVectorToFromElementsToShuffleTree.cpp
LowerVectorTransfer.cpp
LowerVectorTranspose.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
new file mode 100644
index 0000000000000..b86e8b274770f
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
@@ -0,0 +1,83 @@
+//===- LowerVectorToElements.cpp - Lower 'vector.to_elements' op ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.to_elements' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+
+#define DEBUG_TYPE "lower-vector-to-elements"
+
+using namespace mlir;
+
+namespace {
+
+struct UnrollToElements : OpRewritePattern<vector::ToElementsOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ToElementsOp op,
+ PatternRewriter &rewriter) const override {
+ SmallVector<Value> vectors;
+ LogicalResult match =
+ mlir::vector::unrollVectorValue(op.getSource(), rewriter, vectors);
+ if (failed(match)) {
+ return match;
+ }
+
+ // May be large vector.
+ std::vector<Value> results;
+ for (const auto &vector : vectors) {
+ // we need to replace the current result
+ auto subElements =
+ rewriter.create<vector::ToElementsOp>(op.getLoc(), vector);
+ results.insert(results.end(), subElements.getResults().begin(),
+ subElements.getResults().end());
+ }
+ rewriter.replaceOp(op, results);
+ return success();
+ }
+};
+
+/// Flattens 2 or more dimensional `vector.to_elements` ops by
+/// `vector.shape_cast` + `vector.to_elements`.
+struct FlattenToElements : OpRewritePattern<vector::ToElementsOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ToElementsOp op,
+ PatternRewriter &rewriter) const override {
+ VectorType vecType = op.getSource().getType();
+ if (vecType.getRank() <= 1)
+ return rewriter.notifyMatchFailure(
+ op, "the rank is already less than or equal to 1");
+ if (vecType.getNumScalableDims() > 0)
+ return rewriter.notifyMatchFailure(
+ op, "scalable vector is not yet supported");
+ auto vec1DType =
+ VectorType::get({vecType.getNumElements()}, vecType.getElementType());
+ Value shapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
+ vec1DType, op.getSource());
+ rewriter.replaceOpWithNewOp<vector::ToElementsOp>(op, op.getResultTypes(),
+ shapeCast);
+ return success();
+ }
+};
+
+} // namespace
+
+void mlir::vector::populateVectorToElementsLoweringPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<UnrollToElements>(patterns.getContext(), benefit);
+}
+
+void mlir::vector::populateVectorToElementsFlatteningPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<FlattenToElements>(patterns.getContext(), benefit);
+}
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 841e1384e03b3..cbedd9563fc29 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -393,6 +393,27 @@ vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
return success();
}
+LogicalResult vector::unrollVectorValue(Value vector, PatternRewriter &rewriter,
+ SmallVector<Value> &subvectors) {
+ assert(isa<VectorType>(vector.getType()) && "expected vector type");
+ VectorType ty = cast<VectorType>(vector.getType());
+ Location loc = vector.getLoc();
+ if (ty.getRank() < 2)
+ return rewriter.notifyMatchFailure(loc, "already 1-D");
+
+ // Unrolling doesn't take vscale into account. Pattern is disabled for
+ // vectors with leading scalable dim(s).
+ if (ty.getScalableDims().front())
+ return rewriter.notifyMatchFailure(loc, "cannot unroll scalable dim");
+
+ // We just need zero indices for the all dimensions except the leading one.
+ for (int64_t i = 0, e = ty.getShape().front(); i < e; ++i) {
+ subvectors.push_back(vector::ExtractOp::create(rewriter, loc, vector, i));
+ }
+
+ return success();
+}
+
LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter,
vector::UnrollVectorOpFn unrollFn) {
assert(op->getNumResults() == 1 && "expected single result");
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 07d335117de01..2d33888854ea7 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1774,3 +1774,45 @@ func.func @from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> v
%0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32>
return %0 : vector<2x1x2xf32>
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.to_elements
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @to_elements_1d(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>
+// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[V0:.+]] = llvm.extractelement %[[ARG0]][%[[C0]] : i64] : vector<2xf32>
+// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[V1:.+]] = llvm.extractelement %[[ARG0]][%[[C1]] : i64] : vector<2xf32>
+// CHECK: return %[[V0]], %[[V1]]
+func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
+ %0:2 = vector.to_elements %arg0 : vector<2xf32>
+ return %0#0, %0#1 : f32, f32
+}
+
+// -----
+
+// NOTE: We unroll multi-dimensional to_elements ops with pattern
+// `UnrollToElements` and then convert the 1-D to_elements ops to llvm.
+
+// CHECK-LABEL: func @to_elements_2d(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32>
+// CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<2x2xf32> to !llvm.array<2 x vector<2xf32>>
+// CHECK: %[[V0:.+]] = llvm.extractvalue %[[CAST]][0] : !llvm.array<2 x vector<2xf32>>
+// CHECK: %[[V1:.+]] = llvm.extractvalue %[[CAST]][1] : !llvm.array<2 x vector<2xf32>>
+// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[R0:.+]] = llvm.extractelement %[[V0]][%[[C0]] : i64] : vector<2xf32>
+// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[R1:.+]] = llvm.extractelement %[[V0]][%[[C1]] : i64] : vector<2xf32>
+// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[R2:.+]] = llvm.extractelement %[[V1]][%[[C0]] : i64] : vector<2xf32>
+// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[R3:.+]] = llvm.extractelement %[[V1]][%[[C1]] : i64] : vector<2xf32>
+// CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]]
+func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
+ %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
+ return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
diff --git a/mlir/test/Dialect/Vector/vector-to-elements-flattening.mlir b/mlir/test/Dialect/Vector/vector-to-elements-flattening.mlir
new file mode 100644
index 0000000000000..a57521c4db467
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-to-elements-flattening.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt %s -test-flatten-vector-to-elements -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @to_elements_1d(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>
+// CHECK: %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32>
+// CHECK: return %[[RES]]#0, %[[RES]]#1
+func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
+ %0:2 = vector.to_elements %arg0 : vector<2xf32>
+ return %0#0, %0#1 : f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @to_elements_2d(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32>
+// CHECK: %[[CAST:.+]] = vector.shape_cast %[[ARG0]]
+// CHECK: %[[RES:.+]]:4 = vector.to_elements %[[CAST]] : vector<4xf32>
+// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3
+func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
+ %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
+ return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
diff --git a/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
new file mode 100644
index 0000000000000..e302dbd174322
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt %s -test-unroll-vector-to-elements -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @to_elements_1d(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>
+// CHECK: %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32>
+// CHECK: return %[[RES]]#0, %[[RES]]#1
+func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
+ %0:2 = vector.to_elements %arg0 : vector<2xf32>
+ return %0#0, %0#1 : f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @to_elements_2d(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32>
+// CHECK: %[[VEC0:.+]] = vector.extract %[[ARG0]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[VEC1:.+]] = vector.extract %[[ARG0]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[RES0:.+]]:2 = vector.to_elements %[[VEC0]] : vector<2xf32>
+// CHECK: %[[RES1:.+]]:2 = vector.to_elements %[[VEC1]] : vector<2xf32>
+// CHECK: return %[[RES0]]#0, %[[RES0]]#1, %[[RES1]]#0, %[[RES1]]#1
+func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
+ %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
+ return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index bb1598ee3efe5..093134c119cea 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -808,6 +808,50 @@ struct TestUnrollVectorFromElements
}
};
+struct TestFlattenVectorToElements
+ : public PassWrapper<TestFlattenVectorToElements,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFlattenVectorToElements)
+
+ StringRef getArgument() const final {
+ return "test-flatten-vector-to-elements";
+ }
+ StringRef getDescription() const final {
+ return "Test flattening patterns for to_elements ops";
+ }
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<func::FuncDialect, vector::VectorDialect>();
+ }
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateVectorToElementsFlatteningPatterns(patterns);
+ (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+ }
+};
+
+struct TestUnrollVectorToElements
+ : public PassWrapper<TestUnrollVectorToElements,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnrollVectorToElements)
+
+ StringRef getArgument() const final {
+ return "test-unroll-vector-to-elements";
+ }
+ StringRef getDescription() const final {
+ return "Test unrolling patterns for to_elements ops";
+ }
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<func::FuncDialect, vector::VectorDialect>();
+ }
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateVectorToElementsLoweringPatterns(patterns);
+ (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+ }
+};
+
struct TestFoldArithExtensionIntoVectorContractPatterns
: public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
OperationPass<func::FuncOp>> {
@@ -1083,6 +1127,10 @@ void registerTestVectorLowerings() {
PassRegistration<TestUnrollVectorFromElements>();
+ PassRegistration<TestUnrollVectorToElements>();
+
+ PassRegistration<TestFlattenVectorToElements>();
+
PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
PassRegistration<TestVectorEmulateMaskedLoadStore>();
|
kuhar
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM but let's wait for a at least one more review
newling
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please make separate PRs for unrolling and flattening?
newling
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
High-level question: Is the motivation for this PR that without it we currently fail to lower to rank-2+ from_elements ops from Vector to LLVM?
Parameters are now: * using TypedValue<VectorType> instead of just Value * using RewriterBase class. Return types are: * changed to FailureOr<SmallValue<Value>> instead of passing a Value as a parameter and returning Logical.
banach-space
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, just some minor comments.
banach-space
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My final asks are minor, so approving as it, but please address them before landing.
Thanks for working on this! 🙏🏻
LGTM % outstanding asks
Co-authored-by: Jakub Kuderski <[email protected]>
…58158) `mlir/test/Dialect/Vector/td/unroll-elements.mlir` is fed as a data dependency into`mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir` added in [#157142](#157142). The Bazel rule here automatically picks up all mlir files as tests, which leads to `vector-to-elements-lowering` failing.
…as data (#158158) `mlir/test/Dialect/Vector/td/unroll-elements.mlir` is fed as a data dependency into`mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir` added in [#157142](llvm/llvm-project#157142). The Bazel rule here automatically picks up all mlir files as tests, which leads to `vector-to-elements-lowering` failing.
After: * llvm/llvm-project#157740 * llvm/llvm-project#157142 the linearization of vector.to_elements pattern can be changed to either the one now upstream or to the unrolling version. This commit changes the strategy from linearizing to unrolling.
After: * llvm/llvm-project#157740 * llvm/llvm-project#157142 the linearization of vector.to_elements pattern can be changed to either the one now upstream or to the unrolling version. This commit changes the strategy from linearizing to unrolling. Signed-off-by: Erick Ochoa <[email protected]>
After: * llvm/llvm-project#157740 * llvm/llvm-project#157142 the linearization of vector.to_elements pattern can be changed to either the one now upstream or to the unrolling version. This commit changes the strategy from linearizing to unrolling. Signed-off-by: Erick Ochoa <[email protected]>
After: * llvm/llvm-project#157740 * llvm/llvm-project#157142 the linearization of vector.to_elements pattern can be changed to either the one now upstream or to the unrolling version. This commit changes the strategy from linearizing to unrolling. Signed-off-by: Erick Ochoa <[email protected]>
This PR adds support for unrolling
vector.to_element's source operand.It transforms
to
This pattern will be applied until there are only 1-D vectors left.