@@ -31,8 +31,6 @@ namespace mlir {
31
31
#include " mlir/Dialect/Linalg/Passes.h.inc"
32
32
} // namespace mlir
33
33
34
- #define DEBUG_TYPE " linalg-convert-to-dps"
35
-
36
34
using namespace mlir ;
37
35
using namespace mlir ::tensor;
38
36
@@ -612,10 +610,23 @@ Value linalg::bufferizeToAllocation(
612
610
}
613
611
614
612
namespace {
613
+ // / Rewrites an arith op operating on tensors, e.g.
614
+ // / `%z = arith.addf %x, %y : tensor<5xf32>`
615
+ // / into an equivalent linalg.generic in destination-passing-style.
616
+ // / ```mlir
617
+ // / %0 = tensor.empty() : tensor<5xf32>
618
+ // / %1 = linalg.generic ...
619
+ // / ins(%x, %y : tensor<5xf32>, tensor<5xf32>)
620
+ // / outs(%0 : tensor<5xf32>) {
621
+ // / ^bb0(%in: f32, %in_0: f32, %out: f32):
622
+ // / %2 = arith.addf %in, %in_0 : f32
623
+ // / linalg.yield %2 : f32
624
+ // / } -> tensor<5xf32>
615
625
template <typename OpTy>
616
626
FailureOr<Operation *>
617
627
rewriteArithInDestinationPassingStyle (RewriterBase &rewriter, OpTy op) {
618
- // reject ops such as `arith.constant` and `arith.select`.
628
+ // Reject ops such as `arith.constant` and `arith.select`.
629
+ // constants don't need dps conversion and select is a a `todo`.
619
630
auto numOperands = op->getNumOperands ();
620
631
if (numOperands == 0 || numOperands > 2 )
621
632
return failure ();
@@ -630,8 +641,8 @@ rewriteArithInDestinationPassingStyle(RewriterBase &rewriter, OpTy op) {
630
641
OpBuilder::InsertionGuard g (rewriter);
631
642
auto dynSizes = reifyOrComputeDynamicSizes (rewriter, op->getOperand (0 ));
632
643
633
- // Create tensor.empty.
634
- Value empty = tensor::EmptyOp::create (rewriter, loc, resultType, dynSizes);
644
+ // Create tensor.empty for `outs` of destination-passing-style .
645
+ Value outs = tensor::EmptyOp::create (rewriter, loc, resultType, dynSizes);
635
646
636
647
// Create linalg.generic
637
648
auto rank = tensorType.getRank ();
@@ -642,7 +653,7 @@ rewriteArithInDestinationPassingStyle(RewriterBase &rewriter, OpTy op) {
642
653
auto genericOp = linalg::GenericOp::create (
643
654
rewriter, loc, tensorType,
644
655
op->getOperands (), // inputs
645
- ValueRange{empty }, // outputs
656
+ ValueRange{outs }, // outputs
646
657
indexingMaps, iteratorTypes,
647
658
[&](OpBuilder &builder, Location loc, ValueRange args) {
648
659
Value res;
0 commit comments