From c94bbb7846d0885a63d01b07ed7c8e362fd49689 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Fri, 3 Oct 2025 05:52:21 -0700 Subject: [PATCH 1/6] Added canonicalization (vector.from_elements + vector.transpose -> vector.from_elements) Signed-off-by: Keshav Vinayak Jha --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 61 +++++++++++++++++++++- mlir/test/Dialect/Vector/canonicalize.mlir | 12 +++++ 2 files changed, 72 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index b0132e889302f..31246f5da49b1 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2499,6 +2499,7 @@ static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp, return DenseElementsAttr::get(destVecType, convertedElements); } + OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) { if (auto res = foldFromElementsToElements(*this)) return res; @@ -6723,6 +6724,63 @@ class FoldTransposeShapeCast final : public OpRewritePattern { } }; +/// Folds transpose(from_elements(...)) into a new from_elements with permuted +/// operands matching the transposed shape. +class FoldTransposeFromElements final + : public OpRewritePattern { +public: + +using Base::Base; + LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, + PatternRewriter &rewriter) const override { + auto fromElementsOp = + transposeOp.getVector().getDefiningOp(); + if (!fromElementsOp) + return failure(); + + VectorType srcTy = fromElementsOp.getDest().getType(); + VectorType dstTy = transposeOp.getType(); + + ArrayRef permutation = transposeOp.getPermutation(); + int64_t rank = srcTy.getRank(); + + // Build inverse permutation to map destination indices back to source. 
+ SmallVector inversePerm(rank, 0); + for (int64_t i = 0; i < rank; ++i) + inversePerm[permutation[i]] = i; + + ArrayRef srcShape = srcTy.getShape(); + ArrayRef dstShape = dstTy.getShape(); + SmallVector srcIdx(rank, 0); + SmallVector dstIdx(rank, 0); + SmallVector srcStrides = computeStrides(srcShape); + SmallVector dstStrides = computeStrides(dstShape); + + auto elements = fromElementsOp.getElements(); + SmallVector newElements; + int64_t dstNumElements = dstTy.getNumElements(); + newElements.reserve(dstNumElements); + + // For each element in destination row-major order, pick the corresponding + // source element. + for (int64_t lin = 0; lin < dstNumElements; ++lin) { + // Pick the destination element index. + dstIdx = delinearize(lin, dstStrides); + // Map the destination element index to the source element index. + for (int64_t j = 0; j < rank; ++j) + srcIdx[j] = dstIdx[inversePerm[j]]; + // Linearize the source element index. + int64_t srcLin = linearize(srcIdx, srcStrides); + // Add the source element to the new elements. + newElements.push_back(elements[srcLin]); + } + + rewriter.replaceOpWithNewOp(transposeOp, dstTy, + newElements); + return success(); + } +}; + /// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is /// 'order preserving', where 'order preserving' means the flattened /// inputs and outputs of the transpose have identical (numerical) values. 
@@ -6823,7 +6881,8 @@ class FoldTransposeBroadcast : public OpRewritePattern { void vector::TransposeOp::getCanonicalizationPatterns( RewritePatternSet &results, MLIRContext *context) { results.add(context); + FoldTransposeSplat, FoldTransposeFromElements, + FoldTransposeBroadcast>(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 5448976f84760..5f34d144cd472 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -308,6 +308,18 @@ func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x // ----- +// CHECK-LABEL: transpose_from_elements_2d +func.func @transpose_from_elements_2d(%a0: i32, %a1: i32, %a2: i32, + %a3: i32, %a4: i32, %a5: i32) -> vector<3x2xi32> { + %v = vector.from_elements %a0, %a1, %a2, %a3, %a4, %a5 : vector<2x3xi32> + %t = vector.transpose %v, [1, 0] : vector<2x3xi32> to vector<3x2xi32> + return %t : vector<3x2xi32> + // CHECK: %[[R:.*]] = vector.from_elements %arg0, %arg3, %arg1, %arg4, %arg2, %arg5 : vector<3x2xi32> + // CHECK-NOT: vector.transpose +} + +// ----- + func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) { %0 = vector.constant_mask [2, 2] : vector<4x3xi1> %1 = vector.extract_strided_slice %0 From 6bef6d259f8abf82d48092eae1404d6a2ebbfac7 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Fri, 3 Oct 2025 05:54:15 -0700 Subject: [PATCH 2/6] Formatted Signed-off-by: Keshav Vinayak Jha --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 31246f5da49b1..7f6313c11ea18 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2499,7 +2499,6 @@ static OpFoldResult 
foldFromElementsToConstant(FromElementsOp fromElementsOp, return DenseElementsAttr::get(destVecType, convertedElements); } - OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) { if (auto res = foldFromElementsToElements(*this)) return res; @@ -6726,11 +6725,9 @@ class FoldTransposeShapeCast final : public OpRewritePattern { /// Folds transpose(from_elements(...)) into a new from_elements with permuted /// operands matching the transposed shape. -class FoldTransposeFromElements final - : public OpRewritePattern { +class FoldTransposeFromElements final : public OpRewritePattern { public: - -using Base::Base; + using Base::Base; LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, PatternRewriter &rewriter) const override { auto fromElementsOp = @@ -6776,7 +6773,7 @@ using Base::Base; } rewriter.replaceOpWithNewOp(transposeOp, dstTy, - newElements); + newElements); return success(); } }; From 70d3d8f7ef66a595f8d4072af9cbfbafd2fe33eb Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Thu, 16 Oct 2025 03:58:20 -0700 Subject: [PATCH 3/6] Addressed comments: 1. Minor nitpicks in code formatting. 2. More lit tests, covering 1D, 2D, 3D cases. Signed-off-by: Keshav Vinayak Jha --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 20 ++++++------- mlir/test/Dialect/Vector/canonicalize.mlir | 33 ++++++++++++++++++---- 2 files changed, 38 insertions(+), 15 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 7f6313c11ea18..75e3a79b22aa9 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -6742,21 +6742,21 @@ class FoldTransposeFromElements final : public OpRewritePattern { int64_t rank = srcTy.getRank(); // Build inverse permutation to map destination indices back to source. 
- SmallVector inversePerm(rank, 0); + SmallVector inversePerm(rank, 0); for (int64_t i = 0; i < rank; ++i) inversePerm[permutation[i]] = i; ArrayRef srcShape = srcTy.getShape(); ArrayRef dstShape = dstTy.getShape(); - SmallVector srcIdx(rank, 0); - SmallVector dstIdx(rank, 0); - SmallVector srcStrides = computeStrides(srcShape); - SmallVector dstStrides = computeStrides(dstShape); + SmallVector srcIdx(rank, 0); + SmallVector dstIdx(rank, 0); + SmallVector srcStrides = computeStrides(srcShape); + SmallVector dstStrides = computeStrides(dstShape); - auto elements = fromElementsOp.getElements(); - SmallVector newElements; + auto elementsOld = fromElementsOp.getElements(); + SmallVector elementsNew; int64_t dstNumElements = dstTy.getNumElements(); - newElements.reserve(dstNumElements); + elementsNew.reserve(dstNumElements); // For each element in destination row-major order, pick the corresponding // source element. @@ -6769,11 +6769,11 @@ class FoldTransposeFromElements final : public OpRewritePattern { // Linearize the source element index. int64_t srcLin = linearize(srcIdx, srcStrides); // Add the source element to the new elements. 
- newElements.push_back(elements[srcLin]); + elementsNew.push_back(elementsOld[srcLin]); } rewriter.replaceOpWithNewOp(transposeOp, dstTy, - newElements); + elementsNew); return success(); } }; diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 5f34d144cd472..d3b92ffb8cc88 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -308,16 +308,39 @@ func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x // ----- +// CHECK-LABEL: transpose_from_elements_1d +func.func @transpose_from_elements_1d(%arg0: i32, %arg1: i32) -> vector<2xi32> { + %v = vector.from_elements %arg0, %arg1 : vector<2xi32> + %t = vector.transpose %v, [0] : vector<2xi32> to vector<2xi32> + return %t : vector<2xi32> + // CHECK: %[[R:.*]] = vector.from_elements %arg0, %arg1 : vector<2xi32> + // CHECK-NOT: vector.transpose +} + // CHECK-LABEL: transpose_from_elements_2d -func.func @transpose_from_elements_2d(%a0: i32, %a1: i32, %a2: i32, - %a3: i32, %a4: i32, %a5: i32) -> vector<3x2xi32> { - %v = vector.from_elements %a0, %a1, %a2, %a3, %a4, %a5 : vector<2x3xi32> - %t = vector.transpose %v, [1, 0] : vector<2x3xi32> to vector<3x2xi32> - return %t : vector<3x2xi32> +func.func @transpose_from_elements_2d( + %arg0: i32, %arg1: i32, %arg2: i32, + %arg3: i32, %arg4: i32, %arg5: i32 +) -> vector<3x2xi32> { + %arg6 = vector.from_elements %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 : vector<2x3xi32> + %arg7 = vector.transpose %arg6, [1, 0] : vector<2x3xi32> to vector<3x2xi32> + return %arg7 : vector<3x2xi32> // CHECK: %[[R:.*]] = vector.from_elements %arg0, %arg3, %arg1, %arg4, %arg2, %arg5 : vector<3x2xi32> // CHECK-NOT: vector.transpose } +// CHECK-LABEL: transpose_from_elements_3d +func.func @transpose_from_elements_3d( + %arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, + %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32 +) -> 
vector<2x2x3xi32> { + %arg12 = vector.from_elements %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11 : vector<2x3x2xi32> + %arg13 = vector.transpose %arg12, [0, 2, 1] : vector<2x3x2xi32> to vector<2x2x3xi32> + return %arg13 : vector<2x2x3xi32> + // CHECK: %[[R:.*]] = vector.from_elements %arg0, %arg2, %arg4, %arg1, %arg3, %arg5, %arg6, %arg8, %arg10, %arg7, %arg9, %arg11 : vector<2x2x3xi32> + // CHECK-NOT: vector.transpose +} + // ----- func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) { From 617267b40f96cd2f21064a1c56125a5afb7e5217 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Thu, 16 Oct 2025 08:06:55 -0700 Subject: [PATCH 4/6] Explainable arg names in lit test Signed-off-by: Keshav Vinayak Jha --- mlir/test/Dialect/Vector/canonicalize.mlir | 37 +++++++++++++--------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index d3b92ffb8cc88..e51eeb9fabbb8 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -309,35 +309,42 @@ func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x // ----- // CHECK-LABEL: transpose_from_elements_1d -func.func @transpose_from_elements_1d(%arg0: i32, %arg1: i32) -> vector<2xi32> { - %v = vector.from_elements %arg0, %arg1 : vector<2xi32> +func.func @transpose_from_elements_1d(%el_0: i32, %el_1: i32) -> vector<2xi32> { + %v = vector.from_elements %el_0, %el_1 : vector<2xi32> %t = vector.transpose %v, [0] : vector<2xi32> to vector<2xi32> return %t : vector<2xi32> - // CHECK: %[[R:.*]] = vector.from_elements %arg0, %arg1 : vector<2xi32> + // CHECK: %[[R:.*]] = vector.from_elements %[[EL_0:.*]], %[[EL_1:.*]] : vector<2xi32> // CHECK-NOT: vector.transpose } // CHECK-LABEL: transpose_from_elements_2d func.func @transpose_from_elements_2d( - %arg0: i32, %arg1: i32, %arg2: i32, - %arg3: i32, %arg4: 
i32, %arg5: i32 + %el_0_0: i32, %el_0_1: i32, %el_0_2: i32, + %el_1_0: i32, %el_1_1: i32, %el_1_2: i32 ) -> vector<3x2xi32> { - %arg6 = vector.from_elements %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 : vector<2x3xi32> - %arg7 = vector.transpose %arg6, [1, 0] : vector<2x3xi32> to vector<3x2xi32> - return %arg7 : vector<3x2xi32> - // CHECK: %[[R:.*]] = vector.from_elements %arg0, %arg3, %arg1, %arg4, %arg2, %arg5 : vector<3x2xi32> + %v = vector.from_elements %el_0_0, %el_0_1, %el_0_2, %el_1_0, %el_1_1, %el_1_2 : vector<2x3xi32> + %t = vector.transpose %v, [1, 0] : vector<2x3xi32> to vector<3x2xi32> + return %t : vector<3x2xi32> + // CHECK: %[[R:.*]] = vector.from_elements %[[EL_0_0:.*]], %[[EL_1_0:.*]], %[[EL_0_1:.*]], %[[EL_1_1:.*]], %[[EL_0_2:.*]], %[[EL_1_2:.*]] : vector<3x2xi32> // CHECK-NOT: vector.transpose } // CHECK-LABEL: transpose_from_elements_3d func.func @transpose_from_elements_3d( - %arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, - %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32 + %el_0_0_0: i32, %el_0_0_1: i32, %el_0_1_0: i32, %el_0_1_1: i32, %el_0_2_0: i32, %el_0_2_1: i32, + %el_1_0_0: i32, %el_1_0_1: i32, %el_1_1_0: i32, %el_1_1_1: i32, %el_1_2_0: i32, %el_1_2_1: i32 ) -> vector<2x2x3xi32> { - %arg12 = vector.from_elements %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11 : vector<2x3x2xi32> - %arg13 = vector.transpose %arg12, [0, 2, 1] : vector<2x3x2xi32> to vector<2x2x3xi32> - return %arg13 : vector<2x2x3xi32> - // CHECK: %[[R:.*]] = vector.from_elements %arg0, %arg2, %arg4, %arg1, %arg3, %arg5, %arg6, %arg8, %arg10, %arg7, %arg9, %arg11 : vector<2x2x3xi32> + %v = vector.from_elements + %el_0_0_0, %el_0_0_1, + %el_0_1_0, %el_0_1_1, + %el_0_2_0, %el_0_2_1, + %el_1_0_0, %el_1_0_1, + %el_1_1_0, %el_1_1_1, + %el_1_2_0, %el_1_2_1 + : vector<2x3x2xi32> + %t = vector.transpose %v, [0, 2, 1] : vector<2x3x2xi32> to vector<2x2x3xi32> + return %t : vector<2x2x3xi32> + // CHECK: 
%[[R:.*]] = vector.from_elements %[[EL_0_0_0:.*]], %[[EL_0_1_0:.*]], %[[EL_0_2_0:.*]], %[[EL_0_0_1:.*]], %[[EL_0_1_1:.*]], %[[EL_0_2_1:.*]], %[[EL_1_0_0:.*]], %[[EL_1_1_0:.*]], %[[EL_1_2_0:.*]], %[[EL_1_0_1:.*]], %[[EL_1_1_1:.*]], %[[EL_1_2_1:.*]] : vector<2x2x3xi32> // CHECK-NOT: vector.transpose } From 2889f3d6795f562cd611b5f351ae4e8abc02c0fb Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Fri, 17 Oct 2025 06:11:59 -0700 Subject: [PATCH 5/6] Addressed Comments: 1. Changed variable name of linearIdx iterator. 2. Moved canonicalizer lit tests to other vector.from_elements tests. 3. Added blocked comments signaling beginning, end, and name of the pattern. Signed-off-by: Keshav Vinayak Jha --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 4 +- mlir/test/Dialect/Vector/canonicalize.mlir | 98 ++++++++++++---------- 2 files changed, 58 insertions(+), 44 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 75e3a79b22aa9..7c588a435aa1a 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -6760,9 +6760,9 @@ class FoldTransposeFromElements final : public OpRewritePattern { // For each element in destination row-major order, pick the corresponding // source element. - for (int64_t lin = 0; lin < dstNumElements; ++lin) { + for (int64_t linearIdx = 0; linearIdx < dstNumElements; ++linearIdx) { // Pick the destination element index. - dstIdx = delinearize(lin, dstStrides); + dstIdx = delinearize(linearIdx, dstStrides); // Map the destination element index to the source element index. 
for (int64_t j = 0; j < rank; ++j) srcIdx[j] = dstIdx[inversePerm[j]]; diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index e51eeb9fabbb8..d5ae12f159a88 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -308,48 +308,6 @@ func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x // ----- -// CHECK-LABEL: transpose_from_elements_1d -func.func @transpose_from_elements_1d(%el_0: i32, %el_1: i32) -> vector<2xi32> { - %v = vector.from_elements %el_0, %el_1 : vector<2xi32> - %t = vector.transpose %v, [0] : vector<2xi32> to vector<2xi32> - return %t : vector<2xi32> - // CHECK: %[[R:.*]] = vector.from_elements %[[EL_0:.*]], %[[EL_1:.*]] : vector<2xi32> - // CHECK-NOT: vector.transpose -} - -// CHECK-LABEL: transpose_from_elements_2d -func.func @transpose_from_elements_2d( - %el_0_0: i32, %el_0_1: i32, %el_0_2: i32, - %el_1_0: i32, %el_1_1: i32, %el_1_2: i32 -) -> vector<3x2xi32> { - %v = vector.from_elements %el_0_0, %el_0_1, %el_0_2, %el_1_0, %el_1_1, %el_1_2 : vector<2x3xi32> - %t = vector.transpose %v, [1, 0] : vector<2x3xi32> to vector<3x2xi32> - return %t : vector<3x2xi32> - // CHECK: %[[R:.*]] = vector.from_elements %[[EL_0_0:.*]], %[[EL_1_0:.*]], %[[EL_0_1:.*]], %[[EL_1_1:.*]], %[[EL_0_2:.*]], %[[EL_1_2:.*]] : vector<3x2xi32> - // CHECK-NOT: vector.transpose -} - -// CHECK-LABEL: transpose_from_elements_3d -func.func @transpose_from_elements_3d( - %el_0_0_0: i32, %el_0_0_1: i32, %el_0_1_0: i32, %el_0_1_1: i32, %el_0_2_0: i32, %el_0_2_1: i32, - %el_1_0_0: i32, %el_1_0_1: i32, %el_1_1_0: i32, %el_1_1_1: i32, %el_1_2_0: i32, %el_1_2_1: i32 -) -> vector<2x2x3xi32> { - %v = vector.from_elements - %el_0_0_0, %el_0_0_1, - %el_0_1_0, %el_0_1_1, - %el_0_2_0, %el_0_2_1, - %el_1_0_0, %el_1_0_1, - %el_1_1_0, %el_1_1_1, - %el_1_2_0, %el_1_2_1 - : vector<2x3x2xi32> - %t = vector.transpose %v, [0, 2, 1] : vector<2x3x2xi32> to vector<2x2x3xi32> - 
return %t : vector<2x2x3xi32> - // CHECK: %[[R:.*]] = vector.from_elements %[[EL_0_0_0:.*]], %[[EL_0_1_0:.*]], %[[EL_0_2_0:.*]], %[[EL_0_0_1:.*]], %[[EL_0_1_1:.*]], %[[EL_0_2_1:.*]], %[[EL_1_0_0:.*]], %[[EL_1_1_0:.*]], %[[EL_1_2_0:.*]], %[[EL_1_0_1:.*]], %[[EL_1_1_1:.*]], %[[EL_1_2_1:.*]] : vector<2x2x3xi32> - // CHECK-NOT: vector.transpose -} - -// ----- - func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) { %0 = vector.constant_mask [2, 2] : vector<4x3xi1> %1 = vector.extract_strided_slice %0 @@ -3527,6 +3485,62 @@ func.func @from_elements_index_to_i64_conversion() -> vector<3xi64> { // ----- +// +--------------------------------------------------------------------------- +// Tests for FoldTransposeFromElements +// +--------------------------------------------------------------------------- + +// CHECK-LABEL: transpose_from_elements_1d +// CHECK-SAME: %[[EL_0:.*]]: i32, %[[EL_1:.*]]: i32 +func.func @transpose_from_elements_1d(%el_0: i32, %el_1: i32) -> vector<2xi32> { + %v = vector.from_elements %el_0, %el_1 : vector<2xi32> + %t = vector.transpose %v, [0] : vector<2xi32> to vector<2xi32> + return %t : vector<2xi32> + // CHECK: %[[R:.*]] = vector.from_elements %[[EL_0]], %[[EL_1]] : vector<2xi32> + // CHECK-NOT: vector.transpose + // CHECK: return %[[R]] +} + +// CHECK-LABEL: transpose_from_elements_2d +// CHECK-SAME: %[[EL_0_0:.*]]: i32, %[[EL_0_1:.*]]: i32, %[[EL_0_2:.*]]: i32, %[[EL_1_0:.*]]: i32, %[[EL_1_1:.*]]: i32, %[[EL_1_2:.*]]: i32 +func.func @transpose_from_elements_2d( + %el_0_0: i32, %el_0_1: i32, %el_0_2: i32, + %el_1_0: i32, %el_1_1: i32, %el_1_2: i32 +) -> vector<3x2xi32> { + %v = vector.from_elements %el_0_0, %el_0_1, %el_0_2, %el_1_0, %el_1_1, %el_1_2 : vector<2x3xi32> + %t = vector.transpose %v, [1, 0] : vector<2x3xi32> to vector<3x2xi32> + return %t : vector<3x2xi32> + // CHECK: %[[R:.*]] = vector.from_elements %[[EL_0_0]], %[[EL_1_0]], %[[EL_0_1]], %[[EL_1_1]], %[[EL_0_2]], %[[EL_1_2]] : vector<3x2xi32> + 
// CHECK-NOT: vector.transpose + // CHECK: return %[[R]] +} + +// CHECK-LABEL: transpose_from_elements_3d +// CHECK-SAME: %[[EL_0_0_0:.*]]: i32, %[[EL_0_0_1:.*]]: i32, %[[EL_0_1_0:.*]]: i32, %[[EL_0_1_1:.*]]: i32, %[[EL_0_2_0:.*]]: i32, %[[EL_0_2_1:.*]]: i32, %[[EL_1_0_0:.*]]: i32, %[[EL_1_0_1:.*]]: i32, %[[EL_1_1_0:.*]]: i32, %[[EL_1_1_1:.*]]: i32, %[[EL_1_2_0:.*]]: i32, %[[EL_1_2_1:.*]]: i32 +func.func @transpose_from_elements_3d( + %el_0_0_0: i32, %el_0_0_1: i32, %el_0_1_0: i32, %el_0_1_1: i32, %el_0_2_0: i32, %el_0_2_1: i32, + %el_1_0_0: i32, %el_1_0_1: i32, %el_1_1_0: i32, %el_1_1_1: i32, %el_1_2_0: i32, %el_1_2_1: i32 +) -> vector<2x2x3xi32> { + %v = vector.from_elements + %el_0_0_0, %el_0_0_1, + %el_0_1_0, %el_0_1_1, + %el_0_2_0, %el_0_2_1, + %el_1_0_0, %el_1_0_1, + %el_1_1_0, %el_1_1_1, + %el_1_2_0, %el_1_2_1 + : vector<2x3x2xi32> + %t = vector.transpose %v, [0, 2, 1] : vector<2x3x2xi32> to vector<2x2x3xi32> + return %t : vector<2x2x3xi32> + // CHECK: %[[R:.*]] = vector.from_elements %[[EL_0_0_0]], %[[EL_0_1_0]], %[[EL_0_2_0]], %[[EL_0_0_1]], %[[EL_0_1_1]], %[[EL_0_2_1]], %[[EL_1_0_0]], %[[EL_1_1_0]], %[[EL_1_2_0]], %[[EL_1_0_1]], %[[EL_1_1_1]], %[[EL_1_2_1]] : vector<2x2x3xi32> + // CHECK-NOT: vector.transpose + // CHECK: return %[[R]] +} + +// +--------------------------------------------------------------------------- +// End of Tests for FoldTransposeFromElements +// +--------------------------------------------------------------------------- + +// ----- + // Not a DenseElementsAttr, don't fold. 
// CHECK-LABEL: func @negative_insert_llvm_undef( From 08a08023b3194252ede656b7539b7037fee4b973 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Tue, 21 Oct 2025 01:18:49 -0700 Subject: [PATCH 6/6] Added example for folder Signed-off-by: Keshav Vinayak Jha --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 7c588a435aa1a..535192b4e10ad 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -6725,6 +6725,18 @@ class FoldTransposeShapeCast final : public OpRewritePattern { /// Folds transpose(from_elements(...)) into a new from_elements with permuted /// operands matching the transposed shape. +/// +/// Example: +/// +///   %v = vector.from_elements %a00, %a01, %a02, %a10, %a11, %a12 +///       : vector<2x3xi32> +///   %t = vector.transpose %v, [1, 0] : vector<2x3xi32> to vector<3x2xi32> +/// +/// becomes -> +/// +///   %r = vector.from_elements %a00, %a10, %a01, %a11, %a02, %a12 +///       : vector<3x2xi32> +/// class FoldTransposeFromElements final : public OpRewritePattern { public: using Base::Base;