Skip to content

Conversation

keshavvinayak01
Copy link
Contributor

@keshavvinayak01 keshavvinayak01 commented Oct 3, 2025

Description

Adds a new canonicalizer that folds vector.from_elements(vector.transpose)) => vector.from_elements. This canonicalization reorders the input elements for vector.from_elements, adjusts the output shape to match the effect of the transpose op and eliminating its need.

Testing

Added a 2D vector lit test that verifies the working of the rewrite.

Signed-off-by: Keshav Vinayak Jha <[email protected]>
@llvmbot
Copy link
Member

llvmbot commented Oct 3, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Keshav Vinayak Jha (keshavvinayak01)

Changes

Description

Adds a new canonicalizer that folds vector.from_elements(vector.broadcast)) => vector.from_elements. This canonicalization reorders the input elements for vector.from_elements, adjusts the output shape to match the effect of the broadcast op and eliminating its need.

Testing

Added a 2D vector lit test that verifies the working of the rewrite.


Full diff: https://github.com/llvm/llvm-project/pull/161841.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+57-1)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+12)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b0132e889302f..7f6313c11ea18 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6723,6 +6723,61 @@ class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
   }
 };
 
+/// Folds transpose(from_elements(...)) into a new from_elements with permuted
+/// operands matching the transposed shape.
+class FoldTransposeFromElements final : public OpRewritePattern<TransposeOp> {
+public:
+  using Base::Base;
+  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    auto fromElementsOp =
+        transposeOp.getVector().getDefiningOp<vector::FromElementsOp>();
+    if (!fromElementsOp)
+      return failure();
+
+    VectorType srcTy = fromElementsOp.getDest().getType();
+    VectorType dstTy = transposeOp.getType();
+
+    ArrayRef<int64_t> permutation = transposeOp.getPermutation();
+    int64_t rank = srcTy.getRank();
+
+    // Build inverse permutation to map destination indices back to source.
+    SmallVector<int64_t, 4> inversePerm(rank, 0);
+    for (int64_t i = 0; i < rank; ++i)
+      inversePerm[permutation[i]] = i;
+
+    ArrayRef<int64_t> srcShape = srcTy.getShape();
+    ArrayRef<int64_t> dstShape = dstTy.getShape();
+    SmallVector<int64_t, 4> srcIdx(rank, 0);
+    SmallVector<int64_t, 4> dstIdx(rank, 0);
+    SmallVector<int64_t, 4> srcStrides = computeStrides(srcShape);
+    SmallVector<int64_t, 4> dstStrides = computeStrides(dstShape);
+
+    auto elements = fromElementsOp.getElements();
+    SmallVector<Value> newElements;
+    int64_t dstNumElements = dstTy.getNumElements();
+    newElements.reserve(dstNumElements);
+
+    // For each element in destination row-major order, pick the corresponding
+    // source element.
+    for (int64_t lin = 0; lin < dstNumElements; ++lin) {
+      // Pick the destination element index.
+      dstIdx = delinearize(lin, dstStrides);
+      // Map the destination element index to the source element index.
+      for (int64_t j = 0; j < rank; ++j)
+        srcIdx[j] = dstIdx[inversePerm[j]];
+      // Linearize the source element index.
+      int64_t srcLin = linearize(srcIdx, srcStrides);
+      // Add the source element to the new elements.
+      newElements.push_back(elements[srcLin]);
+    }
+
+    rewriter.replaceOpWithNewOp<FromElementsOp>(transposeOp, dstTy,
+                                                newElements);
+    return success();
+  }
+};
+
 /// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
 /// 'order preserving', where 'order preserving' means the flattened
 /// inputs and outputs of the transpose have identical (numerical) values.
@@ -6823,7 +6878,8 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
 void vector::TransposeOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
   results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
-              FoldTransposeSplat, FoldTransposeBroadcast>(context);
+              FoldTransposeSplat, FoldTransposeFromElements,
+              FoldTransposeBroadcast>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 5448976f84760..5f34d144cd472 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -308,6 +308,18 @@ func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x
 
 // -----
 
+// CHECK-LABEL: transpose_from_elements_2d
+func.func @transpose_from_elements_2d(%a0: i32, %a1: i32, %a2: i32,
+                                      %a3: i32, %a4: i32, %a5: i32) -> vector<3x2xi32> {
+  %v = vector.from_elements %a0, %a1, %a2, %a3, %a4, %a5 : vector<2x3xi32>
+  %t = vector.transpose %v, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
+  return %t : vector<3x2xi32>
+  // CHECK: %[[R:.*]] = vector.from_elements %arg0, %arg3, %arg1, %arg4, %arg2, %arg5 : vector<3x2xi32>
+  // CHECK-NOT: vector.transpose
+}
+
+// -----
+
 func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
   %0 = vector.constant_mask [2, 2] : vector<4x3xi1>
   %1 = vector.extract_strided_slice %0

@keshavvinayak01 keshavvinayak01 changed the title [Vector] Added canonicalizer for folding from_elements + transpose [MLIR] [Vector] Added canonicalizer for folding from_elements + transpose Oct 3, 2025
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!


// For each element in destination row-major order, pick the corresponding
// source element.
for (int64_t lin = 0; lin < dstNumElements; ++lin) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does lin represent?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So "lin" is short for "linear index" - it's a 1D index that represents the position of an element when the multi-dimensional vector is laid out in row-major order in memory. I felt it was a good iter name.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find lin a bit too enigmatic. Why not linearIdx?

1. Minor nitpicks in code formatting.
2. More lit tests, convering 1D, 2D, 3D cases.

Signed-off-by: Keshav Vinayak Jha <[email protected]>

// For each element in destination row-major order, pick the corresponding
// source element.
for (int64_t lin = 0; lin < dstNumElements; ++lin) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find lin a bit too enigmatic. Why not linearIdx?

Signed-off-by: Keshav Vinayak Jha <[email protected]>
1. Changed variable name of linearIdx iterator.
2. Moved canonicalizer lit tests to other vector.from_elements tests.
3. Added blocked comments signaling beginning, end, and name of the pattern.

Signed-off-by: Keshav Vinayak Jha <[email protected]>
@dcaballe
Copy link
Contributor

LGTM, thanks! (modulo pending comments)

@keshavvinayak01 keshavvinayak01 merged commit 00092f9 into llvm:main Oct 21, 2025
10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants