Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
d661413
[mlir][vector] Add support for lowering n-D vector.to_elements op.
hanhanW Sep 5, 2025
edf4019
Add new populate patterns for flattening.
amd-eochoalo Sep 5, 2025
12fd6fc
[mlir][vector] Add function to unroll vectors.
amd-eochoalo Sep 5, 2025
0415e3b
[mlir][vector] Add vector.to_elements unrolling.
amd-eochoalo Sep 5, 2025
1b21117
Split tests for unrolling and flattening to elements
amd-eochoalo Sep 5, 2025
567af4b
Fix test
amd-eochoalo Sep 5, 2025
c58b5c4
Address review comments
amd-eochoalo Sep 5, 2025
871f5a5
Remove unnecessary comment
amd-eochoalo Sep 5, 2025
244eed5
Address review comments
amd-eochoalo Sep 5, 2025
9933387
Add comment
amd-eochoalo Sep 9, 2025
351086b
Disable and move vector flattening
amd-eochoalo Sep 9, 2025
9f7d15d
Re-enable and rename flattening to linearize
amd-eochoalo Sep 9, 2025
a8f9b9c
Remove LinearizeToElements
amd-eochoalo Sep 9, 2025
0568cba
Improve API of unrollVectorValue.
amd-eochoalo Sep 10, 2025
00cc94b
Documentation
amd-eochoalo Sep 10, 2025
6894af0
Add transform.apply_patterns.vector.unroll_to_elements
amd-eochoalo Sep 10, 2025
82004a2
Minor changes
amd-eochoalo Sep 10, 2025
cb6cf99
Remove comment and use inline storage for SmallVector
amd-eochoalo Sep 11, 2025
db83d2f
Add test with transform interpreter
amd-eochoalo Sep 11, 2025
ccec33c
Use transform dialect library file
amd-eochoalo Sep 11, 2025
ecca0d0
Merge branch 'main' into eochoa/2025-09-05/vector-to-elements-lowering
amd-eochoalo Sep 11, 2025
7e52d00
Update mlir/test/Dialect/Vector/lit.local.cfg
amd-eochoalo Sep 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,17 @@ def ApplyUnrollFromElementsPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}

// Transform dialect pattern-descriptor op with the mnemonic
// `apply_patterns.vector.unroll_to_elements`. It takes no operands and is
// printed/parsed as just its attribute dictionary.
def ApplyUnrollToElementsPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.unroll_to_elements",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
Indicates that vector to_elements operations should be unrolled
along the outermost dimension.
}];

let assemblyFormat = "attr-dict";
}

def ApplyLowerScanPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.lower_scan",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,12 @@ void populateVectorToFromElementsToShuffleTreePatterns(
void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);

/// Populate the pattern set with the following patterns:
///
/// [UnrollToElements]
/// Unrolls a `vector.to_elements` op whose source has rank >= 2 along the
/// outermost dimension: the source is split into its outermost-dimension
/// slices with `vector.extract` and a `vector.to_elements` op is created for
/// each slice. Sources with a scalable leading dimension are not unrolled.
void populateVectorToElementsLoweringPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);

/// Populate the pattern set with the following patterns:
///
/// [ContractionOpToMatmulOpLowering]
Expand Down
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,12 @@ using UnrollVectorOpFn =
LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter,
UnrollVectorOpFn unrollFn);

/// Generic utility for unrolling values of type vector<NxAxBx...>
/// to N values of type vector<AxBx...> using vector.extract. The returned
/// values are ordered by increasing index along the leading dimension.
/// If the input has rank lower than 2 or has a leading scalable dimension,
/// failure is returned and no operation is created.
FailureOr<SmallVector<Value>> unrollVectorValue(TypedValue<VectorType>,
RewriterBase &);

} // namespace vector

/// Constructs a permutation map of invariant memref indices to vector
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorRankReducingFMAPattern(patterns);
populateVectorGatherLoweringPatterns(patterns);
populateVectorFromElementsLoweringPatterns(patterns);
populateVectorToElementsLoweringPatterns(patterns);
if (armI8MM) {
if (armNeon)
arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
Expand Down
5 changes: 5 additions & 0 deletions mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,11 @@ void transform::ApplyUnrollFromElementsPatternsOp::populatePatterns(
vector::populateVectorFromElementsLoweringPatterns(patterns);
}

/// Adds the `vector.to_elements` lowering (unrolling) patterns to `patterns`,
/// using the default pattern benefit.
void transform::ApplyUnrollToElementsPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::populateVectorToElementsLoweringPatterns(patterns);
}

void transform::ApplyLowerScanPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::populateVectorScanLoweringPatterns(patterns);
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
LowerVectorScan.cpp
LowerVectorShapeCast.cpp
LowerVectorStep.cpp
LowerVectorToElements.cpp
LowerVectorToFromElementsToShuffleTree.cpp
LowerVectorTransfer.cpp
LowerVectorTranspose.cpp
Expand Down
54 changes: 54 additions & 0 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
//===- LowerVectorToElements.cpp - Lower 'vector.to_elements' op ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.to_elements' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"

#define DEBUG_TYPE "lower-vector-to-elements"

using namespace mlir;

namespace {

/// Unrolls a `vector.to_elements` op along the outermost dimension of its
/// source vector.
///
/// The source is split into its outermost-dimension slices via
/// `vector::unrollVectorValue` (one `vector.extract` per slice). A new
/// `vector.to_elements` op is created for every slice, and the results of
/// these ops, concatenated in slice order, replace the results of the
/// original op. One application peels exactly one dimension.
///
/// The pattern fails to match whenever `vector::unrollVectorValue` fails,
/// i.e. when the source has rank lower than 2 or a scalable leading
/// dimension.
struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ToElementsOp op,
                                PatternRewriter &rewriter) const override {
    TypedValue<VectorType> source = op.getSource();
    FailureOr<SmallVector<Value>> subvectors =
        vector::unrollVectorValue(source, rewriter);
    if (failed(subvectors))
      return failure();

    // No inline storage: the number of scalar results may be large. The
    // replacement values map one-to-one onto the results of `op`, so the
    // final size is known up front.
    SmallVector<Value, 0> results;
    results.reserve(op->getNumResults());

    // Iterate the unrolled slices in place (no copy of the container);
    // `Value` is a cheap handle and is taken by value.
    for (Value subvector : *subvectors) {
      auto subElements =
          vector::ToElementsOp::create(rewriter, op.getLoc(), subvector);
      llvm::append_range(results, subElements.getResults());
    }
    rewriter.replaceOp(op, results);
    return success();
  }
};

} // namespace

/// Registers the `UnrollToElements` pattern in `patterns` with the given
/// `benefit`.
void mlir::vector::populateVectorToElementsLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  auto *context = patterns.getContext();
  patterns.add<UnrollToElements>(context, benefit);
}
35 changes: 35 additions & 0 deletions mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,41 @@ vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
return success();
}

/// Takes a 2+ dimensional vector as an input and returns n vector values
/// produced by n vector.extract operations.
/// I.e. calling unrollVectorValue([[%v]], rewriter) such that
///
/// %v : vector<nxaxb...>
///
/// will produce the following IR changes
///
/// %v0 = vector.extract %v[0] : vector<axbx...> from vector<nxaxb...>
/// %v1 = vector.extract %v[1] : vector<axbx...> from vector<nxaxb...>
/// ...
/// %vnminusone = vector.extract %v[n-1] : vector<axbx...> from ...
///
/// and returns SmallVector<Value> r = {[[%v0]], [[%v1]], ..., [[%vnminusone]]}
///
/// Failure is reported through `rewriter.notifyMatchFailure`, without
/// creating any operation, when the rank of the input is lower than 2 or when
/// its leading dimension is scalable.
FailureOr<SmallVector<Value>>
vector::unrollVectorValue(TypedValue<VectorType> vector,
                          RewriterBase &rewriter) {
  // `TypedValue<VectorType>::getType()` already yields a `VectorType`; no
  // cast is needed.
  VectorType ty = vector.getType();
  Location loc = vector.getLoc();
  if (ty.getRank() < 2)
    return rewriter.notifyMatchFailure(loc, "already 1-D");

  // Unrolling doesn't take vscale into account. Pattern is disabled for
  // vectors with leading scalable dim(s).
  if (ty.getScalableDims().front())
    return rewriter.notifyMatchFailure(loc, "cannot unroll scalable dim");

  // One subvector per index of the leading dimension; the count is known, so
  // allocate the result storage once.
  const int64_t numSubvectors = ty.getShape().front();
  SmallVector<Value> subvectors;
  subvectors.reserve(numSubvectors);
  for (int64_t i = 0; i < numSubvectors; ++i)
    subvectors.push_back(vector::ExtractOp::create(rewriter, loc, vector, i));

  return subvectors;
}

LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter,
vector::UnrollVectorOpFn unrollFn) {
assert(op->getNumResults() == 1 && "expected single result");
Expand Down
42 changes: 42 additions & 0 deletions mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1774,3 +1774,45 @@ func.func @from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> v
%0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32>
return %0 : vector<2x1x2xf32>
}

// -----

//===----------------------------------------------------------------------===//
// vector.to_elements
//===----------------------------------------------------------------------===//

// A 1-D to_elements op is expected to lower to one llvm.extractelement per
// element, each indexed by an i64 constant.
// CHECK-LABEL: func @to_elements_1d(
// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>
// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[V0:.+]] = llvm.extractelement %[[ARG0]][%[[C0]] : i64] : vector<2xf32>
// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[V1:.+]] = llvm.extractelement %[[ARG0]][%[[C1]] : i64] : vector<2xf32>
// CHECK: return %[[V0]], %[[V1]]
func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
%0:2 = vector.to_elements %arg0 : vector<2xf32>
return %0#0, %0#1 : f32, f32
}

// -----

// NOTE: We unroll multi-dimensional to_elements ops with pattern
// `UnrollToElements` and then convert the 1-D to_elements ops to llvm.
// The expected output extracts each row of the array with llvm.extractvalue
// and then each element of a row with llvm.extractelement; elements are
// returned row by row (both elements of row 0, then both elements of row 1).

// CHECK-LABEL: func @to_elements_2d(
// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32>
// CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<2x2xf32> to !llvm.array<2 x vector<2xf32>>
// CHECK: %[[V0:.+]] = llvm.extractvalue %[[CAST]][0] : !llvm.array<2 x vector<2xf32>>
// CHECK: %[[V1:.+]] = llvm.extractvalue %[[CAST]][1] : !llvm.array<2 x vector<2xf32>>
// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[R0:.+]] = llvm.extractelement %[[V0]][%[[C0]] : i64] : vector<2xf32>
// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[R1:.+]] = llvm.extractelement %[[V0]][%[[C1]] : i64] : vector<2xf32>
// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[R2:.+]] = llvm.extractelement %[[V1]][%[[C0]] : i64] : vector<2xf32>
// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[R3:.+]] = llvm.extractelement %[[V1]][%[[C1]] : i64] : vector<2xf32>
// CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]]
func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
%0:4 = vector.to_elements %arg0 : vector<2x2xf32>
return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
}
24 changes: 24 additions & 0 deletions mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// RUN: mlir-opt %s -test-unroll-vector-to-elements -split-input-file | FileCheck %s

// A 1-D to_elements op is expected to be left unchanged by the unrolling
// pattern.
// CHECK-LABEL: func.func @to_elements_1d(
// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>
// CHECK: %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32>
// CHECK: return %[[RES]]#0, %[[RES]]#1
func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
%0:2 = vector.to_elements %arg0 : vector<2xf32>
return %0#0, %0#1 : f32, f32
}

// -----

// A 2-D to_elements op is expected to be unrolled into one vector.extract per
// index of the outermost dimension, each followed by a 1-D to_elements op;
// the original results are replaced in row order.
// CHECK-LABEL: func.func @to_elements_2d(
// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32>
// CHECK: %[[VEC0:.+]] = vector.extract %[[ARG0]][0] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[VEC1:.+]] = vector.extract %[[ARG0]][1] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[RES0:.+]]:2 = vector.to_elements %[[VEC0]] : vector<2xf32>
// CHECK: %[[RES1:.+]]:2 = vector.to_elements %[[VEC1]] : vector<2xf32>
// CHECK: return %[[RES0]]#0, %[[RES0]]#1, %[[RES1]]#0, %[[RES1]]#1
func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
%0:4 = vector.to_elements %arg0 : vector<2x2xf32>
return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
}
24 changes: 24 additions & 0 deletions mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -808,6 +808,28 @@ struct TestUnrollVectorFromElements
}
};

/// Test-only pass, exposed as `-test-unroll-vector-to-elements`, that runs the
/// `vector.to_elements` lowering patterns greedily on each `func.func`.
struct TestUnrollVectorToElements
    : public PassWrapper<TestUnrollVectorToElements,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnrollVectorToElements)

  /// Dialects whose operations appear in the IR this pass processes/produces.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<func::FuncDialect>();
    registry.insert<vector::VectorDialect>();
  }

  /// One-line summary of the pass.
  StringRef getDescription() const final {
    return "Test unrolling patterns for to_elements ops";
  }

  /// Command-line flag under which the pass is registered.
  StringRef getArgument() const final {
    return "test-unroll-vector-to-elements";
  }

  void runOnOperation() override {
    auto *context = &getContext();
    RewritePatternSet unrollPatterns(context);
    populateVectorToElementsLoweringPatterns(unrollPatterns);
    // The result of the greedy driver is deliberately ignored.
    (void)applyPatternsGreedily(getOperation(), std::move(unrollPatterns));
  }
};

struct TestFoldArithExtensionIntoVectorContractPatterns
: public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
OperationPass<func::FuncOp>> {
Expand Down Expand Up @@ -1083,6 +1105,8 @@ void registerTestVectorLowerings() {

PassRegistration<TestUnrollVectorFromElements>();

PassRegistration<TestUnrollVectorToElements>();

PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();

PassRegistration<TestVectorEmulateMaskedLoadStore>();
Expand Down
2 changes: 2 additions & 0 deletions mlir/test/python/dialects/transform_vector_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def non_configurable_patterns():
vector.ApplyLowerGatherPatternsOp()
# CHECK: transform.apply_patterns.vector.unroll_from_elements
vector.ApplyUnrollFromElementsPatternsOp()
# CHECK: transform.apply_patterns.vector.unroll_to_elements
vector.ApplyUnrollToElementsPatternsOp()
# CHECK: transform.apply_patterns.vector.lower_scan
vector.ApplyLowerScanPatternsOp()
# CHECK: transform.apply_patterns.vector.lower_shape_cast
Expand Down
Loading