From 065c3fbfba855a419313d2a9e003a49298420a47 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Tue, 15 Oct 2024 16:32:32 +0200
Subject: [PATCH 01/20] [mlir] Add `arith-int-range-narrowing` pass

This pass is intended to narrow integer calculations to a specific
bitwidth, using `IntegerRangeAnalysis`.

We already have the `arith-int-narrowing` pass, but it mostly does only
local analysis, while `IntegerRangeAnalysis` analyzes the entire
program. The two should ideally be unified, but that is a task for the
future.
---
 .../mlir/Dialect/Arith/Transforms/Passes.h    |   6 +
 .../mlir/Dialect/Arith/Transforms/Passes.td   |  19 ++
 .../Transforms/IntRangeOptimizations.cpp      | 260 ++++++++++++++++++
 .../Dialect/Arith/int-range-narrowing.mlir    |  57 ++++
 4 files changed, 342 insertions(+)
 create mode 100644 mlir/test/Dialect/Arith/int-range-narrowing.mlir

diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index e866ac518dbbc..b66bfbb23bc60 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -77,6 +77,12 @@ void populateUnsignedWhenEquivalentPatterns(RewritePatternSet &patterns,
 /// Create a pass which do optimizations based on integer range analysis.
 std::unique_ptr createIntRangeOptimizationsPass();
 
+/// Add patterns for int range based norrowing.
+void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns,
+                                       DataFlowSolver &solver,
+                                       unsigned targetBitwidth);
+
+// TODO: merge these two narrowing passes.
 /// Add patterns for integer bitwidth narrowing.
 void populateArithIntNarrowingPatterns(RewritePatternSet &patterns,
                                        const ArithIntNarrowingOptions &options);
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index 1517f71f1a7c9..8c565f6489638 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -50,6 +50,25 @@ def ArithIntRangeOpts : Pass<"int-range-optimizations"> {
   ];
 }
 
+def ArithIntRangeNarrowing : Pass<"arith-int-range-narrowing"> {
+  let summary = "Reduce integer operation bitwidth based on integer range analysis";
+  let description = [{
+    This pass runs integer range analysis and tries to narrow arith ops to the
+    specified bitwidth based on the analysis results.
+  }];
+
+  let options = [
+    Option<"targetBitwidth", "target-bitwidth", "unsigned",
+           /*default=*/"32", "Target bitwidth this pass will try to narrow to">,
+  ];
+
+  // Explicitly depend on "arith" because this pass could create operations in
+  // `arith` out of thin air in some cases.
+ let dependentDialects = [ + "::mlir::arith::ArithDialect" + ]; +} + def ArithEmulateUnsupportedFloats : Pass<"arith-emulate-unsupported-floats"> { let summary = "Emulate operations on unsupported floats with extf/truncf"; let description = [{ diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp index d494bba081f80..005033dfd5e11 100644 --- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp @@ -15,8 +15,10 @@ #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/IRMapping.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -24,6 +26,9 @@ namespace mlir::arith { #define GEN_PASS_DEF_ARITHINTRANGEOPTS #include "mlir/Dialect/Arith/Transforms/Passes.h.inc" + +#define GEN_PASS_DEF_ARITHINTRANGENARROWING +#include "mlir/Dialect/Arith/Transforms/Passes.h.inc" } // namespace mlir::arith using namespace mlir; @@ -190,6 +195,223 @@ struct DeleteTrivialRem : public OpRewritePattern { DataFlowSolver &solver; }; +static Type checkArithType(Type type, unsigned targetBitwidth) { + type = getElementTypeOrSelf(type); + if (isa(type)) + return type; + + if (auto intType = dyn_cast(type)) + if (intType.getWidth() > targetBitwidth) + return type; + + return nullptr; +} + +static Type checkElementwiseOpType(Operation *op, unsigned targetBitwidth) { + if (op->getNumOperands() == 0 || op->getNumResults() == 0) + return nullptr; + + Type type; + for (auto range : + {ValueRange(op->getOperands()), ValueRange(op->getResults())}) { + for (Value val : range) { + if (!type) { + type = val.getType(); + continue; + } else if (type != val.getType()) { + return nullptr; + } + } + } + + return checkArithType(type, targetBitwidth); +} + +static std::optional getOperandsRange(DataFlowSolver &solver, + ValueRange results) { + std::optional ret; + for (Value value : results) { + auto *maybeInferredRange = + solver.lookupState(value); + if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized()) + return std::nullopt; + + const ConstantIntRanges &inferredRange = + maybeInferredRange->getValue().getValue(); + + if (!ret) { + ret = inferredRange; + } else { + ret = ret->rangeUnion(inferredRange); + } + } + return ret; +} + +static Type getTargetType(Type srcType, unsigned targetBitwidth) { + auto dstType = IntegerType::get(srcType.getContext(), targetBitwidth); + if (auto shaped = dyn_cast(srcType)) + return shaped.clone(dstType); + + assert(srcType.isIntOrIndex() && "Invalid src type"); + return dstType; +} + +static bool checkRange(const ConstantIntRanges &range, APInt smin, APInt smax, + APInt umin, APInt umax) { + auto sge = [](APInt val1, APInt val2) -> bool { + unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth()); + val1 = val1.sext(width); + val2 = val2.sext(width); + return val1.sge(val2); + }; + auto sle = [](APInt val1, APInt val2) -> bool { + unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth()); + val1 = val1.sext(width); + val2 = val2.sext(width); + return val1.sle(val2); + }; + auto uge = [](APInt val1, APInt val2) -> bool { + unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth()); + val1 = 
val1.zext(width);
+    val2 = val2.zext(width);
+    return val1.uge(val2);
+  };
+  auto ule = [](APInt val1, APInt val2) -> bool {
+    unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
+    val1 = val1.zext(width);
+    val2 = val2.zext(width);
+    return val1.ule(val2);
+  };
+  return sge(range.smin(), smin) && sle(range.smax(), smax) &&
+         uge(range.umin(), umin) && ule(range.umax(), umax);
+}
+
+static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) {
+  Type srcType = src.getType();
+  assert(srcType.isIntOrIndex() && "Invalid src type");
+  assert(dstType.isIntOrIndex() && "Invalid dst type");
+  if (srcType == dstType)
+    return src;
+
+  if (isa(srcType) || isa(dstType))
+    return builder.create(loc, dstType, src);
+
+  auto srcInt = cast(srcType);
+  auto dstInt = cast(dstType);
+  if (dstInt.getWidth() < srcInt.getWidth()) {
+    return builder.create(loc, dstType, src);
+  } else {
+    return builder.create(loc, dstType, src);
+  }
+}
+
+struct NarrowElementwise final
+    : public OpTraitRewritePattern {
+  NarrowElementwise(MLIRContext *context, DataFlowSolver &s, unsigned target)
+      : OpTraitRewritePattern(context), solver(s),
+        targetBitwidth(target) {}
+
+  using OpTraitRewritePattern::OpTraitRewritePattern;
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    Type srcType = checkElementwiseOpType(op, targetBitwidth);
+    if (!srcType)
+      return failure();
+
+    std::optional range =
+        getOperandsRange(solver, op->getResults());
+    if (!range)
+      return failure();
+
+    // We are truncating op args to the desired bitwidth before the op and then
+    // extending op results back to the original width after.
+    // extui and extsi will produce different results for negative values, so
+    // limit signed range to non-negative values.
+ auto smin = APInt::getZero(targetBitwidth); + auto smax = APInt::getSignedMaxValue(targetBitwidth); + auto umin = APInt::getMinValue(targetBitwidth); + auto umax = APInt::getMaxValue(targetBitwidth); + if (!checkRange(*range, smin, smax, umin, umax)) + return failure(); + + Type targetType = getTargetType(srcType, targetBitwidth); + if (targetType == srcType) + return failure(); + + Location loc = op->getLoc(); + IRMapping mapping; + for (Value arg : op->getOperands()) { + Value newArg = doCast(rewriter, loc, arg, targetType); + mapping.map(arg, newArg); + } + + Operation *newOp = rewriter.clone(*op, mapping); + rewriter.modifyOpInPlace(newOp, [&]() { + for (OpResult res : newOp->getResults()) { + res.setType(targetType); + } + }); + SmallVector newResults; + for (Value res : newOp->getResults()) + newResults.emplace_back(doCast(rewriter, loc, res, srcType)); + + rewriter.replaceOp(op, newResults); + return success(); + } + +private: + DataFlowSolver &solver; + unsigned targetBitwidth; +}; + +struct NarrowCmpi final : public OpRewritePattern { + NarrowCmpi(MLIRContext *context, PatternBenefit benefit, DataFlowSolver &s, + unsigned target) + : OpRewritePattern(context, benefit), solver(s), targetBitwidth(target) {} + + LogicalResult matchAndRewrite(arith::CmpIOp op, + PatternRewriter &rewriter) const override { + Value lhs = op.getLhs(); + Value rhs = op.getRhs(); + + Type srcType = checkArithType(lhs.getType(), targetBitwidth); + if (!srcType) + return failure(); + + std::optional range = + getOperandsRange(solver, {lhs, rhs}); + if (!range) + return failure(); + + auto smin = APInt::getSignedMinValue(targetBitwidth); + auto smax = APInt::getSignedMaxValue(targetBitwidth); + auto umin = APInt::getMinValue(targetBitwidth); + auto umax = APInt::getMaxValue(targetBitwidth); + if (!checkRange(*range, smin, smax, umin, umax)) + return failure(); + + Type targetType = getTargetType(srcType, targetBitwidth); + if (targetType == srcType) + return failure(); + + Location loc = op->getLoc(); + IRMapping mapping; + for (Value arg : op->getOperands()) { + Value newArg = doCast(rewriter, loc, arg, targetType); + mapping.map(arg, newArg); + } + + Operation *newOp = rewriter.clone(*op, mapping); + rewriter.replaceOp(op, newOp->getResults()); + return success(); + } + +private: + DataFlowSolver &solver; + unsigned targetBitwidth; +}; + struct IntRangeOptimizationsPass : public arith::impl::ArithIntRangeOptsBase { @@ -214,6 +436,32 @@ struct IntRangeOptimizationsPass signalPassFailure(); } }; + +struct IntRangeNarrowingPass + : public arith::impl::ArithIntRangeNarrowingBase { + using ArithIntRangeNarrowingBase::ArithIntRangeNarrowingBase; + + void runOnOperation() override { + Operation *op = getOperation(); + MLIRContext *ctx = op->getContext(); + DataFlowSolver solver; + solver.load(); + solver.load(); + if (failed(solver.initializeAndRun(op))) + return signalPassFailure(); + + DataFlowListener listener(solver); + + RewritePatternSet patterns(ctx); + populateIntRangeNarrowingPatterns(patterns, solver, this->targetBitwidth); + + GreedyRewriteConfig config; + config.listener = &listener; + + if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config))) + signalPassFailure(); + } +}; } // namespace void mlir::arith::populateIntRangeOptimizationsPatterns( @@ -222,6 +470,18 @@ void mlir::arith::populateIntRangeOptimizationsPatterns( DeleteTrivialRem>(patterns.getContext(), solver); } +void mlir::arith::populateIntRangeNarrowingPatterns(RewritePatternSet &patterns, + DataFlowSolver &solver, + 
unsigned targetBitwidth) {
+  // Cmpi uses argument ranges instead of result ranges; run it with a higher
+  // benefit, as its arguments can potentially be replaced.
+  patterns.add(patterns.getContext(), /*benefit*/ 10, solver,
+               targetBitwidth);
+
+  patterns.add(patterns.getContext(), solver,
+               targetBitwidth);
+}
+
 std::unique_ptr mlir::arith::createIntRangeOptimizationsPass() {
   return std::make_unique();
 }
diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
new file mode 100644
index 0000000000000..d85cb3384061b
--- /dev/null
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt --arith-int-range-narrowing="target-bitwidth=32" %s | FileCheck %s
+
+// Do not truncate negative values
+// CHECK-LABEL: func @test_addi_neg
+// CHECK: %[[RES:.*]] = arith.addi %{{.*}}, %{{.*}} : index
+// CHECK: return %[[RES]] : index
+func.func @test_addi_neg() -> index {
+  %0 = test.with_bounds { umin = 0 : index, umax = 1 : index, smin = 0 : index, smax = 1 : index } : index
+  %1 = test.with_bounds { umin = 0 : index, umax = -1 : index, smin = -1 : index, smax = 0 : index } : index
+  %2 = arith.addi %0, %1 : index
+  return %2 : index
+}
+
+// CHECK-LABEL: func @test_addi
+// CHECK: %[[A:.*]] = test.with_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index} : index
+// CHECK: %[[B:.*]] = test.with_bounds {smax = 7 : index, smin = 6 : index, umax = 7 : index, umin = 6 : index} : index
+// CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i32
+// CHECK: %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i32
+// CHECK: %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i32
+// CHECK: %[[RES_CASTED:.*]] = arith.index_castui %[[RES]] : i32 to index
+// CHECK: return %[[RES_CASTED]] : index
+func.func @test_addi() -> index {
+  %0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : index
+  %1 = test.with_bounds { umin = 6 : index, umax = 7 : index, smin = 6 : index, smax = 7 : index } : index
+  %2 = arith.addi %0, %1 : index
+  return %2 : index
+}
+
+
+// CHECK-LABEL: func @test_addi_i64
+// CHECK: %[[A:.*]] = test.with_bounds {smax = 5 : i64, smin = 4 : i64, umax = 5 : i64, umin = 4 : i64} : i64
+// CHECK: %[[B:.*]] = test.with_bounds {smax = 7 : i64, smin = 6 : i64, umax = 7 : i64, umin = 6 : i64} : i64
+// CHECK: %[[A_CASTED:.*]] = arith.trunci %[[A]] : i64 to i32
+// CHECK: %[[B_CASTED:.*]] = arith.trunci %[[B]] : i64 to i32
+// CHECK: %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i32
+// CHECK: %[[RES_CASTED:.*]] = arith.extui %[[RES]] : i32 to i64
+// CHECK: return %[[RES_CASTED]] : i64
+func.func @test_addi_i64() -> i64 {
+  %0 = test.with_bounds { umin = 4 : i64, umax = 5 : i64, smin = 4 : i64, smax = 5 : i64 } : i64
+  %1 = test.with_bounds { umin = 6 : i64, umax = 7 : i64, smin = 6 : i64, smax = 7 : i64 } : i64
+  %2 = arith.addi %0, %1 : i64
+  return %2 : i64
+}
+
+// CHECK-LABEL: func @test_cmpi
+// CHECK: %[[A:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
+// CHECK: %[[B:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
+// CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i32
+// CHECK: %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i32
+// CHECK: %[[RES:.*]] = arith.cmpi slt, %[[A_CASTED]], %[[B_CASTED]] : i32
+// CHECK: return %[[RES]] : i1
+func.func @test_cmpi() -> i1 {
+  %0 = 
test.with_bounds { umin = 0 : index, umax = 10 : index, smin = 0 : index, smax = 10 : index } : index + %1 = test.with_bounds { umin = 0 : index, umax = 10 : index, smin = 0 : index, smax = 10 : index } : index + %2 = arith.cmpi slt, %0, %1 : index + return %2 : i1 +} From ed4920c90d685aff762063692c11381727afdd44 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 15 Oct 2024 18:59:07 +0200 Subject: [PATCH 02/20] typo --- mlir/include/mlir/Dialect/Arith/Transforms/Passes.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h index b66bfbb23bc60..da7ded699f21c 100644 --- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h @@ -77,7 +77,7 @@ void populateUnsignedWhenEquivalentPatterns(RewritePatternSet &patterns, /// Create a pass which do optimizations based on integer range analysis. std::unique_ptr createIntRangeOptimizationsPass(); -/// Add patterns for int range based norrowing. +/// Add patterns for int range based narrowing. void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns, DataFlowSolver &solver, unsigned targetBitwidth); From eb91b30450e72b8ac8b17f198a941b255af01b5d Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 16 Oct 2024 13:05:29 +0200 Subject: [PATCH 03/20] use list of bitwidths instead of 1 --- .../mlir/Dialect/Arith/Transforms/Passes.h | 2 +- .../mlir/Dialect/Arith/Transforms/Passes.td | 4 +- .../Transforms/IntRangeOptimizations.cpp | 141 ++++++++++-------- .../Dialect/Arith/int-range-narrowing.mlir | 2 +- 4 files changed, 79 insertions(+), 70 deletions(-) diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h index da7ded699f21c..b6a87e88c6efb 100644 --- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h @@ -80,7 +80,7 @@ std::unique_ptr createIntRangeOptimizationsPass(); /// Add patterns for int range based narrowing. void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns, DataFlowSolver &solver, - unsigned targetBitwidth); + ArrayRef bitwidthsSupported); // TODO: merge these two narrowing passes. /// Add patterns for integer bitwidth narrowing. 
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index 8c565f6489638..898d74249af61 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -58,8 +58,8 @@ def ArithIntRangeNarrowing : Pass<"arith-int-range-narrowing"> {
   }];
 
   let options = [
-    Option<"targetBitwidth", "target-bitwidth", "unsigned",
-           /*default=*/"32", "Target bitwidth this pass will try to narrow to">,
+    ListOption<"bitwidthsSupported", "int-bitwidths-supported", "unsigned",
+               "Integer bitwidths supported">,
   ];
 
   // Explicitly depend on "arith" because this pass could create operations in
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 005033dfd5e11..8c651076df2e5 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -308,108 +308,117 @@ static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) {
 
 struct NarrowElementwise final
     : public OpTraitRewritePattern {
-  NarrowElementwise(MLIRContext *context, DataFlowSolver &s, unsigned target)
+  NarrowElementwise(MLIRContext *context, DataFlowSolver &s,
+                    ArrayRef target)
       : OpTraitRewritePattern(context), solver(s),
-        targetBitwidth(target) {}
+        targetBitwidths(target) {}
 
   using OpTraitRewritePattern::OpTraitRewritePattern;
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
-    Type srcType = checkElementwiseOpType(op, targetBitwidth);
-    if (!srcType)
-      return failure();
 
     std::optional range =
        getOperandsRange(solver, op->getResults());
     if (!range)
       return failure();
 
-    // We are truncating op args to the desired bitwidth before the op and then
-    // extending op results back to the original width after.
-    // extui and extsi will produce different results for negative values, so
-    // limit signed range to non-negative values.
-    auto smin = APInt::getZero(targetBitwidth);
-    auto smax = APInt::getSignedMaxValue(targetBitwidth);
-    auto umin = APInt::getMinValue(targetBitwidth);
-    auto umax = APInt::getMaxValue(targetBitwidth);
-    if (!checkRange(*range, smin, smax, umin, umax))
-      return failure();
+    for (unsigned targetBitwidth : targetBitwidths) {
+      Type srcType = checkElementwiseOpType(op, targetBitwidth);
+      if (!srcType)
+        continue;
 
-    Type targetType = getTargetType(srcType, targetBitwidth);
-    if (targetType == srcType)
-      return failure();
+      // We are truncating op args to the desired bitwidth before the op and
+      // then extending op results back to the original width after. extui and
+      // extsi will produce different results for negative values, so limit
+      // signed range to non-negative values.
+ auto smin = APInt::getZero(targetBitwidth); + auto smax = APInt::getSignedMaxValue(targetBitwidth); + auto umin = APInt::getMinValue(targetBitwidth); + auto umax = APInt::getMaxValue(targetBitwidth); + if (!checkRange(*range, smin, smax, umin, umax)) + continue; - Location loc = op->getLoc(); - IRMapping mapping; - for (Value arg : op->getOperands()) { - Value newArg = doCast(rewriter, loc, arg, targetType); - mapping.map(arg, newArg); - } + Type targetType = getTargetType(srcType, targetBitwidth); + if (targetType == srcType) + continue; - Operation *newOp = rewriter.clone(*op, mapping); - rewriter.modifyOpInPlace(newOp, [&]() { - for (OpResult res : newOp->getResults()) { - res.setType(targetType); + Location loc = op->getLoc(); + IRMapping mapping; + for (Value arg : op->getOperands()) { + Value newArg = doCast(rewriter, loc, arg, targetType); + mapping.map(arg, newArg); } - }); - SmallVector newResults; - for (Value res : newOp->getResults()) - newResults.emplace_back(doCast(rewriter, loc, res, srcType)); - rewriter.replaceOp(op, newResults); - return success(); + Operation *newOp = rewriter.clone(*op, mapping); + rewriter.modifyOpInPlace(newOp, [&]() { + for (OpResult res : newOp->getResults()) { + res.setType(targetType); + } + }); + SmallVector newResults; + for (Value res : newOp->getResults()) + newResults.emplace_back(doCast(rewriter, loc, res, srcType)); + + rewriter.replaceOp(op, newResults); + return success(); + } + return failure(); } private: DataFlowSolver &solver; - unsigned targetBitwidth; + SmallVector targetBitwidths; }; struct NarrowCmpi final : public OpRewritePattern { NarrowCmpi(MLIRContext *context, PatternBenefit benefit, DataFlowSolver &s, - unsigned target) - : OpRewritePattern(context, benefit), solver(s), targetBitwidth(target) {} + ArrayRef target) + : OpRewritePattern(context, benefit), solver(s), targetBitwidths(target) { + } LogicalResult matchAndRewrite(arith::CmpIOp op, PatternRewriter &rewriter) const override { Value lhs = op.getLhs(); Value rhs = op.getRhs(); - Type srcType = checkArithType(lhs.getType(), targetBitwidth); - if (!srcType) - return failure(); - std::optional range = getOperandsRange(solver, {lhs, rhs}); if (!range) return failure(); - auto smin = APInt::getSignedMinValue(targetBitwidth); - auto smax = APInt::getSignedMaxValue(targetBitwidth); - auto umin = APInt::getMinValue(targetBitwidth); - auto umax = APInt::getMaxValue(targetBitwidth); - if (!checkRange(*range, smin, smax, umin, umax)) - return failure(); + for (unsigned targetBitwidth : targetBitwidths) { + Type srcType = checkArithType(lhs.getType(), targetBitwidth); + if (!srcType) + continue; - Type targetType = getTargetType(srcType, targetBitwidth); - if (targetType == srcType) - return failure(); + auto smin = APInt::getSignedMinValue(targetBitwidth); + auto smax = APInt::getSignedMaxValue(targetBitwidth); + auto umin = APInt::getMinValue(targetBitwidth); + auto umax = APInt::getMaxValue(targetBitwidth); + if (!checkRange(*range, smin, smax, umin, umax)) + continue; - Location loc = op->getLoc(); - IRMapping mapping; - for (Value arg : op->getOperands()) { - Value newArg = doCast(rewriter, loc, arg, targetType); - mapping.map(arg, newArg); - } + Type targetType = getTargetType(srcType, targetBitwidth); + if (targetType == srcType) + continue; - Operation *newOp = rewriter.clone(*op, mapping); - rewriter.replaceOp(op, newOp->getResults()); - return success(); + Location loc = op->getLoc(); + IRMapping mapping; + for (Value arg : op->getOperands()) { + Value newArg = 
doCast(rewriter, loc, arg, targetType);
+        mapping.map(arg, newArg);
+      }
+
+      Operation *newOp = rewriter.clone(*op, mapping);
+      rewriter.replaceOp(op, newOp->getResults());
+      return success();
+    }
+    return failure();
   }
 
 private:
   DataFlowSolver &solver;
-  unsigned targetBitwidth;
+  SmallVector targetBitwidths;
 };
 
 struct IntRangeOptimizationsPass
@@ -453,7 +462,7 @@ struct IntRangeNarrowingPass
     DataFlowListener listener(solver);
 
     RewritePatternSet patterns(ctx);
-    populateIntRangeNarrowingPatterns(patterns, solver, this->targetBitwidth);
+    populateIntRangeNarrowingPatterns(patterns, solver, bitwidthsSupported);
 
     GreedyRewriteConfig config;
    config.listener = &listener;
@@ -470,16 +479,16 @@ void mlir::arith::populateIntRangeOptimizationsPatterns(
       DeleteTrivialRem>(patterns.getContext(), solver);
 }
 
-void mlir::arith::populateIntRangeNarrowingPatterns(RewritePatternSet &patterns,
-                                                    DataFlowSolver &solver,
-                                                    unsigned targetBitwidth) {
+void mlir::arith::populateIntRangeNarrowingPatterns(
+    RewritePatternSet &patterns, DataFlowSolver &solver,
+    ArrayRef bitwidthsSupported) {
   // Cmpi uses argument ranges instead of result ranges; run it with a higher
   // benefit, as its arguments can potentially be replaced.
   patterns.add(patterns.getContext(), /*benefit*/ 10, solver,
-               targetBitwidth);
+               bitwidthsSupported);
 
   patterns.add(patterns.getContext(), solver,
-               targetBitwidth);
+               bitwidthsSupported);
 }
 
 std::unique_ptr mlir::arith::createIntRangeOptimizationsPass() {
diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index d85cb3384061b..6283284567dce 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --arith-int-range-narrowing="target-bitwidth=32" %s | FileCheck %s
+// RUN: mlir-opt --arith-int-range-narrowing="int-bitwidths-supported=32" %s | FileCheck %s
 
 // Do not truncate negative values
 // CHECK-LABEL: func @test_addi_neg

From 170157fe1606a894310f66df012ff4aa25a5a757 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Wed, 16 Oct 2024 13:12:00 +0200
Subject: [PATCH 04/20] update tests

---
 .../Dialect/Arith/int-range-narrowing.mlir    | 24 +++++++++----------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index 6283284567dce..8b73f02fd214a 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --arith-int-range-narrowing="int-bitwidths-supported=32" %s | FileCheck %s
+// RUN: mlir-opt --arith-int-range-narrowing="int-bitwidths-supported=1,8,16,24,32" %s | FileCheck %s
 
 // Do not truncate negative values
 // CHECK-LABEL: func @test_addi_neg
@@ -14,10 +14,10 @@ func.func @test_addi_neg() -> index {
 // CHECK-LABEL: func @test_addi
 // CHECK: %[[A:.*]] = test.with_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index} : index
 // CHECK: %[[B:.*]] = test.with_bounds {smax = 7 : index, smin = 6 : index, umax = 7 : index, umin = 6 : index} : index
-// CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i32
-// CHECK: %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i32
-// CHECK: %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i32
-// CHECK: %[[RES_CASTED:.*]] = arith.index_castui %[[RES]] : i32 to index
+// CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i8
+// CHECK: %[[B_CASTED:.*]] = 
arith.index_castui %[[B]] : index to i8 +// CHECK: %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i8 +// CHECK: %[[RES_CASTED:.*]] = arith.index_castui %[[RES]] : i8 to index // CHECK: return %[[RES_CASTED]] : index func.func @test_addi() -> index { %0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : index @@ -30,10 +30,10 @@ func.func @test_addi() -> index { // CHECK-LABEL: func @test_addi_i64 // CHECK: %[[A:.*]] = test.with_bounds {smax = 5 : i64, smin = 4 : i64, umax = 5 : i64, umin = 4 : i64} : i64 // CHECK: %[[B:.*]] = test.with_bounds {smax = 7 : i64, smin = 6 : i64, umax = 7 : i64, umin = 6 : i64} : i64 -// CHECK: %[[A_CASTED:.*]] = arith.trunci %[[A]] : i64 to i32 -// CHECK: %[[B_CASTED:.*]] = arith.trunci %[[B]] : i64 to i32 -// CHECK: %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i32 -// CHECK: %[[RES_CASTED:.*]] = arith.extui %[[RES]] : i32 to i64 +// CHECK: %[[A_CASTED:.*]] = arith.trunci %[[A]] : i64 to i8 +// CHECK: %[[B_CASTED:.*]] = arith.trunci %[[B]] : i64 to i8 +// CHECK: %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i8 +// CHECK: %[[RES_CASTED:.*]] = arith.extui %[[RES]] : i8 to i64 // CHECK: return %[[RES_CASTED]] : i64 func.func @test_addi_i64() -> i64 { %0 = test.with_bounds { umin = 4 : i64, umax = 5 : i64, smin = 4 : i64, smax = 5 : i64 } : i64 @@ -45,9 +45,9 @@ func.func @test_addi_i64() -> i64 { // CHECK-LABEL: func @test_cmpi // CHECK: %[[A:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index // CHECK: %[[B:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index -// CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i32 -// CHECK: %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i32 -// CHECK: %[[RES:.*]] = arith.cmpi slt, %[[A_CASTED]], %[[B_CASTED]] : i32 +// CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i8 +// CHECK: %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i8 +// CHECK: %[[RES:.*]] = arith.cmpi slt, %[[A_CASTED]], %[[B_CASTED]] : i8 // CHECK: return %[[RES]] : i1 func.func @test_cmpi() -> i1 { %0 = test.with_bounds { umin = 0 : index, umax = 10 : index, smin = 0 : index, smax = 10 : index } : index From 0af7e9e67cf04d4cdd3f5882a1992ef5566a74a0 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 16 Oct 2024 14:13:06 +0200 Subject: [PATCH 05/20] more tests --- .../Dialect/Arith/int-range-narrowing.mlir | 142 ++++++++++++++++++ 1 file changed, 142 insertions(+) diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir index 8b73f02fd214a..cd0a4c449913e 100644 --- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir +++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir @@ -1,5 +1,9 @@ // RUN: mlir-opt --arith-int-range-narrowing="int-bitwidths-supported=1,8,16,24,32" %s | FileCheck %s +//===----------------------------------------------------------------------===// +// Some basic tests +//===----------------------------------------------------------------------===// + // Do not truncate negative values // CHECK-LABEL: func @test_addi_neg // CHECK: %[[RES:.*]] = arith.addi %{{.*}}, %{{.*}} : index @@ -55,3 +59,141 @@ func.func @test_cmpi() -> i1 { %2 = arith.cmpi slt, %0, %1 : index return %2 : i1 } + +//===----------------------------------------------------------------------===// +// arith.addi +//===----------------------------------------------------------------------===// + 
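+// Unlike the `test.with_bounds` cases above, the tests below get their ranges
+// from the producing `arith.extui`/`arith.extsi` ops. As a rough illustration
+// (not a FileCheck test), two i8 values zero-extended to i32 are known to lie
+// in [0, 255], so their sum lies in [0, 510] and is expected to be computed in
+// the smallest enabled bitwidth that fits it, i16:
+//   %a   = arith.extui %x : i8 to i32
+//   %b   = arith.extui %y : i8 to i32
+//   %lhs = arith.trunci %a : i32 to i16
+//   %rhs = arith.trunci %b : i32 to i16
+//   %add = arith.addi %lhs, %rhs : i16
+//   %res = arith.extui %add : i16 to i32
+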
+// CHECK-LABEL: func.func @addi_extui_i8
+// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT: %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
+// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
+// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16
+// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16
+// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[LHS]], %[[RHS]] : i16
+// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[ADD]] : i16 to i32
+// CHECK-NEXT: return %[[RET]] : i32
+func.func @addi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
+  %a = arith.extui %lhs : i8 to i32
+  %b = arith.extui %rhs : i8 to i32
+  %r = arith.addi %a, %b : i32
+  return %r : i32
+}
+
+// This case should not get optimized because of mixed extensions.
+//
+// CHECK-LABEL: func.func @addi_mixed_ext_i8
+// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
+// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
+// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[EXT0]], %[[EXT1]] : i32
+// CHECK-NEXT: return %[[ADD]] : i32
+func.func @addi_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
+  %a = arith.extsi %lhs : i8 to i32
+  %b = arith.extui %rhs : i8 to i32
+  %r = arith.addi %a, %b : i32
+  return %r : i32
+}
+
+// This case should not get optimized because we cannot reduce the bitwidth
+// below i16, given the pass options set.
+//
+// CHECK-LABEL: func.func @addi_extsi_i16
+// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i16
+// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i16
+// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[EXT0]], %[[EXT1]] : i16
+// CHECK-NEXT: return %[[ADD]] : i16
+func.func @addi_extsi_i16(%lhs: i8, %rhs: i8) -> i16 {
+  %a = arith.extsi %lhs : i8 to i16
+  %b = arith.extsi %rhs : i8 to i16
+  %r = arith.addi %a, %b : i16
+  return %r : i16
+}
+
+//===----------------------------------------------------------------------===//
+// arith.subi
+//===----------------------------------------------------------------------===//
+
+// This pattern should only apply to `arith.subi` ops with sign-extended
+// arguments.
+//
+// CHECK-LABEL: func.func @subi_extui_i8
+// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT: %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
+// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
+// CHECK-NEXT: %[[SUB:.+]] = arith.subi %[[EXT0]], %[[EXT1]] : i32
+// CHECK-NEXT: return %[[SUB]] : i32
+func.func @subi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
+  %a = arith.extui %lhs : i8 to i32
+  %b = arith.extui %rhs : i8 to i32
+  %r = arith.subi %a, %b : i32
+  return %r : i32
+}
+
+// This case should not get optimized because of mixed extensions.
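+// With one sign-extended and one zero-extended i8 operand, the inferred result
+// range is roughly [-128 - 255, 127 - 0] == [-383, 127]; a negative smin fails
+// the non-negative requirement imposed by the extui-based re-extension, so no
+// enabled bitwidth qualifies.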
+//
+// CHECK-LABEL: func.func @subi_mixed_ext_i8
+// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
+// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
+// CHECK-NEXT: %[[SUB:.+]] = arith.subi %[[EXT0]], %[[EXT1]] : i32
+// CHECK-NEXT: return %[[SUB]] : i32
+func.func @subi_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
+  %a = arith.extsi %lhs : i8 to i32
+  %b = arith.extui %rhs : i8 to i32
+  %r = arith.subi %a, %b : i32
+  return %r : i32
+}
+
+//===----------------------------------------------------------------------===//
+// arith.muli
+//===----------------------------------------------------------------------===//
+
+// TODO: This should be optimized into i16: the unsigned result range
+// (255 * 255 == 65025) fits into 16 bits, but the check conservatively also
+// requires the signed maximum to fit, which pushes the result to i24.
+// CHECK-LABEL: func.func @muli_extui_i8
+// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT: %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
+// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
+// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i24
+// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i24
+// CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[LHS]], %[[RHS]] : i24
+// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[MUL]] : i24 to i32
+// CHECK-NEXT: return %[[RET]] : i32
+func.func @muli_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
+  %a = arith.extui %lhs : i8 to i32
+  %b = arith.extui %rhs : i8 to i32
+  %r = arith.muli %a, %b : i32
+  return %r : i32
+}
+
+// We do not expect this case to be optimized because, given n-bit operands,
+// arith.muli may produce up to 2n bits of result.
+//
+// CHECK-LABEL: func.func @muli_extsi_i32
+// CHECK-SAME: (%[[ARG0:.+]]: i16, %[[ARG1:.+]]: i16)
+// CHECK-NEXT: %[[LHS:.+]] = arith.extsi %[[ARG0]] : i16 to i32
+// CHECK-NEXT: %[[RHS:.+]] = arith.extsi %[[ARG1]] : i16 to i32
+// CHECK-NEXT: %[[RET:.+]] = arith.muli %[[LHS]], %[[RHS]] : i32
+// CHECK-NEXT: return %[[RET]] : i32
+func.func @muli_extsi_i32(%lhs: i16, %rhs: i16) -> i32 {
+  %a = arith.extsi %lhs : i16 to i32
+  %b = arith.extsi %rhs : i16 to i32
+  %r = arith.muli %a, %b : i32
+  return %r : i32
+}
+
+// This case should not get optimized because of mixed extensions.
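+// Here the inferred range is roughly [-128, 127] * [0, 255] ==
+// [-32640, 32385], which again contains negative values and is therefore
+// rejected by the same non-negative smin check.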
+// +// CHECK-LABEL: func.func @muli_mixed_ext_i8 +// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8) +// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32 +// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32 +// CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[EXT0]], %[[EXT1]] : i32 +// CHECK-NEXT: return %[[MUL]] : i32 +func.func @muli_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 { + %a = arith.extsi %lhs : i8 to i32 + %b = arith.extui %rhs : i8 to i32 + %r = arith.muli %a, %b : i32 + return %r : i32 +} From 9616513bab14facd97e1302b07fd2df478912a34 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 16 Oct 2024 18:06:43 +0200 Subject: [PATCH 06/20] remove the old pass --- .../mlir/Dialect/Arith/Transforms/Passes.h | 5 - .../mlir/Dialect/Arith/Transforms/Passes.td | 14 - .../Dialect/Arith/Transforms/CMakeLists.txt | 1 - .../Dialect/Arith/Transforms/IntNarrowing.cpp | 790 -------------- .../Arith/int-narrowing-invalid-options.mlir | 16 - mlir/test/Dialect/Arith/int-narrowing.mlir | 997 ------------------ mlir/test/Dialect/Linalg/int-narrowing.mlir | 147 --- 7 files changed, 1970 deletions(-) delete mode 100644 mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp delete mode 100644 mlir/test/Dialect/Arith/int-narrowing-invalid-options.mlir delete mode 100644 mlir/test/Dialect/Arith/int-narrowing.mlir delete mode 100644 mlir/test/Dialect/Linalg/int-narrowing.mlir diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h index b6a87e88c6efb..58dce89fdb578 100644 --- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h @@ -82,11 +82,6 @@ void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns, DataFlowSolver &solver, ArrayRef bitwidthsSupported); -// TODO: merge these two narrowing passes. -/// Add patterns for integer bitwidth narrowing. -void populateArithIntNarrowingPatterns(RewritePatternSet &patterns, - const ArithIntNarrowingOptions &options); - //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td index 898d74249af61..98f90d120fa1c 100644 --- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td @@ -111,18 +111,4 @@ def ArithEmulateWideInt : Pass<"arith-emulate-wide-int"> { let dependentDialects = ["vector::VectorDialect"]; } -def ArithIntNarrowing : Pass<"arith-int-narrowing"> { - let summary = "Reduce integer operation bitwidth"; - let description = [{ - Reduce bitwidths of integer types used in arith operations. This pass - prefers the narrowest available integer bitwidths that are guaranteed to - produce the same results. 
- }]; - let dependentDialects = ["vector::VectorDialect"]; - let options = [ - ListOption<"bitwidthsSupported", "int-bitwidths-supported", "unsigned", - "Integer bitwidths supported">, - ]; - } - #endif // MLIR_DIALECT_ARITH_TRANSFORMS_PASSES diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt index 93a004d31916f..912853871b7f8 100644 --- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt @@ -6,7 +6,6 @@ add_mlir_dialect_library(MLIRArithTransforms EmulateWideInt.cpp EmulateNarrowType.cpp ExpandOps.cpp - IntNarrowing.cpp IntRangeOptimizations.cpp ReifyValueBounds.cpp UnsignedWhenEquivalent.cpp diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp deleted file mode 100644 index b61218bb7f1af..0000000000000 --- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp +++ /dev/null @@ -1,790 +0,0 @@ -//===- IntNarrowing.cpp - Integer bitwidth reduction optimizations --------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Arith/Transforms/Passes.h" - -#include "mlir/Analysis/Presburger/IntegerRelation.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Arith/Transforms/Transforms.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/Interfaces/ValueBoundsOpInterface.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include -#include - -namespace mlir::arith { -#define GEN_PASS_DEF_ARITHINTNARROWING -#include "mlir/Dialect/Arith/Transforms/Passes.h.inc" -} // namespace mlir::arith - -namespace mlir::arith { -namespace { -//===----------------------------------------------------------------------===// -// Common Helpers -//===----------------------------------------------------------------------===// - -/// The base for integer bitwidth narrowing patterns. -template -struct NarrowingPattern : OpRewritePattern { - NarrowingPattern(MLIRContext *ctx, const ArithIntNarrowingOptions &options, - PatternBenefit benefit = 1) - : OpRewritePattern(ctx, benefit), - supportedBitwidths(options.bitwidthsSupported.begin(), - options.bitwidthsSupported.end()) { - assert(!supportedBitwidths.empty() && "Invalid options"); - assert(!llvm::is_contained(supportedBitwidths, 0) && "Invalid bitwidth"); - llvm::sort(supportedBitwidths); - } - - FailureOr - getNarrowestCompatibleBitwidth(unsigned bitsRequired) const { - for (unsigned candidate : supportedBitwidths) - if (candidate >= bitsRequired) - return candidate; - - return failure(); - } - - /// Returns the narrowest supported type that fits `bitsRequired`. 
- FailureOr getNarrowType(unsigned bitsRequired, Type origTy) const { - assert(origTy); - FailureOr bestBitwidth = - getNarrowestCompatibleBitwidth(bitsRequired); - if (failed(bestBitwidth)) - return failure(); - - Type elemTy = getElementTypeOrSelf(origTy); - if (!isa(elemTy)) - return failure(); - - auto newElemTy = IntegerType::get(origTy.getContext(), *bestBitwidth); - if (newElemTy == elemTy) - return failure(); - - if (origTy == elemTy) - return newElemTy; - - if (auto shapedTy = dyn_cast(origTy)) - if (dyn_cast(shapedTy.getElementType())) - return shapedTy.clone(shapedTy.getShape(), newElemTy); - - return failure(); - } - -private: - // Supported integer bitwidths in the ascending order. - llvm::SmallVector supportedBitwidths; -}; - -/// Returns the integer bitwidth required to represent `type`. -FailureOr calculateBitsRequired(Type type) { - assert(type); - if (auto intTy = dyn_cast(getElementTypeOrSelf(type))) - return intTy.getWidth(); - - return failure(); -} - -enum class ExtensionKind { Sign, Zero }; - -/// Wrapper around `arith::ExtSIOp` and `arith::ExtUIOp` ops that abstracts away -/// the exact op type. Exposes helper functions to query the types, operands, -/// and the result. This is so that we can handle both extension kinds without -/// needing to use templates or branching. -class ExtensionOp { -public: - /// Attemps to create a new extension op from `op`. Returns an extension op - /// wrapper when `op` is either `arith.extsi` or `arith.extui`, and failure - /// otherwise. - static FailureOr from(Operation *op) { - if (dyn_cast_or_null(op)) - return ExtensionOp{op, ExtensionKind::Sign}; - if (dyn_cast_or_null(op)) - return ExtensionOp{op, ExtensionKind::Zero}; - - return failure(); - } - - ExtensionOp(const ExtensionOp &) = default; - ExtensionOp &operator=(const ExtensionOp &) = default; - - /// Creates a new extension op of the same kind. - Operation *recreate(PatternRewriter &rewriter, Location loc, Type newType, - Value in) { - if (kind == ExtensionKind::Sign) - return rewriter.create(loc, newType, in); - - return rewriter.create(loc, newType, in); - } - - /// Replaces `toReplace` with a new extension op of the same kind. - void recreateAndReplace(PatternRewriter &rewriter, Operation *toReplace, - Value in) { - assert(toReplace->getNumResults() == 1); - Type newType = toReplace->getResult(0).getType(); - Operation *newOp = recreate(rewriter, toReplace->getLoc(), newType, in); - rewriter.replaceOp(toReplace, newOp->getResult(0)); - } - - ExtensionKind getKind() { return kind; } - - Value getResult() { return op->getResult(0); } - Value getIn() { return op->getOperand(0); } - - Type getType() { return getResult().getType(); } - Type getElementType() { return getElementTypeOrSelf(getType()); } - Type getInType() { return getIn().getType(); } - Type getInElementType() { return getElementTypeOrSelf(getInType()); } - -private: - ExtensionOp(Operation *op, ExtensionKind kind) : op(op), kind(kind) { - assert(op); - assert((isa(op)) && "Not an extension op"); - } - Operation *op = nullptr; - ExtensionKind kind = {}; -}; - -/// Returns the integer bitwidth required to represent `value`. -unsigned calculateBitsRequired(const APInt &value, - ExtensionKind lookThroughExtension) { - // For unsigned values, we only need the active bits. As a special case, zero - // requires one bit. - if (lookThroughExtension == ExtensionKind::Zero) - return std::max(value.getActiveBits(), 1u); - - // If a signed value is nonnegative, we need one extra bit for the sign. 
- if (value.isNonNegative()) - return value.getActiveBits() + 1; - - // For the signed min, we need all the bits. - if (value.isMinSignedValue()) - return value.getBitWidth(); - - // For negative values, we need all the non-sign bits and one extra bit for - // the sign. - return value.getBitWidth() - value.getNumSignBits() + 1; -} - -/// Returns the integer bitwidth required to represent `value`. -/// Looks through either sign- or zero-extension as specified by -/// `lookThroughExtension`. -FailureOr calculateBitsRequired(Value value, - ExtensionKind lookThroughExtension) { - // Handle constants. - if (TypedAttr attr; matchPattern(value, m_Constant(&attr))) { - if (auto intAttr = dyn_cast(attr)) - return calculateBitsRequired(intAttr.getValue(), lookThroughExtension); - - if (auto elemsAttr = dyn_cast(attr)) { - if (elemsAttr.getElementType().isIntOrIndex()) { - if (elemsAttr.isSplat()) - return calculateBitsRequired(elemsAttr.getSplatValue(), - lookThroughExtension); - - unsigned maxBits = 1; - for (const APInt &elemValue : elemsAttr.getValues()) - maxBits = std::max( - maxBits, calculateBitsRequired(elemValue, lookThroughExtension)); - return maxBits; - } - } - } - - if (lookThroughExtension == ExtensionKind::Sign) { - if (auto sext = value.getDefiningOp()) - return calculateBitsRequired(sext.getIn().getType()); - } else if (lookThroughExtension == ExtensionKind::Zero) { - if (auto zext = value.getDefiningOp()) - return calculateBitsRequired(zext.getIn().getType()); - } - - // If nothing else worked, return the type requirements for this element type. - return calculateBitsRequired(value.getType()); -} - -/// Base pattern for arith binary ops. -/// Example: -/// ``` -/// %lhs = arith.extsi %a : i8 to i32 -/// %rhs = arith.extsi %b : i8 to i32 -/// %r = arith.addi %lhs, %rhs : i32 -/// ==> -/// %lhs = arith.extsi %a : i8 to i16 -/// %rhs = arith.extsi %b : i8 to i16 -/// %add = arith.addi %lhs, %rhs : i16 -/// %r = arith.extsi %add : i16 to i32 -/// ``` -template -struct BinaryOpNarrowingPattern : NarrowingPattern { - using NarrowingPattern::NarrowingPattern; - - /// Returns the number of bits required to represent the full result, assuming - /// that both operands are `operandBits`-wide. Derived classes must implement - /// this, taking into account `BinaryOp` semantics. - virtual unsigned getResultBitsProduced(unsigned operandBits) const = 0; - - /// Customization point for patterns that should only apply with - /// zero/sign-extension ops as arguments. - virtual bool isSupported(ExtensionOp) const { return true; } - - LogicalResult matchAndRewrite(BinaryOp op, - PatternRewriter &rewriter) const final { - Type origTy = op.getType(); - FailureOr resultBits = calculateBitsRequired(origTy); - if (failed(resultBits)) - return failure(); - - // For the optimization to apply, we expect the lhs to be an extension op, - // and for the rhs to either be the same extension op or a constant. - FailureOr ext = ExtensionOp::from(op.getLhs().getDefiningOp()); - if (failed(ext) || !isSupported(*ext)) - return failure(); - - FailureOr lhsBitsRequired = - calculateBitsRequired(ext->getIn(), ext->getKind()); - if (failed(lhsBitsRequired) || *lhsBitsRequired >= *resultBits) - return failure(); - - FailureOr rhsBitsRequired = - calculateBitsRequired(op.getRhs(), ext->getKind()); - if (failed(rhsBitsRequired) || *rhsBitsRequired >= *resultBits) - return failure(); - - // Negotiate a common bit requirements for both lhs and rhs, accounting for - // the result requiring more bits than the operands. 
- unsigned commonBitsRequired = - getResultBitsProduced(std::max(*lhsBitsRequired, *rhsBitsRequired)); - FailureOr narrowTy = this->getNarrowType(commonBitsRequired, origTy); - if (failed(narrowTy) || calculateBitsRequired(*narrowTy) >= *resultBits) - return failure(); - - Location loc = op.getLoc(); - Value newLhs = - rewriter.createOrFold(loc, *narrowTy, op.getLhs()); - Value newRhs = - rewriter.createOrFold(loc, *narrowTy, op.getRhs()); - Value newAdd = rewriter.create(loc, newLhs, newRhs); - ext->recreateAndReplace(rewriter, op, newAdd); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// AddIOp Pattern -//===----------------------------------------------------------------------===// - -struct AddIPattern final : BinaryOpNarrowingPattern { - using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern; - - // Addition may require one extra bit for the result. - // Example: `UINT8_MAX + 1 == 255 + 1 == 256`. - unsigned getResultBitsProduced(unsigned operandBits) const override { - return operandBits + 1; - } -}; - -//===----------------------------------------------------------------------===// -// SubIOp Pattern -//===----------------------------------------------------------------------===// - -struct SubIPattern final : BinaryOpNarrowingPattern { - using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern; - - // This optimization only applies to signed arguments. - bool isSupported(ExtensionOp ext) const override { - return ext.getKind() == ExtensionKind::Sign; - } - - // Subtraction may require one extra bit for the result. - // Example: `INT8_MAX - (-1) == 127 - (-1) == 128`. - unsigned getResultBitsProduced(unsigned operandBits) const override { - return operandBits + 1; - } -}; - -//===----------------------------------------------------------------------===// -// MulIOp Pattern -//===----------------------------------------------------------------------===// - -struct MulIPattern final : BinaryOpNarrowingPattern { - using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern; - - // Multiplication may require up double the operand bits. - // Example: `UNT8_MAX * UINT8_MAX == 255 * 255 == 65025`. - unsigned getResultBitsProduced(unsigned operandBits) const override { - return 2 * operandBits; - } -}; - -//===----------------------------------------------------------------------===// -// DivSIOp Pattern -//===----------------------------------------------------------------------===// - -struct DivSIPattern final : BinaryOpNarrowingPattern { - using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern; - - // This optimization only applies to signed arguments. - bool isSupported(ExtensionOp ext) const override { - return ext.getKind() == ExtensionKind::Sign; - } - - // Unlike multiplication, signed division requires only one more result bit. - // Example: `INT8_MIN / (-1) == -128 / (-1) == 128`. - unsigned getResultBitsProduced(unsigned operandBits) const override { - return operandBits + 1; - } -}; - -//===----------------------------------------------------------------------===// -// DivUIOp Pattern -//===----------------------------------------------------------------------===// - -struct DivUIPattern final : BinaryOpNarrowingPattern { - using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern; - - // This optimization only applies to unsigned arguments. - bool isSupported(ExtensionOp ext) const override { - return ext.getKind() == ExtensionKind::Zero; - } - - // Unsigned division does not require any extra result bits. 
- unsigned getResultBitsProduced(unsigned operandBits) const override { - return operandBits; - } -}; - -//===----------------------------------------------------------------------===// -// Min/Max Patterns -//===----------------------------------------------------------------------===// - -template -struct MinMaxPattern final : BinaryOpNarrowingPattern { - using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern; - - bool isSupported(ExtensionOp ext) const override { - return ext.getKind() == Kind; - } - - // Min/max returns one of the arguments and does not require any extra result - // bits. - unsigned getResultBitsProduced(unsigned operandBits) const override { - return operandBits; - } -}; -using MaxSIPattern = MinMaxPattern; -using MaxUIPattern = MinMaxPattern; -using MinSIPattern = MinMaxPattern; -using MinUIPattern = MinMaxPattern; - -//===----------------------------------------------------------------------===// -// *IToFPOp Patterns -//===----------------------------------------------------------------------===// - -template -struct IToFPPattern final : NarrowingPattern { - using NarrowingPattern::NarrowingPattern; - - LogicalResult matchAndRewrite(IToFPOp op, - PatternRewriter &rewriter) const override { - FailureOr narrowestWidth = - calculateBitsRequired(op.getIn(), Extension); - if (failed(narrowestWidth)) - return failure(); - - FailureOr narrowTy = - this->getNarrowType(*narrowestWidth, op.getIn().getType()); - if (failed(narrowTy)) - return failure(); - - Value newIn = rewriter.createOrFold(op.getLoc(), *narrowTy, - op.getIn()); - rewriter.replaceOpWithNewOp(op, op.getType(), newIn); - return success(); - } -}; -using SIToFPPattern = IToFPPattern; -using UIToFPPattern = IToFPPattern; - -//===----------------------------------------------------------------------===// -// Index Cast Patterns -//===----------------------------------------------------------------------===// - -// These rely on the `ValueBounds` interface for index values. For example, we -// can often statically tell index value bounds of loop induction variables. - -template -struct IndexCastPattern final : NarrowingPattern { - using NarrowingPattern::NarrowingPattern; - - LogicalResult matchAndRewrite(CastOp op, - PatternRewriter &rewriter) const override { - Value in = op.getIn(); - // We only support scalar index -> integer casts. - if (!isa(in.getType())) - return failure(); - - // Check the lower bound in both the signed and unsigned cast case. We - // conservatively assume that even unsigned casts may be performed on - // negative indices. 
- FailureOr lb = ValueBoundsConstraintSet::computeConstantBound( - presburger::BoundType::LB, in); - if (failed(lb)) - return failure(); - - FailureOr ub = ValueBoundsConstraintSet::computeConstantBound( - presburger::BoundType::UB, in, - /*stopCondition=*/nullptr, /*closedUB=*/true); - if (failed(ub)) - return failure(); - - assert(*lb <= *ub && "Invalid bounds"); - unsigned lbBitsRequired = calculateBitsRequired(APInt(64, *lb), Kind); - unsigned ubBitsRequired = calculateBitsRequired(APInt(64, *ub), Kind); - unsigned bitsRequired = std::max(lbBitsRequired, ubBitsRequired); - - IntegerType resultTy = cast(op.getType()); - if (resultTy.getWidth() <= bitsRequired) - return failure(); - - FailureOr narrowTy = this->getNarrowType(bitsRequired, resultTy); - if (failed(narrowTy)) - return failure(); - - Value newCast = rewriter.create(op.getLoc(), *narrowTy, op.getIn()); - - if (Kind == ExtensionKind::Sign) - rewriter.replaceOpWithNewOp(op, resultTy, newCast); - else - rewriter.replaceOpWithNewOp(op, resultTy, newCast); - return success(); - } -}; -using IndexCastSIPattern = - IndexCastPattern; -using IndexCastUIPattern = - IndexCastPattern; - -//===----------------------------------------------------------------------===// -// Patterns to Commute Extension Ops -//===----------------------------------------------------------------------===// - -struct ExtensionOverBroadcast final : NarrowingPattern { - using NarrowingPattern::NarrowingPattern; - - LogicalResult matchAndRewrite(vector::BroadcastOp op, - PatternRewriter &rewriter) const override { - FailureOr ext = - ExtensionOp::from(op.getSource().getDefiningOp()); - if (failed(ext)) - return failure(); - - VectorType origTy = op.getResultVectorType(); - VectorType newTy = - origTy.cloneWith(origTy.getShape(), ext->getInElementType()); - Value newBroadcast = - rewriter.create(op.getLoc(), newTy, ext->getIn()); - ext->recreateAndReplace(rewriter, op, newBroadcast); - return success(); - } -}; - -struct ExtensionOverExtract final : NarrowingPattern { - using NarrowingPattern::NarrowingPattern; - - LogicalResult matchAndRewrite(vector::ExtractOp op, - PatternRewriter &rewriter) const override { - FailureOr ext = - ExtensionOp::from(op.getVector().getDefiningOp()); - if (failed(ext)) - return failure(); - - Value newExtract = rewriter.create( - op.getLoc(), ext->getIn(), op.getMixedPosition()); - ext->recreateAndReplace(rewriter, op, newExtract); - return success(); - } -}; - -struct ExtensionOverExtractElement final - : NarrowingPattern { - using NarrowingPattern::NarrowingPattern; - - LogicalResult matchAndRewrite(vector::ExtractElementOp op, - PatternRewriter &rewriter) const override { - FailureOr ext = - ExtensionOp::from(op.getVector().getDefiningOp()); - if (failed(ext)) - return failure(); - - Value newExtract = rewriter.create( - op.getLoc(), ext->getIn(), op.getPosition()); - ext->recreateAndReplace(rewriter, op, newExtract); - return success(); - } -}; - -struct ExtensionOverExtractStridedSlice final - : NarrowingPattern { - using NarrowingPattern::NarrowingPattern; - - LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp op, - PatternRewriter &rewriter) const override { - FailureOr ext = - ExtensionOp::from(op.getVector().getDefiningOp()); - if (failed(ext)) - return failure(); - - VectorType origTy = op.getType(); - VectorType extractTy = - origTy.cloneWith(origTy.getShape(), ext->getInElementType()); - Value newExtract = rewriter.create( - op.getLoc(), extractTy, ext->getIn(), op.getOffsets(), op.getSizes(), - op.getStrides()); 
- ext->recreateAndReplace(rewriter, op, newExtract); - return success(); - } -}; - -/// Base pattern for `vector.insert` narrowing patterns. -template -struct ExtensionOverInsertionPattern : NarrowingPattern { - using NarrowingPattern::NarrowingPattern; - - /// Derived classes must provide a function to create the matching insertion - /// op based on the original op and new arguments. - virtual InsertionOp createInsertionOp(PatternRewriter &rewriter, - InsertionOp origInsert, - Value narrowValue, - Value narrowDest) const = 0; - - LogicalResult matchAndRewrite(InsertionOp op, - PatternRewriter &rewriter) const final { - FailureOr ext = - ExtensionOp::from(op.getSource().getDefiningOp()); - if (failed(ext)) - return failure(); - - FailureOr newInsert = createNarrowInsert(op, rewriter, *ext); - if (failed(newInsert)) - return failure(); - ext->recreateAndReplace(rewriter, op, *newInsert); - return success(); - } - - FailureOr createNarrowInsert(InsertionOp op, - PatternRewriter &rewriter, - ExtensionOp insValue) const { - // Calculate the operand and result bitwidths. We can only apply narrowing - // when the inserted source value and destination vector require fewer bits - // than the result. Because the source and destination may have different - // bitwidths requirements, we have to find the common narrow bitwidth that - // is greater equal to the operand bitwidth requirements and still narrower - // than the result. - FailureOr origBitsRequired = calculateBitsRequired(op.getType()); - if (failed(origBitsRequired)) - return failure(); - - // TODO: We could relax this check by disregarding bitwidth requirements of - // elements that we know will be replaced by the insertion. - FailureOr destBitsRequired = - calculateBitsRequired(op.getDest(), insValue.getKind()); - if (failed(destBitsRequired) || *destBitsRequired >= *origBitsRequired) - return failure(); - - FailureOr insertedBitsRequired = - calculateBitsRequired(insValue.getIn(), insValue.getKind()); - if (failed(insertedBitsRequired) || - *insertedBitsRequired >= *origBitsRequired) - return failure(); - - // Find a narrower element type that satisfies the bitwidth requirements of - // both the source and the destination values. 
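  // A hypothetical walk-through of the bound computed below (the numbers are
  // illustrative and are not taken from any test in this patch): if the
  // destination vector requires 12 bits and the inserted value requires 9,
  // the insertion can be performed at max(12, 9) = 12 bits, which
  // getNarrowType then rounds up to the narrowest enabled bitwidth, e.g. i16
  // with options 1,8,16,24,32. That is accepted because i16 is still
  // narrower than the original i32 result type.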
- unsigned newInsertionBits = - std::max(*destBitsRequired, *insertedBitsRequired); - FailureOr newVecTy = - this->getNarrowType(newInsertionBits, op.getType()); - if (failed(newVecTy) || *newVecTy == op.getType()) - return failure(); - - FailureOr newInsertedValueTy = - this->getNarrowType(newInsertionBits, insValue.getType()); - if (failed(newInsertedValueTy)) - return failure(); - - Location loc = op.getLoc(); - Value narrowValue = rewriter.createOrFold( - loc, *newInsertedValueTy, insValue.getResult()); - Value narrowDest = - rewriter.createOrFold(loc, *newVecTy, op.getDest()); - return createInsertionOp(rewriter, op, narrowValue, narrowDest); - } -}; - -struct ExtensionOverInsert final - : ExtensionOverInsertionPattern { - using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern; - - vector::InsertOp createInsertionOp(PatternRewriter &rewriter, - vector::InsertOp origInsert, - Value narrowValue, - Value narrowDest) const override { - return rewriter.create(origInsert.getLoc(), narrowValue, - narrowDest, - origInsert.getMixedPosition()); - } -}; - -struct ExtensionOverInsertElement final - : ExtensionOverInsertionPattern { - using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern; - - vector::InsertElementOp createInsertionOp(PatternRewriter &rewriter, - vector::InsertElementOp origInsert, - Value narrowValue, - Value narrowDest) const override { - return rewriter.create( - origInsert.getLoc(), narrowValue, narrowDest, origInsert.getPosition()); - } -}; - -struct ExtensionOverInsertStridedSlice final - : ExtensionOverInsertionPattern { - using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern; - - vector::InsertStridedSliceOp - createInsertionOp(PatternRewriter &rewriter, - vector::InsertStridedSliceOp origInsert, Value narrowValue, - Value narrowDest) const override { - return rewriter.create( - origInsert.getLoc(), narrowValue, narrowDest, origInsert.getOffsets(), - origInsert.getStrides()); - } -}; - -struct ExtensionOverShapeCast final : NarrowingPattern { - using NarrowingPattern::NarrowingPattern; - - LogicalResult matchAndRewrite(vector::ShapeCastOp op, - PatternRewriter &rewriter) const override { - FailureOr ext = - ExtensionOp::from(op.getSource().getDefiningOp()); - if (failed(ext)) - return failure(); - - VectorType origTy = op.getResultVectorType(); - VectorType newTy = - origTy.cloneWith(origTy.getShape(), ext->getInElementType()); - Value newCast = - rewriter.create(op.getLoc(), newTy, ext->getIn()); - ext->recreateAndReplace(rewriter, op, newCast); - return success(); - } -}; - -struct ExtensionOverTranspose final : NarrowingPattern { - using NarrowingPattern::NarrowingPattern; - - LogicalResult matchAndRewrite(vector::TransposeOp op, - PatternRewriter &rewriter) const override { - FailureOr ext = - ExtensionOp::from(op.getVector().getDefiningOp()); - if (failed(ext)) - return failure(); - - VectorType origTy = op.getResultVectorType(); - VectorType newTy = - origTy.cloneWith(origTy.getShape(), ext->getInElementType()); - Value newTranspose = rewriter.create( - op.getLoc(), newTy, ext->getIn(), op.getPermutation()); - ext->recreateAndReplace(rewriter, op, newTranspose); - return success(); - } -}; - -struct ExtensionOverFlatTranspose final - : NarrowingPattern { - using NarrowingPattern::NarrowingPattern; - - LogicalResult matchAndRewrite(vector::FlatTransposeOp op, - PatternRewriter &rewriter) const override { - FailureOr ext = - ExtensionOp::from(op.getMatrix().getDefiningOp()); - if (failed(ext)) - return failure(); - - VectorType 
origTy = op.getType(); - VectorType newTy = - origTy.cloneWith(origTy.getShape(), ext->getInElementType()); - Value newTranspose = rewriter.create( - op.getLoc(), newTy, ext->getIn(), op.getRowsAttr(), - op.getColumnsAttr()); - ext->recreateAndReplace(rewriter, op, newTranspose); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// Pass Definitions -//===----------------------------------------------------------------------===// - -struct ArithIntNarrowingPass final - : impl::ArithIntNarrowingBase { - using ArithIntNarrowingBase::ArithIntNarrowingBase; - - void runOnOperation() override { - if (bitwidthsSupported.empty() || - llvm::is_contained(bitwidthsSupported, 0)) { - // Invalid pass options. - return signalPassFailure(); - } - - Operation *op = getOperation(); - MLIRContext *ctx = op->getContext(); - RewritePatternSet patterns(ctx); - populateArithIntNarrowingPatterns( - patterns, ArithIntNarrowingOptions{ - llvm::to_vector_of(bitwidthsSupported)}); - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) - signalPassFailure(); - } -}; -} // namespace - -//===----------------------------------------------------------------------===// -// Public API -//===----------------------------------------------------------------------===// - -void populateArithIntNarrowingPatterns( - RewritePatternSet &patterns, const ArithIntNarrowingOptions &options) { - // Add commute patterns with a higher benefit. This is to expose more - // optimization opportunities to narrowing patterns. - patterns.add( - patterns.getContext(), options, PatternBenefit(2)); - - patterns.add(patterns.getContext(), options); -} - -} // namespace mlir::arith diff --git a/mlir/test/Dialect/Arith/int-narrowing-invalid-options.mlir b/mlir/test/Dialect/Arith/int-narrowing-invalid-options.mlir deleted file mode 100644 index 0e34108973b4c..0000000000000 --- a/mlir/test/Dialect/Arith/int-narrowing-invalid-options.mlir +++ /dev/null @@ -1,16 +0,0 @@ -// RUN: not mlir-opt %s --arith-int-narrowing --mlir-print-ir-after-failure 2>&1 \ -// RUN: | FileCheck %s - -// RUN: not mlir-opt %s --arith-int-narrowing="int-bitwidths-supported=0" \ -// RUN: --mlir-print-ir-after-failure 2>&1 | FileCheck %s - -// Make sure we do not crash on invalid pass options. 
- -// CHECK: IR Dump After ArithIntNarrowing Failed (arith-int-narrowing) -// CHECK-LABEL: func.func @addi_extsi_i8 -func.func @addi_extsi_i8(%lhs: i8, %rhs: i8) -> i32 { - %a = arith.extsi %lhs : i8 to i32 - %b = arith.extsi %rhs : i8 to i32 - %r = arith.addi %a, %b : i32 - return %r : i32 -} diff --git a/mlir/test/Dialect/Arith/int-narrowing.mlir b/mlir/test/Dialect/Arith/int-narrowing.mlir deleted file mode 100644 index 153c0a8576262..0000000000000 --- a/mlir/test/Dialect/Arith/int-narrowing.mlir +++ /dev/null @@ -1,997 +0,0 @@ -// RUN: mlir-opt --arith-int-narrowing="int-bitwidths-supported=1,8,16,24,32" \ -// RUN: --verify-diagnostics %s | FileCheck %s - -//===----------------------------------------------------------------------===// -// arith.addi -//===----------------------------------------------------------------------===// - -// CHECK-LABEL: func.func @addi_extsi_i8 -// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8) -// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32 -// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32 -// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16 -// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16 -// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[LHS]], %[[RHS]] : i16 -// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[ADD]] : i16 to i32 -// CHECK-NEXT: return %[[RET]] : i32 -func.func @addi_extsi_i8(%lhs: i8, %rhs: i8) -> i32 { - %a = arith.extsi %lhs : i8 to i32 - %b = arith.extsi %rhs : i8 to i32 - %r = arith.addi %a, %b : i32 - return %r : i32 -} - -// CHECK-LABEL: func.func @addi_extui_i8 -// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8) -// CHECK-NEXT: %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32 -// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32 -// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16 -// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16 -// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[LHS]], %[[RHS]] : i16 -// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[ADD]] : i16 to i32 -// CHECK-NEXT: return %[[RET]] : i32 -func.func @addi_extui_i8(%lhs: i8, %rhs: i8) -> i32 { - %a = arith.extui %lhs : i8 to i32 - %b = arith.extui %rhs : i8 to i32 - %r = arith.addi %a, %b : i32 - return %r : i32 -} - -// arith.addi produces one more bit of result than the operand bitwidth. -// -// CHECK-LABEL: func.func @addi_extsi_i24 -// CHECK-SAME: (%[[ARG0:.+]]: i16, %[[ARG1:.+]]: i16) -// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i16 to i32 -// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i16 to i32 -// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i24 -// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i24 -// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[LHS]], %[[RHS]] : i24 -// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[ADD]] : i24 to i32 -// CHECK-NEXT: return %[[RET]] : i32 -func.func @addi_extsi_i24(%lhs: i16, %rhs: i16) -> i32 { - %a = arith.extsi %lhs : i16 to i32 - %b = arith.extsi %rhs : i16 to i32 - %r = arith.addi %a, %b : i32 - return %r : i32 -} - -// This case should not get optimized because of mixed extensions. 
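// A sketch of the obstruction (illustrative values, not an original test
// case): the rewrite has to re-extend the narrowed result with a single
// extension kind. For %lhs = -1 and %rhs = 0 below, the i32 sum is -1; an
// i16 add yields 0xFFFF, and zero-extending that gives 65535 instead of -1.
// Rather than reasoning about which kind would be safe, the pattern simply
// requires both operands to use the same extension op.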
-//
-// CHECK-LABEL: func.func @addi_mixed_ext_i8
-// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
-// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
-// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
-// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[EXT0]], %[[EXT1]] : i32
-// CHECK-NEXT: return %[[ADD]] : i32
-func.func @addi_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
-  %a = arith.extsi %lhs : i8 to i32
-  %b = arith.extui %rhs : i8 to i32
-  %r = arith.addi %a, %b : i32
-  return %r : i32
-}
-
-// This case should not get optimized because we cannot reduce the bitwidth
-// below i16, given the pass options set.
-//
-// CHECK-LABEL: func.func @addi_extsi_i16
-// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
-// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i16
-// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i16
-// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[EXT0]], %[[EXT1]] : i16
-// CHECK-NEXT: return %[[ADD]] : i16
-func.func @addi_extsi_i16(%lhs: i8, %rhs: i8) -> i16 {
-  %a = arith.extsi %lhs : i8 to i16
-  %b = arith.extsi %rhs : i8 to i16
-  %r = arith.addi %a, %b : i16
-  return %r : i16
-}
-
-// CHECK-LABEL: func.func @addi_extsi_3xi8_cst
-// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi8>)
-// CHECK-NEXT: %[[CST:.+]] = arith.constant dense<[-1, 127, 42]> : vector<3xi16>
-// CHECK-NEXT: %[[EXT:.+]] = arith.extsi %[[ARG0]] : vector<3xi8> to vector<3xi32>
-// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT]] : vector<3xi32> to vector<3xi16>
-// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[LHS]], %[[CST]] : vector<3xi16>
-// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[ADD]] : vector<3xi16> to vector<3xi32>
-// CHECK-NEXT: return %[[RET]] : vector<3xi32>
-func.func @addi_extsi_3xi8_cst(%lhs: vector<3xi8>) -> vector<3xi32> {
-  %cst = arith.constant dense<[-1, 127, 42]> : vector<3xi32>
-  %a = arith.extsi %lhs : vector<3xi8> to vector<3xi32>
-  %r = arith.addi %a, %cst : vector<3xi32>
-  return %r : vector<3xi32>
-}
-
-//===----------------------------------------------------------------------===//
-// arith.subi
-//===----------------------------------------------------------------------===//
-
-// CHECK-LABEL: func.func @subi_extsi_i8
-// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
-// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
-// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32
-// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16
-// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16
-// CHECK-NEXT: %[[SUB:.+]] = arith.subi %[[LHS]], %[[RHS]] : i16
-// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[SUB]] : i16 to i32
-// CHECK-NEXT: return %[[RET]] : i32
-func.func @subi_extsi_i8(%lhs: i8, %rhs: i8) -> i32 {
-  %a = arith.extsi %lhs : i8 to i32
-  %b = arith.extsi %rhs : i8 to i32
-  %r = arith.subi %a, %b : i32
-  return %r : i32
-}
-
-// This pattern should only apply to `arith.subi` ops with sign-extended
-// arguments.
-//
-// CHECK-LABEL: func.func @subi_extui_i8
-// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
-// CHECK-NEXT: %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
-// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
-// CHECK-NEXT: %[[SUB:.+]] = arith.subi %[[EXT0]], %[[EXT1]] : i32
-// CHECK-NEXT: return %[[SUB]] : i32
-func.func @subi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
-  %a = arith.extui %lhs : i8 to i32
-  %b = arith.extui %rhs : i8 to i32
-  %r = arith.subi %a, %b : i32
-  return %r : i32
-}
-
-// This case should not get optimized because of mixed extensions.
-// -// CHECK-LABEL: func.func @subi_mixed_ext_i8 -// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8) -// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32 -// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32 -// CHECK-NEXT: %[[ADD:.+]] = arith.subi %[[EXT0]], %[[EXT1]] : i32 -// CHECK-NEXT: return %[[ADD]] : i32 -func.func @subi_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 { - %a = arith.extsi %lhs : i8 to i32 - %b = arith.extui %rhs : i8 to i32 - %r = arith.subi %a, %b : i32 - return %r : i32 -} - -// arith.subi produces one more bit of result than the operand bitwidth. -// -// CHECK-LABEL: func.func @subi_extsi_i24 -// CHECK-SAME: (%[[ARG0:.+]]: i16, %[[ARG1:.+]]: i16) -// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i16 to i32 -// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i16 to i32 -// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i24 -// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i24 -// CHECK-NEXT: %[[ADD:.+]] = arith.subi %[[LHS]], %[[RHS]] : i24 -// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[ADD]] : i24 to i32 -// CHECK-NEXT: return %[[RET]] : i32 -func.func @subi_extsi_i24(%lhs: i16, %rhs: i16) -> i32 { - %a = arith.extsi %lhs : i16 to i32 - %b = arith.extsi %rhs : i16 to i32 - %r = arith.subi %a, %b : i32 - return %r : i32 -} - -//===----------------------------------------------------------------------===// -// arith.muli -//===----------------------------------------------------------------------===// - -// CHECK-LABEL: func.func @muli_extsi_i8 -// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8) -// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32 -// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32 -// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16 -// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16 -// CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[LHS]], %[[RHS]] : i16 -// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[MUL]] : i16 to i32 -// CHECK-NEXT: return %[[RET]] : i32 -func.func @muli_extsi_i8(%lhs: i8, %rhs: i8) -> i32 { - %a = arith.extsi %lhs : i8 to i32 - %b = arith.extsi %rhs : i8 to i32 - %r = arith.muli %a, %b : i32 - return %r : i32 -} - -// CHECK-LABEL: func.func @muli_extui_i8 -// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8) -// CHECK-NEXT: %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32 -// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32 -// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16 -// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16 -// CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[LHS]], %[[RHS]] : i16 -// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[MUL]] : i16 to i32 -// CHECK-NEXT: return %[[RET]] : i32 -func.func @muli_extui_i8(%lhs: i8, %rhs: i8) -> i32 { - %a = arith.extui %lhs : i8 to i32 - %b = arith.extui %rhs : i8 to i32 - %r = arith.muli %a, %b : i32 - return %r : i32 -} - -// We do not expect this case to be optimized because given n-bit operands, -// arith.muli produces 2n bits of result. 
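// To make the 2n-bit claim concrete (editorial arithmetic, not part of the
// original test suite): the largest-magnitude i16 product is
// (-32768) * (-32768) = 2^30, and representing 2^30 as a signed value
// already takes 32 bits. No type narrower than i32 can hold every
// i16 * i16 product, so the multiply below has to stay in i32.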
-//
-// CHECK-LABEL: func.func @muli_extsi_i32
-// CHECK-SAME: (%[[ARG0:.+]]: i16, %[[ARG1:.+]]: i16)
-// CHECK-NEXT: %[[LHS:.+]] = arith.extsi %[[ARG0]] : i16 to i32
-// CHECK-NEXT: %[[RHS:.+]] = arith.extsi %[[ARG1]] : i16 to i32
-// CHECK-NEXT: %[[RET:.+]] = arith.muli %[[LHS]], %[[RHS]] : i32
-// CHECK-NEXT: return %[[RET]] : i32
-func.func @muli_extsi_i32(%lhs: i16, %rhs: i16) -> i32 {
-  %a = arith.extsi %lhs : i16 to i32
-  %b = arith.extsi %rhs : i16 to i32
-  %r = arith.muli %a, %b : i32
-  return %r : i32
-}
-
-// This case should not get optimized because of mixed extensions.
-//
-// CHECK-LABEL: func.func @muli_mixed_ext_i8
-// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
-// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
-// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
-// CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[EXT0]], %[[EXT1]] : i32
-// CHECK-NEXT: return %[[MUL]] : i32
-func.func @muli_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
-  %a = arith.extsi %lhs : i8 to i32
-  %b = arith.extui %rhs : i8 to i32
-  %r = arith.muli %a, %b : i32
-  return %r : i32
-}
-
-// CHECK-LABEL: func.func @muli_extsi_3xi8_cst
-// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi8>)
-// CHECK-NEXT: %[[CST:.+]] = arith.constant dense<[-1, 127, 42]> : vector<3xi16>
-// CHECK-NEXT: %[[EXT:.+]] = arith.extsi %[[ARG0]] : vector<3xi8> to vector<3xi32>
-// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT]] : vector<3xi32> to vector<3xi16>
-// CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[LHS]], %[[CST]] : vector<3xi16>
-// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[MUL]] : vector<3xi16> to vector<3xi32>
-// CHECK-NEXT: return %[[RET]] : vector<3xi32>
-func.func @muli_extsi_3xi8_cst(%lhs: vector<3xi8>) -> vector<3xi32> {
-  %cst = arith.constant dense<[-1, 127, 42]> : vector<3xi32>
-  %a = arith.extsi %lhs : vector<3xi8> to vector<3xi32>
-  %r = arith.muli %a, %cst : vector<3xi32>
-  return %r : vector<3xi32>
-}
-
-//===----------------------------------------------------------------------===//
-// arith.divsi
-//===----------------------------------------------------------------------===//
-
-// CHECK-LABEL: func.func @divsi_extsi_i8
-// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
-// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
-// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32
-// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16
-// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16
-// CHECK-NEXT: %[[SUB:.+]] = arith.divsi %[[LHS]], %[[RHS]] : i16
-// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[SUB]] : i16 to i32
-// CHECK-NEXT: return %[[RET]] : i32
-func.func @divsi_extsi_i8(%lhs: i8, %rhs: i8) -> i32 {
-  %a = arith.extsi %lhs : i8 to i32
-  %b = arith.extsi %rhs : i8 to i32
-  %r = arith.divsi %a, %b : i32
-  return %r : i32
-}
-
-// This pattern should only apply to `arith.divsi` ops with sign-extended
-// arguments.
-//
-// CHECK-LABEL: func.func @divsi_extui_i8
-// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
-// CHECK-NEXT: %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
-// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
-// CHECK-NEXT: %[[SUB:.+]] = arith.divsi %[[EXT0]], %[[EXT1]] : i32
-// CHECK-NEXT: return %[[SUB]] : i32
-func.func @divsi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
-  %a = arith.extui %lhs : i8 to i32
-  %b = arith.extui %rhs : i8 to i32
-  %r = arith.divsi %a, %b : i32
-  return %r : i32
-}
-
-// arith.divsi produces one more bit of result than the operand bitwidth.
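// A worked corner case for the extra bit (values are illustrative): in i16,
// -32768 / -1 = 32768, which overflows i16 but fits in 17 bits. Any width of
// at least 17 bits is thus safe for i16 operands, and i24 is the narrowest
// such width enabled by the RUN line of this file.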
-//
-// CHECK-LABEL: func.func @divsi_extsi_i24
-// CHECK-SAME: (%[[ARG0:.+]]: i16, %[[ARG1:.+]]: i16)
-// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i16 to i32
-// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i16 to i32
-// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i24
-// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i24
-// CHECK-NEXT: %[[ADD:.+]] = arith.divsi %[[LHS]], %[[RHS]] : i24
-// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[ADD]] : i24 to i32
-// CHECK-NEXT: return %[[RET]] : i32
-func.func @divsi_extsi_i24(%lhs: i16, %rhs: i16) -> i32 {
-  %a = arith.extsi %lhs : i16 to i32
-  %b = arith.extsi %rhs : i16 to i32
-  %r = arith.divsi %a, %b : i32
-  return %r : i32
-}
-
-//===----------------------------------------------------------------------===//
-// arith.divui
-//===----------------------------------------------------------------------===//
-
-// CHECK-LABEL: func.func @divui_extui_i8
-// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
-// CHECK-NEXT: %[[SUB:.+]] = arith.divui %[[ARG0]], %[[ARG1]] : i8
-// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[SUB]] : i8 to i32
-// CHECK-NEXT: return %[[RET]] : i32
-func.func @divui_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
-  %a = arith.extui %lhs : i8 to i32
-  %b = arith.extui %rhs : i8 to i32
-  %r = arith.divui %a, %b : i32
-  return %r : i32
-}
-
-// This pattern should only apply to `arith.divui` ops with zero-extended
-// arguments.
-//
-// CHECK-LABEL: func.func @divui_extsi_i8
-// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
-// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
-// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32
-// CHECK-NEXT: %[[SUB:.+]] = arith.divui %[[EXT0]], %[[EXT1]] : i32
-// CHECK-NEXT: return %[[SUB]] : i32
-func.func @divui_extsi_i8(%lhs: i8, %rhs: i8) -> i32 {
-  %a = arith.extsi %lhs : i8 to i32
-  %b = arith.extsi %rhs : i8 to i32
-  %r = arith.divui %a, %b : i32
-  return %r : i32
-}
-
-//===----------------------------------------------------------------------===//
-// arith.*itofp
-//===----------------------------------------------------------------------===//
-
-// CHECK-LABEL: func.func @sitofp_extsi_i16
-// CHECK-SAME: (%[[ARG:.+]]: i16)
-// CHECK-NEXT: %[[RET:.+]] = arith.sitofp %[[ARG]] : i16 to f16
-// CHECK-NEXT: return %[[RET]] : f16
-func.func @sitofp_extsi_i16(%a: i16) -> f16 {
-  %b = arith.extsi %a : i16 to i32
-  %f = arith.sitofp %b : i32 to f16
-  return %f : f16
-}
-
-// CHECK-LABEL: func.func @sitofp_extsi_vector_i16
-// CHECK-SAME: (%[[ARG:.+]]: vector<3xi16>)
-// CHECK-NEXT: %[[RET:.+]] = arith.sitofp %[[ARG]] : vector<3xi16> to vector<3xf16>
-// CHECK-NEXT: return %[[RET]] : vector<3xf16>
-func.func @sitofp_extsi_vector_i16(%a: vector<3xi16>) -> vector<3xf16> {
-  %b = arith.extsi %a : vector<3xi16> to vector<3xi32>
-  %f = arith.sitofp %b : vector<3xi32> to vector<3xf16>
-  return %f : vector<3xf16>
-}
-
-// CHECK-LABEL: func.func @sitofp_extsi_tensor_i16
-// CHECK-SAME: (%[[ARG:.+]]: tensor<3x?xi16>)
-// CHECK-NEXT: %[[RET:.+]] = arith.sitofp %[[ARG]] : tensor<3x?xi16> to tensor<3x?xf16>
-// CHECK-NEXT: return %[[RET]] : tensor<3x?xf16>
-func.func @sitofp_extsi_tensor_i16(%a: tensor<3x?xi16>) -> tensor<3x?xf16> {
-  %b = arith.extsi %a : tensor<3x?xi16> to tensor<3x?xi32>
-  %f = arith.sitofp %b : tensor<3x?xi32> to tensor<3x?xf16>
-  return %f : tensor<3x?xf16>
-}
-
-// Narrowing to i64 is not enabled in pass options.
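// For comparison, a hypothetical invocation that also enables 64-bit types
// (this is not a RUN line of this file):
//
//   mlir-opt --arith-int-narrowing="int-bitwidths-supported=1,8,16,24,32,64" %s
//
// could narrow the i128 arithmetic below to i64. With the options actually
// used here, 64 is absent, so the extension is left untouched.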
-// -// CHECK-LABEL: func.func @sitofp_extsi_i64 -// CHECK-SAME: (%[[ARG:.+]]: i64) -// CHECK-NEXT: %[[EXT:.+]] = arith.extsi %[[ARG]] : i64 to i128 -// CHECK-NEXT: %[[RET:.+]] = arith.sitofp %[[EXT]] : i128 to f32 -// CHECK-NEXT: return %[[RET]] : f32 -func.func @sitofp_extsi_i64(%a: i64) -> f32 { - %b = arith.extsi %a : i64 to i128 - %f = arith.sitofp %b : i128 to f32 - return %f : f32 -} - -// CHECK-LABEL: func.func @uitofp_extui_i16 -// CHECK-SAME: (%[[ARG:.+]]: i16) -// CHECK-NEXT: %[[RET:.+]] = arith.uitofp %[[ARG]] : i16 to f16 -// CHECK-NEXT: return %[[RET]] : f16 -func.func @uitofp_extui_i16(%a: i16) -> f16 { - %b = arith.extui %a : i16 to i32 - %f = arith.uitofp %b : i32 to f16 - return %f : f16 -} - -// CHECK-LABEL: func.func @sitofp_extsi_extsi_i8 -// CHECK-SAME: (%[[ARG:.+]]: i8) -// CHECK-NEXT: %[[RET:.+]] = arith.sitofp %[[ARG]] : i8 to f16 -// CHECK-NEXT: return %[[RET]] : f16 -func.func @sitofp_extsi_extsi_i8(%a: i8) -> f16 { - %b = arith.extsi %a : i8 to i16 - %c = arith.extsi %b : i16 to i32 - %f = arith.sitofp %c : i32 to f16 - return %f : f16 -} - -// CHECK-LABEL: func.func @uitofp_extui_extui_i8 -// CHECK-SAME: (%[[ARG:.+]]: i8) -// CHECK-NEXT: %[[RET:.+]] = arith.uitofp %[[ARG]] : i8 to f16 -// CHECK-NEXT: return %[[RET]] : f16 -func.func @uitofp_extui_extui_i8(%a: i8) -> f16 { - %b = arith.extui %a : i8 to i16 - %c = arith.extui %b : i16 to i32 - %f = arith.uitofp %c : i32 to f16 - return %f : f16 -} - -// CHECK-LABEL: func.func @uitofp_extsi_extui_i8 -// CHECK-SAME: (%[[ARG:.+]]: i8) -// CHECK-NEXT: %[[EXT:.+]] = arith.extsi %[[ARG]] : i8 to i16 -// CHECK-NEXT: %[[RET:.+]] = arith.uitofp %[[EXT]] : i16 to f16 -// CHECK-NEXT: return %[[RET]] : f16 -func.func @uitofp_extsi_extui_i8(%a: i8) -> f16 { - %b = arith.extsi %a : i8 to i16 - %c = arith.extui %b : i16 to i32 - %f = arith.uitofp %c : i32 to f16 - return %f : f16 -} - -// CHECK-LABEL: func.func @uitofp_trunci_extui_i8 -// CHECK-SAME: (%[[ARG:.+]]: i16) -// CHECK-NEXT: %[[TR:.+]] = arith.trunci %[[ARG]] : i16 to i8 -// CHECK-NEXT: %[[RET:.+]] = arith.uitofp %[[TR]] : i8 to f16 -// CHECK-NEXT: return %[[RET]] : f16 -func.func @uitofp_trunci_extui_i8(%a: i16) -> f16 { - %b = arith.trunci %a : i16 to i8 - %c = arith.extui %b : i8 to i32 - %f = arith.uitofp %c : i32 to f16 - return %f : f16 -} - -// This should not be folded because arith.extui changes the signed -// range of the number. For example: -// extsi -1 : i16 to i32 ==> -1 -// extui -1 : i16 to i32 ==> U16_MAX -// -/// CHECK-LABEL: func.func @sitofp_extui_i16 -// CHECK-SAME: (%[[ARG:.+]]: i16) -// CHECK-NEXT: %[[EXT:.+]] = arith.extui %[[ARG]] : i16 to i32 -// CHECK-NEXT: %[[RET:.+]] = arith.sitofp %[[EXT]] : i32 to f16 -// CHECK-NEXT: return %[[RET]] : f16 -func.func @sitofp_extui_i16(%a: i16) -> f16 { - %b = arith.extui %a : i16 to i32 - %f = arith.sitofp %b : i32 to f16 - return %f : f16 -} - -// This should not be folded because arith.extsi changes the unsigned -// range of the number. 
-// For example:
-// extsi -1 : i16 to i32 ==> U32_MAX
-// extui -1 : i16 to i32 ==> U16_MAX
-//
-// CHECK-LABEL: func.func @uitofp_extsi_i16
-// CHECK-SAME: (%[[ARG:.+]]: i16)
-// CHECK-NEXT: %[[EXT:.+]] = arith.extsi %[[ARG]] : i16 to i32
-// CHECK-NEXT: %[[RET:.+]] = arith.uitofp %[[EXT]] : i32 to f16
-// CHECK-NEXT: return %[[RET]] : f16
-func.func @uitofp_extsi_i16(%a: i16) -> f16 {
-  %b = arith.extsi %a : i16 to i32
-  %f = arith.uitofp %b : i32 to f16
-  return %f : f16
-}
-
-//===----------------------------------------------------------------------===//
-// arith.maxsi
-//===----------------------------------------------------------------------===//
-
-// CHECK-LABEL: func.func @maxsi_extsi_i8
-// CHECK-SAME: (%[[LHS:.+]]: i8, %[[RHS:.+]]: i8)
-// CHECK-NEXT: %[[MAX:.+]] = arith.maxsi %[[LHS]], %[[RHS]] : i8
-// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[MAX]] : i8 to i32
-// CHECK-NEXT: return %[[RET]] : i32
-func.func @maxsi_extsi_i8(%lhs: i8, %rhs: i8) -> i32 {
-  %a = arith.extsi %lhs : i8 to i32
-  %b = arith.extsi %rhs : i8 to i32
-  %r = arith.maxsi %a, %b : i32
-  return %r : i32
-}
-
-// This pattern should only apply to `arith.maxsi` ops with sign-extended
-// arguments.
-//
-// CHECK-LABEL: func.func @maxsi_extui_i8
-// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
-// CHECK-NEXT: %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
-// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
-// CHECK-NEXT: %[[MAX:.+]] = arith.maxsi %[[EXT0]], %[[EXT1]] : i32
-// CHECK-NEXT: return %[[MAX]] : i32
-func.func @maxsi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
-  %a = arith.extui %lhs : i8 to i32
-  %b = arith.extui %rhs : i8 to i32
-  %r = arith.maxsi %a, %b : i32
-  return %r : i32
-}
-
-//===----------------------------------------------------------------------===//
-// arith.maxui
-//===----------------------------------------------------------------------===//
-
-// CHECK-LABEL: func.func @maxui_extui_i8
-// CHECK-SAME: (%[[LHS:.+]]: i8, %[[RHS:.+]]: i8)
-// CHECK-NEXT: %[[MAX:.+]] = arith.maxui %[[LHS]], %[[RHS]] : i8
-// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[MAX]] : i8 to i32
-// CHECK-NEXT: return %[[RET]] : i32
-func.func @maxui_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
-  %a = arith.extui %lhs : i8 to i32
-  %b = arith.extui %rhs : i8 to i32
-  %r = arith.maxui %a, %b : i32
-  return %r : i32
-}
-
-// This pattern should only apply to `arith.maxui` ops with zero-extended
-// arguments.
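// The underlying issue is the unsigned reading of a sign-extended value
// (illustrative numbers): extsi -1 : i8 to i32 yields 0xFFFFFFFF, whose
// unsigned value is 4294967295, while the raw i8 bit pattern 0xFF reads as
// 255. The pattern does not attempt to prove that an unsigned comparison
// survives this change of reading, so it conservatively requires extui
// operands; the same reasoning applies to the `arith.minui` case further
// below.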
-//
-// CHECK-LABEL: func.func @maxui_extsi_i8
-// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
-// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
-// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32
-// CHECK-NEXT: %[[MAX:.+]] = arith.maxui %[[EXT0]], %[[EXT1]] : i32
-// CHECK-NEXT: return %[[MAX]] : i32
-func.func @maxui_extsi_i8(%lhs: i8, %rhs: i8) -> i32 {
-  %a = arith.extsi %lhs : i8 to i32
-  %b = arith.extsi %rhs : i8 to i32
-  %r = arith.maxui %a, %b : i32
-  return %r : i32
-}
-
-//===----------------------------------------------------------------------===//
-// arith.minsi
-//===----------------------------------------------------------------------===//
-
-// CHECK-LABEL: func.func @minsi_extsi_i8
-// CHECK-SAME: (%[[LHS:.+]]: i8, %[[RHS:.+]]: i8)
-// CHECK-NEXT: %[[min:.+]] = arith.minsi %[[LHS]], %[[RHS]] : i8
-// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[min]] : i8 to i32
-// CHECK-NEXT: return %[[RET]] : i32
-func.func @minsi_extsi_i8(%lhs: i8, %rhs: i8) -> i32 {
-  %a = arith.extsi %lhs : i8 to i32
-  %b = arith.extsi %rhs : i8 to i32
-  %r = arith.minsi %a, %b : i32
-  return %r : i32
-}
-
-// This pattern should only apply to `arith.minsi` ops with sign-extended
-// arguments.
-//
-// CHECK-LABEL: func.func @minsi_extui_i8
-// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
-// CHECK-NEXT: %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
-// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
-// CHECK-NEXT: %[[min:.+]] = arith.minsi %[[EXT0]], %[[EXT1]] : i32
-// CHECK-NEXT: return %[[min]] : i32
-func.func @minsi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
-  %a = arith.extui %lhs : i8 to i32
-  %b = arith.extui %rhs : i8 to i32
-  %r = arith.minsi %a, %b : i32
-  return %r : i32
-}
-
-//===----------------------------------------------------------------------===//
-// arith.minui
-//===----------------------------------------------------------------------===//
-
-// CHECK-LABEL: func.func @minui_extui_i8
-// CHECK-SAME: (%[[LHS:.+]]: i8, %[[RHS:.+]]: i8)
-// CHECK-NEXT: %[[min:.+]] = arith.minui %[[LHS]], %[[RHS]] : i8
-// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[min]] : i8 to i32
-// CHECK-NEXT: return %[[RET]] : i32
-func.func @minui_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
-  %a = arith.extui %lhs : i8 to i32
-  %b = arith.extui %rhs : i8 to i32
-  %r = arith.minui %a, %b : i32
-  return %r : i32
-}
-
-// This pattern should only apply to `arith.minui` ops with zero-extended
-// arguments.
-// -// CHECK-LABEL: func.func @minui_extsi_i8 -// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8) -// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32 -// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32 -// CHECK-NEXT: %[[min:.+]] = arith.minui %[[EXT0]], %[[EXT1]] : i32 -// CHECK-NEXT: return %[[min]] : i32 -func.func @minui_extsi_i8(%lhs: i8, %rhs: i8) -> i32 { - %a = arith.extsi %lhs : i8 to i32 - %b = arith.extsi %rhs : i8 to i32 - %r = arith.minui %a, %b : i32 - return %r : i32 -} - -//===----------------------------------------------------------------------===// -// Commute Extension over Vector Ops -//===----------------------------------------------------------------------===// - -// CHECK-LABEL: func.func @extsi_over_extract_3xi16 -// CHECK-SAME: (%[[ARG:.+]]: vector<3xi16>) -// CHECK-NEXT: %[[EXTR:.+]] = vector.extract %[[ARG]][1] : i16 from vector<3xi16> -// CHECK-NEXT: %[[RET:.+]] = arith.sitofp %[[EXTR]] : i16 to f16 -// CHECK-NEXT: return %[[RET]] : f16 -func.func @extsi_over_extract_3xi16(%a: vector<3xi16>) -> f16 { - %b = arith.extsi %a : vector<3xi16> to vector<3xi32> - %c = vector.extract %b[1] : i32 from vector<3xi32> - %f = arith.sitofp %c : i32 to f16 - return %f : f16 -} - -// CHECK-LABEL: func.func @extui_over_extract_3xi16 -// CHECK-SAME: (%[[ARG:.+]]: vector<3xi16>) -// CHECK-NEXT: %[[EXTR:.+]] = vector.extract %[[ARG]][1] : i16 from vector<3xi16> -// CHECK-NEXT: %[[RET:.+]] = arith.uitofp %[[EXTR]] : i16 to f16 -// CHECK-NEXT: return %[[RET]] : f16 -func.func @extui_over_extract_3xi16(%a: vector<3xi16>) -> f16 { - %b = arith.extui %a : vector<3xi16> to vector<3xi32> - %c = vector.extract %b[1] : i32 from vector<3xi32> - %f = arith.uitofp %c : i32 to f16 - return %f : f16 -} - -// CHECK-LABEL: func.func @extsi_over_extractelement_3xi16 -// CHECK-SAME: (%[[ARG:.+]]: vector<3xi16>, %[[POS:.+]]: i32) -// CHECK-NEXT: %[[EXTR:.+]] = vector.extractelement %[[ARG]][%[[POS]] : i32] : vector<3xi16> -// CHECK-NEXT: %[[RET:.+]] = arith.sitofp %[[EXTR]] : i16 to f16 -// CHECK-NEXT: return %[[RET]] : f16 -func.func @extsi_over_extractelement_3xi16(%a: vector<3xi16>, %pos: i32) -> f16 { - %b = arith.extsi %a : vector<3xi16> to vector<3xi32> - %c = vector.extractelement %b[%pos : i32] : vector<3xi32> - %f = arith.sitofp %c : i32 to f16 - return %f : f16 -} - -// CHECK-LABEL: func.func @extui_over_extractelement_3xi16 -// CHECK-SAME: (%[[ARG:.+]]: vector<3xi16>, %[[POS:.+]]: i32) -// CHECK-NEXT: %[[EXTR:.+]] = vector.extractelement %[[ARG]][%[[POS]] : i32] : vector<3xi16> -// CHECK-NEXT: %[[RET:.+]] = arith.uitofp %[[EXTR]] : i16 to f16 -// CHECK-NEXT: return %[[RET]] : f16 -func.func @extui_over_extractelement_3xi16(%a: vector<3xi16>, %pos: i32) -> f16 { - %b = arith.extui %a : vector<3xi16> to vector<3xi32> - %c = vector.extractelement %b[%pos : i32] : vector<3xi32> - %f = arith.uitofp %c : i32 to f16 - return %f : f16 -} - -// CHECK-LABEL: func.func @extsi_over_extract_strided_slice_1d -// CHECK-SAME: (%[[ARG:.+]]: vector<3xi16>) -// CHECK-NEXT: %[[EXTR:.+]] = vector.extract_strided_slice %[[ARG]] {offsets = [1], sizes = [2], strides = [1]} : vector<3xi16> to vector<2xi16> -// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[EXTR]] : vector<2xi16> to vector<2xi32> -// CHECK-NEXT: return %[[RET]] : vector<2xi32> -func.func @extsi_over_extract_strided_slice_1d(%a: vector<3xi16>) -> vector<2xi32> { - %b = arith.extsi %a : vector<3xi16> to vector<3xi32> - %c = vector.extract_strided_slice %b - {offsets = [1], sizes = [2], strides = [1]} : vector<3xi32> to 
vector<2xi32> - return %c : vector<2xi32> -} - -// CHECK-LABEL: func.func @extui_over_extract_strided_slice_1d -// CHECK-SAME: (%[[ARG:.+]]: vector<3xi16>) -// CHECK-NEXT: %[[EXTR:.+]] = vector.extract_strided_slice %[[ARG]] {offsets = [1], sizes = [2], strides = [1]} : vector<3xi16> to vector<2xi16> -// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[EXTR]] : vector<2xi16> to vector<2xi32> -// CHECK-NEXT: return %[[RET]] : vector<2xi32> -func.func @extui_over_extract_strided_slice_1d(%a: vector<3xi16>) -> vector<2xi32> { - %b = arith.extui %a : vector<3xi16> to vector<3xi32> - %c = vector.extract_strided_slice %b - {offsets = [1], sizes = [2], strides = [1]} : vector<3xi32> to vector<2xi32> - return %c : vector<2xi32> -} - -// CHECK-LABEL: func.func @extsi_over_extract_strided_slice_2d -// CHECK-SAME: (%[[ARG:.+]]: vector<2x3xi16>) -// CHECK-NEXT: %[[EXTR:.+]] = vector.extract_strided_slice %arg0 {offsets = [1, 1], sizes = [1, 2], strides = [1, 1]} : vector<2x3xi16> to vector<1x2xi16> -// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[EXTR]] : vector<1x2xi16> to vector<1x2xi32> -// CHECK-NEXT: return %[[RET]] : vector<1x2xi32> -func.func @extsi_over_extract_strided_slice_2d(%a: vector<2x3xi16>) -> vector<1x2xi32> { - %b = arith.extsi %a : vector<2x3xi16> to vector<2x3xi32> - %c = vector.extract_strided_slice %b - {offsets = [1, 1], sizes = [1, 2], strides = [1, 1]} : vector<2x3xi32> to vector<1x2xi32> - return %c : vector<1x2xi32> -} - -// CHECK-LABEL: func.func @extui_over_extract_strided_slice_2d -// CHECK-SAME: (%[[ARG:.+]]: vector<2x3xi16>) -// CHECK-NEXT: %[[EXTR:.+]] = vector.extract_strided_slice %arg0 {offsets = [1, 1], sizes = [1, 2], strides = [1, 1]} : vector<2x3xi16> to vector<1x2xi16> -// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[EXTR]] : vector<1x2xi16> to vector<1x2xi32> -// CHECK-NEXT: return %[[RET]] : vector<1x2xi32> -func.func @extui_over_extract_strided_slice_2d(%a: vector<2x3xi16>) -> vector<1x2xi32> { - %b = arith.extui %a : vector<2x3xi16> to vector<2x3xi32> - %c = vector.extract_strided_slice %b - {offsets = [1, 1], sizes = [1, 2], strides = [1, 1]} : vector<2x3xi32> to vector<1x2xi32> - return %c : vector<1x2xi32> -} - -// CHECK-LABEL: func.func @extsi_over_insert_3xi16 -// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi16>, %[[ARG1:.+]]: i16) -// CHECK-NEXT: %[[INS:.+]] = vector.insert %[[ARG1]], %[[ARG0]] [1] : i16 into vector<3xi16> -// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[INS]] : vector<3xi16> to vector<3xi32> -// CHECK-NEXT: return %[[RET]] : vector<3xi32> -func.func @extsi_over_insert_3xi16(%a: vector<3xi16>, %b: i16) -> vector<3xi32> { - %c = arith.extsi %a : vector<3xi16> to vector<3xi32> - %d = arith.extsi %b : i16 to i32 - %e = vector.insert %d, %c [1] : i32 into vector<3xi32> - return %e : vector<3xi32> -} - -// CHECK-LABEL: func.func @extui_over_insert_3xi16 -// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi16>, %[[ARG1:.+]]: i16) -// CHECK-NEXT: %[[INS:.+]] = vector.insert %[[ARG1]], %[[ARG0]] [1] : i16 into vector<3xi16> -// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[INS]] : vector<3xi16> to vector<3xi32> -// CHECK-NEXT: return %[[RET]] : vector<3xi32> -func.func @extui_over_insert_3xi16(%a: vector<3xi16>, %b: i16) -> vector<3xi32> { - %c = arith.extui %a : vector<3xi16> to vector<3xi32> - %d = arith.extui %b : i16 to i32 - %e = vector.insert %d, %c [1] : i32 into vector<3xi32> - return %e : vector<3xi32> -} - -// CHECK-LABEL: func.func @extsi_over_insert_3xi16_cst_0 -// CHECK-SAME: (%[[ARG:.+]]: i16) -// CHECK-NEXT: %[[CST:.+]] = arith.constant dense<0> : vector<3xi16> -// 
CHECK-NEXT: %[[INS:.+]] = vector.insert %[[ARG]], %[[CST]] [1] : i16 into vector<3xi16> -// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[INS]] : vector<3xi16> to vector<3xi32> -// CHECK-NEXT: return %[[RET]] : vector<3xi32> -func.func @extsi_over_insert_3xi16_cst_0(%a: i16) -> vector<3xi32> { - %cst = arith.constant dense<0> : vector<3xi32> - %d = arith.extsi %a : i16 to i32 - %e = vector.insert %d, %cst [1] : i32 into vector<3xi32> - return %e : vector<3xi32> -} - -// CHECK-LABEL: func.func @extsi_over_insert_3xi8_cst -// CHECK-SAME: (%[[ARG:.+]]: i8) -// CHECK-NEXT: %[[CST:.+]] = arith.constant dense<[-1, 127, -128]> : vector<3xi8> -// CHECK-NEXT: %[[INS:.+]] = vector.insert %[[ARG]], %[[CST]] [1] : i8 into vector<3xi8> -// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[INS]] : vector<3xi8> to vector<3xi32> -// CHECK-NEXT: return %[[RET]] : vector<3xi32> -func.func @extsi_over_insert_3xi8_cst(%a: i8) -> vector<3xi32> { - %cst = arith.constant dense<[-1, 127, -128]> : vector<3xi32> - %d = arith.extsi %a : i8 to i32 - %e = vector.insert %d, %cst [1] : i32 into vector<3xi32> - return %e : vector<3xi32> -} - -// CHECK-LABEL: func.func @extui_over_insert_3xi8_cst -// CHECK-SAME: (%[[ARG:.+]]: i8) -// CHECK-NEXT: %[[CST:.+]] = arith.constant dense<[1, 127, -1]> : vector<3xi8> -// CHECK-NEXT: %[[INS:.+]] = vector.insert %[[ARG]], %[[CST]] [1] : i8 into vector<3xi8> -// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[INS]] : vector<3xi8> to vector<3xi32> -// CHECK-NEXT: return %[[RET]] : vector<3xi32> -func.func @extui_over_insert_3xi8_cst(%a: i8) -> vector<3xi32> { - %cst = arith.constant dense<[1, 127, 255]> : vector<3xi32> - %d = arith.extui %a : i8 to i32 - %e = vector.insert %d, %cst [1] : i32 into vector<3xi32> - return %e : vector<3xi32> -} - -// CHECK-LABEL: func.func @extsi_over_insert_3xi16_cst_i16 -// CHECK-SAME: (%[[ARG:.+]]: i8) -// CHECK-NEXT: %[[CST:.+]] = arith.constant dense<[-1, 128, 0]> : vector<3xi16> -// CHECK-NEXT: %[[SRCE:.+]] = arith.extsi %[[ARG]] : i8 to i32 -// CHECK-NEXT: %[[SRCT:.+]] = arith.trunci %[[SRCE]] : i32 to i16 -// CHECK-NEXT: %[[INS:.+]] = vector.insert %[[SRCT]], %[[CST]] [1] : i16 into vector<3xi16> -// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[INS]] : vector<3xi16> to vector<3xi32> -// CHECK-NEXT: return %[[RET]] : vector<3xi32> -func.func @extsi_over_insert_3xi16_cst_i16(%a: i8) -> vector<3xi32> { - %cst = arith.constant dense<[-1, 128, 0]> : vector<3xi32> - %d = arith.extsi %a : i8 to i32 - %e = vector.insert %d, %cst [1] : i32 into vector<3xi32> - return %e : vector<3xi32> -} - -// CHECK-LABEL: func.func @extui_over_insert_3xi16_cst_i16 -// CHECK-SAME: (%[[ARG:.+]]: i8) -// CHECK-NEXT: %[[CST:.+]] = arith.constant dense<[1, 256, 0]> : vector<3xi16> -// CHECK-NEXT: %[[SRCE:.+]] = arith.extui %[[ARG]] : i8 to i32 -// CHECK-NEXT: %[[SRCT:.+]] = arith.trunci %[[SRCE]] : i32 to i16 -// CHECK-NEXT: %[[INS:.+]] = vector.insert %[[SRCT]], %[[CST]] [1] : i16 into vector<3xi16> -// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[INS]] : vector<3xi16> to vector<3xi32> -// CHECK-NEXT: return %[[RET]] : vector<3xi32> -func.func @extui_over_insert_3xi16_cst_i16(%a: i8) -> vector<3xi32> { - %cst = arith.constant dense<[1, 256, 0]> : vector<3xi32> - %d = arith.extui %a : i8 to i32 - %e = vector.insert %d, %cst [1] : i32 into vector<3xi32> - return %e : vector<3xi32> -} - -// CHECK-LABEL: func.func @extsi_over_insertelement_3xi16 -// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi16>, %[[ARG1:.+]]: i16, %[[POS:.+]]: i32) -// CHECK-NEXT: %[[INS:.+]] = vector.insertelement %[[ARG1]], %[[ARG0]][%[[POS]] 
: i32] : vector<3xi16> -// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[INS]] : vector<3xi16> to vector<3xi32> -// CHECK-NEXT: return %[[RET]] : vector<3xi32> -func.func @extsi_over_insertelement_3xi16(%a: vector<3xi16>, %b: i16, %pos: i32) -> vector<3xi32> { - %c = arith.extsi %a : vector<3xi16> to vector<3xi32> - %d = arith.extsi %b : i16 to i32 - %e = vector.insertelement %d, %c[%pos : i32] : vector<3xi32> - return %e : vector<3xi32> -} - -// CHECK-LABEL: func.func @extui_over_insertelement_3xi16 -// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi16>, %[[ARG1:.+]]: i16, %[[POS:.+]]: i32) -// CHECK-NEXT: %[[INS:.+]] = vector.insertelement %[[ARG1]], %[[ARG0]][%[[POS]] : i32] : vector<3xi16> -// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[INS]] : vector<3xi16> to vector<3xi32> -// CHECK-NEXT: return %[[RET]] : vector<3xi32> -func.func @extui_over_insertelement_3xi16(%a: vector<3xi16>, %b: i16, %pos: i32) -> vector<3xi32> { - %c = arith.extui %a : vector<3xi16> to vector<3xi32> - %d = arith.extui %b : i16 to i32 - %e = vector.insertelement %d, %c[%pos : i32] : vector<3xi32> - return %e : vector<3xi32> -} - -// CHECK-LABEL: func.func @extsi_over_insertelement_3xi16_cst_i16 -// CHECK-SAME: (%[[ARG:.+]]: i8, %[[POS:.+]]: i32) -// CHECK-NEXT: %[[CST:.+]] = arith.constant dense<[-1, 128, 0]> : vector<3xi16> -// CHECK-NEXT: %[[SRCE:.+]] = arith.extsi %[[ARG]] : i8 to i32 -// CHECK-NEXT: %[[SRCT:.+]] = arith.trunci %[[SRCE]] : i32 to i16 -// CHECK-NEXT: %[[INS:.+]] = vector.insertelement %[[SRCT]], %[[CST]][%[[POS]] : i32] : vector<3xi16> -// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[INS]] : vector<3xi16> to vector<3xi32> -// CHECK-NEXT: return %[[RET]] : vector<3xi32> -func.func @extsi_over_insertelement_3xi16_cst_i16(%a: i8, %pos: i32) -> vector<3xi32> { - %cst = arith.constant dense<[-1, 128, 0]> : vector<3xi32> - %d = arith.extsi %a : i8 to i32 - %e = vector.insertelement %d, %cst[%pos : i32] : vector<3xi32> - return %e : vector<3xi32> -} - -// CHECK-LABEL: func.func @extui_over_insertelement_3xi16_cst_i16 -// CHECK-SAME: (%[[ARG:.+]]: i8, %[[POS:.+]]: i32) -// CHECK-NEXT: %[[CST:.+]] = arith.constant dense<[1, 256, 0]> : vector<3xi16> -// CHECK-NEXT: %[[SRCE:.+]] = arith.extui %[[ARG]] : i8 to i32 -// CHECK-NEXT: %[[SRCT:.+]] = arith.trunci %[[SRCE]] : i32 to i16 -// CHECK-NEXT: %[[INS:.+]] = vector.insertelement %[[SRCT]], %[[CST]][%[[POS]] : i32] : vector<3xi16> -// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[INS]] : vector<3xi16> to vector<3xi32> -// CHECK-NEXT: return %[[RET]] : vector<3xi32> -func.func @extui_over_insertelement_3xi16_cst_i16(%a: i8, %pos: i32) -> vector<3xi32> { - %cst = arith.constant dense<[1, 256, 0]> : vector<3xi32> - %d = arith.extui %a : i8 to i32 - %e = vector.insertelement %d, %cst[%pos : i32] : vector<3xi32> - return %e : vector<3xi32> -} - -// CHECK-LABEL: func.func @extsi_over_insert_strided_slice_1d -// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi16>, %[[ARG1:.+]]: vector<2xi16>) -// CHECK-NEXT: %[[INS:.+]] = vector.insert_strided_slice %[[ARG1]], %[[ARG0]] -// CHECK-SAME: {offsets = [1], strides = [1]} : vector<2xi16> into vector<3xi16> -// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[INS]] : vector<3xi16> to vector<3xi32> -// CHECK-NEXT: return %[[RET]] : vector<3xi32> -func.func @extsi_over_insert_strided_slice_1d(%a: vector<3xi16>, %b: vector<2xi16>) -> vector<3xi32> { - %c = arith.extsi %a : vector<3xi16> to vector<3xi32> - %d = arith.extsi %b : vector<2xi16> to vector<2xi32> - %e = vector.insert_strided_slice %d, %c {offsets = [1], strides = [1]} : vector<2xi32> into vector<3xi32> - 
return %e : vector<3xi32> -} - -// CHECK-LABEL: func.func @extui_over_insert_strided_slice_1d -// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi16>, %[[ARG1:.+]]: vector<2xi16>) -// CHECK-NEXT: %[[INS:.+]] = vector.insert_strided_slice %[[ARG1]], %[[ARG0]] -// CHECK-SAME: {offsets = [1], strides = [1]} : vector<2xi16> into vector<3xi16> -// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[INS]] : vector<3xi16> to vector<3xi32> -// CHECK-NEXT: return %[[RET]] : vector<3xi32> -func.func @extui_over_insert_strided_slice_1d(%a: vector<3xi16>, %b: vector<2xi16>) -> vector<3xi32> { - %c = arith.extui %a : vector<3xi16> to vector<3xi32> - %d = arith.extui %b : vector<2xi16> to vector<2xi32> - %e = vector.insert_strided_slice %d, %c {offsets = [1], strides = [1]} : vector<2xi32> into vector<3xi32> - return %e : vector<3xi32> -} - -// CHECK-LABEL: func.func @extsi_over_insert_strided_slice_cst_2d -// CHECK-SAME: (%[[ARG:.+]]: vector<1x2xi8>) -// CHECK-NEXT: %[[CST:.+]] = arith.constant -// CHECK-SAME{LITERAL}: dense<[[-1, 128, 0], [-129, 42, 1337]]> : vector<2x3xi16> -// CHECK-NEXT: %[[SRCE:.+]] = arith.extsi %[[ARG]] : vector<1x2xi8> to vector<1x2xi32> -// CHECK-NEXT: %[[SRCT:.+]] = arith.trunci %[[SRCE]] : vector<1x2xi32> to vector<1x2xi16> -// CHECK-NEXT: %[[INS:.+]] = vector.insert_strided_slice %[[SRCT]], %[[CST]] -// CHECK-SAME: {offsets = [0, 1], strides = [1, 1]} : vector<1x2xi16> into vector<2x3xi16> -// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[INS]] : vector<2x3xi16> to vector<2x3xi32> -// CHECK-NEXT: return %[[RET]] : vector<2x3xi32> -func.func @extsi_over_insert_strided_slice_cst_2d(%a: vector<1x2xi8>) -> vector<2x3xi32> { - %cst = arith.constant dense<[[-1, 128, 0], [-129, 42, 1337]]> : vector<2x3xi32> - %d = arith.extsi %a : vector<1x2xi8> to vector<1x2xi32> - %e = vector.insert_strided_slice %d, %cst {offsets = [0, 1], strides = [1, 1]} : vector<1x2xi32> into vector<2x3xi32> - return %e : vector<2x3xi32> -} - -// CHECK-LABEL: func.func @extui_over_insert_strided_slice_cst_2d -// CHECK-SAME: (%[[ARG:.+]]: vector<1x2xi8>) -// CHECK-NEXT: %[[CST:.+]] = arith.constant -// CHECK-SAME{LITERAL}: dense<[[1, 128, 0], [256, 42, 1337]]> : vector<2x3xi16> -// CHECK-NEXT: %[[SRCE:.+]] = arith.extui %[[ARG]] : vector<1x2xi8> to vector<1x2xi32> -// CHECK-NEXT: %[[SRCT:.+]] = arith.trunci %[[SRCE]] : vector<1x2xi32> to vector<1x2xi16> -// CHECK-NEXT: %[[INS:.+]] = vector.insert_strided_slice %[[SRCT]], %[[CST]] -// CHECK-SAME: {offsets = [0, 1], strides = [1, 1]} : vector<1x2xi16> into vector<2x3xi16> -// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[INS]] : vector<2x3xi16> to vector<2x3xi32> -// CHECK-NEXT: return %[[RET]] : vector<2x3xi32> -func.func @extui_over_insert_strided_slice_cst_2d(%a: vector<1x2xi8>) -> vector<2x3xi32> { - %cst = arith.constant dense<[[1, 128, 0], [256, 42, 1337]]> : vector<2x3xi32> - %d = arith.extui %a : vector<1x2xi8> to vector<1x2xi32> - %e = vector.insert_strided_slice %d, %cst {offsets = [0, 1], strides = [1, 1]} : vector<1x2xi32> into vector<2x3xi32> - return %e : vector<2x3xi32> -} - -// CHECK-LABEL: func.func @extsi_over_broadcast_3xi16 -// CHECK-SAME: (%[[ARG:.+]]: i16) -// CHECK-NEXT: %[[BCST:.+]] = vector.broadcast %[[ARG]] : i16 to vector<3xi16> -// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[BCST]] : vector<3xi16> to vector<3xi32> -// CHECK-NEXT: return %[[RET]] : vector<3xi32> -func.func @extsi_over_broadcast_3xi16(%a: i16) -> vector<3xi32> { - %b = arith.extsi %a : i16 to i32 - %r = vector.broadcast %b : i32 to vector<3xi32> - return %r : vector<3xi32> -} - -// CHECK-LABEL: 
func.func @extui_over_broadcast_2x3xi16 -// CHECK-SAME: (%[[ARG:.+]]: vector<3xi16>) -// CHECK-NEXT: %[[BCST:.+]] = vector.broadcast %[[ARG]] : vector<3xi16> to vector<2x3xi16> -// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[BCST]] : vector<2x3xi16> to vector<2x3xi32> -// CHECK-NEXT: return %[[RET]] : vector<2x3xi32> -func.func @extui_over_broadcast_2x3xi16(%a: vector<3xi16>) -> vector<2x3xi32> { - %b = arith.extui %a : vector<3xi16> to vector<3xi32> - %r = vector.broadcast %b : vector<3xi32> to vector<2x3xi32> - return %r : vector<2x3xi32> -} - -// CHECK-LABEL: func.func @extsi_over_shape_cast_2x3xi16 -// CHECK-SAME: (%[[ARG:.+]]: vector<2x3xi16>) -// CHECK-NEXT: %[[CAST:.+]] = vector.shape_cast %[[ARG]] : vector<2x3xi16> to vector<3x2xi16> -// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[CAST]] : vector<3x2xi16> to vector<3x2xi32> -// CHECK-NEXT: return %[[RET]] : vector<3x2xi32> -func.func @extsi_over_shape_cast_2x3xi16(%a: vector<2x3xi16>) -> vector<3x2xi32> { - %b = arith.extsi %a : vector<2x3xi16> to vector<2x3xi32> - %r = vector.shape_cast %b : vector<2x3xi32> to vector<3x2xi32> - return %r : vector<3x2xi32> -} - -// CHECK-LABEL: func.func @extui_over_shape_cast_5x2x3xi16 -// CHECK-SAME: (%[[ARG:.+]]: vector<5x2x3xi16>) -// CHECK-NEXT: %[[CAST:.+]] = vector.shape_cast %[[ARG]] : vector<5x2x3xi16> to vector<2x3x5xi16> -// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[CAST]] : vector<2x3x5xi16> to vector<2x3x5xi32> -// CHECK-NEXT: return %[[RET]] : vector<2x3x5xi32> -func.func @extui_over_shape_cast_5x2x3xi16(%a: vector<5x2x3xi16>) -> vector<2x3x5xi32> { - %b = arith.extui %a : vector<5x2x3xi16> to vector<5x2x3xi32> - %r = vector.shape_cast %b : vector<5x2x3xi32> to vector<2x3x5xi32> - return %r : vector<2x3x5xi32> -} - -// CHECK-LABEL: func.func @extsi_over_transpose_2x3xi16 -// CHECK-SAME: (%[[ARG:.+]]: vector<2x3xi16>) -// CHECK-NEXT: %[[TRAN:.+]] = vector.transpose %[[ARG]], [1, 0] : vector<2x3xi16> to vector<3x2xi16> -// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[TRAN]] : vector<3x2xi16> to vector<3x2xi32> -// CHECK-NEXT: return %[[RET]] : vector<3x2xi32> -func.func @extsi_over_transpose_2x3xi16(%a: vector<2x3xi16>) -> vector<3x2xi32> { - %b = arith.extsi %a : vector<2x3xi16> to vector<2x3xi32> - %r = vector.transpose %b, [1, 0] : vector<2x3xi32> to vector<3x2xi32> - return %r : vector<3x2xi32> -} - -// CHECK-LABEL: func.func @extui_over_transpose_5x2x3xi16 -// CHECK-SAME: (%[[ARG:.+]]: vector<5x2x3xi16>) -// CHECK-NEXT: %[[TRAN:.+]] = vector.transpose %[[ARG]], [1, 2, 0] : vector<5x2x3xi16> to vector<2x3x5xi16> -// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[TRAN]] : vector<2x3x5xi16> to vector<2x3x5xi32> -// CHECK-NEXT: return %[[RET]] : vector<2x3x5xi32> -func.func @extui_over_transpose_5x2x3xi16(%a: vector<5x2x3xi16>) -> vector<2x3x5xi32> { - %b = arith.extui %a : vector<5x2x3xi16> to vector<5x2x3xi32> - %r = vector.transpose %b, [1, 2, 0] : vector<5x2x3xi32> to vector<2x3x5xi32> - return %r : vector<2x3x5xi32> -} - -// CHECK-LABEL: func.func @extsi_over_flat_transpose_16xi16 -// CHECK-SAME: (%[[ARG:.+]]: vector<16xi16>) -// CHECK-NEXT: %[[TRAN:.+]] = vector.flat_transpose %[[ARG]] {columns = 4 : i32, rows = 4 : i32} : vector<16xi16> -> vector<16xi16> -// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[TRAN]] : vector<16xi16> to vector<16xi32> -// CHECK-NEXT: return %[[RET]] : vector<16xi32> -func.func @extsi_over_flat_transpose_16xi16(%a: vector<16xi16>) -> vector<16xi32> { - %b = arith.extsi %a : vector<16xi16> to vector<16xi32> - %r = vector.flat_transpose %b {columns = 4 : i32, rows = 4 : i32} : 
vector<16xi32> -> vector<16xi32> - return %r : vector<16xi32> -} - -// CHECK-LABEL: func.func @extui_over_flat_transpose_16xi16 -// CHECK-SAME: (%[[ARG:.+]]: vector<16xi16>) -// CHECK-NEXT: %[[TRAN:.+]] = vector.flat_transpose %[[ARG]] {columns = 8 : i32, rows = 2 : i32} : vector<16xi16> -> vector<16xi16> -// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[TRAN]] : vector<16xi16> to vector<16xi32> -// CHECK-NEXT: return %[[RET]] : vector<16xi32> -func.func @extui_over_flat_transpose_16xi16(%a: vector<16xi16>) -> vector<16xi32> { - %b = arith.extui %a : vector<16xi16> to vector<16xi32> - %r = vector.flat_transpose %b {columns = 8 : i32, rows = 2 : i32} : vector<16xi32> -> vector<16xi32> - return %r : vector<16xi32> -} diff --git a/mlir/test/Dialect/Linalg/int-narrowing.mlir b/mlir/test/Dialect/Linalg/int-narrowing.mlir deleted file mode 100644 index 8063d504597a3..0000000000000 --- a/mlir/test/Dialect/Linalg/int-narrowing.mlir +++ /dev/null @@ -1,147 +0,0 @@ -// RUN: mlir-opt --arith-int-narrowing="int-bitwidths-supported=1,8,16,32" \ -// RUN: --verify-diagnostics %s | FileCheck %s - -// Check that we can calculate `linalg.index` value bounds and use them to -// optimize index casts. - -//===----------------------------------------------------------------------===// -// arith.index_cast -//===----------------------------------------------------------------------===// - -// CHECK-LABEL: func @linalg_indexcast_dim_0_i8 -// CHECK: %[[IDX:.+]] = linalg.index 0 : index -// CHECK-NEXT: %[[INT:.+]] = arith.index_cast %[[IDX]] : index to i8 -// CHECK-NEXT: %[[FP:.+]] = arith.sitofp %[[INT]] : i8 to f16 -// CHECK-NEXT: linalg.yield %[[FP]] : f16 -func.func @linalg_indexcast_dim_0_i8(%arg0: tensor) -> tensor<128xf16> { - %init = tensor.empty() : tensor<128xf16> - %res = linalg.generic { - indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>], - iterator_types = ["parallel"] - } - ins(%arg0 : tensor) - outs(%init : tensor<128xf16>) { - ^bb0(%in: f16, %out: f16): - %idx = linalg.index 0 : index - %int = arith.index_cast %idx : index to i64 - %fp = arith.sitofp %int : i64 to f16 - linalg.yield %fp : f16 - } -> tensor<128xf16> - - return %res : tensor<128xf16> -} - -// CHECK-LABEL: func @linalg_indexcast_dim_1_i16 -// CHECK: %[[IDX:.+]] = linalg.index 1 : index -// CHECK-NEXT: %[[INT:.+]] = arith.index_cast %[[IDX]] : index to i16 -// CHECK-NEXT: %[[FP:.+]] = arith.sitofp %[[INT]] : i16 to f16 -// CHECK-NEXT: linalg.yield %[[FP]] : f16 -func.func @linalg_indexcast_dim_1_i16(%arg0: tensor, %arg1: tensor) -> tensor { - %res = linalg.generic { - indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"] - } - ins(%arg0 : tensor) - outs(%arg1 : tensor) { - ^bb0(%in: f16, %out: f16): - %idx = linalg.index 1 : index - %int = arith.index_cast %idx : index to i64 - %fp = arith.sitofp %int : i64 to f16 - linalg.yield %fp : f16 - } -> tensor - - return %res : tensor -} - -// CHECK-LABEL: func @linalg_indexcast_dynamic_dim_i64 -// CHECK: %[[IDX:.+]] = linalg.index 0 : index -// CHECK-NEXT: %[[INT:.+]] = arith.index_cast %[[IDX]] : index to i64 -// CHECK-NEXT: %[[FP:.+]] = arith.sitofp %[[INT]] : i64 to f16 -// CHECK-NEXT: linalg.yield %[[FP]] : f16 -func.func @linalg_indexcast_dynamic_dim_i64(%arg0: tensor, %arg1: tensor) -> tensor { - %res = linalg.generic { - indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>], - iterator_types = ["parallel"] - } - ins(%arg0 : tensor) - outs(%arg1 : tensor) { - ^bb0(%in: f16, %out: f16): - %idx 
= linalg.index 0 : index - %int = arith.index_cast %idx : index to i64 - %fp = arith.sitofp %int : i64 to f16 - linalg.yield %fp : f16 - } -> tensor - - return %res : tensor -} - -//===----------------------------------------------------------------------===// -// arith.index_castui -//===----------------------------------------------------------------------===// - -// CHECK-LABEL: func @linalg_indexcastui_dim_0_i8 -// CHECK: %[[IDX:.+]] = linalg.index 0 : index -// CHECK-NEXT: %[[INT:.+]] = arith.index_castui %[[IDX]] : index to i8 -// CHECK-NEXT: %[[FP:.+]] = arith.uitofp %[[INT]] : i8 to f16 -// CHECK-NEXT: linalg.yield %[[FP]] : f16 -func.func @linalg_indexcastui_dim_0_i8(%arg0: tensor) -> tensor<256xf16> { - %init = tensor.empty() : tensor<256xf16> - %res = linalg.generic { - indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>], - iterator_types = ["parallel"] - } - ins(%arg0 : tensor) - outs(%init : tensor<256xf16>) { - ^bb0(%in: f16, %out: f16): - %idx = linalg.index 0 : index - %int = arith.index_castui %idx : index to i64 - %fp = arith.uitofp %int : i64 to f16 - linalg.yield %fp : f16 - } -> tensor<256xf16> - - return %res : tensor<256xf16> -} - -// CHECK-LABEL: func @linalg_indexcastui_dim_1_i16 -// CHECK: %[[IDX:.+]] = linalg.index 1 : index -// CHECK-NEXT: %[[INT:.+]] = arith.index_castui %[[IDX]] : index to i16 -// CHECK-NEXT: %[[FP:.+]] = arith.uitofp %[[INT]] : i16 to f16 -// CHECK-NEXT: linalg.yield %[[FP]] : f16 -func.func @linalg_indexcastui_dim_1_i16(%arg0: tensor, %arg1: tensor) -> tensor { - %res = linalg.generic { - indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"] - } - ins(%arg0 : tensor) - outs(%arg1 : tensor) { - ^bb0(%in: f16, %out: f16): - %idx = linalg.index 1 : index - %int = arith.index_castui %idx : index to i64 - %fp = arith.uitofp %int : i64 to f16 - linalg.yield %fp : f16 - } -> tensor - - return %res : tensor -} - -// CHECK-LABEL: func @linalg_indexcastui_dynamic_dim_i64 -// CHECK: %[[IDX:.+]] = linalg.index 0 : index -// CHECK-NEXT: %[[INT:.+]] = arith.index_castui %[[IDX]] : index to i64 -// CHECK-NEXT: %[[FP:.+]] = arith.uitofp %[[INT]] : i64 to f16 -// CHECK-NEXT: linalg.yield %[[FP]] : f16 -func.func @linalg_indexcastui_dynamic_dim_i64(%arg0: tensor, %arg1: tensor) -> tensor { - %res = linalg.generic { - indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>], - iterator_types = ["parallel"] - } - ins(%arg0 : tensor) - outs(%arg1 : tensor) { - ^bb0(%in: f16, %out: f16): - %idx = linalg.index 0 : index - %int = arith.index_castui %idx : index to i64 - %fp = arith.uitofp %int : i64 to f16 - linalg.yield %fp : f16 - } -> tensor - - return %res : tensor -} From b4275530a289255ee9147e679384654813d2e642 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 16 Oct 2024 22:37:26 +0200 Subject: [PATCH 07/20] nits --- .../Transforms/IntRangeOptimizations.cpp | 28 +++++++++---------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp index 8c651076df2e5..e026632d0201d 100644 --- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp @@ -218,9 +218,10 @@ static Type checkElementwiseOpType(Operation *op, unsigned targetBitwidth) { if (!type) { type = val.getType(); continue; - } else if (type != val.getType()) { - return nullptr; } + + if (type != 
val.getType()) + return nullptr; } } @@ -301,13 +302,11 @@ static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) { auto dstInt = cast(dstType); if (dstInt.getWidth() < srcInt.getWidth()) { return builder.create(loc, dstType, src); - } else { - return builder.create(loc, dstType, src); } + return builder.create(loc, dstType, src); } -struct NarrowElementwise final - : public OpTraitRewritePattern { +struct NarrowElementwise final : OpTraitRewritePattern { NarrowElementwise(MLIRContext *context, DataFlowSolver &s, ArrayRef target) : OpTraitRewritePattern(context), solver(s), @@ -316,7 +315,6 @@ struct NarrowElementwise final using OpTraitRewritePattern::OpTraitRewritePattern; LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - std::optional range = getOperandsRange(solver, op->getResults()); if (!range) @@ -370,8 +368,8 @@ struct NarrowElementwise final SmallVector targetBitwidths; }; -struct NarrowCmpi final : public OpRewritePattern { - NarrowCmpi(MLIRContext *context, PatternBenefit benefit, DataFlowSolver &s, +struct NarrowCmpI final : OpRewritePattern { + NarrowCmpI(MLIRContext *context, PatternBenefit benefit, DataFlowSolver &s, ArrayRef target) : OpRewritePattern(context, benefit), solver(s), targetBitwidths(target) { } @@ -421,8 +419,8 @@ struct NarrowCmpi final : public OpRewritePattern { SmallVector targetBitwidths; }; -struct IntRangeOptimizationsPass - : public arith::impl::ArithIntRangeOptsBase { +struct IntRangeOptimizationsPass final + : arith::impl::ArithIntRangeOptsBase { void runOnOperation() override { Operation *op = getOperation(); @@ -446,8 +444,8 @@ struct IntRangeOptimizationsPass } }; -struct IntRangeNarrowingPass - : public arith::impl::ArithIntRangeNarrowingBase { +struct IntRangeNarrowingPass final + : arith::impl::ArithIntRangeNarrowingBase { using ArithIntRangeNarrowingBase::ArithIntRangeNarrowingBase; void runOnOperation() override { @@ -482,9 +480,9 @@ void mlir::arith::populateIntRangeOptimizationsPatterns( void mlir::arith::populateIntRangeNarrowingPatterns( RewritePatternSet &patterns, DataFlowSolver &solver, ArrayRef bitwidthsSupported) { - // Cmpi uses args ranges instead of results, run it with higher benefit, + // CmpI uses args ranges instead of results, run it with higher benefit, // as its argumens can be potentially replaced. - patterns.add(patterns.getContext(), /*benefit*/ 10, solver, + patterns.add(patterns.getContext(), /*benefit*/ 10, solver, bitwidthsSupported); patterns.add(patterns.getContext(), solver, From 1f8358b7933776aa9f98774b3b08348c3af4e408 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 16 Oct 2024 23:01:37 +0200 Subject: [PATCH 08/20] comments --- .../Transforms/IntRangeOptimizations.cpp | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp index e026632d0201d..43ea7ee9a85b0 100644 --- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp @@ -195,7 +195,8 @@ struct DeleteTrivialRem : public OpRewritePattern { DataFlowSolver &solver; }; -static Type checkArithType(Type type, unsigned targetBitwidth) { +/// Check if `type` is index or integer type with `getWidth() > targetBitwidth`. 
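/// For example (an editorial sketch, not part of the original patch): with
/// targetBitwidth = 32, `index` and `i64` are narrowing candidates, while
/// `i32` and `i16` are rejected because they are already at or below the
/// target width.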
+static Type checkIntType(Type type, unsigned targetBitwidth) { type = getElementTypeOrSelf(type); if (isa(type)) return type; @@ -207,6 +208,9 @@ static Type checkArithType(Type type, unsigned targetBitwidth) { return nullptr; } +/// Check if op have same type for all operands and results and this type +/// is suitable for truncation. +/// Retuns args type or empty. static Type checkElementwiseOpType(Operation *op, unsigned targetBitwidth) { if (op->getNumOperands() == 0 || op->getNumResults() == 0) return nullptr; @@ -225,13 +229,14 @@ static Type checkElementwiseOpType(Operation *op, unsigned targetBitwidth) { } } - return checkArithType(type, targetBitwidth); + return checkIntType(type, targetBitwidth); } +/// Return union of all operands values ranges. static std::optional getOperandsRange(DataFlowSolver &solver, - ValueRange results) { + ValueRange operands) { std::optional ret; - for (Value value : results) { + for (Value value : operands) { auto *maybeInferredRange = solver.lookupState(value); if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized()) @@ -249,6 +254,8 @@ static std::optional getOperandsRange(DataFlowSolver &solver, return ret; } +/// Return int type truncated to `targetBitwidth`. If `srcType` is shaped, +/// return shaped type as well. static Type getTargetType(Type srcType, unsigned targetBitwidth) { auto dstType = IntegerType::get(srcType.getContext(), targetBitwidth); if (auto shaped = dyn_cast(srcType)) @@ -258,6 +265,7 @@ static Type getTargetType(Type srcType, unsigned targetBitwidth) { return dstType; } +/// Check privided `range` is inside `smin, smax, umin, umax` bounds. static bool checkRange(const ConstantIntRanges &range, APInt smin, APInt smax, APInt umin, APInt umax) { auto sge = [](APInt val1, APInt val2) -> bool { @@ -300,9 +308,9 @@ static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) { auto srcInt = cast(srcType); auto dstInt = cast(dstType); - if (dstInt.getWidth() < srcInt.getWidth()) { + if (dstInt.getWidth() < srcInt.getWidth()) return builder.create(loc, dstType, src); - } + return builder.create(loc, dstType, src); } @@ -385,7 +393,7 @@ struct NarrowCmpI final : OpRewritePattern { return failure(); for (unsigned targetBitwidth : targetBitwidths) { - Type srcType = checkArithType(lhs.getType(), targetBitwidth); + Type srcType = checkIntType(lhs.getType(), targetBitwidth); if (!srcType) continue; From f7b5485ee2cc888716ba1593fdbb52eee185aea8 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 1 Nov 2024 23:43:16 +0100 Subject: [PATCH 09/20] add vector support --- .../Transforms/IntRangeOptimizations.cpp | 20 +++++++------ .../Dialect/Arith/int-range-narrowing.mlir | 28 +++++++++++++++++++ 2 files changed, 40 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp index 43ea7ee9a85b0..45e870eac180d 100644 --- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp @@ -197,11 +197,11 @@ struct DeleteTrivialRem : public OpRewritePattern { /// Check if `type` is index or integer type with `getWidth() > targetBitwidth`. 
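/// Illustrative vector cases (an editorial sketch, not part of the original
/// patch): eligibility is decided on the element type, so with
/// targetBitwidth = 32 a `vector<4xi64>` is a narrowing candidate and keeps
/// its shape, becoming `vector<4xi32>` after truncation, while
/// `vector<4xi16>` is rejected.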
static Type checkIntType(Type type, unsigned targetBitwidth) { - type = getElementTypeOrSelf(type); - if (isa(type)) + Type elemType = getElementTypeOrSelf(type); + if (isa(elemType)) return type; - if (auto intType = dyn_cast(type)) + if (auto intType = dyn_cast(elemType)) if (intType.getWidth() > targetBitwidth) return type; @@ -298,16 +298,20 @@ static bool checkRange(const ConstantIntRanges &range, APInt smin, APInt smax, static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) { Type srcType = src.getType(); - assert(srcType.isIntOrIndex() && "Invalid src type"); - assert(dstType.isIntOrIndex() && "Invalid dst type"); + assert(isa(srcType) == isa(dstType) && + "Mixing vector and non-vector types"); + Type srcElemType = getElementTypeOrSelf(srcType); + Type dstElemType = getElementTypeOrSelf(dstType); + assert(srcElemType.isIntOrIndex() && "Invalid src type"); + assert(dstElemType.isIntOrIndex() && "Invalid dst type"); if (srcType == dstType) return src; - if (isa(srcType) || isa(dstType)) + if (isa(srcElemType) || isa(dstElemType)) return builder.create(loc, dstType, src); - auto srcInt = cast(srcType); - auto dstInt = cast(dstType); + auto srcInt = cast(srcElemType); + auto dstInt = cast(dstElemType); if (dstInt.getWidth() < srcInt.getWidth()) return builder.create(loc, dstType, src); diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir index cd0a4c449913e..1378fb1c3c98c 100644 --- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir +++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir @@ -30,6 +30,20 @@ func.func @test_addi() -> index { return %2 : index } +// CHECK-LABEL: func @test_addi_vec +// CHECK: %[[A:.*]] = test.with_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index} : vector<4xindex> +// CHECK: %[[B:.*]] = test.with_bounds {smax = 7 : index, smin = 6 : index, umax = 7 : index, umin = 6 : index} : vector<4xindex> +// CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] : vector<4xindex> to vector<4xi8> +// CHECK: %[[B_CASTED:.*]] = arith.index_castui %[[B]] : vector<4xindex> to vector<4xi8> +// CHECK: %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : vector<4xi8> +// CHECK: %[[RES_CASTED:.*]] = arith.index_castui %[[RES]] : vector<4xi8> to vector<4xindex> +// CHECK: return %[[RES_CASTED]] : vector<4xindex> +func.func @test_addi_vec() -> vector<4xindex> { + %0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : vector<4xindex> + %1 = test.with_bounds { umin = 6 : index, umax = 7 : index, smin = 6 : index, smax = 7 : index } : vector<4xindex> + %2 = arith.addi %0, %1 : vector<4xindex> + return %2 : vector<4xindex> +} // CHECK-LABEL: func @test_addi_i64 // CHECK: %[[A:.*]] = test.with_bounds {smax = 5 : i64, smin = 4 : i64, umax = 5 : i64, umin = 4 : i64} : i64 @@ -60,6 +74,20 @@ func.func @test_cmpi() -> i1 { return %2 : i1 } +// CHECK-LABEL: func @test_cmpi_vec +// CHECK: %[[A:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : vector<4xindex> +// CHECK: %[[B:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : vector<4xindex> +// CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] : vector<4xindex> to vector<4xi8> +// CHECK: %[[B_CASTED:.*]] = arith.index_castui %[[B]] : vector<4xindex> to vector<4xi8> +// CHECK: %[[RES:.*]] = arith.cmpi slt, %[[A_CASTED]], %[[B_CASTED]] : vector<4xi8> +// CHECK: return %[[RES]] : 
vector<4xi1> +func.func @test_cmpi_vec() -> vector<4xi1> { + %0 = test.with_bounds { umin = 0 : index, umax = 10 : index, smin = 0 : index, smax = 10 : index } : vector<4xindex> + %1 = test.with_bounds { umin = 0 : index, umax = 10 : index, smin = 0 : index, smax = 10 : index } : vector<4xindex> + %2 = arith.cmpi slt, %0, %1 : vector<4xindex> + return %2 : vector<4xi1> +} + //===----------------------------------------------------------------------===// // arith.addi //===----------------------------------------------------------------------===// From 976f7c72678b9e53a1d65d1e47bbf20f945b15d4 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 3 Nov 2024 12:23:40 +0100 Subject: [PATCH 10/20] remove benefit --- .../Arith/Transforms/IntRangeOptimizations.cpp | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp index 45e870eac180d..80199ce9b618d 100644 --- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp @@ -381,10 +381,8 @@ struct NarrowElementwise final : OpTraitRewritePattern { }; struct NarrowCmpI final : OpRewritePattern { - NarrowCmpI(MLIRContext *context, PatternBenefit benefit, DataFlowSolver &s, - ArrayRef target) - : OpRewritePattern(context, benefit), solver(s), targetBitwidths(target) { - } + NarrowCmpI(MLIRContext *context, DataFlowSolver &s, ArrayRef target) + : OpRewritePattern(context), solver(s), targetBitwidths(target) {} LogicalResult matchAndRewrite(arith::CmpIOp op, PatternRewriter &rewriter) const override { @@ -492,13 +490,8 @@ void mlir::arith::populateIntRangeOptimizationsPatterns( void mlir::arith::populateIntRangeNarrowingPatterns( RewritePatternSet &patterns, DataFlowSolver &solver, ArrayRef bitwidthsSupported) { - // CmpI uses args ranges instead of results, run it with higher benefit, - // as its argumens can be potentially replaced. - patterns.add(patterns.getContext(), /*benefit*/ 10, solver, - bitwidthsSupported); - - patterns.add(patterns.getContext(), solver, - bitwidthsSupported); + patterns.add(patterns.getContext(), solver, + bitwidthsSupported); } std::unique_ptr mlir::arith::createIntRangeOptimizationsPass() { From 98a0da4b85095928f9caaaaee803e353b90723c6 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 3 Nov 2024 12:34:11 +0100 Subject: [PATCH 11/20] style fixes --- .../Arith/Transforms/IntRangeOptimizations.cpp | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp index 80199ce9b618d..692b40830704e 100644 --- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp @@ -245,11 +245,7 @@ static std::optional getOperandsRange(DataFlowSolver &solver, const ConstantIntRanges &inferredRange = maybeInferredRange->getValue().getValue(); - if (!ret) { - ret = inferredRange; - } else { - ret = ret->rangeUnion(inferredRange); - } + ret = (ret ? ret->rangeUnion(inferredRange) : inferredRange); } return ret; } @@ -265,7 +261,7 @@ static Type getTargetType(Type srcType, unsigned targetBitwidth) { return dstType; } -/// Check privided `range` is inside `smin, smax, umin, umax` bounds. +/// Check provided `range` is inside `smin, smax, umin, umax` bounds. 
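/// Worked example (an editorial sketch, not part of the original patch): a
/// range {smin = 0, smax = 10, umin = 0, umax = 10} fits the i8 bounds
/// {-128, 127, 0, 255}, so an i8 trunc/ext round-trip of such a value is
/// lossless under both the signed and the unsigned interpretation, while a
/// range containing -1 fails the unsigned checks (its umax wraps to
/// 2^width - 1) and the value is left wide. The lambdas below first extend
/// both operands to the wider of the two widths so that APInts of different
/// bitwidths compare correctly.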
static bool checkRange(const ConstantIntRanges &range, APInt smin, APInt smax, APInt umin, APInt umax) { auto sge = [](APInt val1, APInt val2) -> bool { @@ -321,8 +317,7 @@ static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) { struct NarrowElementwise final : OpTraitRewritePattern { NarrowElementwise(MLIRContext *context, DataFlowSolver &s, ArrayRef target) - : OpTraitRewritePattern(context), solver(s), - targetBitwidths(target) {} + : OpTraitRewritePattern(context), solver(s), targetBitwidths(target) {} using OpTraitRewritePattern::OpTraitRewritePattern; LogicalResult matchAndRewrite(Operation *op, From cb1741f27161f8041fed847f7c9f405d6204dc0a Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 3 Nov 2024 12:37:08 +0100 Subject: [PATCH 12/20] llvm concat --- .../Arith/Transforms/IntRangeOptimizations.cpp | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp index 692b40830704e..d2e36f9ef8239 100644 --- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp @@ -216,17 +216,14 @@ static Type checkElementwiseOpType(Operation *op, unsigned targetBitwidth) { return nullptr; Type type; - for (auto range : - {ValueRange(op->getOperands()), ValueRange(op->getResults())}) { - for (Value val : range) { - if (!type) { - type = val.getType(); - continue; - } - - if (type != val.getType()) - return nullptr; + for (Value val : llvm::concat(op->getOperands(), op->getResults())) { + if (!type) { + type = val.getType(); + continue; } + + if (type != val.getType()) + return nullptr; } return checkIntType(type, targetBitwidth); From 581760a562afb1ab4e20366bcf7663dbd58412dc Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 3 Nov 2024 13:10:26 +0100 Subject: [PATCH 13/20] checkIntType refac --- .../Transforms/IntRangeOptimizations.cpp | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp index d2e36f9ef8239..79fdd16575e8f 100644 --- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp @@ -196,24 +196,23 @@ struct DeleteTrivialRem : public OpRewritePattern { }; /// Check if `type` is index or integer type with `getWidth() > targetBitwidth`. -static Type checkIntType(Type type, unsigned targetBitwidth) { +static bool checkIntType(Type type, unsigned targetBitwidth) { Type elemType = getElementTypeOrSelf(type); if (isa(elemType)) - return type; + return true; if (auto intType = dyn_cast(elemType)) if (intType.getWidth() > targetBitwidth) - return type; + return true; - return nullptr; + return false; } /// Check if op have same type for all operands and results and this type /// is suitable for truncation. -/// Retuns args type or empty. 
-static Type checkElementwiseOpType(Operation *op, unsigned targetBitwidth) { +static bool checkElementwiseOpType(Operation *op, unsigned targetBitwidth) { if (op->getNumOperands() == 0 || op->getNumResults() == 0) - return nullptr; + return false; Type type; for (Value val : llvm::concat(op->getOperands(), op->getResults())) { @@ -223,7 +222,7 @@ static Type checkElementwiseOpType(Operation *op, unsigned targetBitwidth) { } if (type != val.getType()) - return nullptr; + return false; } return checkIntType(type, targetBitwidth); @@ -325,10 +324,11 @@ struct NarrowElementwise final : OpTraitRewritePattern { return failure(); for (unsigned targetBitwidth : targetBitwidths) { - Type srcType = checkElementwiseOpType(op, targetBitwidth); - if (!srcType) + if (!checkElementwiseOpType(op, targetBitwidth)) continue; + Type srcType = op->getResult(0).getType(); + // We are truncating op args to the desired bitwidth before the op and // then extending op results back to the original width after. extui and // exti will produce different results for negative values, so limit @@ -387,8 +387,8 @@ struct NarrowCmpI final : OpRewritePattern { return failure(); for (unsigned targetBitwidth : targetBitwidths) { - Type srcType = checkIntType(lhs.getType(), targetBitwidth); - if (!srcType) + Type srcType = lhs.getType(); + if (!checkIntType(srcType, targetBitwidth)) continue; auto smin = APInt::getSignedMinValue(targetBitwidth); From fdcbb5f2cd083676cda570850b84870d1e8534f4 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 3 Nov 2024 13:19:08 +0100 Subject: [PATCH 14/20] add test --- .../Dialect/Arith/int-range-narrowing.mlir | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir index 1378fb1c3c98c..054d8bba9f7b1 100644 --- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir +++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir @@ -88,6 +88,27 @@ func.func @test_cmpi_vec() -> vector<4xi1> { return %2 : vector<4xi1> } +// CHECK-LABEL: func @test_add_cmpi +// CHECK: %[[A:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index +// CHECK: %[[B:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index +// CHECK: %[[C:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index +// CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i8 +// CHECK: %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i8 +// CHECK: %[[RES1:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i8 +// CHECK: %[[RES1_CASTED1:.*]] = arith.index_castui %[[RES1]] : i8 to index +// CHECK: %[[C_CASTED:.*]] = arith.index_castui %[[C]] : index to i8 +// CHECK: %[[RES1_CASTED2:.*]] = arith.index_castui %[[RES1_CASTED1]] : index to i8 +// CHECK: %[[RES2:.*]] = arith.cmpi slt, %[[C_CASTED]], %[[RES1_CASTED2]] : i8 +// CHECK: return %[[RES2]] : i1 +func.func @test_add_cmpi() -> i1 { + %0 = test.with_bounds { umin = 0 : index, umax = 10 : index, smin = 0 : index, smax = 10 : index } : index + %1 = test.with_bounds { umin = 0 : index, umax = 10 : index, smin = 0 : index, smax = 10 : index } : index + %3 = test.with_bounds { umin = 0 : index, umax = 10 : index, smin = 0 : index, smax = 10 : index } : index + %4 = arith.addi %0, %1 : index + %5 = arith.cmpi slt, %3, %4 : index + return %5 : i1 +} + //===----------------------------------------------------------------------===// // 
arith.addi //===----------------------------------------------------------------------===// From c2c9e39b8b0d64b981041923a04162a9f9649692 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 3 Nov 2024 13:30:50 +0100 Subject: [PATCH 15/20] add comment --- mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp index 79fdd16575e8f..f5b2d0d283796 100644 --- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp @@ -465,6 +465,9 @@ struct IntRangeNarrowingPass final populateIntRangeNarrowingPatterns(patterns, solver, bitwidthsSupported); GreedyRewriteConfig config; + // We specifically need bottom-up traversal as cmpi pattern needs range + // data, attched to it's original arguments. + config.useTopDownTraversal = false; config.listener = &listener; if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config))) From 1d88642fc8584e7b6fb92bdcca0eb3206c9d331b Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 3 Nov 2024 13:47:10 +0100 Subject: [PATCH 16/20] fold index cast chain --- .../Transforms/IntRangeOptimizations.cpp | 30 +++++++++++++++++++ .../Dialect/Arith/int-range-narrowing.mlir | 4 +-- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp index f5b2d0d283796..06406b0852c5d 100644 --- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp @@ -421,6 +421,35 @@ struct NarrowCmpI final : OpRewritePattern { SmallVector targetBitwidths; }; +/// Fold index_cast(index_cast(%arg: i8, index), i8) -> %arg +/// This pattern assumes all passed `targetBitwidths` are not wider than index +/// type. 
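/// Concretely (an editorial sketch, not part of the original patch):
///   %c = arith.index_castui %v : i8 to index
///   %r = arith.index_castui %c : index to i8
/// folds to plain %v when 8 is among the target bitwidths, since under the
/// assumption above the round-trip through `index` is lossless.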
+struct FoldIndexCastChain final : OpRewritePattern<arith::IndexCastUIOp> {
+  FoldIndexCastChain(MLIRContext *context, ArrayRef<unsigned> target)
+      : OpRewritePattern(context), targetBitwidths(target) {}
+
+  LogicalResult matchAndRewrite(arith::IndexCastUIOp op,
+                                PatternRewriter &rewriter) const override {
+    auto srcOp = op.getIn().getDefiningOp<arith::IndexCastUIOp>();
+    if (!srcOp)
+      return failure();
+
+    Value src = srcOp.getIn();
+    if (src.getType() != op.getType())
+      return failure();
+
+    auto intType = dyn_cast<IntegerType>(op.getType());
+    if (!intType || !llvm::is_contained(targetBitwidths, intType.getWidth()))
+      return failure();
+
+    rewriter.replaceOp(op, src);
+    return success();
+  }
+
+private:
+  SmallVector<unsigned, 4> targetBitwidths;
+};
+
 struct IntRangeOptimizationsPass final
     : arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
 
@@ -487,6 +516,7 @@ void mlir::arith::populateIntRangeNarrowingPatterns(
     ArrayRef<unsigned> bitwidthsSupported) {
   patterns.add<NarrowElementwise, NarrowCmpI>(patterns.getContext(), solver,
                                               bitwidthsSupported);
+  patterns.add<FoldIndexCastChain>(patterns.getContext(), bitwidthsSupported);
 }
 
 std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {
diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index 054d8bba9f7b1..5ad89805a1b45 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -95,10 +95,8 @@ func.func @test_cmpi_vec() -> vector<4xi1> {
 // CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i8
 // CHECK: %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i8
 // CHECK: %[[RES1:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i8
-// CHECK: %[[RES1_CASTED1:.*]] = arith.index_castui %[[RES1]] : i8 to index
 // CHECK: %[[C_CASTED:.*]] = arith.index_castui %[[C]] : index to i8
-// CHECK: %[[RES1_CASTED2:.*]] = arith.index_castui %[[RES1_CASTED1]] : index to i8
-// CHECK: %[[RES2:.*]] = arith.cmpi slt, %[[C_CASTED]], %[[RES1_CASTED2]] : i8
+// CHECK: %[[RES2:.*]] = arith.cmpi slt, %[[C_CASTED]], %[[RES1]] : i8
 // CHECK: return %[[RES2]] : i1
 func.func @test_add_cmpi() -> i1 {
   %0 = test.with_bounds { umin = 0 : index, umax = 10 : index, smin = 0 : index, smax = 10 : index } : index

From 08297538bee33e479dccc0f8902b893758e5a712 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Sun, 3 Nov 2024 13:51:49 +0100
Subject: [PATCH 17/20] test

---
 .../Dialect/Arith/int-range-narrowing.mlir    | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index 5ad89805a1b45..8893f299177ce 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -107,6 +107,25 @@ func.func @test_add_cmpi() -> i1 {
   return %5 : i1
 }
 
+// CHECK-LABEL: func @test_add_cmpi_i64
+// CHECK: %[[A:.*]] = test.with_bounds {smax = 10 : i64, smin = 0 : i64, umax = 10 : i64, umin = 0 : i64} : i64
+// CHECK: %[[B:.*]] = test.with_bounds {smax = 10 : i64, smin = 0 : i64, umax = 10 : i64, umin = 0 : i64} : i64
+// CHECK: %[[C:.*]] = test.with_bounds {smax = 10 : i64, smin = 0 : i64, umax = 10 : i64, umin = 0 : i64} : i64
+// CHECK: %[[A_CASTED:.*]] = arith.trunci %[[A]] : i64 to i8
+// CHECK: %[[B_CASTED:.*]] = arith.trunci %[[B]] : i64 to i8
+// CHECK: %[[RES1:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i8
+// CHECK: %[[C_CASTED:.*]] = arith.trunci %[[C]] : i64 to i8
+// CHECK: %[[RES2:.*]] = arith.cmpi slt, %[[C_CASTED]], %[[RES1]] : i8
+// CHECK: return %[[RES2]] : i1
+func.func @test_add_cmpi_i64() -> i1 {
+  %0 = 
test.with_bounds { umin = 0 : i64, umax = 10 : i64, smin = 0 : i64, smax = 10 : i64 } : i64 + %1 = test.with_bounds { umin = 0 : i64, umax = 10 : i64, smin = 0 : i64, smax = 10 : i64 } : i64 + %3 = test.with_bounds { umin = 0 : i64, umax = 10 : i64, smin = 0 : i64, smax = 10 : i64 } : i64 + %4 = arith.addi %0, %1 : i64 + %5 = arith.cmpi slt, %3, %4 : i64 + return %5 : i1 +} + //===----------------------------------------------------------------------===// // arith.addi //===----------------------------------------------------------------------===// From 324b8bdc84311a093a48b6fca8c80b178ce92167 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 3 Nov 2024 13:51:56 +0100 Subject: [PATCH 18/20] fix comment --- mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp index 06406b0852c5d..450d3972bb99d 100644 --- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp @@ -495,7 +495,7 @@ struct IntRangeNarrowingPass final GreedyRewriteConfig config; // We specifically need bottom-up traversal as cmpi pattern needs range - // data, attched to it's original arguments. + // data, attached to its original argument values. config.useTopDownTraversal = false; config.listener = &listener; From 262e991e59ece054971a8b05a64ee6e27b6c20ae Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 3 Nov 2024 13:55:22 +0100 Subject: [PATCH 19/20] update pass desc --- mlir/include/mlir/Dialect/Arith/Transforms/Passes.td | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td index 98f90d120fa1c..1d37314885d93 100644 --- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td @@ -55,6 +55,9 @@ def ArithIntRangeNarrowing : Pass<"arith-int-range-narrowing"> { let description = [{ This pass runs integer range analysis and tries to narrow arith ops to the specified bitwidth based on its results. + + `bitwidthsSupported` assumed to be not wider than `index` type. + TODO: get index width from DLTI. }]; let options = [ From d1bc17c4034b5fc3fa2f636fb46869777c3811ab Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 4 Nov 2024 16:30:17 +0100 Subject: [PATCH 20/20] LogicalResult --- .../Transforms/IntRangeOptimizations.cpp | 31 ++++++++++--------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp index 450d3972bb99d..efc4db7e4c996 100644 --- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp @@ -196,23 +196,24 @@ struct DeleteTrivialRem : public OpRewritePattern { }; /// Check if `type` is index or integer type with `getWidth() > targetBitwidth`. 
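// Editorial note (not part of the original patch): returning LogicalResult
// lets the call sites below compose with MLIR's usual failed()/succeeded()
// helpers, e.g.
//   if (failed(checkIntType(srcType, targetBitwidth)))
//     continue;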
-static bool checkIntType(Type type, unsigned targetBitwidth) {
+static LogicalResult checkIntType(Type type, unsigned targetBitwidth) {
   Type elemType = getElementTypeOrSelf(type);
   if (isa<IndexType>(elemType))
-    return true;
+    return success();
 
   if (auto intType = dyn_cast<IntegerType>(elemType))
     if (intType.getWidth() > targetBitwidth)
-      return true;
+      return success();
 
-  return false;
+  return failure();
 }
 
 /// Check if op have same type for all operands and results and this type
 /// is suitable for truncation.
-static bool checkElementwiseOpType(Operation *op, unsigned targetBitwidth) {
+static LogicalResult checkElementwiseOpType(Operation *op,
+                                            unsigned targetBitwidth) {
   if (op->getNumOperands() == 0 || op->getNumResults() == 0)
-    return false;
+    return failure();
 
   Type type;
   for (Value val : llvm::concat<Value>(op->getOperands(), op->getResults())) {
@@ -223,7 +222,7 @@ static bool checkElementwiseOpType(Operation *op, unsigned targetBitwidth) {
     }
 
     if (type != val.getType())
-      return false;
+      return failure();
   }
 
   return checkIntType(type, targetBitwidth);
@@ -258,8 +259,8 @@ static Type getTargetType(Type srcType, unsigned targetBitwidth) {
 }
 
 /// Check provided `range` is inside `smin, smax, umin, umax` bounds.
-static bool checkRange(const ConstantIntRanges &range, APInt smin, APInt smax,
-                       APInt umin, APInt umax) {
+static LogicalResult checkRange(const ConstantIntRanges &range, APInt smin,
+                                APInt smax, APInt umin, APInt umax) {
   auto sge = [](APInt val1, APInt val2) -> bool {
     unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
     val1 = val1.sext(width);
@@ -284,8 +285,8 @@ static bool checkRange(const ConstantIntRanges &range, APInt smin, APInt smax,
     val2 = val2.zext(width);
     return val1.ule(val2);
   };
-  return sge(range.smin(), smin) && sle(range.smax(), smax) &&
-         uge(range.umin(), umin) && ule(range.umax(), umax);
+  return success(sge(range.smin(), smin) && sle(range.smax(), smax) &&
+                 uge(range.umin(), umin) && ule(range.umax(), umax));
 }
 
 static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) {
@@ -324,7 +325,7 @@ struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
       return failure();
 
     for (unsigned targetBitwidth : targetBitwidths) {
-      if (!checkElementwiseOpType(op, targetBitwidth))
+      if (failed(checkElementwiseOpType(op, targetBitwidth)))
         continue;
 
       Type srcType = op->getResult(0).getType();
@@ -337,7 +338,7 @@ struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
       auto smax = APInt::getSignedMaxValue(targetBitwidth);
       auto umin = APInt::getMinValue(targetBitwidth);
       auto umax = APInt::getMaxValue(targetBitwidth);
-      if (!checkRange(*range, smin, smax, umin, umax))
+      if (failed(checkRange(*range, smin, smax, umin, umax)))
         continue;
 
       Type targetType = getTargetType(srcType, targetBitwidth);
@@ -388,14 +389,14 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
 
     for (unsigned targetBitwidth : targetBitwidths) {
       Type srcType = lhs.getType();
-      if (!checkIntType(srcType, targetBitwidth))
+      if (failed(checkIntType(srcType, targetBitwidth)))
        continue;
 
       auto smin = APInt::getSignedMinValue(targetBitwidth);
       auto smax = APInt::getSignedMaxValue(targetBitwidth);
       auto umin = APInt::getMinValue(targetBitwidth);
       auto umax = APInt::getMaxValue(targetBitwidth);
-      if (!checkRange(*range, smin, smax, umin, umax))
+      if (failed(checkRange(*range, smin, smax, umin, umax)))
         continue;
 
       Type targetType = getTargetType(srcType, targetBitwidth);
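As a closing illustration, a case the finished pass must leave untouched (an
editorial sketch in the style of the tests above, with a made-up function
name, not a test from this series): when the inferred result range does not
fit any supported bitwidth, no casts are introduced.

  func.func @no_narrow_sketch() -> i64 {
    // The addi result range is [0, 2000], which exceeds the i8 bounds
    // [0, 255], so with 8 as the only supported bitwidth the op stays i64.
    %0 = test.with_bounds { umin = 0 : i64, umax = 1000 : i64, smin = 0 : i64, smax = 1000 : i64 } : i64
    %1 = test.with_bounds { umin = 0 : i64, umax = 1000 : i64, smin = 0 : i64, smax = 1000 : i64 } : i64
    %2 = arith.addi %0, %1 : i64
    return %2 : i64
  }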