Skip to content

Conversation

@amd-eochoalo
Copy link
Contributor

@amd-eochoalo amd-eochoalo commented Sep 5, 2025

This PR adds support for unrolling vector.to_element's source operand.

It transforms

%0:8 = vector.to_elements %v : vector<2x2x2xf32>

to

%v0 = vector.extract %v[0] : vector<2x2xf32> from vector<2x2x2xf32>
%v1 = vector.extract %v[1] : vector<2x2xf32> from vector<2x2x2xf32>
%0:4 = vector.to_elements %v0 : vector<2x2xf32>
%1:4 = vector.to_elements %v1 : vector<2x2xf32>
// %0:8 = %0:4 - %1:4

This pattern will be applied until there are only 1-D vectors left.

hanhanW and others added 6 commits September 4, 2025 17:30
The revision adds a pattern that flattens 2 or more dimensional
`vector.to_elements` ops by `vector.shape_cast` + `vector.to_elements`.

It also adds the lowering pattern to ConvertVectorToLLVMPass and
complete the tests.

It recovers the e2e lowering breakage from llvm@b4c31dc on LLVM path.

Signed-off-by: hanhanW <[email protected]>
Extract n vector<axbx...> from vector<nxaxbx...>

This patch adds a utility function that will unroll vector values.
This is different from the current utility function that focuses
on unrolling vector operations.
@llvmbot
Copy link
Member

llvmbot commented Sep 5, 2025

@llvm/pr-subscribers-mlir-vector

Author: Erick Ochoa Lopez (amd-eochoalo)

Changes

This PR adds support for unrolling vector.to_element's source operand.

It transforms

%0:8 = vector.to_elements %v : vector&lt;2x2x2xf32&gt;

to

%v0 = vector.extract %v[0] : vector&lt;2x2xf32&gt; from vector&lt;2x2x2xf32&gt;
%v1 = vector.extract %v[1] : vector&lt;2x2xf32&gt; from vector&lt;2x2x2xf32&gt;
%0:4 = vector.to_elements %v0 : vector&lt;2x2xf32&gt;
%1:4 = vector.to_elements %v1 : vector&lt;2x2xf32&gt;
// %0:8 = %0:4 - %1:4

This pattern will be applied until there are only 1-D vectors left.


Full diff: https://github.com/llvm/llvm-project/pull/157142.diff

10 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h (+12)
  • (modified) mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h (+3)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (+1)
  • (modified) mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp (+83)
  • (modified) mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp (+21)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+42)
  • (added) mlir/test/Dialect/Vector/vector-to-elements-flattening.mlir (+22)
  • (added) mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir (+24)
  • (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+48)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 47f96112a9433..31150a2afc19f 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -311,6 +311,18 @@ void populateVectorToFromElementsToShuffleTreePatterns(
 void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns,
                                                 PatternBenefit benefit = 1);
 
+/// Populate the pattern set with the following patterns:
+///
+/// [UnrollToElements]
+void populateVectorToElementsLoweringPatterns(RewritePatternSet &patterns,
+                                              PatternBenefit benefit = 1);
+
+/// Populate the pattern set with the following patterns:
+///
+/// [FlattenToElements]
+void populateVectorToElementsFlatteningPatterns(RewritePatternSet &patterns,
+                                                PatternBenefit benefit = 1);
+
 /// Populate the pattern set with the following patterns:
 ///
 /// [ContractionOpToMatmulOpLowering]
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index ace26990601c8..95f2ee5a7ac1d 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -255,6 +255,9 @@ using UnrollVectorOpFn =
 LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter,
                              UnrollVectorOpFn unrollFn);
 
+LogicalResult unrollVectorValue(Value vector, PatternRewriter &rewriter,
+                                SmallVector<Value> &values);
+
 } // namespace vector
 
 /// Constructs a permutation map of invariant memref indices to vector
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 9852df6970fdc..0b44ca7ceee42 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -95,6 +95,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorRankReducingFMAPattern(patterns);
     populateVectorGatherLoweringPatterns(patterns);
     populateVectorFromElementsLoweringPatterns(patterns);
+    populateVectorToElementsLoweringPatterns(patterns);
     if (armI8MM) {
       if (armNeon)
         arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index acbf2b746037b..d74007f13a95b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   LowerVectorScan.cpp
   LowerVectorShapeCast.cpp
   LowerVectorStep.cpp
+  LowerVectorToElements.cpp
   LowerVectorToFromElementsToShuffleTree.cpp
   LowerVectorTransfer.cpp
   LowerVectorTranspose.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
new file mode 100644
index 0000000000000..b86e8b274770f
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
@@ -0,0 +1,83 @@
+//===- LowerVectorToElements.cpp - Lower 'vector.to_elements' op ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.to_elements' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+
+#define DEBUG_TYPE "lower-vector-to-elements"
+
+using namespace mlir;
+
+namespace {
+
+struct UnrollToElements : OpRewritePattern<vector::ToElementsOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ToElementsOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Value> vectors;
+    LogicalResult match =
+        mlir::vector::unrollVectorValue(op.getSource(), rewriter, vectors);
+    if (failed(match)) {
+      return match;
+    }
+
+    // May be large vector.
+    std::vector<Value> results;
+    for (const auto &vector : vectors) {
+      // we need to replace the current result
+      auto subElements =
+          rewriter.create<vector::ToElementsOp>(op.getLoc(), vector);
+      results.insert(results.end(), subElements.getResults().begin(),
+                     subElements.getResults().end());
+    }
+    rewriter.replaceOp(op, results);
+    return success();
+  }
+};
+
+/// Flattens 2 or more dimensional `vector.to_elements` ops by
+/// `vector.shape_cast` + `vector.to_elements`.
+struct FlattenToElements : OpRewritePattern<vector::ToElementsOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ToElementsOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType vecType = op.getSource().getType();
+    if (vecType.getRank() <= 1)
+      return rewriter.notifyMatchFailure(
+          op, "the rank is already less than or equal to 1");
+    if (vecType.getNumScalableDims() > 0)
+      return rewriter.notifyMatchFailure(
+          op, "scalable vector is not yet supported");
+    auto vec1DType =
+        VectorType::get({vecType.getNumElements()}, vecType.getElementType());
+    Value shapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
+                                                  vec1DType, op.getSource());
+    rewriter.replaceOpWithNewOp<vector::ToElementsOp>(op, op.getResultTypes(),
+                                                      shapeCast);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::vector::populateVectorToElementsLoweringPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<UnrollToElements>(patterns.getContext(), benefit);
+}
+
+void mlir::vector::populateVectorToElementsFlatteningPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<FlattenToElements>(patterns.getContext(), benefit);
+}
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 841e1384e03b3..cbedd9563fc29 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -393,6 +393,27 @@ vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
   return success();
 }
 
+LogicalResult vector::unrollVectorValue(Value vector, PatternRewriter &rewriter,
+                                        SmallVector<Value> &subvectors) {
+  assert(isa<VectorType>(vector.getType()) && "expected vector type");
+  VectorType ty = cast<VectorType>(vector.getType());
+  Location loc = vector.getLoc();
+  if (ty.getRank() < 2)
+    return rewriter.notifyMatchFailure(loc, "already 1-D");
+
+  // Unrolling doesn't take vscale into account. Pattern is disabled for
+  // vectors with leading scalable dim(s).
+  if (ty.getScalableDims().front())
+    return rewriter.notifyMatchFailure(loc, "cannot unroll scalable dim");
+
+  // We just need zero indices for the all dimensions except the leading one.
+  for (int64_t i = 0, e = ty.getShape().front(); i < e; ++i) {
+    subvectors.push_back(vector::ExtractOp::create(rewriter, loc, vector, i));
+  }
+
+  return success();
+}
+
 LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter,
                                      vector::UnrollVectorOpFn unrollFn) {
   assert(op->getNumResults() == 1 && "expected single result");
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 07d335117de01..2d33888854ea7 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1774,3 +1774,45 @@ func.func @from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> v
   %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32>
   return %0 : vector<2x1x2xf32>
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.to_elements
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @to_elements_1d(
+// CHECK-SAME:    %[[ARG0:.+]]: vector<2xf32>
+// CHECK:         %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK:         %[[V0:.+]] = llvm.extractelement %[[ARG0]][%[[C0]] : i64] : vector<2xf32>
+// CHECK:         %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK:         %[[V1:.+]] = llvm.extractelement %[[ARG0]][%[[C1]] : i64] : vector<2xf32>
+// CHECK:         return %[[V0]], %[[V1]]
+func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
+  %0:2 = vector.to_elements %arg0 : vector<2xf32>
+  return %0#0, %0#1 : f32, f32
+}
+
+// -----
+
+// NOTE: We unroll multi-dimensional to_elements ops with pattern
+// `UnrollToElements` and then convert the 1-D to_elements ops to llvm.
+
+// CHECK-LABEL: func @to_elements_2d(
+// CHECK-SAME:    %[[ARG0:.+]]: vector<2x2xf32>
+// CHECK:         %[[CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<2x2xf32> to !llvm.array<2 x vector<2xf32>>
+// CHECK:         %[[V0:.+]] = llvm.extractvalue %[[CAST]][0] : !llvm.array<2 x vector<2xf32>>
+// CHECK:         %[[V1:.+]] = llvm.extractvalue %[[CAST]][1] : !llvm.array<2 x vector<2xf32>>
+// CHECK:         %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK:         %[[R0:.+]] = llvm.extractelement %[[V0]][%[[C0]] : i64] : vector<2xf32>
+// CHECK:         %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK:         %[[R1:.+]] = llvm.extractelement %[[V0]][%[[C1]] : i64] : vector<2xf32>
+// CHECK:         %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK:         %[[R2:.+]] = llvm.extractelement %[[V1]][%[[C0]] : i64] : vector<2xf32>
+// CHECK:         %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK:         %[[R3:.+]] = llvm.extractelement %[[V1]][%[[C1]] : i64] : vector<2xf32>
+// CHECK:         return %[[R0]], %[[R1]], %[[R2]], %[[R3]]
+func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
+  %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
+  return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
diff --git a/mlir/test/Dialect/Vector/vector-to-elements-flattening.mlir b/mlir/test/Dialect/Vector/vector-to-elements-flattening.mlir
new file mode 100644
index 0000000000000..a57521c4db467
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-to-elements-flattening.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt %s -test-flatten-vector-to-elements -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @to_elements_1d(
+// CHECK-SAME:    %[[ARG0:.+]]: vector<2xf32>
+// CHECK:         %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32>
+// CHECK:         return %[[RES]]#0, %[[RES]]#1
+func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
+  %0:2 = vector.to_elements %arg0 : vector<2xf32>
+  return %0#0, %0#1 : f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @to_elements_2d(
+// CHECK-SAME:    %[[ARG0:.+]]: vector<2x2xf32>
+// CHECK:         %[[CAST:.+]] = vector.shape_cast %[[ARG0]]
+// CHECK:         %[[RES:.+]]:4 = vector.to_elements %[[CAST]] : vector<4xf32>
+// CHECK:         return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3
+func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
+  %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
+  return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
diff --git a/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
new file mode 100644
index 0000000000000..e302dbd174322
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt %s -test-unroll-vector-to-elements -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @to_elements_1d(
+// CHECK-SAME:    %[[ARG0:.+]]: vector<2xf32>
+// CHECK:         %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32>
+// CHECK:         return %[[RES]]#0, %[[RES]]#1
+func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
+  %0:2 = vector.to_elements %arg0 : vector<2xf32>
+  return %0#0, %0#1 : f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @to_elements_2d(
+// CHECK-SAME:    %[[ARG0:.+]]: vector<2x2xf32>
+// CHECK:         %[[VEC0:.+]] = vector.extract %[[ARG0]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK:         %[[VEC1:.+]] = vector.extract %[[ARG0]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK:         %[[RES0:.+]]:2 = vector.to_elements %[[VEC0]] : vector<2xf32>
+// CHECK:         %[[RES1:.+]]:2 = vector.to_elements %[[VEC1]] : vector<2xf32>
+// CHECK:         return %[[RES0]]#0, %[[RES0]]#1, %[[RES1]]#0, %[[RES1]]#1
+func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
+  %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
+  return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index bb1598ee3efe5..093134c119cea 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -808,6 +808,50 @@ struct TestUnrollVectorFromElements
   }
 };
 
+struct TestFlattenVectorToElements
+    : public PassWrapper<TestFlattenVectorToElements,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFlattenVectorToElements)
+
+  StringRef getArgument() const final {
+    return "test-flatten-vector-to-elements";
+  }
+  StringRef getDescription() const final {
+    return "Test flattening patterns for to_elements ops";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<func::FuncDialect, vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateVectorToElementsFlatteningPatterns(patterns);
+    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+  }
+};
+
+struct TestUnrollVectorToElements
+    : public PassWrapper<TestUnrollVectorToElements,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnrollVectorToElements)
+
+  StringRef getArgument() const final {
+    return "test-unroll-vector-to-elements";
+  }
+  StringRef getDescription() const final {
+    return "Test unrolling patterns for to_elements ops";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<func::FuncDialect, vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateVectorToElementsLoweringPatterns(patterns);
+    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 struct TestFoldArithExtensionIntoVectorContractPatterns
     : public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
                          OperationPass<func::FuncOp>> {
@@ -1083,6 +1127,10 @@ void registerTestVectorLowerings() {
 
   PassRegistration<TestUnrollVectorFromElements>();
 
+  PassRegistration<TestUnrollVectorToElements>();
+
+  PassRegistration<TestFlattenVectorToElements>();
+
   PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
 
   PassRegistration<TestVectorEmulateMaskedLoadStore>();

@llvmbot
Copy link
Member

llvmbot commented Sep 5, 2025

@llvm/pr-subscribers-mlir

Author: Erick Ochoa Lopez (amd-eochoalo)

Changes

This PR adds support for unrolling vector.to_element's source operand.

It transforms

%0:8 = vector.to_elements %v : vector&lt;2x2x2xf32&gt;

to

%v0 = vector.extract %v[0] : vector&lt;2x2xf32&gt; from vector&lt;2x2x2xf32&gt;
%v1 = vector.extract %v[1] : vector&lt;2x2xf32&gt; from vector&lt;2x2x2xf32&gt;
%0:4 = vector.to_elements %v0 : vector&lt;2x2xf32&gt;
%1:4 = vector.to_elements %v1 : vector&lt;2x2xf32&gt;
// %0:8 = %0:4 - %1:4

This pattern will be applied until there are only 1-D vectors left.


Full diff: https://github.com/llvm/llvm-project/pull/157142.diff

10 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h (+12)
  • (modified) mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h (+3)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (+1)
  • (modified) mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp (+83)
  • (modified) mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp (+21)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+42)
  • (added) mlir/test/Dialect/Vector/vector-to-elements-flattening.mlir (+22)
  • (added) mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir (+24)
  • (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+48)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 47f96112a9433..31150a2afc19f 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -311,6 +311,18 @@ void populateVectorToFromElementsToShuffleTreePatterns(
 void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns,
                                                 PatternBenefit benefit = 1);
 
+/// Populate the pattern set with the following patterns:
+///
+/// [UnrollToElements]
+void populateVectorToElementsLoweringPatterns(RewritePatternSet &patterns,
+                                              PatternBenefit benefit = 1);
+
+/// Populate the pattern set with the following patterns:
+///
+/// [FlattenToElements]
+void populateVectorToElementsFlatteningPatterns(RewritePatternSet &patterns,
+                                                PatternBenefit benefit = 1);
+
 /// Populate the pattern set with the following patterns:
 ///
 /// [ContractionOpToMatmulOpLowering]
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index ace26990601c8..95f2ee5a7ac1d 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -255,6 +255,9 @@ using UnrollVectorOpFn =
 LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter,
                              UnrollVectorOpFn unrollFn);
 
+LogicalResult unrollVectorValue(Value vector, PatternRewriter &rewriter,
+                                SmallVector<Value> &values);
+
 } // namespace vector
 
 /// Constructs a permutation map of invariant memref indices to vector
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 9852df6970fdc..0b44ca7ceee42 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -95,6 +95,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorRankReducingFMAPattern(patterns);
     populateVectorGatherLoweringPatterns(patterns);
     populateVectorFromElementsLoweringPatterns(patterns);
+    populateVectorToElementsLoweringPatterns(patterns);
     if (armI8MM) {
       if (armNeon)
         arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index acbf2b746037b..d74007f13a95b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   LowerVectorScan.cpp
   LowerVectorShapeCast.cpp
   LowerVectorStep.cpp
+  LowerVectorToElements.cpp
   LowerVectorToFromElementsToShuffleTree.cpp
   LowerVectorTransfer.cpp
   LowerVectorTranspose.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
new file mode 100644
index 0000000000000..b86e8b274770f
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
@@ -0,0 +1,83 @@
+//===- LowerVectorToElements.cpp - Lower 'vector.to_elements' op ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.to_elements' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+
+#define DEBUG_TYPE "lower-vector-to-elements"
+
+using namespace mlir;
+
+namespace {
+
+struct UnrollToElements : OpRewritePattern<vector::ToElementsOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ToElementsOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Value> vectors;
+    LogicalResult match =
+        mlir::vector::unrollVectorValue(op.getSource(), rewriter, vectors);
+    if (failed(match)) {
+      return match;
+    }
+
+    // May be large vector.
+    std::vector<Value> results;
+    for (const auto &vector : vectors) {
+      // we need to replace the current result
+      auto subElements =
+          rewriter.create<vector::ToElementsOp>(op.getLoc(), vector);
+      results.insert(results.end(), subElements.getResults().begin(),
+                     subElements.getResults().end());
+    }
+    rewriter.replaceOp(op, results);
+    return success();
+  }
+};
+
+/// Flattens 2 or more dimensional `vector.to_elements` ops by
+/// `vector.shape_cast` + `vector.to_elements`.
+struct FlattenToElements : OpRewritePattern<vector::ToElementsOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ToElementsOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType vecType = op.getSource().getType();
+    if (vecType.getRank() <= 1)
+      return rewriter.notifyMatchFailure(
+          op, "the rank is already less than or equal to 1");
+    if (vecType.getNumScalableDims() > 0)
+      return rewriter.notifyMatchFailure(
+          op, "scalable vector is not yet supported");
+    auto vec1DType =
+        VectorType::get({vecType.getNumElements()}, vecType.getElementType());
+    Value shapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
+                                                  vec1DType, op.getSource());
+    rewriter.replaceOpWithNewOp<vector::ToElementsOp>(op, op.getResultTypes(),
+                                                      shapeCast);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::vector::populateVectorToElementsLoweringPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<UnrollToElements>(patterns.getContext(), benefit);
+}
+
+void mlir::vector::populateVectorToElementsFlatteningPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<FlattenToElements>(patterns.getContext(), benefit);
+}
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 841e1384e03b3..cbedd9563fc29 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -393,6 +393,27 @@ vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
   return success();
 }
 
+LogicalResult vector::unrollVectorValue(Value vector, PatternRewriter &rewriter,
+                                        SmallVector<Value> &subvectors) {
+  assert(isa<VectorType>(vector.getType()) && "expected vector type");
+  VectorType ty = cast<VectorType>(vector.getType());
+  Location loc = vector.getLoc();
+  if (ty.getRank() < 2)
+    return rewriter.notifyMatchFailure(loc, "already 1-D");
+
+  // Unrolling doesn't take vscale into account. Pattern is disabled for
+  // vectors with leading scalable dim(s).
+  if (ty.getScalableDims().front())
+    return rewriter.notifyMatchFailure(loc, "cannot unroll scalable dim");
+
+  // We just need zero indices for the all dimensions except the leading one.
+  for (int64_t i = 0, e = ty.getShape().front(); i < e; ++i) {
+    subvectors.push_back(vector::ExtractOp::create(rewriter, loc, vector, i));
+  }
+
+  return success();
+}
+
 LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter,
                                      vector::UnrollVectorOpFn unrollFn) {
   assert(op->getNumResults() == 1 && "expected single result");
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 07d335117de01..2d33888854ea7 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1774,3 +1774,45 @@ func.func @from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> v
   %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32>
   return %0 : vector<2x1x2xf32>
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.to_elements
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @to_elements_1d(
+// CHECK-SAME:    %[[ARG0:.+]]: vector<2xf32>
+// CHECK:         %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK:         %[[V0:.+]] = llvm.extractelement %[[ARG0]][%[[C0]] : i64] : vector<2xf32>
+// CHECK:         %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK:         %[[V1:.+]] = llvm.extractelement %[[ARG0]][%[[C1]] : i64] : vector<2xf32>
+// CHECK:         return %[[V0]], %[[V1]]
+func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
+  %0:2 = vector.to_elements %arg0 : vector<2xf32>
+  return %0#0, %0#1 : f32, f32
+}
+
+// -----
+
+// NOTE: We unroll multi-dimensional to_elements ops with pattern
+// `UnrollToElements` and then convert the 1-D to_elements ops to llvm.
+
+// CHECK-LABEL: func @to_elements_2d(
+// CHECK-SAME:    %[[ARG0:.+]]: vector<2x2xf32>
+// CHECK:         %[[CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<2x2xf32> to !llvm.array<2 x vector<2xf32>>
+// CHECK:         %[[V0:.+]] = llvm.extractvalue %[[CAST]][0] : !llvm.array<2 x vector<2xf32>>
+// CHECK:         %[[V1:.+]] = llvm.extractvalue %[[CAST]][1] : !llvm.array<2 x vector<2xf32>>
+// CHECK:         %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK:         %[[R0:.+]] = llvm.extractelement %[[V0]][%[[C0]] : i64] : vector<2xf32>
+// CHECK:         %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK:         %[[R1:.+]] = llvm.extractelement %[[V0]][%[[C1]] : i64] : vector<2xf32>
+// CHECK:         %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK:         %[[R2:.+]] = llvm.extractelement %[[V1]][%[[C0]] : i64] : vector<2xf32>
+// CHECK:         %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK:         %[[R3:.+]] = llvm.extractelement %[[V1]][%[[C1]] : i64] : vector<2xf32>
+// CHECK:         return %[[R0]], %[[R1]], %[[R2]], %[[R3]]
+func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
+  %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
+  return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
diff --git a/mlir/test/Dialect/Vector/vector-to-elements-flattening.mlir b/mlir/test/Dialect/Vector/vector-to-elements-flattening.mlir
new file mode 100644
index 0000000000000..a57521c4db467
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-to-elements-flattening.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt %s -test-flatten-vector-to-elements -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @to_elements_1d(
+// CHECK-SAME:    %[[ARG0:.+]]: vector<2xf32>
+// CHECK:         %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32>
+// CHECK:         return %[[RES]]#0, %[[RES]]#1
+func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
+  %0:2 = vector.to_elements %arg0 : vector<2xf32>
+  return %0#0, %0#1 : f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @to_elements_2d(
+// CHECK-SAME:    %[[ARG0:.+]]: vector<2x2xf32>
+// CHECK:         %[[CAST:.+]] = vector.shape_cast %[[ARG0]]
+// CHECK:         %[[RES:.+]]:4 = vector.to_elements %[[CAST]] : vector<4xf32>
+// CHECK:         return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3
+func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
+  %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
+  return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
diff --git a/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
new file mode 100644
index 0000000000000..e302dbd174322
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt %s -test-unroll-vector-to-elements -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @to_elements_1d(
+// CHECK-SAME:    %[[ARG0:.+]]: vector<2xf32>
+// CHECK:         %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32>
+// CHECK:         return %[[RES]]#0, %[[RES]]#1
+func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
+  %0:2 = vector.to_elements %arg0 : vector<2xf32>
+  return %0#0, %0#1 : f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @to_elements_2d(
+// CHECK-SAME:    %[[ARG0:.+]]: vector<2x2xf32>
+// CHECK:         %[[VEC0:.+]] = vector.extract %[[ARG0]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK:         %[[VEC1:.+]] = vector.extract %[[ARG0]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK:         %[[RES0:.+]]:2 = vector.to_elements %[[VEC0]] : vector<2xf32>
+// CHECK:         %[[RES1:.+]]:2 = vector.to_elements %[[VEC1]] : vector<2xf32>
+// CHECK:         return %[[RES0]]#0, %[[RES0]]#1, %[[RES1]]#0, %[[RES1]]#1
+func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
+  %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
+  return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index bb1598ee3efe5..093134c119cea 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -808,6 +808,50 @@ struct TestUnrollVectorFromElements
   }
 };
 
+struct TestFlattenVectorToElements
+    : public PassWrapper<TestFlattenVectorToElements,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFlattenVectorToElements)
+
+  StringRef getArgument() const final {
+    return "test-flatten-vector-to-elements";
+  }
+  StringRef getDescription() const final {
+    return "Test flattening patterns for to_elements ops";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<func::FuncDialect, vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateVectorToElementsFlatteningPatterns(patterns);
+    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+  }
+};
+
+struct TestUnrollVectorToElements
+    : public PassWrapper<TestUnrollVectorToElements,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnrollVectorToElements)
+
+  StringRef getArgument() const final {
+    return "test-unroll-vector-to-elements";
+  }
+  StringRef getDescription() const final {
+    return "Test unrolling patterns for to_elements ops";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<func::FuncDialect, vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateVectorToElementsLoweringPatterns(patterns);
+    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 struct TestFoldArithExtensionIntoVectorContractPatterns
     : public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
                          OperationPass<func::FuncOp>> {
@@ -1083,6 +1127,10 @@ void registerTestVectorLowerings() {
 
   PassRegistration<TestUnrollVectorFromElements>();
 
+  PassRegistration<TestUnrollVectorToElements>();
+
+  PassRegistration<TestFlattenVectorToElements>();
+
   PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
 
   PassRegistration<TestVectorEmulateMaskedLoadStore>();

@amd-eochoalo amd-eochoalo requested a review from newling September 5, 2025 17:52
@amd-eochoalo
Copy link
Contributor Author

Thanks @kuhar, I addressed your comments in a single commit c58b5c4

@amd-eochoalo amd-eochoalo requested a review from kuhar September 5, 2025 19:26
@amd-eochoalo amd-eochoalo requested a review from kuhar September 5, 2025 20:07
Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM but let's wait for a at least one more review

Copy link
Contributor

@newling newling left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please make separate PRs for unrolling and flattening?

Copy link
Contributor

@newling newling left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

High-level question: Is the motivation for this PR that without it we currently fail to lower to rank-2+ from_elements ops from Vector to LLVM?

Parameters are now:
* using TypedValue<VectorType> instead of just Value
* using RewriterBase class.

Return types are:
* changed to FailureOr<SmallValue<Value>> instead of passing
  a Value as a parameter and returning Logical.
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, just some minor comments.

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My final asks are minor, so approving as it, but please address them before landing.

Thanks for working on this! 🙏🏻

LGTM % outstanding asks

@amd-eochoalo amd-eochoalo merged commit 9d19250 into llvm:main Sep 11, 2025
9 checks passed
slackito pushed a commit that referenced this pull request Sep 11, 2025
…58158)

`mlir/test/Dialect/Vector/td/unroll-elements.mlir` is fed as a data
dependency
into`mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir` added in
[#157142](#157142). The Bazel
rule here automatically picks up all mlir files as tests, which leads to
`vector-to-elements-lowering` failing.
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Sep 11, 2025
…as data (#158158)

`mlir/test/Dialect/Vector/td/unroll-elements.mlir` is fed as a data
dependency
into`mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir` added in
[#157142](llvm/llvm-project#157142). The Bazel
rule here automatically picks up all mlir files as tests, which leads to
`vector-to-elements-lowering` failing.
amd-eochoalo added a commit to amd-eochoalo/iree that referenced this pull request Sep 12, 2025
After:

* llvm/llvm-project#157740
* llvm/llvm-project#157142

the linearization of vector.to_elements pattern
can be changed to either the one now upstream
or to the unrolling version.

This commit changes the strategy from linearizing
to unrolling.
amd-eochoalo added a commit to amd-eochoalo/iree that referenced this pull request Sep 12, 2025
After:

* llvm/llvm-project#157740
* llvm/llvm-project#157142

the linearization of vector.to_elements pattern
can be changed to either the one now upstream
or to the unrolling version.

This commit changes the strategy from linearizing
to unrolling.

Signed-off-by: Erick Ochoa <[email protected]>
amd-eochoalo added a commit to amd-eochoalo/iree that referenced this pull request Sep 12, 2025
After:

* llvm/llvm-project#157740
* llvm/llvm-project#157142

the linearization of vector.to_elements pattern
can be changed to either the one now upstream
or to the unrolling version.

This commit changes the strategy from linearizing
to unrolling.

Signed-off-by: Erick Ochoa <[email protected]>
kuhar pushed a commit to iree-org/iree that referenced this pull request Sep 12, 2025
After:

* llvm/llvm-project#157740
* llvm/llvm-project#157142

the linearization of vector.to_elements pattern
can be changed to either the one now upstream
or to the unrolling version.

This commit changes the strategy from linearizing
to unrolling.

Signed-off-by: Erick Ochoa <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants