diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 45c54c7587c69..ad8255a95cb4e 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6835,6 +6835,73 @@ class FoldTransposeShapeCast final : public OpRewritePattern {
   }
 };
 
+/// Folds transpose(from_elements(...)) into a new from_elements with permuted
+/// operands matching the transposed shape.
+///
+/// Example:
+///
+///   %v = vector.from_elements %a00, %a01, %a02, %a10, %a11, %a12
+///       : vector<2x3xi32>
+///   %t = vector.transpose %v, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
+///
+/// becomes ->
+///
+///   %r = vector.from_elements %a00, %a10, %a01, %a11, %a02, %a12
+///       : vector<3x2xi32>
+///
+class FoldTransposeFromElements final : public OpRewritePattern<TransposeOp> {
+public:
+  using Base::Base;
+  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    auto fromElementsOp =
+        transposeOp.getVector().getDefiningOp<vector::FromElementsOp>();
+    if (!fromElementsOp)
+      return failure();
+
+    VectorType srcTy = fromElementsOp.getDest().getType();
+    VectorType dstTy = transposeOp.getType();
+
+    ArrayRef<int64_t> permutation = transposeOp.getPermutation();
+    int64_t rank = srcTy.getRank();
+
+    // Build inverse permutation to map destination indices back to source.
+    SmallVector<int64_t> inversePerm(rank, 0);
+    for (int64_t i = 0; i < rank; ++i)
+      inversePerm[permutation[i]] = i;
+
+    ArrayRef<int64_t> srcShape = srcTy.getShape();
+    ArrayRef<int64_t> dstShape = dstTy.getShape();
+    SmallVector<int64_t> srcIdx(rank, 0);
+    SmallVector<int64_t> dstIdx(rank, 0);
+    SmallVector<int64_t> srcStrides = computeStrides(srcShape);
+    SmallVector<int64_t> dstStrides = computeStrides(dstShape);
+
+    auto elementsOld = fromElementsOp.getElements();
+    SmallVector<Value> elementsNew;
+    int64_t dstNumElements = dstTy.getNumElements();
+    elementsNew.reserve(dstNumElements);
+
+    // For each element in destination row-major order, pick the corresponding
+    // source element.
+    for (int64_t linearIdx = 0; linearIdx < dstNumElements; ++linearIdx) {
+      // Pick the destination element index.
+      dstIdx = delinearize(linearIdx, dstStrides);
+      // Map it to the source index (i.e. srcIdx[permutation[i]] = dstIdx[i]).
+      for (int64_t j = 0; j < rank; ++j)
+        srcIdx[j] = dstIdx[inversePerm[j]];
+      // Linearize the source element index.
+      int64_t srcLin = linearize(srcIdx, srcStrides);
+      // Add the source element to the new elements.
+      elementsNew.push_back(elementsOld[srcLin]);
+    }
+
+    rewriter.replaceOpWithNewOp<vector::FromElementsOp>(transposeOp, dstTy,
+                                                        elementsNew);
+    return success();
+  }
+};
+
 /// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
 /// 'order preserving', where 'order preserving' means the flattened
 /// inputs and outputs of the transpose have identical (numerical) values.
@@ -6935,7 +7002,8 @@ class FoldTransposeBroadcast : public OpRewritePattern {
 void vector::TransposeOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
   results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
-              FoldTransposeSplat, FoldTransposeBroadcast>(context);
+              FoldTransposeSplat, FoldTransposeFromElements,
+              FoldTransposeBroadcast>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 59774f92cac36..084f49fca212f 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3530,6 +3530,62 @@ func.func @from_elements_index_to_i64_conversion() -> vector<3xi64> {
 
 // -----
 
+// +---------------------------------------------------------------------------
+// Tests for FoldTransposeFromElements
+// +---------------------------------------------------------------------------
+
+// CHECK-LABEL: transpose_from_elements_1d
+// CHECK-SAME: %[[EL_0:.*]]: i32, %[[EL_1:.*]]: i32
+func.func @transpose_from_elements_1d(%el_0: i32, %el_1: i32) -> vector<2xi32> {
+  %v = vector.from_elements %el_0, %el_1 : vector<2xi32>
+  %t = vector.transpose %v, [0] : vector<2xi32> to vector<2xi32>
+  return %t : vector<2xi32>
+  // CHECK: %[[R:.*]] = vector.from_elements %[[EL_0]], %[[EL_1]] : vector<2xi32>
+  // CHECK-NOT: vector.transpose
+  // CHECK: return %[[R]]
+}
+
+// CHECK-LABEL: transpose_from_elements_2d
+// CHECK-SAME: %[[EL_0_0:.*]]: i32, %[[EL_0_1:.*]]: i32, %[[EL_0_2:.*]]: i32, %[[EL_1_0:.*]]: i32, %[[EL_1_1:.*]]: i32, %[[EL_1_2:.*]]: i32
+func.func @transpose_from_elements_2d(
+  %el_0_0: i32, %el_0_1: i32, %el_0_2: i32,
+  %el_1_0: i32, %el_1_1: i32, %el_1_2: i32
+) -> vector<3x2xi32> {
+  %v = vector.from_elements %el_0_0, %el_0_1, %el_0_2, %el_1_0, %el_1_1, %el_1_2 : vector<2x3xi32>
+  %t = vector.transpose %v, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
+  return %t : vector<3x2xi32>
+  // CHECK: %[[R:.*]] = vector.from_elements %[[EL_0_0]], %[[EL_1_0]], %[[EL_0_1]], %[[EL_1_1]], %[[EL_0_2]], %[[EL_1_2]] : vector<3x2xi32>
+  // CHECK-NOT: vector.transpose
+  // CHECK: return %[[R]]
+}
+
+// CHECK-LABEL: transpose_from_elements_3d
+// CHECK-SAME: %[[EL_0_0_0:.*]]: i32, %[[EL_0_0_1:.*]]: i32, %[[EL_0_1_0:.*]]: i32, %[[EL_0_1_1:.*]]: i32, %[[EL_0_2_0:.*]]: i32, %[[EL_0_2_1:.*]]: i32, %[[EL_1_0_0:.*]]: i32, %[[EL_1_0_1:.*]]: i32, %[[EL_1_1_0:.*]]: i32, %[[EL_1_1_1:.*]]: i32, %[[EL_1_2_0:.*]]: i32, %[[EL_1_2_1:.*]]: i32
+func.func @transpose_from_elements_3d(
+  %el_0_0_0: i32, %el_0_0_1: i32, %el_0_1_0: i32, %el_0_1_1: i32, %el_0_2_0: i32, %el_0_2_1: i32,
+  %el_1_0_0: i32, %el_1_0_1: i32, %el_1_1_0: i32, %el_1_1_1: i32, %el_1_2_0: i32, %el_1_2_1: i32
+) -> vector<2x2x3xi32> {
+  %v = vector.from_elements
+    %el_0_0_0, %el_0_0_1,
+    %el_0_1_0, %el_0_1_1,
+    %el_0_2_0, %el_0_2_1,
+    %el_1_0_0, %el_1_0_1,
+    %el_1_1_0, %el_1_1_1,
+    %el_1_2_0, %el_1_2_1
+    : vector<2x3x2xi32>
+  %t = vector.transpose %v, [0, 2, 1] : vector<2x3x2xi32> to vector<2x2x3xi32>
+  return %t : vector<2x2x3xi32>
+  // CHECK: %[[R:.*]] = vector.from_elements %[[EL_0_0_0]], %[[EL_0_1_0]], %[[EL_0_2_0]], %[[EL_0_0_1]], %[[EL_0_1_1]], %[[EL_0_2_1]], %[[EL_1_0_0]], %[[EL_1_1_0]], %[[EL_1_2_0]], %[[EL_1_0_1]], %[[EL_1_1_1]], %[[EL_1_2_1]] : vector<2x2x3xi32>
+  // CHECK-NOT: vector.transpose
+  // CHECK: return %[[R]]
+}
+
+// +---------------------------------------------------------------------------
+// End of Tests for FoldTransposeFromElements
+// +---------------------------------------------------------------------------
+
+// -----
+
 // Not a DenseElementsAttr, don't fold.
 // CHECK-LABEL: func @negative_insert_llvm_undef(