keshavvinayak01
Contributor

@keshavvinayak01 keshavvinayak01 commented Oct 3, 2025

Description

Adds a new canonicalization pattern that folds vector.transpose(vector.from_elements(...)) => vector.from_elements. The pattern reorders the operands of vector.from_elements and adjusts the result shape to match the effect of the transpose, which makes the vector.transpose op unnecessary.
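
To illustrate the reordering described above, here is a minimal standalone sketch with no MLIR dependencies. The helper names `rowMajorStrides` and `transposedOperandOrder` are hypothetical and are not part of this patch; the sketch assumes the convention that result dimension `i` of the transpose corresponds to source dimension `perm[i]`, and that elements are listed in row-major order.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Row-major strides for `shape`, e.g. {2, 3} -> {3, 1}.
std::vector<int64_t> rowMajorStrides(const std::vector<int64_t> &shape) {
  std::vector<int64_t> strides(shape.size(), 1);
  for (int64_t i = static_cast<int64_t>(shape.size()) - 2; i >= 0; --i)
    strides[i] = strides[i + 1] * shape[i + 1];
  return strides;
}

// Hypothetical helper (not from the patch): for every destination position,
// visited in row-major order, returns the linear position of the source
// operand that lands there after the transpose.
// Assumed convention: dstShape[i] == srcShape[perm[i]].
std::vector<int64_t>
transposedOperandOrder(const std::vector<int64_t> &srcShape,
                       const std::vector<int64_t> &perm) {
  assert(perm.size() == srcShape.size() && "permutation must match the rank");
  const std::size_t rank = srcShape.size();

  std::vector<int64_t> dstShape(rank);
  for (std::size_t i = 0; i < rank; ++i)
    dstShape[i] = srcShape[perm[i]];

  const std::vector<int64_t> srcStrides = rowMajorStrides(srcShape);
  const std::vector<int64_t> dstStrides = rowMajorStrides(dstShape);

  int64_t numElements = 1;
  for (int64_t dim : srcShape)
    numElements *= dim;

  std::vector<int64_t> order;
  order.reserve(static_cast<std::size_t>(numElements));
  for (int64_t linearIdx = 0; linearIdx < numElements; ++linearIdx) {
    int64_t remainder = linearIdx;
    int64_t srcLinearIdx = 0;
    for (std::size_t i = 0; i < rank; ++i) {
      // Coordinate along destination dimension i.
      const int64_t dstCoord = remainder / dstStrides[i];
      remainder %= dstStrides[i];
      // Destination dimension i is source dimension perm[i].
      srcLinearIdx += dstCoord * srcStrides[perm[i]];
    }
    order.push_back(srcLinearIdx);
  }
  return order;
}
```

For a 2x3 source and permutation [1, 0], this sketch returns the order 0, 3, 1, 4, 2, 5, which agrees with the operand order expected by the 2D lit test in this PR.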

Testing

Added a 2D lit test that verifies the rewrite.

Signed-off-by: Keshav Vinayak Jha <[email protected]>
@llvmbot
Member

llvmbot commented Oct 3, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Keshav Vinayak Jha (keshavvinayak01)

Changes

Description

Adds a new canonicalization pattern that folds vector.transpose(vector.from_elements(...)) => vector.from_elements. The pattern reorders the operands of vector.from_elements and adjusts the result shape to match the effect of the transpose, which makes the vector.transpose op unnecessary.

Testing

Added a 2D lit test that verifies the rewrite.


Full diff: https://github.com/llvm/llvm-project/pull/161841.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+57-1)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+12)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b0132e889302f..7f6313c11ea18 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6723,6 +6723,61 @@ class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
   }
 };
 
+/// Folds transpose(from_elements(...)) into a new from_elements with permuted
+/// operands matching the transposed shape.
+class FoldTransposeFromElements final : public OpRewritePattern<TransposeOp> {
+public:
+  using Base::Base;
+  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    auto fromElementsOp =
+        transposeOp.getVector().getDefiningOp<vector::FromElementsOp>();
+    if (!fromElementsOp)
+      return failure();
+
+    VectorType srcTy = fromElementsOp.getDest().getType();
+    VectorType dstTy = transposeOp.getType();
+
+    ArrayRef<int64_t> permutation = transposeOp.getPermutation();
+    int64_t rank = srcTy.getRank();
+
+    // Build inverse permutation to map destination indices back to source.
+    SmallVector<int64_t, 4> inversePerm(rank, 0);
+    for (int64_t i = 0; i < rank; ++i)
+      inversePerm[permutation[i]] = i;
+
+    ArrayRef<int64_t> srcShape = srcTy.getShape();
+    ArrayRef<int64_t> dstShape = dstTy.getShape();
+    SmallVector<int64_t, 4> srcIdx(rank, 0);
+    SmallVector<int64_t, 4> dstIdx(rank, 0);
+    SmallVector<int64_t, 4> srcStrides = computeStrides(srcShape);
+    SmallVector<int64_t, 4> dstStrides = computeStrides(dstShape);
+
+    auto elements = fromElementsOp.getElements();
+    SmallVector<Value> newElements;
+    int64_t dstNumElements = dstTy.getNumElements();
+    newElements.reserve(dstNumElements);
+
+    // For each element in destination row-major order, pick the corresponding
+    // source element.
+    for (int64_t lin = 0; lin < dstNumElements; ++lin) {
+      // Pick the destination element index.
+      dstIdx = delinearize(lin, dstStrides);
+      // Map the destination element index to the source element index.
+      for (int64_t j = 0; j < rank; ++j)
+        srcIdx[j] = dstIdx[inversePerm[j]];
+      // Linearize the source element index.
+      int64_t srcLin = linearize(srcIdx, srcStrides);
+      // Add the source element to the new elements.
+      newElements.push_back(elements[srcLin]);
+    }
+
+    rewriter.replaceOpWithNewOp<FromElementsOp>(transposeOp, dstTy,
+                                                newElements);
+    return success();
+  }
+};
+
 /// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
 /// 'order preserving', where 'order preserving' means the flattened
 /// inputs and outputs of the transpose have identical (numerical) values.
@@ -6823,7 +6878,8 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
 void vector::TransposeOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
   results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
-              FoldTransposeSplat, FoldTransposeBroadcast>(context);
+              FoldTransposeSplat, FoldTransposeFromElements,
+              FoldTransposeBroadcast>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 5448976f84760..5f34d144cd472 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -308,6 +308,18 @@ func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x
 
 // -----
 
+// CHECK-LABEL: transpose_from_elements_2d
+func.func @transpose_from_elements_2d(%a0: i32, %a1: i32, %a2: i32,
+                                      %a3: i32, %a4: i32, %a5: i32) -> vector<3x2xi32> {
+  %v = vector.from_elements %a0, %a1, %a2, %a3, %a4, %a5 : vector<2x3xi32>
+  %t = vector.transpose %v, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
+  return %t : vector<3x2xi32>
+  // CHECK: %[[R:.*]] = vector.from_elements %arg0, %arg3, %arg1, %arg4, %arg2, %arg5 : vector<3x2xi32>
+  // CHECK-NOT: vector.transpose
+}
+
+// -----
+
 func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
   %0 = vector.constant_mask [2, 2] : vector<4x3xi1>
   %1 = vector.extract_strided_slice %0

keshavvinayak01 changed the title from "[Vector] Added canonicalizer for folding from_elements + transpose" to "[MLIR] [Vector] Added canonicalizer for folding from_elements + transpose" on Oct 3, 2025
Contributor

@banach-space banach-space left a comment

Thanks!


// For each element in destination row-major order, pick the corresponding
// source element.
for (int64_t lin = 0; lin < dstNumElements; ++lin) {
Contributor

What does lin represent?

Contributor Author

So "lin" is short for "linear index" - it's a 1D index that represents the position of an element when the multi-dimensional vector is laid out in row-major order in memory. I felt it was a good iter name.
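
As a rough illustration of what a row-major linear index means, the sketch below converts between a linear index and per-dimension coordinates. The functions `delinearizeRowMajor` and `linearizeRowMajor` are hypothetical standalone stand-ins written for this example; they take a shape rather than strides and are not the MLIR utilities used in the patch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in: converts a row-major linear index into
// per-dimension coordinates for `shape`. The last dimension varies fastest.
std::vector<int64_t> delinearizeRowMajor(int64_t linearIdx,
                                         const std::vector<int64_t> &shape) {
  std::vector<int64_t> coords(shape.size(), 0);
  for (int64_t i = static_cast<int64_t>(shape.size()) - 1; i >= 0; --i) {
    coords[i] = linearIdx % shape[i];
    linearIdx /= shape[i];
  }
  assert(linearIdx == 0 && "linear index out of bounds for shape");
  return coords;
}

// Hypothetical stand-in: the inverse mapping, from coordinates back to the
// row-major linear index.
int64_t linearizeRowMajor(const std::vector<int64_t> &coords,
                          const std::vector<int64_t> &shape) {
  assert(coords.size() == shape.size() && "rank mismatch");
  int64_t linearIdx = 0;
  for (std::size_t i = 0; i < shape.size(); ++i)
    linearIdx = linearIdx * shape[i] + coords[i];
  return linearIdx;
}
```

With a 3x2 shape, linear index 4 corresponds to coordinates (2, 0), and linearizing (2, 0) gives 4 again.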

Contributor

I find lin a bit too enigmatic. Why not linearIdx?

1. Minor nitpicks in code formatting.
2. More lit tests, covering 1D, 2D, and 3D cases.

Signed-off-by: Keshav Vinayak Jha <[email protected]>

Signed-off-by: Keshav Vinayak Jha <[email protected]>
Comment on lines 311 to 318
// CHECK-LABEL: transpose_from_elements_1d
func.func @transpose_from_elements_1d(%el_0: i32, %el_1: i32) -> vector<2xi32> {
%v = vector.from_elements %el_0, %el_1 : vector<2xi32>
%t = vector.transpose %v, [0] : vector<2xi32> to vector<2xi32>
return %t : vector<2xi32>
// CHECK: %[[R:.*]] = vector.from_elements %[[EL_0:.*]], %[[EL_1:.*]] : vector<2xi32>
// CHECK-NOT: vector.transpose
}
Contributor

  1. Variable R is defined, but never used. Please fix by adding return %[[R]].
  2. Variables EL_0, EL_1 should be defined near function signature and then re-used here.

Specifically:

Suggested change
- // CHECK-LABEL: transpose_from_elements_1d
- func.func @transpose_from_elements_1d(%el_0: i32, %el_1: i32) -> vector<2xi32> {
-   %v = vector.from_elements %el_0, %el_1 : vector<2xi32>
-   %t = vector.transpose %v, [0] : vector<2xi32> to vector<2xi32>
-   return %t : vector<2xi32>
-   // CHECK: %[[R:.*]] = vector.from_elements %[[EL_0:.*]], %[[EL_1:.*]] : vector<2xi32>
-   // CHECK-NOT: vector.transpose
- }
+ // CHECK-LABEL: transpose_from_elements_1d
+ // CHECK-SAME: %[[EL_0:.*]]: i32, %[[EL_1:.*]]: i32
+ func.func @transpose_from_elements_1d(%el_0: i32, %el_1: i32) -> vector<2xi32> {
+   %v = vector.from_elements %el_0, %el_1 : vector<2xi32>
+   %t = vector.transpose %v, [0] : vector<2xi32> to vector<2xi32>
+   return %t : vector<2xi32>
+   // CHECK: %[[R:.*]] = vector.from_elements %[[EL_0]], %[[EL_1]] : vector<2xi32>
+   // CHECK-NOT: vector.transpose
+   // CHECK: return %[[R]]
+ }

Similar comment for other tests. For more details, see e.g. https://llvm.org/docs/CommandGuide/FileCheck.html#filecheck-string-substitution-blocks

Contributor

Please move the newly added tests near other tests for vector.from_elements, e.g. https://github.com/llvm/llvm-project/blob/main/mlir/test/Dialect/Vector/canonicalize.mlir

Also, please add block comments documenting what folder is tested. Examples:

// +---------------------------------------------------------------------------
// Tests for foldFromElementsToConstant
// +---------------------------------------------------------------------------

// +---------------------------------------------------------------------------
// End of Tests for foldFromElementsToConstant
// +---------------------------------------------------------------------------

1. Renamed the loop iterator from lin to linearIdx.
2. Moved the canonicalizer lit tests next to the other vector.from_elements tests.
3. Added block comments marking the beginning and end of the tests and naming the pattern under test.

Signed-off-by: Keshav Vinayak Jha <[email protected]>
};

/// Folds transpose(from_elements(...)) into a new from_elements with permuted
/// operands matching the transposed shape.
Contributor

Could you add a before and after IR example? That usually helps a lot with understanding.

@dcaballe
Contributor

LGTM, thanks! (modulo pending comments)
