Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ std::unique_ptr<Pass> createArithUnsignedWhenEquivalentPass();
void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns,
DataFlowSolver &solver);

/// Replace signed ops with unsigned ones where they are proven equivalent.
void populateUnsignedWhenEquivalentPatterns(RewritePatternSet &patterns,
DataFlowSolver &solver);

/// Create a pass which does optimizations based on integer range analysis.
std::unique_ptr<Pass> createIntRangeOptimizationsPass();

Expand Down
98 changes: 62 additions & 36 deletions mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace arith {
Expand All @@ -29,6 +30,9 @@ using namespace mlir::dataflow;
/// Succeeds when a value is statically non-negative in that it has a lower
/// bound on its value (if it is treated as signed) and that bound is
/// non-negative.
// TODO: IntegerRangeAnalysis internally assumes index is 64bit and this pattern
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure this is true - all that IntegerRangeAnalysis does is store index as 64-bit. The implementations for various ops on index ought to handle both cases - if they don't might be a bug

// relies on this. These transformations may not be valid for 32bit index,
// need more investigation.
static LogicalResult staticallyNonNegative(DataFlowSolver &solver, Value v) {
auto *result = solver.lookupState<IntegerValueRangeLattice>(v);
if (!result || result->getValue().isUninitialized())
Expand Down Expand Up @@ -85,35 +89,60 @@ static CmpIPredicate toUnsignedPred(CmpIPredicate pred) {
}

namespace {
/// Rewriter listener that keeps the DataFlowSolver's cached analysis results
/// in sync with the IR while patterns run: when an operation is erased, the
/// corresponding program-point state and the lattice state of its results are
/// dropped so the solver never serves stale values for deleted IR.
class DataFlowListener : public RewriterBase::Listener {
public:
DataFlowListener(DataFlowSolver &s) : s(s) {}

protected:
/// Erase all solver state attached to `op` and its results.
void notifyOperationErased(Operation *op) override {
s.eraseState(s.getProgramPointAfter(op));
for (Value res : op->getResults())
s.eraseState(res);
}

// Solver whose state is invalidated on erasure; must outlive this listener.
DataFlowSolver &s;
};

/// Rewrites a signed arith operation (`Signed`) into its unsigned counterpart
/// (`Unsigned`) when integer range analysis proves the op and all of its
/// operands are statically non-negative, making the two forms equivalent.
/// Operands, result types, and attributes are carried over unchanged.
template <typename Signed, typename Unsigned>
struct ConvertOpToUnsigned final : OpRewritePattern<Signed> {
  ConvertOpToUnsigned(MLIRContext *context, DataFlowSolver &s)
      : OpRewritePattern<Signed>(context), solver(s) {}

  LogicalResult matchAndRewrite(Signed op, PatternRewriter &rw) const override {
    // Only rewrite when the analysis proves non-negativity; otherwise the
    // signed and unsigned forms may disagree.
    if (failed(
            staticallyNonNegative(this->solver, static_cast<Operation *>(op))))
      return failure();

    rw.replaceOpWithNewOp<Unsigned>(op, op->getResultTypes(), op->getOperands(),
                                    op->getAttrs());
    return success();
  }

private:
  // Range-analysis results consulted by the match; not owned.
  DataFlowSolver &solver;
};

/// Rewrites a signed `arith.cmpi` predicate into the equivalent unsigned
/// predicate (via toUnsignedPred) when the analysis in `solver` shows the
/// comparison is convertible. The operands are reused as-is.
struct ConvertCmpIToUnsigned final : OpRewritePattern<CmpIOp> {
  ConvertCmpIToUnsigned(MLIRContext *context, DataFlowSolver &s)
      : OpRewritePattern<CmpIOp>(context), solver(s) {}

  LogicalResult matchAndRewrite(CmpIOp op, PatternRewriter &rw) const override {
    // Bail out unless range analysis proves signed/unsigned equivalence.
    if (failed(isCmpIConvertable(this->solver, op)))
      return failure();

    rw.replaceOpWithNewOp<CmpIOp>(op, toUnsignedPred(op.getPredicate()),
                                  op.getLhs(), op.getRhs());
    return success();
  }

private:
  // Range-analysis results consulted by the match; not owned.
  DataFlowSolver &solver;
};

struct ArithUnsignedWhenEquivalentPass
: public arith::impl::ArithUnsignedWhenEquivalentBase<
ArithUnsignedWhenEquivalentPass> {
/// Implementation structure: first find all equivalent ops and collect them,
/// then perform all the rewrites in a second pass over the target op. This
/// ensures that analysis results are not invalidated during rewriting.

void runOnOperation() override {
Operation *op = getOperation();
MLIRContext *ctx = op->getContext();
Expand All @@ -123,35 +152,32 @@ struct ArithUnsignedWhenEquivalentPass
if (failed(solver.initializeAndRun(op)))
return signalPassFailure();

ConversionTarget target(*ctx);
target.addLegalDialect<ArithDialect>();
target.addDynamicallyLegalOp<DivSIOp, CeilDivSIOp, FloorDivSIOp, RemSIOp,
MinSIOp, MaxSIOp, ExtSIOp>(
[&solver](Operation *op) -> std::optional<bool> {
return failed(staticallyNonNegative(solver, op));
});
target.addDynamicallyLegalOp<CmpIOp>(
[&solver](CmpIOp op) -> std::optional<bool> {
return failed(isCmpIConvertable(solver, op));
});
DataFlowListener listener(solver);

RewritePatternSet patterns(ctx);
patterns.add<ConvertOpToUnsigned<DivSIOp, DivUIOp>,
ConvertOpToUnsigned<CeilDivSIOp, CeilDivUIOp>,
ConvertOpToUnsigned<FloorDivSIOp, DivUIOp>,
ConvertOpToUnsigned<RemSIOp, RemUIOp>,
ConvertOpToUnsigned<MinSIOp, MinUIOp>,
ConvertOpToUnsigned<MaxSIOp, MaxUIOp>,
ConvertOpToUnsigned<ExtSIOp, ExtUIOp>, ConvertCmpIToUnsigned>(
ctx);

if (failed(applyPartialConversion(op, target, std::move(patterns)))) {
populateUnsignedWhenEquivalentPatterns(patterns, solver);

GreedyRewriteConfig config;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we expect to have to iterate to convergence here? Otherwise can we set options to limit to a single iteration?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think, it should finish in single iteration.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, adding maxIterations = 1 causing it to fail to converge. I think I will leave things as is for now.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You probably also need to change it to top down iteration in order to get the same convergence behavior as before.

I've definitely seen combined passes that are doing other optimizations along with unsigned conversions require multiple iterations to converge (and be more efficient with bottom up iteration), but I expect that this simple test pass just needs one top down pass through the IR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It fails to converge (i.e. applyPatternsAndFoldGreedily returns failure) even with

    config.maxIterations = 1;
    config.useTopDownTraversal = true;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is mostly test pass anyway so I don't see much problem here. Downstream we plan to combine it with other patterns. And using greedy driver here may not be ideal, but it's still better than current dialect conversion driver.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is a test pass, it should be moved to the test folder and be named accordingly though.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't intended as a test pass - you're meant to be able to run this (barring the usual philosophical disagreements about even having non-test passes upstream)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wrote this code using the dialect converter because I wanted a one-shot "walk this function exactly once and apply matching patterns"

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't intended as a test pass

Well, then my point stands: we shouldn't involve the greedy rewriter here.

config.listener = &listener;

if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
signalPassFailure();
}
}
};
} // end anonymous namespace

/// Populate `patterns` with rewrites that replace signed arith ops
/// (divsi, ceildivsi, floordivsi, remsi, minsi, maxsi, extsi, and signed
/// cmpi predicates) with their unsigned counterparts, gated on range
/// information from `solver` proving the operands non-negative.
/// `solver` must already be initialized and must outlive the pattern set.
void mlir::arith::populateUnsignedWhenEquivalentPatterns(
RewritePatternSet &patterns, DataFlowSolver &solver) {
patterns.add<ConvertOpToUnsigned<DivSIOp, DivUIOp>,
ConvertOpToUnsigned<CeilDivSIOp, CeilDivUIOp>,
ConvertOpToUnsigned<FloorDivSIOp, DivUIOp>,
ConvertOpToUnsigned<RemSIOp, RemUIOp>,
ConvertOpToUnsigned<MinSIOp, MinUIOp>,
ConvertOpToUnsigned<MaxSIOp, MaxUIOp>,
ConvertOpToUnsigned<ExtSIOp, ExtUIOp>, ConvertCmpIToUnsigned>(
patterns.getContext(), solver);
}

/// Factory for the arith unsigned-when-equivalent pass.
std::unique_ptr<Pass> mlir::arith::createArithUnsignedWhenEquivalentPass() {
return std::make_unique<ArithUnsignedWhenEquivalentPass>();
}
20 changes: 10 additions & 10 deletions mlir/test/Dialect/Arith/unsigned-when-equivalent.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// CHECK: arith.cmpi slt
// CHECK: arith.cmpi sge
// CHECK: arith.cmpi sgt
func.func @not_with_maybe_overflow(%arg0 : i32) {
func.func @not_with_maybe_overflow(%arg0 : i32) -> (i32, i32, i32, i32, i32, i32, i64, i1, i1, i1, i1) {
%ci32_smax = arith.constant 0x7fffffff : i32
%c1 = arith.constant 1 : i32
%c4 = arith.constant 4 : i32
Expand All @@ -29,7 +29,7 @@ func.func @not_with_maybe_overflow(%arg0 : i32) {
%10 = arith.cmpi slt, %1, %c4 : i32
%11 = arith.cmpi sge, %1, %c4 : i32
%12 = arith.cmpi sgt, %1, %c4 : i32
func.return
func.return %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12 : i32, i32, i32, i32, i32, i32, i64, i1, i1, i1, i1
}

// CHECK-LABEL: func @yes_with_no_overflow
Expand All @@ -44,7 +44,7 @@ func.func @not_with_maybe_overflow(%arg0 : i32) {
// CHECK: arith.cmpi ult
// CHECK: arith.cmpi uge
// CHECK: arith.cmpi ugt
func.func @yes_with_no_overflow(%arg0 : i32) {
func.func @yes_with_no_overflow(%arg0 : i32) -> (i32, i32, i32, i32, i32, i32, i64, i1, i1, i1, i1) {
%ci32_almost_smax = arith.constant 0x7ffffffe : i32
%c1 = arith.constant 1 : i32
%c4 = arith.constant 4 : i32
Expand All @@ -61,7 +61,7 @@ func.func @yes_with_no_overflow(%arg0 : i32) {
%10 = arith.cmpi slt, %1, %c4 : i32
%11 = arith.cmpi sge, %1, %c4 : i32
%12 = arith.cmpi sgt, %1, %c4 : i32
func.return
func.return %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12 : i32, i32, i32, i32, i32, i32, i64, i1, i1, i1, i1
}

// CHECK-LABEL: func @preserves_structure
Expand Down Expand Up @@ -90,20 +90,20 @@ func.func @preserves_structure(%arg0 : memref<8xindex>) {
func.func private @external() -> i8

// CHECK-LABEL: @dead_code
// Returns the floordivsi result so the op is live; the CHECK verifies the
// signed op is NOT converted (operand range is unknown: external call).
func.func @dead_code() -> i8 {
  %0 = call @external() : () -> i8
  // CHECK: arith.floordivsi
  %1 = arith.floordivsi %0, %0 : i8
  return %1 : i8
}

// Make sure not crash.
// CHECK-LABEL: @no_integer_or_index
// Vector (non-scalar) cmpi: the pattern must not crash and must leave the
// signed predicate untouched; the result is returned to keep the op live.
func.func @no_integer_or_index(%arg0: vector<1xi32>) -> vector<1xi1> {
  // CHECK: arith.cmpi
  %cst_0 = arith.constant dense<[0]> : vector<1xi32>
  %cmp = arith.cmpi slt, %cst_0, %arg0 : vector<1xi32>
  return %cmp : vector<1xi1>
}

// CHECK-LABEL: @gpu_func
Expand All @@ -113,4 +113,4 @@ func.func @gpu_func(%arg0: memref<2x32xf32>, %arg1: memref<2x32xf32>, %arg2: mem
gpu.terminator
}
return %arg1 : memref<2x32xf32>
}
}
Loading