diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 98d98f067de14..59a7ea761a5ce 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5776,11 +5776,10 @@ class ShapeCastCreateMaskFolderTrailingOneDim final } }; -/// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast. -/// This only applies when the shape of the broadcast source -/// 1. is a suffix of the shape of the result (i.e. when broadcast without -/// reshape is expressive enough to capture the result in a single op), or -/// 2. has the same element count as the shape cast result. +/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either +/// i) Y = ShapeCast(X), or +/// ii) Y = Broadcast(X) +/// If both (i) and (ii) are possible, (i) is chosen. class ShapeCastBroadcastFolder final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -5792,26 +5791,17 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern { if (!broadcastOp) return failure(); - ArrayRef broadcastSourceShape; - if (auto srcType = dyn_cast(broadcastOp.getSourceType())) - broadcastSourceShape = srcType.getShape(); - ArrayRef shapeCastTargetShape = - shapeCastOp.getResultVectorType().getShape(); - - // If `broadcastSourceShape` is a suffix of the result, we can just replace - // with a broadcast to the final shape. - if (broadcastSourceShape == - shapeCastTargetShape.take_back(broadcastSourceShape.size())) { - rewriter.replaceOpWithNewOp( - shapeCastOp, shapeCastOp.getResultVectorType(), - broadcastOp.getSource()); - return success(); - } - - // Otherwise, if the final result has the same element count, we can replace - // with a shape cast. - if (auto srcType = dyn_cast(broadcastOp.getSourceType())) { - if (srcType.getNumElements() == + auto srcVectorType = dyn_cast(broadcastOp.getSourceType()); + bool srcIsScalar = !srcVectorType; + + // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X). + // Example: + // %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32> + // %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32> + // to + // %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32> + if (srcVectorType) { + if (srcVectorType.getNumElements() == shapeCastOp.getResultVectorType().getNumElements()) { rewriter.replaceOpWithNewOp( shapeCastOp, shapeCastOp.getResultVectorType(), @@ -5820,6 +5810,19 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern { } } + // Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X) + // Example + // %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32> + // %1 = vector.shape_cast %0 : vector<2x4x3xf32> to vector<8x3xf32> + // to + // %1 = vector.broadcast %in : vector<3xf32> to vector<8x3xf32> + VectorType dstVectorType = shapeCastOp.getResultVectorType(); + if (srcIsScalar || isBroadcastableTo(srcVectorType, dstVectorType) == + BroadcastableToResult::Success) { + rewriter.replaceOpWithNewOp( + shapeCastOp, dstVectorType, broadcastOp.getSource()); + return success(); + } return failure(); } }; @@ -6079,7 +6082,7 @@ class TransposeFolder final : public OpRewritePattern { } }; -// Folds transpose(broadcast()) into brodcast(). +// Folds transpose(broadcast()) into broadcast(). struct FoldTransposedScalarBroadcast final : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index b7db8ec834be7..aafa07dc19c38 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1006,13 +1006,48 @@ func.func @canonicalize_broadcast_shapecast_to_broadcast(%arg0: vector<3xf32>) - // ----- -// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapecast +// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_broadcast_ones +// CHECK: vector.broadcast {{.*}} vector<1x1xi8> to vector<1x1x6x1x4xi8> +// CHECK-NOT: vector.shape_cast +func.func @canonicalize_broadcast_shapecast_to_broadcast_ones(%arg0: vector<1x1xi8>) -> vector<1x1x6x1x4xi8> { + %0 = vector.broadcast %arg0 : vector<1x1xi8> to vector<6x4xi8> + %1 = vector.shape_cast %0 : vector<6x4xi8> to vector<1x1x6x1x4xi8> + return %1 : vector<1x1x6x1x4xi8> +} + +// ----- + +// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_broadcast_scalar +// CHECK: vector.broadcast {{.*}} f32 to vector<3x4x1xf32> +// CHECK-NOT: vector.shape_cast +func.func @canonicalize_broadcast_shapecast_to_broadcast_scalar(%arg0: f32) -> vector<3x4x1xf32> { + %0 = vector.broadcast %arg0 : f32 to vector<12xf32> + %1 = vector.shape_cast %0 : vector<12xf32> to vector<3x4x1xf32> + return %1 : vector<3x4x1xf32> +} + +// ----- + +// In this test, broadcast (2)->(1,2,1) is not legal, but shape_cast (2)->(1,2,1) is. +// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapcast +// CHECK-NOT: vector.broadcast +// CHECK: vector.shape_cast {{.+}} : vector<2xf32> to vector<1x2x1xf32> +func.func @canonicalize_broadcast_shapecast_to_shapcast(%arg0 : vector<2xf32>) -> vector<1x2x1xf32> { + %0 = vector.broadcast %arg0 : vector<2xf32> to vector<1x2xf32> + %1 = vector.shape_cast %0 : vector<1x2xf32> to vector<1x2x1xf32> + return %1 : vector<1x2x1xf32> +} + +// ----- + +// In this test, broadcast (1)->(1,1) and shape_cast (1)->(1,1) are both legal. shape_cast is chosen. +// CHECK-LABEL: func @canonicalize_broadcast_shapecast_both_possible // CHECK-NOT: vector.broadcast -// CHECK: vector.shape_cast {{.+}} : vector<3x4xf32> to vector<1x12xf32> -func.func @canonicalize_broadcast_shapecast_to_shapecast(%arg0: vector<3x4xf32>) -> vector<1x12xf32> { - %0 = vector.broadcast %arg0 : vector<3x4xf32> to vector<1x1x3x4xf32> - %1 = vector.shape_cast %0 : vector<1x1x3x4xf32> to vector<1x12xf32> - return %1 : vector<1x12xf32> +// CHECK: vector.shape_cast {{.+}} : vector<1xf32> to vector<1x1xf32> +func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>) -> vector<1x1xf32> { + %0 = vector.broadcast %arg0 : vector<1xf32> to vector<1x1x1xf32> + %1 = vector.shape_cast %0 : vector<1x1x1xf32> to vector<1x1xf32> + return %1 : vector<1x1xf32> } // -----