Skip to content

Commit 348e8aa

Browse files
committed
address review comments
1 parent 846f3f7 commit 348e8aa

File tree

3 files changed

+28
-17
lines changed

3 files changed

+28
-17
lines changed

mlir/include/mlir/Dialect/Linalg/Passes.td

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -182,17 +182,17 @@ def LinalgConvertToDPSPass : Pass<"linalg-convert-to-dps"> {
182182
```
183183
gets rewritten as:
184184
```mlir
185-
%c0 = arith.constant 0 : index
186-
%dim = tensor.dim %arg0, %c0 : tensor<?xi32>
187-
%0 = tensor.empty(%dim) : tensor<?xf32>
188-
%1 = linalg.generic
189-
{indexing_maps = [#map, #map], iterator_types = ["parallel"]}
190-
ins(%arg0 : tensor<?xi32>) outs(%0 : tensor<?xf32>) {
185+
%c0 = arith.constant 0 : index
186+
%dim = tensor.dim %arg0, %c0 : tensor<?xi32>
187+
%0 = tensor.empty(%dim) : tensor<?xf32>
188+
%1 = linalg.generic
189+
{indexing_maps = [#map, #map], iterator_types = ["parallel"]}
190+
ins(%arg0 : tensor<?xi32>) outs(%0 : tensor<?xf32>) {
191191
^bb0(%in: i32, %out: f32):
192192
%2 = arith.uitofp %in : i32 to f32
193193
linalg.yield %2 : f32
194-
} -> tensor<?xf32>
195-
```
194+
} -> tensor<?xf32>
195+
```
196196
}];
197197
let dependentDialects = ["linalg::LinalgDialect"];
198198
}

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,10 @@ using namespace mlir::transform;
5858
/// pattern failed to apply. Extra arguments are forwarded to the pattern
5959
/// constructor.
6060
template <typename PatternTy, typename... Args>
61-
static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&... args) {
61+
static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
6262
// Check if the given operation has the type expected by the pattern.
63-
using OpTy = typename llvm::function_traits<decltype(
64-
&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
63+
using OpTy = typename llvm::function_traits<
64+
decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
6565
auto op = dyn_cast<OpTy>(operation);
6666
if (!op)
6767
return failure();

mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@ namespace mlir {
3131
#include "mlir/Dialect/Linalg/Passes.h.inc"
3232
} // namespace mlir
3333

34-
#define DEBUG_TYPE "linalg-convert-to-dps"
35-
3634
using namespace mlir;
3735
using namespace mlir::tensor;
3836

@@ -612,10 +610,23 @@ Value linalg::bufferizeToAllocation(
612610
}
613611

614612
namespace {
613+
/// Rewrites an arith op operating on tensors, e.g.
614+
/// `%z = arith.addf %x, %y : tensor<5xf32>`
615+
/// into an equivalent linalg.generic in destination-passing-style.
616+
/// ```mlir
617+
/// %0 = tensor.empty() : tensor<5xf32>
618+
/// %1 = linalg.generic ...
619+
/// ins(%x, %y : tensor<5xf32>, tensor<5xf32>)
620+
/// outs(%0 : tensor<5xf32>) {
621+
/// ^bb0(%in: f32, %in_0: f32, %out: f32):
622+
/// %2 = arith.addf %in, %in_0 : f32
623+
/// linalg.yield %2 : f32
624+
/// } -> tensor<5xf32>
615625
template <typename OpTy>
616626
FailureOr<Operation *>
617627
rewriteArithInDestinationPassingStyle(RewriterBase &rewriter, OpTy op) {
618-
// reject ops such as `arith.constant` and `arith.select`.
628+
// Reject ops such as `arith.constant` and `arith.select`.
629+
// constants don't need dps conversion and select is a a `todo`.
619630
auto numOperands = op->getNumOperands();
620631
if (numOperands == 0 || numOperands > 2)
621632
return failure();
@@ -630,8 +641,8 @@ rewriteArithInDestinationPassingStyle(RewriterBase &rewriter, OpTy op) {
630641
OpBuilder::InsertionGuard g(rewriter);
631642
auto dynSizes = reifyOrComputeDynamicSizes(rewriter, op->getOperand(0));
632643

633-
// Create tensor.empty.
634-
Value empty = tensor::EmptyOp::create(rewriter, loc, resultType, dynSizes);
644+
// Create tensor.empty for `outs` of destination-passing-style.
645+
Value outs = tensor::EmptyOp::create(rewriter, loc, resultType, dynSizes);
635646

636647
// Create linalg.generic
637648
auto rank = tensorType.getRank();
@@ -642,7 +653,7 @@ rewriteArithInDestinationPassingStyle(RewriterBase &rewriter, OpTy op) {
642653
auto genericOp = linalg::GenericOp::create(
643654
rewriter, loc, tensorType,
644655
op->getOperands(), // inputs
645-
ValueRange{empty}, // outputs
656+
ValueRange{outs}, // outputs
646657
indexingMaps, iteratorTypes,
647658
[&](OpBuilder &builder, Location loc, ValueRange args) {
648659
Value res;

0 commit comments

Comments
 (0)