Skip to content

Commit d661413

Browse files
committed
[mlir][vector] Add support for lowering n-D vector.to_elements op.
The revision adds a pattern that flattens 2 or more dimensional `vector.to_elements` ops by `vector.shape_cast` + `vector.to_elements`. It also adds the lowering pattern to ConvertVectorToLLVMPass and complete the tests. It recovers the e2e lowering breakage from b4c31dc on LLVM path. Signed-off-by: hanhanW <[email protected]>
1 parent dc2ed00 commit d661413

File tree

7 files changed

+146
-0
lines changed

7 files changed

+146
-0
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,12 @@ void populateVectorToFromElementsToShuffleTreePatterns(
311311
void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns,
312312
PatternBenefit benefit = 1);
313313

314+
/// Populate the pattern set with the following patterns:
315+
///
316+
/// [FlattenToElements]
317+
void populateVectorToElementsLoweringPatterns(RewritePatternSet &patterns,
318+
PatternBenefit benefit = 1);
319+
314320
/// Populate the pattern set with the following patterns:
315321
///
316322
/// [ContractionOpToMatmulOpLowering]

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
9595
populateVectorRankReducingFMAPattern(patterns);
9696
populateVectorGatherLoweringPatterns(patterns);
9797
populateVectorFromElementsLoweringPatterns(patterns);
98+
populateVectorToElementsLoweringPatterns(patterns);
9899
if (armI8MM) {
99100
if (armNeon)
100101
arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);

mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
1111
LowerVectorScan.cpp
1212
LowerVectorShapeCast.cpp
1313
LowerVectorStep.cpp
14+
LowerVectorToElements.cpp
1415
LowerVectorToFromElementsToShuffleTree.cpp
1516
LowerVectorTransfer.cpp
1617
LowerVectorTranspose.cpp
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
//===- LowerVectorToElements.cpp - Lower 'vector.to_elements' op ----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements target-independent rewrites and utilities to lower the
10+
// 'vector.to_elements' operation.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
15+
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
16+
17+
#define DEBUG_TYPE "lower-vector-to-elements"
18+
19+
using namespace mlir;
20+
21+
namespace {
22+
23+
/// Flattens 2 or more dimensional `vector.to_elements` ops by
24+
/// `vector.shape_cast` + `vector.to_elements`.
25+
struct FlattenToElements : OpRewritePattern<vector::ToElementsOp> {
26+
using OpRewritePattern::OpRewritePattern;
27+
28+
LogicalResult matchAndRewrite(vector::ToElementsOp op,
29+
PatternRewriter &rewriter) const override {
30+
VectorType vecType = op.getSource().getType();
31+
if (vecType.getRank() <= 1)
32+
return rewriter.notifyMatchFailure(
33+
op, "the rank is already less than or equal to 1");
34+
if (vecType.getNumScalableDims() > 0)
35+
return rewriter.notifyMatchFailure(
36+
op, "scalable vector is not yet supported");
37+
auto vec1DType =
38+
VectorType::get({vecType.getNumElements()}, vecType.getElementType());
39+
Value shapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
40+
vec1DType, op.getSource());
41+
rewriter.replaceOpWithNewOp<vector::ToElementsOp>(op, op.getResultTypes(),
42+
shapeCast);
43+
return success();
44+
}
45+
};
46+
47+
} // namespace
48+
49+
void mlir::vector::populateVectorToElementsLoweringPatterns(
50+
RewritePatternSet &patterns, PatternBenefit benefit) {
51+
patterns.add<FlattenToElements>(patterns.getContext(), benefit);
52+
}

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1774,3 +1774,43 @@ func.func @from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> v
17741774
%0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32>
17751775
return %0 : vector<2x1x2xf32>
17761776
}
1777+
1778+
// -----
1779+
1780+
//===----------------------------------------------------------------------===//
1781+
// vector.to_elements
1782+
//===----------------------------------------------------------------------===//
1783+
1784+
// CHECK-LABEL: func @to_elements_1d(
1785+
// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>
1786+
// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
1787+
// CHECK: %[[V0:.+]] = llvm.extractelement %[[ARG0]][%[[C0]] : i64] : vector<2xf32>
1788+
// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
1789+
// CHECK: %[[V1:.+]] = llvm.extractelement %[[ARG0]][%[[C1]] : i64] : vector<2xf32>
1790+
// CHECK: return %[[V0]], %[[V1]]
1791+
func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
1792+
%0:2 = vector.to_elements %arg0 : vector<2xf32>
1793+
return %0#0, %0#1 : f32, f32
1794+
}
1795+
1796+
// -----
1797+
1798+
// NOTE: We flatten multi-dimensional to_elements ops with pattern
1799+
// `FlattenToElements` and then convert the 1-D to_elements ops to llvm.
1800+
1801+
// CHECK-LABEL: func @to_elements_2d(
1802+
// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32>
1803+
// CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<2x2xf32> to !llvm.array<2 x vector<2xf32>>
1804+
// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
1805+
// CHECK: %[[V0:.+]] = llvm.extractelement %{{.+}}[%[[C0]] : i64] : vector<4xf32>
1806+
// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
1807+
// CHECK: %[[V1:.+]] = llvm.extractelement %{{.+}}[%[[C1]] : i64] : vector<4xf32>
1808+
// CHECK: %[[C2:.+]] = llvm.mlir.constant(2 : i64) : i64
1809+
// CHECK: %[[V2:.+]] = llvm.extractelement %{{.+}}[%[[C2]] : i64] : vector<4xf32>
1810+
// CHECK: %[[C3:.+]] = llvm.mlir.constant(3 : i64) : i64
1811+
// CHECK: %[[V3:.+]] = llvm.extractelement %{{.+}}[%[[C3]] : i64] : vector<4xf32>
1812+
// CHECK: return %[[V0]], %[[V1]], %[[V2]], %[[V3]]
1813+
func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
1814+
%0:4 = vector.to_elements %arg0 : vector<2x2xf32>
1815+
return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
1816+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: mlir-opt %s -test-flatten-vector-to-elements -split-input-file | FileCheck %s
2+
3+
// CHECK-LABEL: func.func @to_elements_1d(
4+
// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>
5+
// CHECK: %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32>
6+
// CHECK: return %[[RES]]#0, %[[RES]]#1
7+
func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
8+
%0:2 = vector.to_elements %arg0 : vector<2xf32>
9+
return %0#0, %0#1 : f32, f32
10+
}
11+
12+
// -----
13+
14+
// CHECK-LABEL: func.func @to_elements_2d(
15+
// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32>
16+
// CHECK: %[[CAST:.+]] = vector.shape_cast %[[ARG0]]
17+
// CHECK: %[[RES:.+]]:4 = vector.to_elements %[[CAST]] : vector<4xf32>
18+
// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3
19+
func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
20+
%0:4 = vector.to_elements %arg0 : vector<2x2xf32>
21+
return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
22+
}

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -808,6 +808,28 @@ struct TestUnrollVectorFromElements
808808
}
809809
};
810810

811+
struct TestFlattenVectorToElements
812+
: public PassWrapper<TestFlattenVectorToElements,
813+
OperationPass<func::FuncOp>> {
814+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFlattenVectorToElements)
815+
816+
StringRef getArgument() const final {
817+
return "test-flatten-vector-to-elements";
818+
}
819+
StringRef getDescription() const final {
820+
return "Test flattening patterns for to_elements ops";
821+
}
822+
void getDependentDialects(DialectRegistry &registry) const override {
823+
registry.insert<func::FuncDialect, vector::VectorDialect>();
824+
}
825+
826+
void runOnOperation() override {
827+
RewritePatternSet patterns(&getContext());
828+
populateVectorToElementsLoweringPatterns(patterns);
829+
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
830+
}
831+
};
832+
811833
struct TestFoldArithExtensionIntoVectorContractPatterns
812834
: public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
813835
OperationPass<func::FuncOp>> {
@@ -1083,6 +1105,8 @@ void registerTestVectorLowerings() {
10831105

10841106
PassRegistration<TestUnrollVectorFromElements>();
10851107

1108+
PassRegistration<TestFlattenVectorToElements>();
1109+
10861110
PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
10871111

10881112
PassRegistration<TestVectorEmulateMaskedLoadStore>();

0 commit comments

Comments
 (0)