[mlir][linalg] convert arith ops to destination-passing-style. #157854

Changes from 2 commits
@@ -17,13 +17,20 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"

namespace mlir {
#define GEN_PASS_DEF_LINALGCONVERTTODPSPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::tensor;

@@ -96,7 +103,7 @@ static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter,
   OpBuilder::InsertionGuard g(rewriter);
   RankedTensorType resultType = padOp.getResultType();

-  // Examine the yielded value to decide if a linalg.generic is neede or a
+  // Examine the yielded value to decide if a linalg.generic is needed or a
   // linalg.fill is sufficient.
   Value yieldedValue =
       cast<tensor::YieldOp>(padOp.getBody()->getTerminator()).getValue();
@@ -603,6 +610,69 @@ Value linalg::bufferizeToAllocation(
}

namespace {
/// Rewrites an arith op operating on tensors, e.g.
/// `%z = arith.addf %x, %y : tensor<5xf32>`
/// into an equivalent linalg.generic in destination-passing-style:
/// ```mlir
/// %0 = tensor.empty() : tensor<5xf32>
/// %1 = linalg.generic ...
///        ins(%x, %y : tensor<5xf32>, tensor<5xf32>)
///        outs(%0 : tensor<5xf32>) {
///      ^bb0(%in: f32, %in_0: f32, %out: f32):
///        %2 = arith.addf %in, %in_0 : f32
///        linalg.yield %2 : f32
///      } -> tensor<5xf32>
/// ```
template <typename OpTy>
FailureOr<Operation *>
rewriteArithInDestinationPassingStyle(RewriterBase &rewriter, OpTy op) {
  // Reject ops such as `arith.constant` and `arith.select`:
  // constants don't need DPS conversion and select is a TODO.
  auto numOperands = op->getNumOperands();
  if (numOperands == 0 || numOperands > 2)
    return failure();

Review thread on this check:

> Why?

> Only unary and binary ops are the ones we care about.

> Why? This is what the code tells me, yes, but it doesn't say why. Also, if
> that's the case, then this would be clearer to me:
> `if (numOperands != 1 && numOperands != 2)`
> EDIT: Sorry, posted my comment before noticing that this comment has been
> updated.

> Why not? This is assuming that the only purpose of this code is to help with
> bufferization. That is fine with me, but make it clear. Otherwise, "constants
> don't need dps conversion" sounds very arbitrary (missing the "why").
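A sketch of the clearer check the thread converges on, with the comment expanded to carry the missing "why" (the wording is a paraphrase of the review discussion, not committed text):

```cpp
// This rewrite exists to put element-wise arith ops on tensors into
// destination-passing style (e.g. for bufferization). Only unary and binary
// ops are handled: constants have no operands and need no destination;
// ternary ops such as arith.select remain a TODO.
if (numOperands != 1 && numOperands != 2)
  return failure();
```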

  // Destination-passing-style rewrite is only for ops on tensor types.
  Type resultType = op->getResult(0).getType();
  auto tensorType = dyn_cast<RankedTensorType>(resultType);
  if (!tensorType)
    return failure();

  auto loc = op.getLoc();
  OpBuilder::InsertionGuard g(rewriter);
  auto dynSizes = reifyOrComputeDynamicSizes(rewriter, op->getOperand(0));

  // Create tensor.empty for `outs` of destination-passing-style.
  Value outs = tensor::EmptyOp::create(rewriter, loc, resultType, dynSizes);

  // Create linalg.generic.
  auto rank = tensorType.getRank();
  SmallVector<AffineMap> indexingMaps(numOperands + 1,
                                      rewriter.getMultiDimIdentityMap(rank));
  SmallVector<utils::IteratorType> iteratorTypes(rank,
                                                 utils::IteratorType::parallel);
  auto genericOp = linalg::GenericOp::create(
      rewriter, loc, tensorType,
      op->getOperands(), // inputs
      ValueRange{outs},  // outputs
      indexingMaps, iteratorTypes,
      [&](OpBuilder &builder, Location loc, ValueRange args) {
        Value res;
        if (args.size() == 2) {
          res =
              builder.create<OpTy>(loc, args[1].getType(), ValueRange{args[0]})
                  .getResult();
        } else if (args.size() == 3) {
          res = builder.create<OpTy>(loc, args[2].getType(),
                                     ValueRange{args[0], args[1]});
        } else
          llvm_unreachable("did not expect ops other than unary and binary");
        linalg::YieldOp::create(builder, loc, res);

Review comment on the lambda body:

> nit: I think the convention is that if one if/else branch uses braces, all
> the other ones should too.
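A minimal sketch of the fully braced form the nit asks for (the suggestion box above was empty; no behavior change intended):

```cpp
        if (args.size() == 2) {
          res =
              builder.create<OpTy>(loc, args[1].getType(), ValueRange{args[0]})
                  .getResult();
        } else if (args.size() == 3) {
          res = builder.create<OpTy>(loc, args[2].getType(),
                                     ValueRange{args[0], args[1]});
        } else {
          llvm_unreachable("did not expect ops other than unary and binary");
        }
```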
      });

  rewriter.replaceAllUsesWith(op, genericOp.getResult(0));
  rewriter.eraseOp(op);
  return genericOp.getOperation();
}

template <typename OpTy>
LogicalResult rewriteOpInDestinationPassingStyle(OpTy op,

@@ -612,9 +682,53 @@ LogicalResult rewriteOpInDestinationPassingStyle(OpTy op,

} // namespace

#define STAMP_OUT_ARITH_DPS_FUNCS(OPTY)                                        \
  FailureOr<Operation *> linalg::rewriteInDestinationPassingStyle(             \
      RewriterBase &rewriter, OPTY op) {                                       \
    return rewriteArithInDestinationPassingStyle<OPTY>(rewriter, op);          \
  }
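For reference, expanding the macro for one op, say `arith::AddFOp`, produces an out-of-line definition like the following (illustrative expansion, not part of the diff):

```cpp
FailureOr<Operation *> linalg::rewriteInDestinationPassingStyle(
    RewriterBase &rewriter, arith::AddFOp op) {
  return rewriteArithInDestinationPassingStyle<arith::AddFOp>(rewriter, op);
}
```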

STAMP_OUT_ARITH_DPS_FUNCS(arith::UIToFPOp)
STAMP_OUT_ARITH_DPS_FUNCS(arith::SIToFPOp)
STAMP_OUT_ARITH_DPS_FUNCS(arith::FPToUIOp)
STAMP_OUT_ARITH_DPS_FUNCS(arith::FPToSIOp)

STAMP_OUT_ARITH_DPS_FUNCS(arith::AddIOp)

STAMP_OUT_ARITH_DPS_FUNCS(arith::AddFOp)
STAMP_OUT_ARITH_DPS_FUNCS(arith::DivFOp)

Review comment:

> We can undef this macro here.
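The reviewer's suggestion would amount to adding, right after the last expansion (a sketch, assuming the macro has no later uses):

```cpp
// The helper macro is only needed for the definitions above.
#undef STAMP_OUT_ARITH_DPS_FUNCS
```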

void linalg::populateConvertToDestinationStylePatterns(
    RewritePatternSet &patterns) {
  patterns.add(rewriteOpInDestinationPassingStyle<tensor::FromElementsOp>);
  patterns.add(rewriteOpInDestinationPassingStyle<tensor::GenerateOp>);
  patterns.add(rewriteOpInDestinationPassingStyle<tensor::PadOp>);

  patterns.add(rewriteOpInDestinationPassingStyle<arith::UIToFPOp>);
  patterns.add(rewriteOpInDestinationPassingStyle<arith::SIToFPOp>);
  patterns.add(rewriteOpInDestinationPassingStyle<arith::FPToUIOp>);
  patterns.add(rewriteOpInDestinationPassingStyle<arith::FPToSIOp>);

  patterns.add(rewriteOpInDestinationPassingStyle<arith::AddIOp>);

  patterns.add(rewriteOpInDestinationPassingStyle<arith::AddFOp>);
  patterns.add(rewriteOpInDestinationPassingStyle<arith::DivFOp>);
}

namespace {
struct LinalgConvertToDPSPass
    : public impl::LinalgConvertToDPSPassBase<LinalgConvertToDPSPass> {
  using impl::LinalgConvertToDPSPassBase<
      LinalgConvertToDPSPass>::LinalgConvertToDPSPassBase;

  void runOnOperation() override;
};

void LinalgConvertToDPSPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  linalg::populateConvertToDestinationStylePatterns(patterns);
  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();
}
} // namespace

Review comment on the pattern driver:

> Would the walk rewrite driver work here? I think this will never have to
> rewrite the same op twice or visit newly created ops.
> https://mlir.llvm.org/docs/PatternRewriter/#walk-pattern-rewrite-driver
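For comparison, a sketch of the walk-driver alternative the reviewer points at, assuming MLIR's `walkAndApplyPatterns` from `mlir/Transforms/WalkPatternRewriteDriver.h`; unlike the greedy driver it returns `void`, so the failure path disappears:

```cpp
#include "mlir/Transforms/WalkPatternRewriteDriver.h"

void LinalgConvertToDPSPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  linalg::populateConvertToDestinationStylePatterns(patterns);
  // Single walk over the IR: replaced ops are not revisited and ops created
  // by the patterns are not walked, which suffices for this conversion.
  walkAndApplyPatterns(getOperation(), std::move(patterns));
}
```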
Review comment on the tests:

> I would suggest creating a new test file with Arith ops. I see two reasons: [...]
> If you prefer to keep everything in one file, could you add a big comment
> separating the arith cases?
@@ -252,3 +252,64 @@ module attributes {transform.with_named_sequence} {
    transform.yield
  }
}

// -----

// CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func @arith_unary_op(
// CHECK-SAME: %[[X:.+]]: tensor<64xi32>) -> tensor<64xf32> {
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<64xf32>
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: {indexing_maps = [#[[$map]], #[[$map]]], iterator_types = ["parallel"]}
// CHECK-SAME: ins(%[[X:.+]] : tensor<64xi32>) outs(%[[EMPTY]] : tensor<64xf32>) {
// CHECK: ^bb0(%[[x:.+]]: i32, %[[Out:.+]]: f32):
// CHECK: %[[z:.+]] = arith.uitofp %[[x]] : i32 to f32
// CHECK: linalg.yield %[[z]] : f32
// CHECK: return %[[GENERIC]] : tensor<64xf32>

func.func @arith_unary_op(%x : tensor<64xi32>) -> tensor<64xf32> {
  %z = arith.uitofp %x : tensor<64xi32> to tensor<64xf32>
  return %z : tensor<64xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["arith.uitofp"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.rewrite_in_destination_passing_style %0
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

// CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func @arith_binop(
// CHECK-SAME: %[[X:.+]]: tensor<?xf32>, %[[Y:.+]]: tensor<?xf32>
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[DIM:.+]] = tensor.dim %[[X]], %[[C0]] : tensor<?xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?xf32>
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: {indexing_maps = [#[[$map]], #[[$map]], #[[$map]]], iterator_types = ["parallel"]}
// CHECK-SAME: ins(%[[X:.+]], %[[Y:.+]] : tensor<?xf32>, tensor<?xf32>) outs(%[[EMPTY]] : tensor<?xf32>) {
// CHECK: ^bb0(%[[x:.+]]: f32, %[[y:.+]]: f32, %[[Out:.+]]: f32):
// CHECK: %[[z:.+]] = arith.addf %[[x]], %[[y]] : f32
// CHECK: linalg.yield %[[z]] : f32
// CHECK: return %[[GENERIC]] : tensor<?xf32>

func.func @arith_binop(%x : tensor<?xf32>, %y : tensor<?xf32>)

Review comment:

> Please use consistent naming for the test functions (e.g. @arith_unary_op
> vs. @arith_binop).
    -> tensor<?xf32> {
  %z = arith.addf %x, %y : tensor<?xf32>
  return %z : tensor<?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["arith.addf"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.rewrite_in_destination_passing_style %0
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

@@ -0,0 +1,26 @@
// RUN: mlir-opt %s -lower-quant-ops -linalg-convert-to-dps \
// RUN:   -linalg-specialize-generic-ops -cse | FileCheck %s

// CHECK-LABEL: func.func @lower_qcast_to_dps(
// CHECK-SAME: %[[X:.+]]: tensor<10xf32>) -> tensor<10x!quant.uniform<i8:f32, 2.000000e+00:10>>
// CHECK-DAG: %[[CST_10I:.+]] = arith.constant dense<10> : tensor<10xi8>
// CHECK-DAG: %[[CST_2F:.+]] = arith.constant dense<2.000000e+00> : tensor<10xf32>
// CHECK: %[[E:.+]] = tensor.empty() : tensor<10xf32>
// CHECK: %[[DIV:.+]] = linalg.div ins(%[[X]], %[[CST_2F]] : tensor<10xf32>, tensor<10xf32>)
// CHECK-SAME: outs(%[[E]] : tensor<10xf32>) -> tensor<10xf32>
//
// CHECK: %[[SITOFP:.+]] = linalg.generic
// CHECK-SAME: ins(%[[CST_10I]] : tensor<10xi8>) outs(%[[E]] : tensor<10xf32>)
// CHECK: %{{.*}} = arith.sitofp %{{.*}} : i8 to f32
//
// CHECK: %[[ADD:.+]] = linalg.add ins(%[[DIV]], %[[SITOFP]] : tensor<10xf32>, tensor<10xf32>)
// CHECK: %{{.*}} = linalg.generic
// CHECK-SAME: ins(%[[ADD]] : tensor<10xf32>)
// CHECK: %{{.*}} = arith.fptosi %{{.*}} : f32 to i8

!qalias = !quant.uniform<i8:f32, 2.0:10>
func.func @lower_qcast_to_dps(%arg0: tensor<10xf32>) -> tensor<10x!qalias> {
  %0 = quant.qcast %arg0 : tensor<10xf32> to tensor<10x!qalias>
  return %0 : tensor<10x!qalias>
}