26 changes: 26 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -171,6 +171,32 @@ def LinalgFoldIntoElementwisePass : Pass<"linalg-fold-into-elementwise"> {
let dependentDialects = ["linalg::LinalgDialect"];
}

def LinalgConvertToDPSPass : Pass<"linalg-convert-to-dps"> {
let summary = "Convert ops to destination-passing-style";
let description = [{
Converts ops that operate on tensors but are not in
destination-passing style (DPS) to an equivalent linalg.generic,
which is in DPS. For example,
```mlir
%0 = arith.uitofp %arg0 : tensor<?xi32> to tensor<?xf32>
```
gets rewritten as:
```mlir
%c0 = arith.constant 0 : index
%dim = tensor.dim %arg0, %c0 : tensor<?xi32>
%0 = tensor.empty(%dim) : tensor<?xf32>
%1 = linalg.generic
{indexing_maps = [#map, #map], iterator_types = ["parallel"]}
ins(%arg0 : tensor<?xi32>) outs(%0 : tensor<?xf32>) {
^bb0(%in: i32, %out: f32):
%2 = arith.uitofp %in : i32 to f32
linalg.yield %2 : f32
} -> tensor<?xf32>
```
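
For the binary case the rewrite is analogous, with one extra input and indexing map.
The sketch below mirrors the `arith.addf` test added later in this patch; the SSA
names are illustrative and `#map = affine_map<(d0) -> (d0)>`:
```mlir
// Sketch: DPS rewrite of a binary element-wise op (names illustrative).
%c0 = arith.constant 0 : index
%dim = tensor.dim %x, %c0 : tensor<?xf32>
%empty = tensor.empty(%dim) : tensor<?xf32>
%res = linalg.generic
    {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]}
    ins(%x, %y : tensor<?xf32>, tensor<?xf32>) outs(%empty : tensor<?xf32>) {
  ^bb0(%a: f32, %b: f32, %out: f32):
    %s = arith.addf %a, %b : f32
    linalg.yield %s : f32
  } -> tensor<?xf32>
```
As in the unary example, the dynamic output size is reified from the first operand.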
}];
let dependentDialects = ["linalg::LinalgDialect"];
}

def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInterface"> {
let summary = "Detensorize linalg ops";
let dependentDialects = [];
17 changes: 17 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1377,6 +1377,23 @@ rewriteInDestinationPassingStyle(RewriterBase &rewriter,
FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
tensor::PadOp padOp);

FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
arith::UIToFPOp op);
FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
arith::SIToFPOp op);
FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
arith::FPToUIOp op);
FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
arith::FPToSIOp op);

FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
arith::AddIOp op);

FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
arith::AddFOp op);
FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
arith::DivFOp op);

/// Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing)
/// and linalg.matmul.
///
10 changes: 6 additions & 4 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -58,10 +58,10 @@ using namespace mlir::transform;
/// pattern failed to apply. Extra arguments are forwarded to the pattern
/// constructor.
template <typename PatternTy, typename... Args>
static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&... args) {
// Check if the given operation has the type expected by the pattern.
using OpTy = typename llvm::function_traits<
decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
using OpTy = typename llvm::function_traits<decltype(
&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
auto op = dyn_cast<OpTy>(operation);
if (!op)
return failure();
@@ -2611,7 +2611,9 @@ transform::RewriteInDestinationPassingStyleOp::applyToOne(
rewriter.setInsertionPoint(target);
FailureOr<Operation *> maybeResult =
TypeSwitch<Operation *, FailureOr<Operation *>>(target)
.Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
.Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp,
arith::UIToFPOp, arith::SIToFPOp, arith::FPToUIOp,
arith::FPToSIOp, arith::AddIOp, arith::AddFOp, arith::DivFOp>(
[&rewriter](auto op) {
return rewriteInDestinationPassingStyle(rewriter, op);
});
105 changes: 104 additions & 1 deletion mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -17,13 +17,22 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"

namespace mlir {
#define GEN_PASS_DEF_LINALGCONVERTTODPSPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

#define DEBUG_TYPE "linalg-convert-to-dps"

using namespace mlir;
using namespace mlir::tensor;

@@ -96,7 +105,7 @@ static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter,
OpBuilder::InsertionGuard g(rewriter);
RankedTensorType resultType = padOp.getResultType();

// Examine the yielded value to decide if a linalg.generic is neede or a
// Examine the yielded value to decide if a linalg.generic is needed or a
// linalg.fill is sufficient.
Value yieldedValue =
cast<tensor::YieldOp>(padOp.getBody()->getTerminator()).getValue();
@@ -603,6 +612,56 @@ Value linalg::bufferizeToAllocation(
}

namespace {
template <typename OpTy>
FailureOr<Operation *>
rewriteArithInDestinationPassingStyle(RewriterBase &rewriter, OpTy op) {
// reject ops such as `arith.constant` and `arith.select`.
auto numOperands = op->getNumOperands();
if (numOperands == 0 || numOperands > 2)
[Review – Contributor] Why?
[Review – Author] Only the unary and binary ops we care about.
[Review – Contributor (banach-space, Sep 15, 2025)] This is what the code tells me, yes, but it doesn't say why. Also, if that's the case, then this would be clearer to me:
    if (numOperands != 1 && numOperands != 2)
EDIT: Sorry, posted my comment before noticing that this comment has been updated. Regarding "constants don't need dps conversion": why not? This assumes that the only purpose of this code is to help with bufferization. That is fine with me, but make it clear; otherwise "constants don't need dps conversion" sounds very arbitrary (missing the "why").
return failure();

// destination passing style rewrite is only for ops on tensor types.
[Review – Member] Suggested change: capitalize the comment:
    // Destination passing style rewrite is only for ops on tensor types.
Type resultType = op->getResult(0).getType();
auto tensorType = dyn_cast<RankedTensorType>(resultType);
if (!tensorType)
return failure();

auto loc = op.getLoc();
OpBuilder::InsertionGuard g(rewriter);
auto dynSizes = reifyOrComputeDynamicSizes(rewriter, op->getOperand(0));

// Create tensor.empty.
Value empty = tensor::EmptyOp::create(rewriter, loc, resultType, dynSizes);

// Create linalg.generic
[Review – Member] Suggested change: end the comment with a period:
    // Create linalg.generic.
auto rank = tensorType.getRank();
SmallVector<AffineMap> indexingMaps(numOperands + 1,
rewriter.getMultiDimIdentityMap(rank));
SmallVector<utils::IteratorType> iteratorTypes(rank,
utils::IteratorType::parallel);
auto genericOp = linalg::GenericOp::create(
rewriter, loc, tensorType,
op->getOperands(), // inputs
ValueRange{empty}, // outputs
indexingMaps, iteratorTypes,
[&](OpBuilder &builder, Location loc, ValueRange args) {
Value res;
if (args.size() == 2) {
res =
builder.create<OpTy>(loc, args[1].getType(), ValueRange{args[0]})
.getResult();
} else if (args.size() == 3) {
res = builder.create<OpTy>(loc, args[2].getType(),
ValueRange{args[0], args[1]});
} else
llvm_unreachable("did not expect ops other than nary and binary");
linalg::YieldOp::create(builder, loc, res);
[Review – Member, on lines +692 to +694] Nit: I think the convention is that if one if/else branch uses braces, all the other ones should too. Suggested change:
    } else {
      llvm_unreachable("did not expect ops other than nary and binary");
    }
    linalg::YieldOp::create(builder, loc, res);

});

rewriter.replaceAllUsesWith(op, genericOp.getResult(0));
rewriter.eraseOp(op);
return genericOp.getOperation();
}

template <typename OpTy>
LogicalResult rewriteOpInDestinationPassingStyle(OpTy op,
@@ -612,9 +671,53 @@ LogicalResult rewriteOpInDestinationPassingStyle(OpTy op,

} // namespace

#define STAMP_OUT_ARITH_DPS_FUNCS(OPTY) \
FailureOr<Operation *> linalg::rewriteInDestinationPassingStyle( \
RewriterBase &rewriter, OPTY op) { \
return rewriteArithInDestinationPassingStyle<OPTY>(rewriter, op); \
[Review – Member] Suggested change:
    return rewriteArithInDestinationPassingStyle(rewriter, op);                \
}

STAMP_OUT_ARITH_DPS_FUNCS(arith::UIToFPOp)
STAMP_OUT_ARITH_DPS_FUNCS(arith::SIToFPOp)
STAMP_OUT_ARITH_DPS_FUNCS(arith::FPToUIOp)
STAMP_OUT_ARITH_DPS_FUNCS(arith::FPToSIOp)

STAMP_OUT_ARITH_DPS_FUNCS(arith::AddIOp)

STAMP_OUT_ARITH_DPS_FUNCS(arith::AddFOp)
STAMP_OUT_ARITH_DPS_FUNCS(arith::DivFOp)
[Review – Member] We can undef this macro here.

void linalg::populateConvertToDestinationStylePatterns(
RewritePatternSet &patterns) {
patterns.add(rewriteOpInDestinationPassingStyle<tensor::FromElementsOp>);
patterns.add(rewriteOpInDestinationPassingStyle<tensor::GenerateOp>);
patterns.add(rewriteOpInDestinationPassingStyle<tensor::PadOp>);

patterns.add(rewriteOpInDestinationPassingStyle<arith::UIToFPOp>);
patterns.add(rewriteOpInDestinationPassingStyle<arith::SIToFPOp>);
patterns.add(rewriteOpInDestinationPassingStyle<arith::FPToUIOp>);
patterns.add(rewriteOpInDestinationPassingStyle<arith::FPToSIOp>);

patterns.add(rewriteOpInDestinationPassingStyle<arith::AddIOp>);

patterns.add(rewriteOpInDestinationPassingStyle<arith::AddFOp>);
patterns.add(rewriteOpInDestinationPassingStyle<arith::DivFOp>);
}

namespace {
struct LinalgConvertToDPSPass
: public impl::LinalgConvertToDPSPassBase<LinalgConvertToDPSPass> {
using impl::LinalgConvertToDPSPassBase<
LinalgConvertToDPSPass>::LinalgConvertToDPSPassBase;
[Review – Member, on lines +746 to +747] Suggested change:
    using Base::Base;

void runOnOperation() override;
};

void LinalgConvertToDPSPass::runOnOperation() {

RewritePatternSet patterns(&getContext());
linalg::populateConvertToDestinationStylePatterns(patterns);
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
[Review – Member (kuhar, Sep 15, 2025)] Would the walk rewrite driver work here? I think this will never have to rewrite the same op twice or visit newly created ops. https://mlir.llvm.org/docs/PatternRewriter/#walk-pattern-rewrite-driver
signalPassFailure();
}
} // namespace
[Review – Contributor] I would suggest creating a new test file with Arith ops. I see two reasons:
  • Separate tests for Tensor and Arith ops (these trigger different patterns, so separating them makes sense IMHO).
  • For Tensor ops, the naming scheme for test functions is @tensor_<op-name>_variant. For Arith, you are using @arith_<binary|unary>_op, so that's a bit inconsistent.
If you prefer to keep everything in one file, could you add a big comment separating Tensor and Arith ops? Here's an example block comment:

///----------------------------------------------------------------------------------------
/// Tests for tensor.pad
///----------------------------------------------------------------------------------------

@@ -252,3 +252,64 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}

// -----

// CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func @arith_unary_op(
// CHECK-SAME: %[[X:.+]]: tensor<64xi32>) -> tensor<64xf32> {
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<64xf32>
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: {indexing_maps = [#[[$map]], #[[$map]]], iterator_types = ["parallel"]}
// CHECK-SAME: ins(%[[X:.+]] : tensor<64xi32>) outs(%[[EMPTY]] : tensor<64xf32>) {
// CHECK: ^bb0(%[[x:.+]]: i32, %[[Out:.+]]: f32):
// CHECK: %[[z:.+]] = arith.uitofp %[[x]] : i32 to f32
// CHECK: linalg.yield %[[z]] : f32
// CHECK: return %[[GENERIC]] : tensor<64xf32>

func.func @arith_unary_op(%x : tensor<64xi32>) -> tensor<64xf32> {
%z = arith.uitofp %x : tensor<64xi32> to tensor<64xf32>
return %z : tensor<64xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["arith.uitofp"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.rewrite_in_destination_passing_style %0
: (!transform.any_op) -> !transform.any_op
transform.yield
}
}

// -----

// CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func @arith_binop(
// CHECK-SAME: %[[X:.+]]: tensor<?xf32>, %[[Y:.+]]: tensor<?xf32>
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[DIM:.+]] = tensor.dim %[[X]], %[[C0]] : tensor<?xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?xf32>
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: {indexing_maps = [#[[$map]], #[[$map]], #[[$map]]], iterator_types = ["parallel"]}
// CHECK-SAME: ins(%[[X:.+]], %[[Y:.+]] : tensor<?xf32>, tensor<?xf32>) outs(%[[EMPTY]] : tensor<?xf32>) {
// CHECK: ^bb0(%[[x:.+]]: f32, %[[y:.+]]: f32, %[[Out:.+]]: f32):
// CHECK: %[[z:.+]] = arith.addf %[[x]], %[[y]] : f32
// CHECK: linalg.yield %[[z]] : f32
// CHECK: return %[[GENERIC]] : tensor<?xf32>

func.func @arith_binop(%x : tensor<?xf32>, %y : tensor<?xf32>)
[Review – Contributor] Please use consistent naming. Suggested change:
    func.func @arith_bin_op(%x : tensor<?xf32>, %y : tensor<?xf32>)
-> tensor<?xf32> {
%z = arith.addf %x, %y : tensor<?xf32>
return %z : tensor<?xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["arith.addf"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.rewrite_in_destination_passing_style %0
: (!transform.any_op) -> !transform.any_op
transform.yield
}
}
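
A standalone regression test driving the new pass flag directly (rather than through the transform dialect) might look like the sketch below; the function name and CHECK lines are hypothetical and only illustrate the shape such a test could take:

```mlir
// RUN: mlir-opt %s -linalg-convert-to-dps | FileCheck %s

// Sketch only: function name and CHECK lines are illustrative, not copied from the patch.
// CHECK-LABEL: func.func @uitofp_to_dps
// CHECK:         tensor.empty() : tensor<64xf32>
// CHECK:         linalg.generic
// CHECK:           arith.uitofp
func.func @uitofp_to_dps(%x : tensor<64xi32>) -> tensor<64xf32> {
  %z = arith.uitofp %x : tensor<64xi32> to tensor<64xf32>
  return %z : tensor<64xf32>
}
```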
26 changes: 26 additions & 0 deletions mlir/test/Dialect/Quant/lower-quant-ops-to-dps.mlir
@@ -0,0 +1,26 @@
// RUN: mlir-opt %s -lower-quant-ops -linalg-convert-to-dps \
// RUN: -linalg-specialize-generic-ops -cse | FileCheck %s

// CHECK-LABEL: func.func @lower_qcast_to_dps(
// CHECK-SAME: %[[X:.+]]: tensor<10xf32>) -> tensor<10x!quant.uniform<i8:f32, 2.000000e+00:10>>
// CHECK-DAG: %[[CST_10I:.+]] = arith.constant dense<10> : tensor<10xi8>
// CHECK-DAG: %[[CST_2F:.+]] = arith.constant dense<2.000000e+00> : tensor<10xf32>
// CHECK: %[[E:.+]] = tensor.empty() : tensor<10xf32>
// CHECK: %[[DIV:.+]] = linalg.div ins(%[[X]], %[[CST_2F]] : tensor<10xf32>, tensor<10xf32>)
// CHECK-SAME: outs(%[[E]] : tensor<10xf32>) -> tensor<10xf32>
//
// CHECK: %[[SITOFP:.+]] = linalg.generic
// CHECK-SAME: ins(%[[CST_10I]] : tensor<10xi8>) outs(%[[E]] : tensor<10xf32>)
// CHECK: %{{.*}} = arith.sitofp %{{.*}} : i8 to f32
//
// CHECK: %[[ADD:.+]] = linalg.add ins(%[[DIV]], %[[SITOFP]] : tensor<10xf32>, tensor<10xf32>)
// CHECK: %{{.*}} = linalg.generic
// CHECK-SAME: ins(%[[ADD]] : tensor<10xf32>)
// CHECK: %{{.*}} = arith.fptosi %{{.*}} : f32 to i8


!qalias = !quant.uniform<i8:f32, 2.0:10>
func.func @lower_qcast_to_dps(%arg0: tensor<10xf32>) -> tensor<10x!qalias> {
%0 = quant.qcast %arg0 : tensor<10xf32> to tensor<10x!qalias>
return %0 : tensor<10x!qalias>
}
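
For context, `-lower-quant-ops` decomposes the `quant.qcast` above into element-wise arith ops on tensors before `-linalg-convert-to-dps` runs. Judging from the CHECK lines, the intermediate IR is roughly the sketch below (hand-written, not verbatim pass output; any rounding/clamping and the final cast back to the `!quant.uniform` storage type are elided):

```mlir
// Sketch of the pre-DPS intermediate IR (illustrative, inferred from the CHECK lines above).
%scale = arith.constant dense<2.000000e+00> : tensor<10xf32>
%zp    = arith.constant dense<10> : tensor<10xi8>
%div   = arith.divf %arg0, %scale : tensor<10xf32>
%zpf   = arith.sitofp %zp : tensor<10xi8> to tensor<10xf32>
%add   = arith.addf %div, %zpf : tensor<10xf32>
%q     = arith.fptosi %add : tensor<10xf32> to tensor<10xi8>
```
Each of these ops is element-wise on tensors, so the new pass rewrites them into DPS `linalg.generic` ops, and `-linalg-specialize-generic-ops` then raises the `addf`/`divf` generics to `linalg.add`/`linalg.div`, which is what the CHECK lines verify.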