Skip to content

Commit 846f3f7

Browse files
committed
[mlir][linalg] convert arith ops to destination-passing-style.
Converts arith ops that operate on tensors but are not in destination passing style (DPS) to equivalent linalg generic which is in DPS. This new pass `linalg-convert-to-dps` has general use, but specifically is useful for loewr-quant-ops which operate on tensors and ops like qcast generates arith ops on tensors which without dps cannot bufferize. e.g. `%0 = arith.uitofp %arg0 : tensor<?xi32> to tensor<?xf32>` gets rewritten as: %c0 = arith.constant 0 : index %dim = tensor.dim %arg0, %c0 : tensor<?xi32> %0 = tensor.empty(%dim) : tensor<?xf32> %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<?xi32>) outs(%0 : tensor<?xf32>) { ^bb0(%in: i32, %out: f32): %2 = arith.uitofp %in : i32 to f32 linalg.yield %2 : f32 } -> tensor<?xf32>
1 parent 16494be commit 846f3f7

File tree

6 files changed

+240
-5
lines changed

6 files changed

+240
-5
lines changed

mlir/include/mlir/Dialect/Linalg/Passes.td

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,32 @@ def LinalgFoldIntoElementwisePass : Pass<"linalg-fold-into-elementwise"> {
171171
let dependentDialects = ["linalg::LinalgDialect"];
172172
}
173173

174+
// Rewrites non-DPS ops on tensors (e.g. elementwise arith ops) into
// destination-passing-style linalg.generic ops; see the pass implementation
// in Linalg/Transforms/ConvertToDestinationStyle.cpp.
def LinalgConvertToDPSPass : Pass<"linalg-convert-to-dps"> {
  let summary = "Convert ops to destination-passing-style";
  let description = [{
    Converts ops that operate on tensors but are not in
    destination passing style (DPS) to equivalent linalg
    generic which is in DPS. e.g.
    ```mlir
    %0 = arith.uitofp %arg0 : tensor<?xi32> to tensor<?xf32>
    ```
    gets rewritten as:
    ```mlir
    %c0 = arith.constant 0 : index
    %dim = tensor.dim %arg0, %c0 : tensor<?xi32>
    %0 = tensor.empty(%dim) : tensor<?xf32>
    %1 = linalg.generic
        {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
        ins(%arg0 : tensor<?xi32>) outs(%0 : tensor<?xf32>) {
    ^bb0(%in: i32, %out: f32):
      %2 = arith.uitofp %in : i32 to f32
      linalg.yield %2 : f32
    } -> tensor<?xf32>
    ```
  }];
  let dependentDialects = ["linalg::LinalgDialect"];
}
199+
174200
def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInterface"> {
175201
let summary = "Detensorize linalg ops";
176202
let dependentDialects = [];

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1377,6 +1377,23 @@ rewriteInDestinationPassingStyle(RewriterBase &rewriter,
13771377
FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
13781378
tensor::PadOp padOp);
13791379

1380+
/// Rewrites the given elementwise arith op on ranked tensors into
/// destination-passing style: a `linalg.generic` writing into a
/// `tensor.empty` destination. Returns the replacement op, or failure if the
/// op's result is not a ranked tensor.
///
/// Unary int/float conversion casts.
FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
                                                        arith::UIToFPOp op);
FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
                                                        arith::SIToFPOp op);
FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
                                                        arith::FPToUIOp op);
FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
                                                        arith::FPToSIOp op);

/// Binary integer arithmetic.
FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
                                                        arith::AddIOp op);

/// Binary float arithmetic.
FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
                                                        arith::AddFOp op);
FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
                                                        arith::DivFOp op);
1396+
13801397
/// Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing)
13811398
/// and linalg.matmul.
13821399
///

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,10 @@ using namespace mlir::transform;
5858
/// pattern failed to apply. Extra arguments are forwarded to the pattern
5959
/// constructor.
6060
template <typename PatternTy, typename... Args>
61-
static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
61+
static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&... args) {
6262
// Check if the given operation has the type expected by the pattern.
63-
using OpTy = typename llvm::function_traits<
64-
decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
63+
using OpTy = typename llvm::function_traits<decltype(
64+
&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
6565
auto op = dyn_cast<OpTy>(operation);
6666
if (!op)
6767
return failure();
@@ -2611,7 +2611,9 @@ transform::RewriteInDestinationPassingStyleOp::applyToOne(
26112611
rewriter.setInsertionPoint(target);
26122612
FailureOr<Operation *> maybeResult =
26132613
TypeSwitch<Operation *, FailureOr<Operation *>>(target)
2614-
.Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
2614+
.Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp,
2615+
arith::UIToFPOp, arith::SIToFPOp, arith::FPToUIOp,
2616+
arith::FPToSIOp, arith::AddIOp, arith::AddFOp, arith::DivFOp>(
26152617
[&rewriter](auto op) {
26162618
return rewriteInDestinationPassingStyle(rewriter, op);
26172619
});

mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp

Lines changed: 104 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,22 @@
1717
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1818
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
1919
#include "mlir/Dialect/Linalg/IR/Linalg.h"
20+
#include "mlir/Dialect/Linalg/Passes.h"
2021
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
2122
#include "mlir/Dialect/Tensor/IR/Tensor.h"
2223
#include "mlir/Dialect/Utils/StaticValueUtils.h"
2324
#include "mlir/IR/Matchers.h"
2425
#include "mlir/IR/PatternMatch.h"
26+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2527
#include "llvm/ADT/STLExtras.h"
2628

29+
namespace mlir {
30+
#define GEN_PASS_DEF_LINALGCONVERTTODPSPASS
31+
#include "mlir/Dialect/Linalg/Passes.h.inc"
32+
} // namespace mlir
33+
34+
#define DEBUG_TYPE "linalg-convert-to-dps"
35+
2736
using namespace mlir;
2837
using namespace mlir::tensor;
2938

@@ -96,7 +105,7 @@ static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter,
96105
OpBuilder::InsertionGuard g(rewriter);
97106
RankedTensorType resultType = padOp.getResultType();
98107

99-
// Examine the yielded value to decide if a linalg.generic is neede or a
108+
// Examine the yielded value to decide if a linalg.generic is needed or a
100109
// linalg.fill is sufficient.
101110
Value yieldedValue =
102111
cast<tensor::YieldOp>(padOp.getBody()->getTerminator()).getValue();
@@ -603,6 +612,56 @@ Value linalg::bufferizeToAllocation(
603612
}
604613

605614
namespace {
615+
template <typename OpTy>
616+
FailureOr<Operation *>
617+
rewriteArithInDestinationPassingStyle(RewriterBase &rewriter, OpTy op) {
618+
// reject ops such as `arith.constant` and `arith.select`.
619+
auto numOperands = op->getNumOperands();
620+
if (numOperands == 0 || numOperands > 2)
621+
return failure();
622+
623+
// destination passing style rewrite is only for ops on tensor types.
624+
Type resultType = op->getResult(0).getType();
625+
auto tensorType = dyn_cast<RankedTensorType>(resultType);
626+
if (!tensorType)
627+
return failure();
628+
629+
auto loc = op.getLoc();
630+
OpBuilder::InsertionGuard g(rewriter);
631+
auto dynSizes = reifyOrComputeDynamicSizes(rewriter, op->getOperand(0));
632+
633+
// Create tensor.empty.
634+
Value empty = tensor::EmptyOp::create(rewriter, loc, resultType, dynSizes);
635+
636+
// Create linalg.generic
637+
auto rank = tensorType.getRank();
638+
SmallVector<AffineMap> indexingMaps(numOperands + 1,
639+
rewriter.getMultiDimIdentityMap(rank));
640+
SmallVector<utils::IteratorType> iteratorTypes(rank,
641+
utils::IteratorType::parallel);
642+
auto genericOp = linalg::GenericOp::create(
643+
rewriter, loc, tensorType,
644+
op->getOperands(), // inputs
645+
ValueRange{empty}, // outputs
646+
indexingMaps, iteratorTypes,
647+
[&](OpBuilder &builder, Location loc, ValueRange args) {
648+
Value res;
649+
if (args.size() == 2) {
650+
res =
651+
builder.create<OpTy>(loc, args[1].getType(), ValueRange{args[0]})
652+
.getResult();
653+
} else if (args.size() == 3) {
654+
res = builder.create<OpTy>(loc, args[2].getType(),
655+
ValueRange{args[0], args[1]});
656+
} else
657+
llvm_unreachable("did not expect ops other than nary and binary");
658+
linalg::YieldOp::create(builder, loc, res);
659+
});
660+
661+
rewriter.replaceAllUsesWith(op, genericOp.getResult(0));
662+
rewriter.eraseOp(op);
663+
return genericOp.getOperation();
664+
}
606665

607666
template <typename OpTy>
608667
LogicalResult rewriteOpInDestinationPassingStyle(OpTy op,
@@ -612,9 +671,53 @@ LogicalResult rewriteOpInDestinationPassingStyle(OpTy op,
612671

613672
} // namespace
614673

674+
/// Stamps out the `linalg::rewriteInDestinationPassingStyle` overload
/// declared in Transforms.h for the given arith op type, forwarding to the
/// shared template implementation.
#define STAMP_OUT_ARITH_DPS_FUNCS(OPTY)                                        \
  FailureOr<Operation *> linalg::rewriteInDestinationPassingStyle(             \
      RewriterBase &rewriter, OPTY op) {                                       \
    return rewriteArithInDestinationPassingStyle<OPTY>(rewriter, op);          \
  }

// Unary int/float conversion casts.
STAMP_OUT_ARITH_DPS_FUNCS(arith::UIToFPOp)
STAMP_OUT_ARITH_DPS_FUNCS(arith::SIToFPOp)
STAMP_OUT_ARITH_DPS_FUNCS(arith::FPToUIOp)
STAMP_OUT_ARITH_DPS_FUNCS(arith::FPToSIOp)

// Binary integer arithmetic.
STAMP_OUT_ARITH_DPS_FUNCS(arith::AddIOp)

// Binary float arithmetic.
STAMP_OUT_ARITH_DPS_FUNCS(arith::AddFOp)
STAMP_OUT_ARITH_DPS_FUNCS(arith::DivFOp)
689+
615690
/// Populates `patterns` with rewrites that convert non-DPS tensor ops into
/// destination-passing style: tensor.from_elements/generate/pad and the
/// supported elementwise arith ops.
void linalg::populateConvertToDestinationStylePatterns(
    RewritePatternSet &patterns) {
  patterns.add(rewriteOpInDestinationPassingStyle<tensor::FromElementsOp>);
  patterns.add(rewriteOpInDestinationPassingStyle<tensor::GenerateOp>);
  patterns.add(rewriteOpInDestinationPassingStyle<tensor::PadOp>);

  // Unary int/float conversion casts.
  patterns.add(rewriteOpInDestinationPassingStyle<arith::UIToFPOp>);
  patterns.add(rewriteOpInDestinationPassingStyle<arith::SIToFPOp>);
  patterns.add(rewriteOpInDestinationPassingStyle<arith::FPToUIOp>);
  patterns.add(rewriteOpInDestinationPassingStyle<arith::FPToSIOp>);

  // Binary integer arithmetic.
  patterns.add(rewriteOpInDestinationPassingStyle<arith::AddIOp>);

  // Binary float arithmetic.
  patterns.add(rewriteOpInDestinationPassingStyle<arith::AddFOp>);
  patterns.add(rewriteOpInDestinationPassingStyle<arith::DivFOp>);
}
706+
707+
namespace {
/// Pass implementation for `linalg-convert-to-dps`: greedily applies the
/// convert-to-destination-style patterns over the whole operation.
struct LinalgConvertToDPSPass
    : public impl::LinalgConvertToDPSPassBase<LinalgConvertToDPSPass> {
  // Inherit the tablegen-generated constructors.
  using impl::LinalgConvertToDPSPassBase<
      LinalgConvertToDPSPass>::LinalgConvertToDPSPassBase;

  void runOnOperation() override;
};

void LinalgConvertToDPSPass::runOnOperation() {

  RewritePatternSet patterns(&getContext());
  linalg::populateConvertToDestinationStylePatterns(patterns);
  // Greedy driver failure (non-convergence) is surfaced as a pass failure.
  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();
}
} // namespace

mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,3 +252,64 @@ module attributes {transform.with_named_sequence} {
252252
transform.yield
253253
}
254254
}
255+
256+
// -----

// Unary elementwise arith op (arith.uitofp) on a statically-shaped tensor:
// the destination is a statically-sized tensor.empty (no tensor.dim needed).

// CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func @arith_unary_op(
// CHECK-SAME:    %[[X:.+]]: tensor<64xi32>) -> tensor<64xf32> {
// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<64xf32>
// CHECK:         %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME:      {indexing_maps = [#[[$map]], #[[$map]]], iterator_types = ["parallel"]}
// CHECK-SAME:      ins(%[[X:.+]] : tensor<64xi32>) outs(%[[EMPTY]] : tensor<64xf32>) {
// CHECK:         ^bb0(%[[x:.+]]: i32, %[[Out:.+]]: f32):
// CHECK:           %[[z:.+]] = arith.uitofp %[[x]] : i32 to f32
// CHECK:           linalg.yield %[[z]] : f32
// CHECK:         return %[[GENERIC]] : tensor<64xf32>

func.func @arith_unary_op(%x : tensor<64xi32>) -> tensor<64xf32> {
  %z = arith.uitofp %x : tensor<64xi32> to tensor<64xf32>
  return %z : tensor<64xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["arith.uitofp"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.rewrite_in_destination_passing_style %0
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}
284+
285+
// -----

// Binary elementwise arith op (arith.addf) on dynamically-shaped tensors:
// the destination's dynamic size is reified via tensor.dim on the first
// operand before creating tensor.empty.

// CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func @arith_binop(
// CHECK-SAME:    %[[X:.+]]: tensor<?xf32>, %[[Y:.+]]: tensor<?xf32>
// CHECK:         %[[C0:.+]] = arith.constant 0 : index
// CHECK:         %[[DIM:.+]] = tensor.dim %[[X]], %[[C0]] : tensor<?xf32>
// CHECK:         %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?xf32>
// CHECK:         %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME:      {indexing_maps = [#[[$map]], #[[$map]], #[[$map]]], iterator_types = ["parallel"]}
// CHECK-SAME:      ins(%[[X:.+]], %[[Y:.+]] : tensor<?xf32>, tensor<?xf32>) outs(%[[EMPTY]] : tensor<?xf32>) {
// CHECK:         ^bb0(%[[x:.+]]: f32, %[[y:.+]]: f32, %[[Out:.+]]: f32):
// CHECK:           %[[z:.+]] = arith.addf %[[x]], %[[y]] : f32
// CHECK:           linalg.yield %[[z]] : f32
// CHECK:         return %[[GENERIC]] : tensor<?xf32>

func.func @arith_binop(%x : tensor<?xf32>, %y : tensor<?xf32>)
    -> tensor<?xf32> {
  %z = arith.addf %x, %y : tensor<?xf32>
  return %z : tensor<?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["arith.addf"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.rewrite_in_destination_passing_style %0
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// RUN: mlir-opt %s -lower-quant-ops -linalg-convert-to-dps \
// RUN:   -linalg-specialize-generic-ops -cse | FileCheck %s

// End-to-end check: quant.qcast lowers to arith ops on tensors, which
// linalg-convert-to-dps rewrites into DPS linalg ops (div/add specialize to
// named ops; the casts remain linalg.generic).

// CHECK-LABEL: func.func @lower_qcast_to_dps(
// CHECK-SAME:    %[[X:.+]]: tensor<10xf32>) -> tensor<10x!quant.uniform<i8:f32, 2.000000e+00:10>>
// CHECK-DAG:     %[[CST_10I:.+]] = arith.constant dense<10> : tensor<10xi8>
// CHECK-DAG:     %[[CST_2F:.+]] = arith.constant dense<2.000000e+00> : tensor<10xf32>
// CHECK:         %[[E:.+]] = tensor.empty() : tensor<10xf32>
// CHECK:         %[[DIV:.+]] = linalg.div ins(%[[X]], %[[CST_2F]] : tensor<10xf32>, tensor<10xf32>)
// CHECK-SAME:      outs(%[[E]] : tensor<10xf32>) -> tensor<10xf32>
//
// CHECK:         %[[SITOFP:.+]] = linalg.generic
// CHECK-SAME:      ins(%[[CST_10I]] : tensor<10xi8>) outs(%[[E]] : tensor<10xf32>)
// CHECK:           %{{.*}} = arith.sitofp %{{.*}} : i8 to f32
//
// CHECK:         %[[ADD:.+]] = linalg.add ins(%[[DIV]], %[[SITOFP]] : tensor<10xf32>, tensor<10xf32>)
// CHECK:         %{{.*}} = linalg.generic
// CHECK-SAME:      ins(%[[ADD]] : tensor<10xf32>)
// CHECK:           %{{.*}} = arith.fptosi %{{.*}} : f32 to i8


!qalias = !quant.uniform<i8:f32, 2.0:10>
func.func @lower_qcast_to_dps(%arg0: tensor<10xf32>) -> tensor<10x!qalias> {
  %0 = quant.qcast %arg0 : tensor<10xf32> to tensor<10x!qalias>
  return %0 : tensor<10x!qalias>
}

0 commit comments

Comments
 (0)