-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][vector] Add support for lowering n-D vector.to_elements op. #156992
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Han-Chung Wang (hanhanW) ChangesThe revision adds a pattern that flattens 2 or more dimensional It also adds the lowering pattern to ConvertVectorToLLVMPass and complete the tests. It recovers the e2e lowering breakage from b4c31dc on LLVM path. Full diff: https://github.com/llvm/llvm-project/pull/156992.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 47f96112a9433..e0f744841db2b 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -311,6 +311,12 @@ void populateVectorToFromElementsToShuffleTreePatterns(
void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
+/// Populate the pattern set with the following patterns:
+///
+/// [FlattenToElements]
+void populateVectorToElementsLoweringPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
/// Populate the pattern set with the following patterns:
///
/// [ContractionOpToMatmulOpLowering]
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 9852df6970fdc..0b44ca7ceee42 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -95,6 +95,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorRankReducingFMAPattern(patterns);
populateVectorGatherLoweringPatterns(patterns);
populateVectorFromElementsLoweringPatterns(patterns);
+ populateVectorToElementsLoweringPatterns(patterns);
if (armI8MM) {
if (armNeon)
arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index acbf2b746037b..d74007f13a95b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
LowerVectorScan.cpp
LowerVectorShapeCast.cpp
LowerVectorStep.cpp
+ LowerVectorToElements.cpp
LowerVectorToFromElementsToShuffleTree.cpp
LowerVectorTransfer.cpp
LowerVectorTranspose.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
new file mode 100644
index 0000000000000..014034b8f9737
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
@@ -0,0 +1,52 @@
+//===- LowerVectorToElements.cpp - Lower 'vector.to_elements' op ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.to_elements' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+
+#define DEBUG_TYPE "lower-vector-to-elements"
+
+using namespace mlir;
+
+namespace {
+
+/// Flattens 2 or more dimensional `vector.to_elements` ops by
+/// `vector.shape_cast` + `vector.to_elements`.
+struct FlattenToElements : OpRewritePattern<vector::ToElementsOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ToElementsOp op,
+ PatternRewriter &rewriter) const override {
+ VectorType vecType = op.getSource().getType();
+ if (vecType.getRank() <= 1)
+ return rewriter.notifyMatchFailure(
+ op, "the rank is already less than or equal to 1");
+ if (vecType.getNumScalableDims() > 0)
+ return rewriter.notifyMatchFailure(
+ op, "scalable vector is not yet supported");
+ auto vec1DType =
+ VectorType::get({vecType.getNumElements()}, vecType.getElementType());
+ Value shapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
+ vec1DType, op.getSource());
+ rewriter.replaceOpWithNewOp<vector::ToElementsOp>(op, op.getResultTypes(),
+ shapeCast);
+ return success();
+ }
+};
+
+} // namespace
+
+void mlir::vector::populateVectorToElementsLoweringPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<FlattenToElements>(patterns.getContext(), benefit);
+}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 07d335117de01..bf4b05f7874de 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1774,3 +1774,43 @@ func.func @from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> v
%0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32>
return %0 : vector<2x1x2xf32>
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.to_elements
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @to_elements_1d(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>
+// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[V0:.+]] = llvm.extractelement %[[ARG0]][%[[C0]] : i64] : vector<2xf32>
+// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[V1:.+]] = llvm.extractelement %[[ARG0]][%[[C1]] : i64] : vector<2xf32>
+// CHECK: return %[[V0]], %[[V1]]
+func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
+ %0:2 = vector.to_elements %arg0 : vector<2xf32>
+ return %0#0, %0#1 : f32, f32
+}
+
+// -----
+
+// NOTE: We flatten multi-dimensional to_elements ops with pattern
+// `FlattenToElements` and then convert the 1-D to_elements ops to llvm.
+
+// CHECK-LABEL: func @to_elements_2d(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32>
+// CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<2x2xf32> to !llvm.array<2 x vector<2xf32>>
+// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[V0:.+]] = llvm.extractelement %{{.+}}[%[[C0]] : i64] : vector<4xf32>
+// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[V1:.+]] = llvm.extractelement %{{.+}}[%[[C1]] : i64] : vector<4xf32>
+// CHECK: %[[C2:.+]] = llvm.mlir.constant(2 : i64) : i64
+// CHECK: %[[V2:.+]] = llvm.extractelement %{{.+}}[%[[C2]] : i64] : vector<4xf32>
+// CHECK: %[[C3:.+]] = llvm.mlir.constant(3 : i64) : i64
+// CHECK: %[[V3:.+]] = llvm.extractelement %{{.+}}[%[[C3]] : i64] : vector<4xf32>
+// CHECK: return %[[V0]], %[[V1]], %[[V2]], %[[V3]]
+func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
+ %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
+ return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
diff --git a/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
new file mode 100644
index 0000000000000..a57521c4db467
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt %s -test-flatten-vector-to-elements -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @to_elements_1d(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>
+// CHECK: %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32>
+// CHECK: return %[[RES]]#0, %[[RES]]#1
+func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
+ %0:2 = vector.to_elements %arg0 : vector<2xf32>
+ return %0#0, %0#1 : f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @to_elements_2d(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32>
+// CHECK: %[[CAST:.+]] = vector.shape_cast %[[ARG0]]
+// CHECK: %[[RES:.+]]:4 = vector.to_elements %[[CAST]] : vector<4xf32>
+// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3
+func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
+ %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
+ return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index bb1598ee3efe5..560a1331bdaf0 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -808,6 +808,28 @@ struct TestUnrollVectorFromElements
}
};
+struct TestFlattenVectorToElements
+ : public PassWrapper<TestFlattenVectorToElements,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFlattenVectorToElements)
+
+ StringRef getArgument() const final {
+ return "test-flatten-vector-to-elements";
+ }
+ StringRef getDescription() const final {
+ return "Test flattening patterns for to_elements ops";
+ }
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<func::FuncDialect, vector::VectorDialect>();
+ }
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateVectorToElementsLoweringPatterns(patterns);
+ (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+ }
+};
+
struct TestFoldArithExtensionIntoVectorContractPatterns
: public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
OperationPass<func::FuncOp>> {
@@ -1083,6 +1105,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestUnrollVectorFromElements>();
+ PassRegistration<TestFlattenVectorToElements>();
+
PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
PassRegistration<TestVectorEmulateMaskedLoadStore>();
|
The revision adds a pattern that flattens 2 or more dimensional `vector.to_elements` ops by `vector.shape_cast` + `vector.to_elements`. It also adds the lowering pattern to ConvertVectorToLLVMPass and complete the tests. It recovers the e2e lowering breakage from llvm@b4c31dc on LLVM path. Signed-off-by: hanhanW <[email protected]>
a82feb4 to
d661413
Compare
|
cc @yangtetris (I can't add you to reviews) |
|
|
||
| /// Flattens 2 or more dimensional `vector.to_elements` ops by | ||
| /// `vector.shape_cast` + `vector.to_elements`. | ||
| struct FlattenToElements : OpRewritePattern<vector::ToElementsOp> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| struct FlattenToElements : OpRewritePattern<vector::ToElementsOp> { | |
| struct FlattenToElements final : OpRewritePattern<vector::ToElementsOp> { |
| if (vecType.getNumScalableDims() > 0) | ||
| return rewriter.notifyMatchFailure( | ||
| op, "scalable vector is not yet supported"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does from_elements support scalable vectors at all? https://mlir.llvm.org/docs/Dialects/Vector/#results-11
I think we can make it an assertion
| void runOnOperation() override { | ||
| RewritePatternSet patterns(&getContext()); | ||
| populateVectorToElementsLoweringPatterns(patterns); | ||
| (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you can use walkAndApplyPatterns here since we never have to revisit newly created ops
It carries a cherry-pick fix that gets the operands from the adaptor: - iree-org/llvm-project@8b88014 Changes: - Update most lit tests to check `vector.from_elements`. - Add unrolling patterns to the final conversion. - Implement n-D `vector::ToElementsOp` lowering, which will be dropped after llvm/llvm-project#156992 is landed. It should be added to all the backends, but somehow only AMDGPU backend needs the pattern. The other backends may address the issue via specialized tiling config + dropping vector unit dim patterns. --------- Signed-off-by: hanhanW <[email protected]>
|
Thanks for the fix! To be honest, I didn't realize that this canonicalization pattern also broke vector.to_elements... Should we also add |
Groverkss
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This needs to be implemented the same way #151175 is implemented. We should be implementing a unrolling pattern for this. The rest of the ops do unrolling by default and we should keep it consistent.
|
I think flattening and unrolling are different approach, and people can make their own decisions. There may be cases that flattening innermost dims and unrolling the rest, but I don't have a use case so far. We will need unrolling for sure, and @amd-eochoalo will work on it, so I'll leave it to him. |
|
Yes, we need both approaches and none of them are implemented for this op. Please, let me know if you plan to work on both or just one so that we plan accordingly. |
|
@dcaballe I am building on top of this PR. I thought both patterns could be merged. I will be opening it up for review in a couple of minutes. |
It carries a cherry-pick fix that gets the operands from the adaptor: - iree-org/llvm-project@8b88014 Changes: - Update most lit tests to check `vector.from_elements`. - Add unrolling patterns to the final conversion. - Implement n-D `vector::ToElementsOp` lowering, which will be dropped after llvm/llvm-project#156992 is landed. It should be added to all the backends, but somehow only AMDGPU backend needs the pattern. The other backends may address the issue via specialized tiling config + dropping vector unit dim patterns. --------- Signed-off-by: hanhanW <[email protected]> Signed-off-by: Ivan Ho <[email protected]>
The revision adds a pattern that flattens 2 or more dimensional
vector.to_elementsops byvector.shape_cast+vector.to_elements.It also adds the lowering pattern to ConvertVectorToLLVMPass and complete the tests.
It recovers the e2e lowering breakage from b4c31dc on LLVM path.