diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h index aee64475171a4..e866ac518dbbc 100644 --- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h @@ -70,6 +70,10 @@ std::unique_ptr createArithUnsignedWhenEquivalentPass(); void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns, DataFlowSolver &solver); +/// Replace signed ops with unsigned ones where they are proven equivalent. +void populateUnsignedWhenEquivalentPatterns(RewritePatternSet &patterns, + DataFlowSolver &solver); + /// Create a pass which do optimizations based on integer range analysis. std::unique_ptr createIntRangeOptimizationsPass(); diff --git a/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp b/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp index 4edce84bafd41..bebe0b5a7c0b6 100644 --- a/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp @@ -13,7 +13,8 @@ #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Transforms/DialectConversion.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { namespace arith { @@ -29,6 +30,9 @@ using namespace mlir::dataflow; /// Succeeds when a value is statically non-negative in that it has a lower /// bound on its value (if it is treated as signed) and that bound is /// non-negative. +// TODO: IntegerRangeAnalysis internally assumes index is 64bit and this pattern +// relies on this. These transformations may not be valid for 32bit index, +// need more investigation. static LogicalResult staticallyNonNegative(DataFlowSolver &solver, Value v) { auto *result = solver.lookupState(v); if (!result || result->getValue().isUninitialized()) @@ -85,35 +89,60 @@ static CmpIPredicate toUnsignedPred(CmpIPredicate pred) { } namespace { +class DataFlowListener : public RewriterBase::Listener { +public: + DataFlowListener(DataFlowSolver &s) : s(s) {} + +protected: + void notifyOperationErased(Operation *op) override { + s.eraseState(s.getProgramPointAfter(op)); + for (Value res : op->getResults()) + s.eraseState(res); + } + + DataFlowSolver &s; +}; + template -struct ConvertOpToUnsigned : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct ConvertOpToUnsigned final : OpRewritePattern { + ConvertOpToUnsigned(MLIRContext *context, DataFlowSolver &s) + : OpRewritePattern(context), solver(s) {} - LogicalResult matchAndRewrite(Signed op, typename Signed::Adaptor adaptor, - ConversionPatternRewriter &rw) const override { - rw.replaceOpWithNewOp(op, op->getResultTypes(), - adaptor.getOperands(), op->getAttrs()); + LogicalResult matchAndRewrite(Signed op, PatternRewriter &rw) const override { + if (failed( + staticallyNonNegative(this->solver, static_cast(op)))) + return failure(); + + rw.replaceOpWithNewOp(op, op->getResultTypes(), op->getOperands(), + op->getAttrs()); return success(); } + +private: + DataFlowSolver &solver; }; -struct ConvertCmpIToUnsigned : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct ConvertCmpIToUnsigned final : OpRewritePattern { + ConvertCmpIToUnsigned(MLIRContext *context, DataFlowSolver &s) + : OpRewritePattern(context), solver(s) {} + + LogicalResult matchAndRewrite(CmpIOp op, PatternRewriter &rw) const override { + if (failed(isCmpIConvertable(this->solver, op))) + return failure(); - LogicalResult matchAndRewrite(CmpIOp op, CmpIOpAdaptor adaptor, - ConversionPatternRewriter &rw) const override { rw.replaceOpWithNewOp(op, toUnsignedPred(op.getPredicate()), op.getLhs(), op.getRhs()); return success(); } + +private: + DataFlowSolver &solver; }; struct ArithUnsignedWhenEquivalentPass : public arith::impl::ArithUnsignedWhenEquivalentBase< ArithUnsignedWhenEquivalentPass> { - /// Implementation structure: first find all equivalent ops and collect them, - /// then perform all the rewrites in a second pass over the target op. This - /// ensures that analysis results are not invalidated during rewriting. + void runOnOperation() override { Operation *op = getOperation(); MLIRContext *ctx = op->getContext(); @@ -123,35 +152,32 @@ struct ArithUnsignedWhenEquivalentPass if (failed(solver.initializeAndRun(op))) return signalPassFailure(); - ConversionTarget target(*ctx); - target.addLegalDialect(); - target.addDynamicallyLegalOp( - [&solver](Operation *op) -> std::optional { - return failed(staticallyNonNegative(solver, op)); - }); - target.addDynamicallyLegalOp( - [&solver](CmpIOp op) -> std::optional { - return failed(isCmpIConvertable(solver, op)); - }); + DataFlowListener listener(solver); RewritePatternSet patterns(ctx); - patterns.add, - ConvertOpToUnsigned, - ConvertOpToUnsigned, - ConvertOpToUnsigned, - ConvertOpToUnsigned, - ConvertOpToUnsigned, - ConvertOpToUnsigned, ConvertCmpIToUnsigned>( - ctx); - - if (failed(applyPartialConversion(op, target, std::move(patterns)))) { + populateUnsignedWhenEquivalentPatterns(patterns, solver); + + GreedyRewriteConfig config; + config.listener = &listener; + + if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config))) signalPassFailure(); - } } }; } // end anonymous namespace +void mlir::arith::populateUnsignedWhenEquivalentPatterns( + RewritePatternSet &patterns, DataFlowSolver &solver) { + patterns.add, + ConvertOpToUnsigned, + ConvertOpToUnsigned, + ConvertOpToUnsigned, + ConvertOpToUnsigned, + ConvertOpToUnsigned, + ConvertOpToUnsigned, ConvertCmpIToUnsigned>( + patterns.getContext(), solver); +} + std::unique_ptr mlir::arith::createArithUnsignedWhenEquivalentPass() { return std::make_unique(); } diff --git a/mlir/test/Dialect/Arith/unsigned-when-equivalent.mlir b/mlir/test/Dialect/Arith/unsigned-when-equivalent.mlir index 49bd74cfe9124..0ea69de8b8f9a 100644 --- a/mlir/test/Dialect/Arith/unsigned-when-equivalent.mlir +++ b/mlir/test/Dialect/Arith/unsigned-when-equivalent.mlir @@ -12,7 +12,7 @@ // CHECK: arith.cmpi slt // CHECK: arith.cmpi sge // CHECK: arith.cmpi sgt -func.func @not_with_maybe_overflow(%arg0 : i32) { +func.func @not_with_maybe_overflow(%arg0 : i32) -> (i32, i32, i32, i32, i32, i32, i64, i1, i1, i1, i1) { %ci32_smax = arith.constant 0x7fffffff : i32 %c1 = arith.constant 1 : i32 %c4 = arith.constant 4 : i32 @@ -29,7 +29,7 @@ func.func @not_with_maybe_overflow(%arg0 : i32) { %10 = arith.cmpi slt, %1, %c4 : i32 %11 = arith.cmpi sge, %1, %c4 : i32 %12 = arith.cmpi sgt, %1, %c4 : i32 - func.return + func.return %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12 : i32, i32, i32, i32, i32, i32, i64, i1, i1, i1, i1 } // CHECK-LABEL: func @yes_with_no_overflow @@ -44,7 +44,7 @@ func.func @not_with_maybe_overflow(%arg0 : i32) { // CHECK: arith.cmpi ult // CHECK: arith.cmpi uge // CHECK: arith.cmpi ugt -func.func @yes_with_no_overflow(%arg0 : i32) { +func.func @yes_with_no_overflow(%arg0 : i32) -> (i32, i32, i32, i32, i32, i32, i64, i1, i1, i1, i1) { %ci32_almost_smax = arith.constant 0x7ffffffe : i32 %c1 = arith.constant 1 : i32 %c4 = arith.constant 4 : i32 @@ -61,7 +61,7 @@ func.func @yes_with_no_overflow(%arg0 : i32) { %10 = arith.cmpi slt, %1, %c4 : i32 %11 = arith.cmpi sge, %1, %c4 : i32 %12 = arith.cmpi sgt, %1, %c4 : i32 - func.return + func.return %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12 : i32, i32, i32, i32, i32, i32, i64, i1, i1, i1, i1 } // CHECK-LABEL: func @preserves_structure @@ -90,20 +90,20 @@ func.func @preserves_structure(%arg0 : memref<8xindex>) { func.func private @external() -> i8 // CHECK-LABEL: @dead_code -func.func @dead_code() { +func.func @dead_code() -> i8 { %0 = call @external() : () -> i8 // CHECK: arith.floordivsi %1 = arith.floordivsi %0, %0 : i8 - return + return %1 : i8 } // Make sure not crash. // CHECK-LABEL: @no_integer_or_index -func.func @no_integer_or_index() { +func.func @no_integer_or_index(%arg0: vector<1xi32>) -> vector<1xi1> { // CHECK: arith.cmpi %cst_0 = arith.constant dense<[0]> : vector<1xi32> - %cmp = arith.cmpi slt, %cst_0, %cst_0 : vector<1xi32> - return + %cmp = arith.cmpi slt, %cst_0, %arg0 : vector<1xi32> + return %cmp : vector<1xi1> } // CHECK-LABEL: @gpu_func @@ -113,4 +113,4 @@ func.func @gpu_func(%arg0: memref<2x32xf32>, %arg1: memref<2x32xf32>, %arg2: mem gpu.terminator } return %arg1 : memref<2x32xf32> -} +}