diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 27b661781f10f..232ed60a0e8ce 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -2292,9 +2292,39 @@ Speculation::Speculatability BroadcastOp::getSpeculatability() { return getGenericSpeculatabilityImpl(cast(getOperation())); } +/// Fold back-to-back broadcasts together. +struct FoldBroadcasts : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp, + PatternRewriter &rewriter) const override { + auto defBroadcastOp = broadcastOp.getInput().getDefiningOp(); + if (!defBroadcastOp) + return failure(); + ArrayRef defDimensions = defBroadcastOp.getDimensions(); + ArrayRef dimensions = broadcastOp.getDimensions(); + SmallVector foldedDims(dimensions); + Value init = broadcastOp.getInit(); + int64_t initRank = cast(init.getType()).getRank(); + // Mapping from input dims to init dims. + SmallVector dimMap; + for (auto dim : llvm::seq(0, initRank)) { + if (!llvm::is_contained(dimensions, dim)) + dimMap.push_back(dim); + } + for (auto dim : defDimensions) + foldedDims.push_back(dimMap[dim]); + + llvm::sort(foldedDims); + rewriter.replaceOpWithNewOp( + broadcastOp, defBroadcastOp.getInput(), init, foldedDims); + return success(); + } +}; + void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add>(context); + results.add, FoldBroadcasts>(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index 9cbb56e4de884..fc88bb2dad0bd 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -1176,6 +1176,52 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>) // ----- +// CHECK-LABEL: @broadcast_broadcast_fold +// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32> +// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x3xf32> +// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<2x3x4xf32> +// CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<2x3x4xf32>) dimensions = [1, 2] +// CHECK-NOT: linalg.broadcast +// CHECK: return %[[BROADCAST]] : tensor<2x3x4xf32> +func.func @broadcast_broadcast_fold(%input: tensor<2xf32>, + %init1: tensor<2x3xf32>, + %init2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { + %broadcast1 = linalg.broadcast + ins(%input: tensor<2xf32>) + outs(%init1: tensor<2x3xf32>) + dimensions = [1] + %broadcast2 = linalg.broadcast + ins(%broadcast1: tensor<2x3xf32>) + outs(%init2: tensor<2x3x4xf32>) + dimensions = [2] + func.return %broadcast2 : tensor<2x3x4xf32> +} + +// ----- + +// CHECK-LABEL: @broadcast_broadcast_fold +// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32> +// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x4xf32> +// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<2x3x4xf32> +// CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<2x3x4xf32>) dimensions = [1, 2] +// CHECK-NOT: linalg.broadcast +// CHECK: return %[[BROADCAST]] : tensor<2x3x4xf32> +func.func @broadcast_broadcast_fold(%input: tensor<2xf32>, + %init1: tensor<2x4xf32>, + %init2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { + %broadcast1 = linalg.broadcast + ins(%input: tensor<2xf32>) + outs(%init1: tensor<2x4xf32>) + dimensions = [1] + %broadcast2 = linalg.broadcast + ins(%broadcast1: tensor<2x4xf32>) + outs(%init2: tensor<2x3x4xf32>) + dimensions = [1] + func.return %broadcast2 : tensor<2x3x4xf32> +} + +// ----- + func.func @transpose_1d(%input: tensor<16xf32>, %init: tensor<16xf32>) -> tensor<16xf32> { %transpose = linalg.transpose