Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
d661413
[mlir][vector] Add support for lowering n-D vector.to_elements op.
hanhanW Sep 5, 2025
edf4019
Add new populate patterns for flattening.
amd-eochoalo Sep 5, 2025
12fd6fc
[mlir][vector] Add function to unroll vectors.
amd-eochoalo Sep 5, 2025
0415e3b
[mlir][vector] Add vector.to_elements unrolling.
amd-eochoalo Sep 5, 2025
1b21117
Split tests for unrolling and flattening to elements
amd-eochoalo Sep 5, 2025
567af4b
Fix test
amd-eochoalo Sep 5, 2025
c58b5c4
Address review comments
amd-eochoalo Sep 5, 2025
871f5a5
Remove unnecessary comment
amd-eochoalo Sep 5, 2025
244eed5
Address review comments
amd-eochoalo Sep 5, 2025
9933387
Add comment
amd-eochoalo Sep 9, 2025
351086b
Disable and move vector flattening
amd-eochoalo Sep 9, 2025
9f7d15d
Re-enable and rename flattening to linearize
amd-eochoalo Sep 9, 2025
a8f9b9c
Remove LinearizeToElements
amd-eochoalo Sep 9, 2025
0568cba
Improve API of unrollVectorValue.
amd-eochoalo Sep 10, 2025
00cc94b
Documentation
amd-eochoalo Sep 10, 2025
6894af0
Add transform.apply_patterns.vector.unroll_to_elements
amd-eochoalo Sep 10, 2025
82004a2
Minor changes
amd-eochoalo Sep 10, 2025
cb6cf99
Remove comment and use inline storage for SmallVector
amd-eochoalo Sep 11, 2025
db83d2f
Add test with transform interpreter
amd-eochoalo Sep 11, 2025
ccec33c
Use transform dialect library file
amd-eochoalo Sep 11, 2025
ecca0d0
Merge branch 'main' into eochoa/2025-09-05/vector-to-elements-lowering
amd-eochoalo Sep 11, 2025
7e52d00
Update mlir/test/Dialect/Vector/lit.local.cfg
amd-eochoalo Sep 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,12 @@ void populateVectorToFromElementsToShuffleTreePatterns(
void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);

/// Populate the pattern set with the following patterns:
///
/// [UnrollToElements]
void populateVectorToElementsLoweringPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);

/// Populate the pattern set with the following patterns:
///
/// [ContractionOpToMatmulOpLowering]
Expand Down
10 changes: 10 additions & 0 deletions mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,16 @@ using UnrollVectorOpFn =
LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter,
UnrollVectorOpFn unrollFn);

/// Generic utility for mapping values of type vector<nxaxbx...>
/// to n values of type vector<axbx...>
/// Follows the following pattern:
/// 1. Check if already 1-D. If so, return failure.
/// 2. Check for scalable dimensions. If so, return failure.
/// 3. Returns the values of n vector.extract operations corresponding
/// to the outermost dimension.
LogicalResult unrollVectorValue(Value vector, PatternRewriter &rewriter,
SmallVector<Value> &values);

} // namespace vector

/// Constructs a permutation map of invariant memref indices to vector
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorRankReducingFMAPattern(patterns);
populateVectorGatherLoweringPatterns(patterns);
populateVectorFromElementsLoweringPatterns(patterns);
populateVectorToElementsLoweringPatterns(patterns);
if (armI8MM) {
if (armNeon)
arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
LowerVectorScan.cpp
LowerVectorShapeCast.cpp
LowerVectorStep.cpp
LowerVectorToElements.cpp
LowerVectorToFromElementsToShuffleTree.cpp
LowerVectorTransfer.cpp
LowerVectorTranspose.cpp
Expand Down
50 changes: 50 additions & 0 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
//===- LowerVectorToElements.cpp - Lower 'vector.to_elements' op ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.to_elements' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"

#define DEBUG_TYPE "lower-vector-to-elements"

using namespace mlir;

namespace {

struct UnrollToElements final : OpRewritePattern<vector::ToElementsOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(vector::ToElementsOp op,
PatternRewriter &rewriter) const override {
SmallVector<Value> vectors;
if (failed(vector::unrollVectorValue(op.getSource(), rewriter, vectors))) {
return failure();
}

// May be a large vector.
SmallVector<Value, 0> results;
for (const Value &vector : vectors) {
auto subElements =
rewriter.create<vector::ToElementsOp>(op.getLoc(), vector);
llvm::append_range(results, subElements.getResults());
}
rewriter.replaceOp(op, results);
return success();
}
};

} // namespace

void mlir::vector::populateVectorToElementsLoweringPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<UnrollToElements>(patterns.getContext(), benefit);
}
19 changes: 19 additions & 0 deletions mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,25 @@ vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
return success();
}

LogicalResult vector::unrollVectorValue(Value vector, PatternRewriter &rewriter,
SmallVector<Value> &subvectors) {
VectorType ty = cast<VectorType>(vector.getType());
Location loc = vector.getLoc();
if (ty.getRank() < 2)
return rewriter.notifyMatchFailure(loc, "already 1-D");

// Unrolling doesn't take vscale into account. Pattern is disabled for
// vectors with leading scalable dim(s).
if (ty.getScalableDims().front())
return rewriter.notifyMatchFailure(loc, "cannot unroll scalable dim");

for (int64_t i = 0, e = ty.getShape().front(); i < e; ++i) {
subvectors.push_back(vector::ExtractOp::create(rewriter, loc, vector, i));
}

return success();
}

LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter,
vector::UnrollVectorOpFn unrollFn) {
assert(op->getNumResults() == 1 && "expected single result");
Expand Down
42 changes: 42 additions & 0 deletions mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1774,3 +1774,45 @@ func.func @from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> v
%0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32>
return %0 : vector<2x1x2xf32>
}

// -----

//===----------------------------------------------------------------------===//
// vector.to_elements
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @to_elements_1d(
// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>
// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[V0:.+]] = llvm.extractelement %[[ARG0]][%[[C0]] : i64] : vector<2xf32>
// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[V1:.+]] = llvm.extractelement %[[ARG0]][%[[C1]] : i64] : vector<2xf32>
// CHECK: return %[[V0]], %[[V1]]
func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
%0:2 = vector.to_elements %arg0 : vector<2xf32>
return %0#0, %0#1 : f32, f32
}

// -----

// NOTE: We unroll multi-dimensional to_elements ops with pattern
// `UnrollToElements` and then convert the 1-D to_elements ops to llvm.

// CHECK-LABEL: func @to_elements_2d(
// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32>
// CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<2x2xf32> to !llvm.array<2 x vector<2xf32>>
// CHECK: %[[V0:.+]] = llvm.extractvalue %[[CAST]][0] : !llvm.array<2 x vector<2xf32>>
// CHECK: %[[V1:.+]] = llvm.extractvalue %[[CAST]][1] : !llvm.array<2 x vector<2xf32>>
// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[R0:.+]] = llvm.extractelement %[[V0]][%[[C0]] : i64] : vector<2xf32>
// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[R1:.+]] = llvm.extractelement %[[V0]][%[[C1]] : i64] : vector<2xf32>
// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[R2:.+]] = llvm.extractelement %[[V1]][%[[C0]] : i64] : vector<2xf32>
// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[R3:.+]] = llvm.extractelement %[[V1]][%[[C1]] : i64] : vector<2xf32>
// CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]]
func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
%0:4 = vector.to_elements %arg0 : vector<2x2xf32>
return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
}
24 changes: 24 additions & 0 deletions mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// RUN: mlir-opt %s -test-unroll-vector-to-elements -split-input-file | FileCheck %s

// CHECK-LABEL: func.func @to_elements_1d(
// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>
// CHECK: %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32>
// CHECK: return %[[RES]]#0, %[[RES]]#1
func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
%0:2 = vector.to_elements %arg0 : vector<2xf32>
return %0#0, %0#1 : f32, f32
}

// -----

// CHECK-LABEL: func.func @to_elements_2d(
// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32>
// CHECK: %[[VEC0:.+]] = vector.extract %[[ARG0]][0] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[VEC1:.+]] = vector.extract %[[ARG0]][1] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[RES0:.+]]:2 = vector.to_elements %[[VEC0]] : vector<2xf32>
// CHECK: %[[RES1:.+]]:2 = vector.to_elements %[[VEC1]] : vector<2xf32>
// CHECK: return %[[RES0]]#0, %[[RES0]]#1, %[[RES1]]#0, %[[RES1]]#1
func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
%0:4 = vector.to_elements %arg0 : vector<2x2xf32>
return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
}
24 changes: 24 additions & 0 deletions mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -808,6 +808,28 @@ struct TestUnrollVectorFromElements
}
};

struct TestUnrollVectorToElements
: public PassWrapper<TestUnrollVectorToElements,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnrollVectorToElements)

StringRef getArgument() const final {
return "test-unroll-vector-to-elements";
}
StringRef getDescription() const final {
return "Test unrolling patterns for to_elements ops";
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<func::FuncDialect, vector::VectorDialect>();
}

void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateVectorToElementsLoweringPatterns(patterns);
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

struct TestFoldArithExtensionIntoVectorContractPatterns
: public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
OperationPass<func::FuncOp>> {
Expand Down Expand Up @@ -1083,6 +1105,8 @@ void registerTestVectorLowerings() {

PassRegistration<TestUnrollVectorFromElements>();

PassRegistration<TestUnrollVectorToElements>();

PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();

PassRegistration<TestVectorEmulateMaskedLoadStore>();
Expand Down