Skip to content

Commit 9d19250

Browse files
amd-eochoalo, hanhanW, and kuhar
authored
[mlir][vector] Add vector.to_elements unrolling (#157142)
This PR adds support for unrolling `vector.to_elements`'s source operand. It transforms
```mlir
%0:8 = vector.to_elements %v : vector<2x2x2xf32>
```
to
```mlir
%v0 = vector.extract %v[0] : vector<2x2xf32> from vector<2x2x2xf32>
%v1 = vector.extract %v[1] : vector<2x2xf32> from vector<2x2x2xf32>
%0:4 = vector.to_elements %v0 : vector<2x2xf32>
%1:4 = vector.to_elements %v1 : vector<2x2xf32>
// %0:8 = %0:4 - %1:4
```
This pattern will be applied until there are only 1-D vectors left. --------- Signed-off-by: hanhanW <[email protected]> Co-authored-by: hanhanW <[email protected]> Co-authored-by: Jakub Kuderski <[email protected]>
1 parent ddb2e34 commit 9d19250

File tree

14 files changed

+225
-0
lines changed

14 files changed

+225
-0
lines changed

mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,17 @@ def ApplyUnrollFromElementsPatternsOp : Op<Transform_Dialect,
265265
let assemblyFormat = "attr-dict";
266266
}
267267

268+
def ApplyUnrollToElementsPatternsOp : Op<Transform_Dialect,
269+
"apply_patterns.vector.unroll_to_elements",
270+
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
271+
let description = [{
272+
Indicates that vector to_elements operations should be unrolled
273+
along the outermost dimension.
274+
}];
275+
276+
let assemblyFormat = "attr-dict";
277+
}
278+
268279
def ApplyLowerScanPatternsOp : Op<Transform_Dialect,
269280
"apply_patterns.vector.lower_scan",
270281
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {

mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,12 @@ void populateVectorToFromElementsToShuffleTreePatterns(
311311
void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns,
312312
PatternBenefit benefit = 1);
313313

314+
/// Populate the pattern set with the following patterns:
315+
///
316+
/// [UnrollToElements]
317+
void populateVectorToElementsLoweringPatterns(RewritePatternSet &patterns,
318+
PatternBenefit benefit = 1);
319+
314320
/// Populate the pattern set with the following patterns:
315321
///
316322
/// [ContractionOpToMatmulOpLowering]

mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,12 @@ using UnrollVectorOpFn =
255255
LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter,
256256
UnrollVectorOpFn unrollFn);
257257

258+
/// Generic utility for unrolling values of type vector<NxAxBx...>
259+
/// to N values of type vector<AxBx...> using vector.extract. If the input
260+
/// has rank lower than 2 or has a leading scalable dimension, failure is returned.
261+
FailureOr<SmallVector<Value>> unrollVectorValue(TypedValue<VectorType>,
262+
RewriterBase &);
263+
258264
} // namespace vector
259265

260266
/// Constructs a permutation map of invariant memref indices to vector

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
9595
populateVectorRankReducingFMAPattern(patterns);
9696
populateVectorGatherLoweringPatterns(patterns);
9797
populateVectorFromElementsLoweringPatterns(patterns);
98+
populateVectorToElementsLoweringPatterns(patterns);
9899
if (armI8MM) {
99100
if (armNeon)
100101
arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);

mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,11 @@ void transform::ApplyUnrollFromElementsPatternsOp::populatePatterns(
144144
vector::populateVectorFromElementsLoweringPatterns(patterns);
145145
}
146146

147+
/// Populates `patterns` for this transform op by forwarding to
/// vector::populateVectorToElementsLoweringPatterns (called here without an
/// explicit pattern benefit).
void transform::ApplyUnrollToElementsPatternsOp::populatePatterns(
148+
RewritePatternSet &patterns) {
149+
vector::populateVectorToElementsLoweringPatterns(patterns);
150+
}
151+
147152
void transform::ApplyLowerScanPatternsOp::populatePatterns(
148153
RewritePatternSet &patterns) {
149154
vector::populateVectorScanLoweringPatterns(patterns);

mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
1111
LowerVectorScan.cpp
1212
LowerVectorShapeCast.cpp
1313
LowerVectorStep.cpp
14+
LowerVectorToElements.cpp
1415
LowerVectorToFromElementsToShuffleTree.cpp
1516
LowerVectorTransfer.cpp
1617
LowerVectorTranspose.cpp
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
//===- LowerVectorToElements.cpp - Lower 'vector.to_elements' op ----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements target-independent rewrites and utilities to lower the
10+
// 'vector.to_elements' operation.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
15+
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
16+
17+
#define DEBUG_TYPE "lower-vector-to-elements"
18+
19+
using namespace mlir;
20+
21+
namespace {
22+
23+
/// Unrolls a `vector.to_elements` op along the outermost dimension of its
/// source vector.
///
/// The source is split into sub-vectors with `vector::unrollVectorValue`, one
/// `vector.to_elements` op is created per sub-vector, and the original op is
/// replaced by the results of the new ops, appended in sub-vector order. The
/// pattern fails to match whenever `vector::unrollVectorValue` fails on the
/// source.
struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ToElementsOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<SmallVector<Value>> subvectors =
        vector::unrollVectorValue(op.getSource(), rewriter);
    if (failed(subvectors))
      return failure();

    // The replacement needs exactly one value per result of the original op,
    // so the storage can be sized up front.
    SmallVector<Value> elements;
    elements.reserve(op->getNumResults());

    Location loc = op.getLoc();
    // Iterate the unrolled values in place instead of copying the vector out
    // of the FailureOr wrapper.
    for (Value subvector : *subvectors) {
      auto subElements = vector::ToElementsOp::create(rewriter, loc, subvector);
      llvm::append_range(elements, subElements.getResults());
    }

    rewriter.replaceOp(op, elements);
    return success();
  }
};
48+
} // namespace
49+
50+
/// Adds the `UnrollToElements` rewrite pattern to `patterns`, registered with
/// the given `benefit`.
void mlir::vector::populateVectorToElementsLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  MLIRContext *context = patterns.getContext();
  patterns.add<UnrollToElements>(context, benefit);
}
}

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,41 @@ vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
393393
return success();
394394
}
395395

396+
/// Takes a 2+ dimensional vector as an input
397+
/// returns n vector values produced by n vector.extract operations.
398+
/// I.e. calling unrollVectorValue([[%v]], rewriter) such that
399+
///
400+
/// %v : vector<nxaxb...>
401+
///
402+
/// will produce the following IR changes
403+
///
404+
/// %v0 = vector.extract %v[0] : vector<axbx...> from vector<nxaxb...>
405+
/// %v1 = vector.extract %v[1] : vector<axbx...> from vector<nxaxb...>
406+
/// ...
407+
/// %vnminusone = vector.extract %v[n-1] : vector<axbx...> from ...
408+
///
409+
/// and returns SmallVector<Value> r = {[[%v0]], [[%v1]], ..., [[%vnminusone]]}
410+
FailureOr<SmallVector<Value>>
411+
vector::unrollVectorValue(TypedValue<VectorType> vector,
412+
RewriterBase &rewriter) {
413+
SmallVector<Value> subvectors;
414+
VectorType ty = cast<VectorType>(vector.getType());
415+
Location loc = vector.getLoc();
416+
if (ty.getRank() < 2)
417+
return rewriter.notifyMatchFailure(loc, "already 1-D");
418+
419+
// Unrolling doesn't take vscale into account. Pattern is disabled for
420+
// vectors with leading scalable dim(s).
421+
if (ty.getScalableDims().front())
422+
return rewriter.notifyMatchFailure(loc, "cannot unroll scalable dim");
423+
424+
for (int64_t i = 0, e = ty.getShape().front(); i < e; ++i) {
425+
subvectors.push_back(vector::ExtractOp::create(rewriter, loc, vector, i));
426+
}
427+
428+
return subvectors;
429+
}
430+
396431
LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter,
397432
vector::UnrollVectorOpFn unrollFn) {
398433
assert(op->getNumResults() == 1 && "expected single result");

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1774,3 +1774,45 @@ func.func @from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> v
17741774
%0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32>
17751775
return %0 : vector<2x1x2xf32>
17761776
}
1777+
1778+
// -----
1779+
1780+
//===----------------------------------------------------------------------===//
1781+
// vector.to_elements
1782+
//===----------------------------------------------------------------------===//
1783+
1784+
// CHECK-LABEL: func @to_elements_1d(
1785+
// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>
1786+
// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
1787+
// CHECK: %[[V0:.+]] = llvm.extractelement %[[ARG0]][%[[C0]] : i64] : vector<2xf32>
1788+
// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
1789+
// CHECK: %[[V1:.+]] = llvm.extractelement %[[ARG0]][%[[C1]] : i64] : vector<2xf32>
1790+
// CHECK: return %[[V0]], %[[V1]]
1791+
func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
1792+
%0:2 = vector.to_elements %arg0 : vector<2xf32>
1793+
return %0#0, %0#1 : f32, f32
1794+
}
1795+
1796+
// -----
1797+
1798+
// NOTE: We unroll multi-dimensional to_elements ops with pattern
1799+
// `UnrollToElements` and then convert the 1-D to_elements ops to llvm.
1800+
1801+
// CHECK-LABEL: func @to_elements_2d(
1802+
// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32>
1803+
// CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<2x2xf32> to !llvm.array<2 x vector<2xf32>>
1804+
// CHECK: %[[V0:.+]] = llvm.extractvalue %[[CAST]][0] : !llvm.array<2 x vector<2xf32>>
1805+
// CHECK: %[[V1:.+]] = llvm.extractvalue %[[CAST]][1] : !llvm.array<2 x vector<2xf32>>
1806+
// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
1807+
// CHECK: %[[R0:.+]] = llvm.extractelement %[[V0]][%[[C0]] : i64] : vector<2xf32>
1808+
// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
1809+
// CHECK: %[[R1:.+]] = llvm.extractelement %[[V0]][%[[C1]] : i64] : vector<2xf32>
1810+
// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
1811+
// CHECK: %[[R2:.+]] = llvm.extractelement %[[V1]][%[[C0]] : i64] : vector<2xf32>
1812+
// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
1813+
// CHECK: %[[R3:.+]] = llvm.extractelement %[[V1]][%[[C1]] : i64] : vector<2xf32>
1814+
// CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]]
1815+
func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
1816+
%0:4 = vector.to_elements %arg0 : vector<2x2xf32>
1817+
return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
1818+
}
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Skip the directory with input TD sequences.
2+
config.excludes = ["td"]

0 commit comments

Comments
 (0)