Skip to content

Commit 4eb1a07

Browse files
yangtetrisnicolasvasilacheYang Bainewlingdcaballe
authored
[mlir][vector] Support multi-dimensional vectors in VectorFromElementsLowering (#151175)
This patch introduces a new unrolling-based approach for lowering multi-dimensional `vector.from_elements` operations. **Implementation Details:** 1. **New Transform Pattern**: Added `UnrollFromElements` that unrolls a N-D(N>=2) from_elements op to a (N-1)-D from_elements op align the outermost dimension. 2. **Utility Functions**: Added `unrollVectorOp` to reuse the unroll algo of vector.gather for vector.from_elements. 3. **Integration**: Added the unrolling pattern to the convert-vector-to-llvm pass as a temporal transformation. 4. Use direct LLVM dialect operations instead of intermediate vector.insert operations for efficiency in `VectorFromElementsLowering`. **Example:** ```mlir // unroll %v = vector.from_elements %e0, %e1, %e2, %e3 : vector<2x2xf32> => %poison_2d = ub.poison : vector<2x2xf32> %vec_1d_0 = vector.from_elements %e0, %e1 : vector<2xf32> %vec_2d_0 = vector.insert %vec_1d_0, %poison_2d [0] : vector<2xf32> into vector<2x2xf32> %vec_1d_1 = vector.from_elements %e2, %e3 : vector<2xf32> %result = vector.insert %vec_1d_1, %vec_2d_0 [1] : vector<2xf32> into vector<2x2xf32> // convert-vector-to-llvm %v = vector.from_elements %e0, %e1, %e2, %e3 : vector<2x2xf32> => %poison_2d = ub.poison : vector<2x2xf32> %poison_2d_cast = builtin.unrealized_conversion_cast %poison_2d : vector<2x2xf32> to !llvm.array<2 x vector<2xf32>> %poison_1d_0 = llvm.mlir.poison : vector<2xf32> %c0_0 = llvm.mlir.constant(0 : i64) : i64 %vec_1d_0_0 = llvm.insertelement %e0, %poison_1d_0[%c0_0 : i64] : vector<2xf32> %c1_0 = llvm.mlir.constant(1 : i64) : i64 %vec_1d_0_1 = llvm.insertelement %e1, %vec_1d_0_0[%c1_0 : i64] : vector<2xf32> %vec_2d_0 = llvm.insertvalue %vec_1d_0_1, %poison_2d_cast[0] : !llvm.array<2 x vector<2xf32>> %poison_1d_1 = llvm.mlir.poison : vector<2xf32> %c0_1 = llvm.mlir.constant(0 : i64) : i64 %vec_1d_1_0 = llvm.insertelement %e2, %poison_1d_1[%c0_1 : i64] : vector<2xf32> %c1_1 = llvm.mlir.constant(1 : i64) : i64 %vec_1d_1_1 = llvm.insertelement %e3, %vec_1d_1_0[%c1_1 : i64] : vector<2xf32> %vec_2d_1 = llvm.insertvalue %vec_1d_1_1, %vec_2d_0[1] : !llvm.array<2 x vector<2xf32>> %result = builtin.unrealized_conversion_cast %vec_2d_1 : !llvm.array<2 x vector<2xf32>> to vector<2x2xf32> ``` --------- Co-authored-by: Nicolas Vasilache <[email protected]> Co-authored-by: Yang Bai <[email protected]> Co-authored-by: James Newling <[email protected]> Co-authored-by: Diego Caballero <[email protected]>
1 parent 8135b7c commit 4eb1a07

File tree

15 files changed

+261
-30
lines changed

15 files changed

+261
-30
lines changed

mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,17 @@ def ApplyLowerGatherPatternsOp : Op<Transform_Dialect,
254254
let assemblyFormat = "attr-dict";
255255
}
256256

257+
def ApplyUnrollFromElementsPatternsOp : Op<Transform_Dialect,
258+
"apply_patterns.vector.unroll_from_elements",
259+
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
260+
let description = [{
261+
Indicates that vector from_elements operations should be unrolled
262+
along the outermost dimension.
263+
}];
264+
265+
let assemblyFormat = "attr-dict";
266+
}
267+
257268
def ApplyLowerScanPatternsOp : Op<Transform_Dialect,
258269
"apply_patterns.vector.lower_scan",
259270
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {

mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,14 @@ void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);
303303
void populateVectorToFromElementsToShuffleTreePatterns(
304304
RewritePatternSet &patterns, PatternBenefit benefit = 1);
305305

306+
/// Populate the pattern set with the following patterns:
307+
///
308+
/// [UnrollFromElements]
309+
/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
310+
/// outermost dimension.
311+
void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns,
312+
PatternBenefit benefit = 1);
313+
306314
/// Populate the pattern set with the following patterns:
307315
///
308316
/// [ContractionOpToMatmulOpLowering]

mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Dialect/Arith/IR/Arith.h"
1313
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1414
#include "mlir/Dialect/Tensor/IR/Tensor.h"
15+
#include "mlir/Dialect/UB/IR/UBOps.h"
1516
#include "mlir/Dialect/Utils/IndexingUtils.h"
1617
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1718
#include "mlir/IR/BuiltinAttributes.h"
@@ -238,6 +239,22 @@ Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
238239
/// static sizes in `shape`.
239240
LogicalResult isValidMaskedInputVector(ArrayRef<int64_t> shape,
240241
ArrayRef<int64_t> inputVectorSizes);
242+
243+
/// Generic utility for unrolling n-D vector operations to (n-1)-D operations.
244+
/// This handles the common pattern of:
245+
/// 1. Check if already 1-D. If so, return failure.
246+
/// 2. Check for scalable dimensions. If so, return failure.
247+
/// 3. Create poison initialized result.
248+
/// 4. Loop through the outermost dimension, execute the UnrollVectorOpFn to
249+
/// create sub vectors.
250+
/// 5. Insert the sub vectors back into the final vector.
251+
/// 6. Replace the original op with the new result.
252+
using UnrollVectorOpFn =
253+
function_ref<Value(PatternRewriter &, Location, VectorType, int64_t)>;
254+
255+
LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter,
256+
UnrollVectorOpFn unrollFn);
257+
241258
} // namespace vector
242259

243260
/// Constructs a permutation map of invariant memref indices to vector

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1891,15 +1891,21 @@ struct VectorFromElementsLowering
18911891
ConversionPatternRewriter &rewriter) const override {
18921892
Location loc = fromElementsOp.getLoc();
18931893
VectorType vectorType = fromElementsOp.getType();
1894-
// TODO: Multi-dimensional vectors lower to !llvm.array<... x vector<>>.
1895-
// Such ops should be handled in the same way as vector.insert.
1894+
// Only support 1-D vectors. Multi-dimensional vectors should have been
1895+
// transformed to 1-D vectors by the vector-to-vector transformations before
1896+
// this.
18961897
if (vectorType.getRank() > 1)
18971898
return rewriter.notifyMatchFailure(fromElementsOp,
18981899
"rank > 1 vectors are not supported");
18991900
Type llvmType = typeConverter->convertType(vectorType);
1901+
Type llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
19001902
Value result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
1901-
for (auto [idx, val] : llvm::enumerate(adaptor.getElements()))
1902-
result = vector::InsertOp::create(rewriter, loc, val, result, idx);
1903+
for (auto [idx, val] : llvm::enumerate(adaptor.getElements())) {
1904+
auto constIdx =
1905+
LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, idx);
1906+
result = LLVM::InsertElementOp::create(rewriter, loc, llvmType, result,
1907+
val, constIdx);
1908+
}
19031909
rewriter.replaceOp(fromElementsOp, result);
19041910
return success();
19051911
}

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
9494
populateVectorStepLoweringPatterns(patterns);
9595
populateVectorRankReducingFMAPattern(patterns);
9696
populateVectorGatherLoweringPatterns(patterns);
97+
populateVectorFromElementsLoweringPatterns(patterns);
9798
if (armI8MM) {
9899
if (armNeon)
99100
arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);

mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,11 @@ void transform::ApplyLowerGatherPatternsOp::populatePatterns(
139139
vector::populateVectorGatherLoweringPatterns(patterns);
140140
}
141141

142+
void transform::ApplyUnrollFromElementsPatternsOp::populatePatterns(
143+
RewritePatternSet &patterns) {
144+
vector::populateVectorFromElementsLoweringPatterns(patterns);
145+
}
146+
142147
void transform::ApplyLowerScanPatternsOp::populatePatterns(
143148
RewritePatternSet &patterns) {
144149
vector::populateVectorScanLoweringPatterns(patterns);

mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
33
LowerVectorBitCast.cpp
44
LowerVectorBroadcast.cpp
55
LowerVectorContract.cpp
6+
LowerVectorFromElements.cpp
67
LowerVectorGather.cpp
78
LowerVectorInterleave.cpp
89
LowerVectorMask.cpp
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
//===- LowerVectorFromElements.cpp - Lower 'vector.from_elements' op -----===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements target-independent rewrites and utilities to lower the
10+
// 'vector.from_elements' operation.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
15+
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
16+
17+
#define DEBUG_TYPE "lower-vector-from-elements"
18+
19+
using namespace mlir;
20+
21+
namespace {
22+
23+
/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
24+
/// outermost dimension. For example:
25+
/// ```
26+
/// %v = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5 : vector<2x3xf32>
27+
///
28+
/// ==>
29+
///
30+
/// %0 = ub.poison : vector<2x3xf32>
31+
/// %v0 = vector.from_elements %e0, %e1, %e2 : vector<3xf32>
32+
/// %1 = vector.insert %v0, %0 [0] : vector<3xf32> into vector<2x3xf32>
33+
/// %v1 = vector.from_elements %e3, %e4, %e5 : vector<3xf32>
34+
/// %v = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32>
35+
/// ```
36+
///
37+
/// When applied exhaustively, this will produce a sequence of 1-d from_elements
38+
/// ops.
39+
struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
40+
using OpRewritePattern::OpRewritePattern;
41+
42+
LogicalResult matchAndRewrite(vector::FromElementsOp op,
43+
PatternRewriter &rewriter) const override {
44+
ValueRange allElements = op.getElements();
45+
46+
auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc,
47+
VectorType subTy, int64_t index) {
48+
size_t subTyNumElements = subTy.getNumElements();
49+
assert((index + 1) * subTyNumElements <= allElements.size() &&
50+
"out of bounds");
51+
ValueRange subElements =
52+
allElements.slice(index * subTyNumElements, subTyNumElements);
53+
return vector::FromElementsOp::create(rewriter, loc, subTy, subElements);
54+
};
55+
56+
return unrollVectorOp(op, rewriter, unrollFromElementsFn);
57+
}
58+
};
59+
60+
} // namespace
61+
62+
void mlir::vector::populateVectorFromElementsLoweringPatterns(
63+
RewritePatternSet &patterns, PatternBenefit benefit) {
64+
patterns.add<UnrollFromElements>(patterns.getContext(), benefit);
65+
}

mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp

Lines changed: 8 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -54,43 +54,26 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {
5454

5555
LogicalResult matchAndRewrite(vector::GatherOp op,
5656
PatternRewriter &rewriter) const override {
57-
VectorType resultTy = op.getType();
58-
if (resultTy.getRank() < 2)
59-
return rewriter.notifyMatchFailure(op, "already 1-D");
60-
61-
// Unrolling doesn't take vscale into account. Pattern is disabled for
62-
// vectors with leading scalable dim(s).
63-
if (resultTy.getScalableDims().front())
64-
return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim");
65-
66-
Location loc = op.getLoc();
6757
Value indexVec = op.getIndexVec();
6858
Value maskVec = op.getMask();
6959
Value passThruVec = op.getPassThru();
7060

71-
Value result = arith::ConstantOp::create(rewriter, loc, resultTy,
72-
rewriter.getZeroAttr(resultTy));
73-
74-
VectorType subTy = VectorType::Builder(resultTy).dropDim(0);
75-
76-
for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
77-
int64_t thisIdx[1] = {i};
61+
auto unrollGatherFn = [&](PatternRewriter &rewriter, Location loc,
62+
VectorType subTy, int64_t index) {
63+
int64_t thisIdx[1] = {index};
7864

7965
Value indexSubVec =
8066
vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx);
8167
Value maskSubVec =
8268
vector::ExtractOp::create(rewriter, loc, maskVec, thisIdx);
8369
Value passThruSubVec =
8470
vector::ExtractOp::create(rewriter, loc, passThruVec, thisIdx);
85-
Value subGather = vector::GatherOp::create(
86-
rewriter, loc, subTy, op.getBase(), op.getIndices(), indexSubVec,
87-
maskSubVec, passThruSubVec);
88-
result =
89-
vector::InsertOp::create(rewriter, loc, subGather, result, thisIdx);
90-
}
71+
return vector::GatherOp::create(rewriter, loc, subTy, op.getBase(),
72+
op.getIndices(), indexSubVec, maskSubVec,
73+
passThruSubVec);
74+
};
9175

92-
rewriter.replaceOp(op, result);
93-
return success();
76+
return unrollVectorOp(op, rewriter, unrollGatherFn);
9477
}
9578
};
9679

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,3 +392,29 @@ vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
392392
}
393393
return success();
394394
}
395+
396+
LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter,
397+
vector::UnrollVectorOpFn unrollFn) {
398+
assert(op->getNumResults() == 1 && "expected single result");
399+
assert(isa<VectorType>(op->getResult(0).getType()) && "expected vector type");
400+
VectorType resultTy = cast<VectorType>(op->getResult(0).getType());
401+
if (resultTy.getRank() < 2)
402+
return rewriter.notifyMatchFailure(op, "already 1-D");
403+
404+
// Unrolling doesn't take vscale into account. Pattern is disabled for
405+
// vectors with leading scalable dim(s).
406+
if (resultTy.getScalableDims().front())
407+
return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim");
408+
409+
Location loc = op->getLoc();
410+
Value result = ub::PoisonOp::create(rewriter, loc, resultTy);
411+
VectorType subTy = VectorType::Builder(resultTy).dropDim(0);
412+
413+
for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
414+
Value subVector = unrollFn(rewriter, loc, subTy, i);
415+
result = vector::InsertOp::create(rewriter, loc, subVector, result, i);
416+
}
417+
418+
rewriter.replaceOp(op, result);
419+
return success();
420+
}

0 commit comments

Comments
 (0)