@@ -31,8 +31,6 @@ namespace mlir {
3131#include " mlir/Dialect/Linalg/Passes.h.inc"
3232} // namespace mlir
3333
34- #define DEBUG_TYPE " linalg-convert-to-dps"
35-
3634using namespace mlir ;
3735using namespace mlir ::tensor;
3836
@@ -612,10 +610,23 @@ Value linalg::bufferizeToAllocation(
612610}
613611
614612namespace {
613+ // / Rewrites an arith op operating on tensors, e.g.
614+ // / `%z = arith.addf %x, %y : tensor<5xf32>`
615+ // / into an equivalent linalg.generic in destination-passing-style.
616+ // / ```mlir
617+ // / %0 = tensor.empty() : tensor<5xf32>
618+ // / %1 = linalg.generic ...
619+ // / ins(%x, %y : tensor<5xf32>, tensor<5xf32>)
620+ // / outs(%0 : tensor<5xf32>) {
621+ // / ^bb0(%in: f32, %in_0: f32, %out: f32):
622+ // / %2 = arith.addf %in, %in_0 : f32
623+ // / linalg.yield %2 : f32
624+ // / } -> tensor<5xf32>
615625template <typename OpTy>
616626FailureOr<Operation *>
617627rewriteArithInDestinationPassingStyle (RewriterBase &rewriter, OpTy op) {
618- // reject ops such as `arith.constant` and `arith.select`.
628+ // Reject ops such as `arith.constant` and `arith.select`.
629+ // Constants don't need dps conversion and select is a `todo`.
619630 auto numOperands = op->getNumOperands ();
620631 if (numOperands == 0 || numOperands > 2 )
621632 return failure ();
@@ -630,8 +641,8 @@ rewriteArithInDestinationPassingStyle(RewriterBase &rewriter, OpTy op) {
630641 OpBuilder::InsertionGuard g (rewriter);
631642 auto dynSizes = reifyOrComputeDynamicSizes (rewriter, op->getOperand (0 ));
632643
633- // Create tensor.empty.
634- Value empty = tensor::EmptyOp::create (rewriter, loc, resultType, dynSizes);
644+ // Create tensor.empty for `outs` of destination-passing-style .
645+ Value outs = tensor::EmptyOp::create (rewriter, loc, resultType, dynSizes);
635646
636647 // Create linalg.generic
637648 auto rank = tensorType.getRank ();
@@ -642,7 +653,7 @@ rewriteArithInDestinationPassingStyle(RewriterBase &rewriter, OpTy op) {
642653 auto genericOp = linalg::GenericOp::create (
643654 rewriter, loc, tensorType,
644655 op->getOperands (), // inputs
645- ValueRange{empty }, // outputs
656+ ValueRange{outs }, // outputs
646657 indexingMaps, iteratorTypes,
647658 [&](OpBuilder &builder, Location loc, ValueRange args) {
648659 Value res;
0 commit comments