diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp index 3c92b222e6bc8..a29ba47b28cde 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp @@ -450,6 +450,59 @@ class Transpose2DWithUnitDimToShapeCast } }; +// Suppose the permutation width is defined as the last index in the permutation +// array that is not equal to its index. This pattern is applied to transpose +// operations where the input vector has a shape with at most one non-unit +// dimension up to the permutation width. The pattern replaces the transpose +// operation with a shape cast operation. +// For example: +// %0 = vector.transpose %1, [1, 0, 2] : vector<4x1x1xi32> to vector<1x4x1xi32> +// is replaced by +// %0 = vector.shape_cast %1 : vector<4x1x1xi32> to vector<1x4x1xi32> +// given the permutation width is 2. +class TransposeWithUnitDimToShapeCast + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + TransposeWithUnitDimToShapeCast(MLIRContext *context, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit) {} + + LogicalResult matchAndRewrite(vector::TransposeOp op, + PatternRewriter &rewriter) const override { + Value input = op.getVector(); + VectorType inputType = op.getSourceVectorType(); + if (inputType.isScalable()) + return rewriter.notifyMatchFailure( + op, "This lowering does not support scalable vectors"); + VectorType resType = op.getResultVectorType(); + + ArrayRef transp = op.getPermutation(); + + // Get the permutation width. + int64_t permWidth = 1; + for (auto &&[idx, val] : llvm::enumerate(transp)) { + if (static_cast(idx) != val) + permWidth = idx + 1; + } + + // Check the no. of non unit dim in the input shape upto permutation width + // is not greater than one. + auto inputShape = inputType.getShape(); + + int64_t countNonUnitDims = 0; + for (int i = 0; i < permWidth; i++) { + if (inputShape[i] != 1) + countNonUnitDims++; + if (countNonUnitDims > 1) + return failure(); + } + rewriter.replaceOpWithNewOp(op, resType, input); + return success(); + } +}; + /// Rewrite a 2-D vector.transpose as a sequence of shuffle ops. /// If the strategy is Shuffle1D, it will be lowered to: /// vector.shape_cast 2D -> 1D @@ -522,8 +575,9 @@ class TransposeOp2DToShuffleLowering void mlir::vector::populateVectorTransposeLoweringPatterns( RewritePatternSet &patterns, VectorTransformsOptions options, PatternBenefit benefit) { - patterns.add(patterns.getContext(), - benefit); + patterns + .add( + patterns.getContext(), benefit); patterns.add( options, patterns.getContext(), benefit); } diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir index 219a72df52a19..68e408488cf06 100644 --- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir @@ -23,53 +23,21 @@ func.func @transpose23(%arg0: vector<2x3xf32>) -> vector<3x2xf32> { // CHECK-LABEL: func @transpose102_1x8x8xf32 func.func @transpose102_1x8x8xf32(%arg0: vector<1x8x8xf32>) -> vector<8x1x8xf32> { - // CHECK: vector.extract {{.*}}[0, 0] : vector<8xf32> from vector<1x8x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8xf32> into vector<8x1x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[0, 1] : vector<8xf32> from vector<1x8x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [1, 0] : vector<8xf32> into vector<8x1x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[0, 2] : vector<8xf32> from vector<1x8x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [2, 0] : vector<8xf32> into vector<8x1x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[0, 3] : vector<8xf32> from vector<1x8x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [3, 0] : vector<8xf32> into vector<8x1x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[0, 4] : vector<8xf32> from vector<1x8x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [4, 0] : vector<8xf32> into vector<8x1x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[0, 5] : vector<8xf32> from vector<1x8x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [5, 0] : vector<8xf32> into vector<8x1x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[0, 6] : vector<8xf32> from vector<1x8x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [6, 0] : vector<8xf32> into vector<8x1x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[0, 7] : vector<8xf32> from vector<1x8x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [7, 0] : vector<8xf32> into vector<8x1x8xf32> + // CHECK: %0 = vector.shape_cast %arg0 : vector<1x8x8xf32> to vector<8x1x8xf32> %0 = vector.transpose %arg0, [1, 0, 2] : vector<1x8x8xf32> to vector<8x1x8xf32> return %0 : vector<8x1x8xf32> } // CHECK-LABEL: func @transpose102_8x1x8xf32 func.func @transpose102_8x1x8xf32(%arg0: vector<8x1x8xf32>) -> vector<1x8x8xf32> { - // CHECK: vector.extract {{.*}}[0, 0] : vector<8xf32> from vector<8x1x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8xf32> into vector<1x8x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[1, 0] : vector<8xf32> from vector<8x1x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [0, 1] : vector<8xf32> into vector<1x8x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[2, 0] : vector<8xf32> from vector<8x1x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [0, 2] : vector<8xf32> into vector<1x8x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[3, 0] : vector<8xf32> from vector<8x1x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [0, 3] : vector<8xf32> into vector<1x8x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[4, 0] : vector<8xf32> from vector<8x1x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [0, 4] : vector<8xf32> into vector<1x8x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[5, 0] : vector<8xf32> from vector<8x1x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [0, 5] : vector<8xf32> into vector<1x8x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[6, 0] : vector<8xf32> from vector<8x1x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [0, 6] : vector<8xf32> into vector<1x8x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[7, 0] : vector<8xf32> from vector<8x1x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [0, 7] : vector<8xf32> into vector<1x8x8xf32> + // CHECK: %0 = vector.shape_cast %arg0 : vector<8x1x8xf32> to vector<1x8x8xf32> %0 = vector.transpose %arg0, [1, 0, 2] : vector<8x1x8xf32> to vector<1x8x8xf32> return %0 : vector<1x8x8xf32> } // CHECK-LABEL: func @transpose1023_1x1x8x8xf32( func.func @transpose1023_1x1x8x8xf32(%arg0: vector<1x1x8x8xf32>) -> vector<1x1x8x8xf32> { - // Note the single 2-D extract/insert pair since 2 and 3 are not transposed! - // CHECK: vector.extract {{.*}}[0, 0] : vector<8x8xf32> from vector<1x1x8x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8x8xf32> into vector<1x1x8x8xf32> + // CHECK: return %arg0 : vector<1x1x8x8xf32> %0 = vector.transpose %arg0, [1, 0, 2, 3] : vector<1x1x8x8xf32> to vector<1x1x8x8xf32> return %0 : vector<1x1x8x8xf32> } @@ -386,6 +354,20 @@ func.func @transpose10_4x1xf32_scalable(%arg0: vector<4x[1]xf32>) -> vector<[1]x return %0 : vector<[1]x4xf32> } +// CHECK-LABEL: func @transpose_nd1 +func.func @transpose_nd1(%arg0: vector<1x2x1x16xf32>) -> vector<1x1x2x16xf32> { + // CHECK-NEXT: vector.shape_cast %arg0 : vector<1x2x1x16xf32> to vector<1x1x2x16xf32> + %0 = vector.transpose %arg0, [0, 2, 1, 3] : vector<1x2x1x16xf32> to vector<1x1x2x16xf32> + return %0 : vector<1x1x2x16xf32> +} + +// CHECK-LABEL: func @transpose_nd2 +func.func @transpose_nd2(%arg0: vector<1x1x2x16xf32>) -> vector<1x2x1x16xf32> { + // CHECK-NEXT: vector.shape_cast %arg0 : vector<1x1x2x16xf32> to vector<1x2x1x16xf32> + %0 = vector.transpose %arg0, [0, 2, 1, 3] : vector<1x1x2x16xf32> to vector<1x2x1x16xf32> + return %0 : vector<1x2x1x16xf32> +} + module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) { %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">