diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index efc4db7e4c996..b54a53f5ef70e 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -46,6 +46,17 @@ static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
   return inferredRange.getConstantValue();
 }
 
+static void copyIntegerRange(DataFlowSolver &solver, Value oldVal,
+                             Value newVal) {
+  assert(oldVal.getType() == newVal.getType() &&
+         "Can't copy integer ranges between different types");
+  auto *oldState = solver.lookupState<dataflow::IntegerValueRangeLattice>(oldVal);
+  if (!oldState)
+    return;
+  (void)solver.getOrCreateState<dataflow::IntegerValueRangeLattice>(newVal)->join(
+      *oldState);
+}
+
 /// Patterned after SCCP
 static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
                                               PatternRewriter &rewriter,
@@ -80,6 +91,7 @@ static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
   if (!constOp)
     return failure();
 
+  copyIntegerRange(solver, value, constOp->getResult(0));
   rewriter.replaceAllUsesWith(value, constOp->getResult(0));
   return success();
 }
@@ -195,56 +207,21 @@ struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
   DataFlowSolver &solver;
 };
 
-/// Check if `type` is index or integer type with `getWidth() > targetBitwidth`.
-static LogicalResult checkIntType(Type type, unsigned targetBitwidth) {
-  Type elemType = getElementTypeOrSelf(type);
-  if (isa<IndexType>(elemType))
-    return success();
-
-  if (auto intType = dyn_cast<IntegerType>(elemType))
-    if (intType.getWidth() > targetBitwidth)
-      return success();
-
-  return failure();
-}
-
-/// Check if op have same type for all operands and results and this type
-/// is suitable for truncation.
-static LogicalResult checkElementwiseOpType(Operation *op,
-                                            unsigned targetBitwidth) {
-  if (op->getNumOperands() == 0 || op->getNumResults() == 0)
-    return failure();
-
-  Type type;
-  for (Value val : llvm::concat<Value>(op->getOperands(), op->getResults())) {
-    if (!type) {
-      type = val.getType();
-      continue;
-    }
-
-    if (type != val.getType())
-      return failure();
-  }
-
-  return checkIntType(type, targetBitwidth);
-}
-
-/// Return union of all operands values ranges.
-static std::optional<ConstantIntRanges> getOperandsRange(DataFlowSolver &solver,
-                                                         ValueRange operands) {
-  std::optional<ConstantIntRanges> ret;
-  for (Value value : operands) {
+/// Gather ranges for all the values in `values`. Appends to the existing
+/// vector.
+static LogicalResult collectRanges(DataFlowSolver &solver, ValueRange values,
+                                   SmallVectorImpl<ConstantIntRanges> &ranges) {
+  for (Value val : values) {
     auto *maybeInferredRange =
-        solver.lookupState<dataflow::IntegerValueRangeLattice>(value);
+        solver.lookupState<dataflow::IntegerValueRangeLattice>(val);
     if (!maybeInferredRange ||
         maybeInferredRange->getValue().isUninitialized())
-      return std::nullopt;
+      return failure();
     const ConstantIntRanges &inferredRange =
        maybeInferredRange->getValue().getValue();
-
-    ret = (ret ? ret->rangeUnion(inferredRange) : inferredRange);
+    ranges.push_back(inferredRange);
  }
-  return ret;
+  return success();
 }
 
 /// Return int type truncated to `targetBitwidth`. If `srcType` is shaped,
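[Reviewer note, not part of the patch] The switch from getOperandsRange to collectRanges trades one unioned range for a ConstantIntRanges per value, which is what lets each operand and result later pick its own cast kind. A standalone sketch of why the old union was too coarse; ConstantIntRanges, fromSigned, fromUnsigned, and rangeUnion are assumed from mlir/Interfaces/InferIntRangeInterface.h, and the widths are illustrative only:

#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "llvm/ADT/APInt.h"

using llvm::APInt;
using mlir::ConstantIntRanges;

// Model of an addi whose lhs was sign-extended from i8 and whose rhs was
// zero-extended from i8, both now living in i32.
void unionIsCoarserThanPerValueRanges() {
  ConstantIntRanges lhs = ConstantIntRanges::fromSigned(
      APInt(32, -128, /*isSigned=*/true), APInt(32, 127, /*isSigned=*/true));
  ConstantIntRanges rhs =
      ConstantIntRanges::fromUnsigned(APInt(32, 0), APInt(32, 255));
  // Per-value ranges: lhs truncates cleanly to i8 as signed, rhs as unsigned.
  // The old code collapsed both into a single range before checking bounds...
  ConstantIntRanges merged = lhs.rangeUnion(rhs);
  // ...and the merged [-128, 255] fits neither an i8 signed nor an i8
  // unsigned window, so the per-operand facts that allow narrowing are lost.
  (void)merged;
}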
@@ -258,41 +235,59 @@ static Type getTargetType(Type srcType, unsigned targetBitwidth) {
   return dstType;
 }
 
-/// Check provided `range` is inside `smin, smax, umin, umax` bounds.
-static LogicalResult checkRange(const ConstantIntRanges &range, APInt smin,
-                                APInt smax, APInt umin, APInt umax) {
-  auto sge = [](APInt val1, APInt val2) -> bool {
-    unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
-    val1 = val1.sext(width);
-    val2 = val2.sext(width);
-    return val1.sge(val2);
-  };
-  auto sle = [](APInt val1, APInt val2) -> bool {
-    unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
-    val1 = val1.sext(width);
-    val2 = val2.sext(width);
-    return val1.sle(val2);
-  };
-  auto uge = [](APInt val1, APInt val2) -> bool {
-    unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
-    val1 = val1.zext(width);
-    val2 = val2.zext(width);
-    return val1.uge(val2);
-  };
-  auto ule = [](APInt val1, APInt val2) -> bool {
-    unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
-    val1 = val1.zext(width);
-    val2 = val2.zext(width);
-    return val1.ule(val2);
-  };
-  return success(sge(range.smin(), smin) && sle(range.smax(), smax) &&
-                 uge(range.umin(), umin) && ule(range.umax(), umax));
+namespace {
+// Enum for tracking which type of truncation should be performed
+// to narrow an operation, if any.
+enum class CastKind : uint8_t { None, Signed, Unsigned, Both };
+} // namespace
+
+/// If the values within `range` can be represented using only `width` bits,
+/// return the kind of truncation needed to preserve that property.
+///
+/// This check relies on the fact that the signed and unsigned ranges are both
+/// always correct, but that one might be an approximation of the other,
+/// so we want to use the correct truncation operation.
+static CastKind checkTruncatability(const ConstantIntRanges &range,
+                                    unsigned targetWidth) {
+  unsigned srcWidth = range.smin().getBitWidth();
+  if (srcWidth <= targetWidth)
+    return CastKind::None;
+  unsigned removedWidth = srcWidth - targetWidth;
+  // The sign bits need to extend into the sign bit of the target width. For
+  // example, if we're truncating 64 bits to 32, we need 64 - 32 + 1 = 33 sign
+  // bits.
+  bool canTruncateSigned =
+      range.smin().getNumSignBits() >= (removedWidth + 1) &&
+      range.smax().getNumSignBits() >= (removedWidth + 1);
+  bool canTruncateUnsigned = range.umin().countLeadingZeros() >= removedWidth &&
+                             range.umax().countLeadingZeros() >= removedWidth;
+  if (canTruncateSigned && canTruncateUnsigned)
+    return CastKind::Both;
+  if (canTruncateSigned)
+    return CastKind::Signed;
+  if (canTruncateUnsigned)
+    return CastKind::Unsigned;
+  return CastKind::None;
+}
+
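[Reviewer note, not part of the patch] A quick standalone check of the sign-bit arithmetic in checkTruncatability above: truncating from 64 to 32 bits removes 32 bits, so a value survives a signed round-trip only if it has at least 32 + 1 = 33 sign bits, and an unsigned round-trip only if it has at least 32 leading zeros. The APInt calls are the same ones the patch uses.

#include "llvm/ADT/APInt.h"
#include <cassert>

int main() {
  using llvm::APInt;
  // INT32_MAX has 33 identical leading bits (all zero, matching the sign
  // bit), so it sits exactly at the signed-truncation threshold for 64 -> 32.
  assert(APInt(64, 0x7fffffffULL).getNumSignBits() == 33);
  // One past INT32_MAX has only 32 sign bits: a signed truncation would flip
  // it negative, so CastKind::Signed must not be reported.
  assert(APInt(64, 0x80000000ULL).getNumSignBits() == 32);
  // UINT32_MAX has 32 leading zeros, so it still truncates fine as unsigned
  // even though it fails the signed test.
  assert(APInt(64, 0xffffffffULL).countLeadingZeros() == 32);
  assert(APInt(64, 0xffffffffULL).getNumSignBits() == 32);
  // -1 sign-extends from any width, and indeed has all 64 sign bits.
  assert(APInt(64, -1, /*isSigned=*/true).getNumSignBits() == 64);
  return 0;
}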
+static CastKind mergeCastKinds(CastKind lhs, CastKind rhs) {
+  if (lhs == CastKind::None || rhs == CastKind::None)
+    return CastKind::None;
+  if (lhs == CastKind::Both)
+    return rhs;
+  if (rhs == CastKind::Both)
+    return lhs;
+  if (lhs == rhs)
+    return lhs;
+  return CastKind::None;
 }
 
-static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) {
+static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType,
+                    CastKind castKind) {
   Type srcType = src.getType();
   assert(isa<VectorType>(srcType) == isa<VectorType>(dstType) &&
          "Mixing vector and non-vector types");
+  assert(castKind != CastKind::None && "Can't cast when casting isn't allowed");
   Type srcElemType = getElementTypeOrSelf(srcType);
   Type dstElemType = getElementTypeOrSelf(dstType);
   assert(srcElemType.isIntOrIndex() && "Invalid src type");
@@ -300,14 +295,19 @@ static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) {
   if (srcType == dstType)
     return src;
 
-  if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType))
+  if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType)) {
+    if (castKind == CastKind::Signed)
+      return builder.create<arith::IndexCastOp>(loc, dstType, src);
     return builder.create<arith::IndexCastUIOp>(loc, dstType, src);
+  }
 
   auto srcInt = cast<IntegerType>(srcElemType);
   auto dstInt = cast<IntegerType>(dstElemType);
   if (dstInt.getWidth() < srcInt.getWidth())
     return builder.create<arith::TruncIOp>(loc, dstType, src);
 
+  if (castKind == CastKind::Signed)
+    return builder.create<arith::ExtSIOp>(loc, dstType, src);
   return builder.create<arith::ExtUIOp>(loc, dstType, src);
 }
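[Reviewer note, not part of the patch] mergeCastKinds above behaves like a meet over a small lattice (Both above Signed and Unsigned, None at the bottom), so an op is only narrowed when every involved range tolerates a common truncation flavor, and doCast then emits the matching signed or unsigned cast. A self-contained model of that merge, duplicated here purely for illustration:

#include <cassert>
#include <cstdint>

enum class CastKind : uint8_t { None, Signed, Unsigned, Both };

// Mirrors the patch's mergeCastKinds: None is absorbing, Both is neutral,
// and disagreeing Signed/Unsigned requirements collapse to None.
static CastKind mergeCastKinds(CastKind lhs, CastKind rhs) {
  if (lhs == CastKind::None || rhs == CastKind::None)
    return CastKind::None;
  if (lhs == CastKind::Both)
    return rhs;
  if (rhs == CastKind::Both)
    return lhs;
  return lhs == rhs ? lhs : CastKind::None;
}

int main() {
  // A value that can be truncated either way never blocks the others.
  assert(mergeCastKinds(CastKind::Both, CastKind::Signed) == CastKind::Signed);
  assert(mergeCastKinds(CastKind::Unsigned, CastKind::Both) ==
         CastKind::Unsigned);
  // One value needing a signed truncation and another an unsigned one means
  // there is no single safe way to narrow the op at this width.
  assert(mergeCastKinds(CastKind::Signed, CastKind::Unsigned) ==
         CastKind::None);
  return 0;
}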
@@ -319,36 +319,47 @@ struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
   using OpTraitRewritePattern::OpTraitRewritePattern;
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
-    std::optional<ConstantIntRanges> range =
-        getOperandsRange(solver, op->getResults());
-    if (!range)
-      return failure();
+    if (op->getNumResults() == 0)
+      return rewriter.notifyMatchFailure(op, "can't narrow resultless op");
+
+    SmallVector<ConstantIntRanges> ranges;
+    if (failed(collectRanges(solver, op->getOperands(), ranges)))
+      return rewriter.notifyMatchFailure(op, "input without specified range");
+    if (failed(collectRanges(solver, op->getResults(), ranges)))
+      return rewriter.notifyMatchFailure(op, "output without specified range");
+
+    Type srcType = op->getResult(0).getType();
+    if (!llvm::all_equal(op->getResultTypes()))
+      return rewriter.notifyMatchFailure(op, "mismatched result types");
+    if (op->getNumOperands() == 0 ||
+        !llvm::all_of(op->getOperandTypes(),
+                      [=](Type t) { return t == srcType; }))
+      return rewriter.notifyMatchFailure(
+          op, "no operands or operand types don't match result type");
 
     for (unsigned targetBitwidth : targetBitwidths) {
-      if (failed(checkElementwiseOpType(op, targetBitwidth)))
-        continue;
-
-      Type srcType = op->getResult(0).getType();
-
-      // We are truncating op args to the desired bitwidth before the op and
-      // then extending op results back to the original width after. extui and
-      // exti will produce different results for negative values, so limit
-      // signed range to non-negative values.
-      auto smin = APInt::getZero(targetBitwidth);
-      auto smax = APInt::getSignedMaxValue(targetBitwidth);
-      auto umin = APInt::getMinValue(targetBitwidth);
-      auto umax = APInt::getMaxValue(targetBitwidth);
-      if (failed(checkRange(*range, smin, smax, umin, umax)))
+      CastKind castKind = CastKind::Both;
+      for (const ConstantIntRanges &range : ranges) {
+        castKind = mergeCastKinds(castKind,
                                  checkTruncatability(range, targetBitwidth));
+        if (castKind == CastKind::None)
+          break;
+      }
+      if (castKind == CastKind::None)
         continue;
-
       Type targetType = getTargetType(srcType, targetBitwidth);
       if (targetType == srcType)
         continue;
 
       Location loc = op->getLoc();
       IRMapping mapping;
-      for (Value arg : op->getOperands()) {
-        Value newArg = doCast(rewriter, loc, arg, targetType);
+      for (auto [arg, argRange] : llvm::zip_first(op->getOperands(), ranges)) {
+        CastKind argCastKind = castKind;
+        // When dealing with `index` values, use an unsigned cast for known
+        // non-negative arguments; a signed cast would hide that fact.
+        if (argCastKind == CastKind::Signed && argRange.smin().isNonNegative())
+          argCastKind = CastKind::Both;
+        Value newArg = doCast(rewriter, loc, arg, targetType, argCastKind);
         mapping.map(arg, newArg);
       }
 
@@ -359,8 +370,12 @@ struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
         }
       });
       SmallVector<Value> newResults;
-      for (Value res : newOp->getResults())
-        newResults.emplace_back(doCast(rewriter, loc, res, srcType));
+      for (auto [newRes, oldRes] :
+           llvm::zip_equal(newOp->getResults(), op->getResults())) {
+        Value castBack = doCast(rewriter, loc, newRes, srcType, castKind);
+        copyIntegerRange(solver, oldRes, castBack);
+        newResults.push_back(castBack);
+      }
 
       rewriter.replaceOp(op, newResults);
       return success();
@@ -382,21 +397,19 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
     Value lhs = op.getLhs();
     Value rhs = op.getRhs();
 
-    std::optional<ConstantIntRanges> range =
-        getOperandsRange(solver, {lhs, rhs});
-    if (!range)
+    SmallVector<ConstantIntRanges> ranges;
+    if (failed(collectRanges(solver, op.getOperands(), ranges)))
       return failure();
 
+    const ConstantIntRanges &lhsRange = ranges[0];
+    const ConstantIntRanges &rhsRange = ranges[1];
+    Type srcType = lhs.getType();
     for (unsigned targetBitwidth : targetBitwidths) {
-      Type srcType = lhs.getType();
-      if (failed(checkIntType(srcType, targetBitwidth)))
-        continue;
-
-      auto smin = APInt::getSignedMinValue(targetBitwidth);
-      auto smax = APInt::getSignedMaxValue(targetBitwidth);
-      auto umin = APInt::getMinValue(targetBitwidth);
-      auto umax = APInt::getMaxValue(targetBitwidth);
-      if (failed(checkRange(*range, smin, smax, umin, umax)))
+      CastKind lhsCastKind = checkTruncatability(lhsRange, targetBitwidth);
+      CastKind rhsCastKind = checkTruncatability(rhsRange, targetBitwidth);
+      CastKind castKind = mergeCastKinds(lhsCastKind, rhsCastKind);
+      // Note: this includes target width > src width.
+      if (castKind == CastKind::None)
         continue;
 
       Type targetType = getTargetType(srcType, targetBitwidth);
@@ -405,12 +418,13 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
 
       Location loc = op->getLoc();
       IRMapping mapping;
-      for (Value arg : op->getOperands()) {
-        Value newArg = doCast(rewriter, loc, arg, targetType);
-        mapping.map(arg, newArg);
-      }
+      Value lhsCast = doCast(rewriter, loc, lhs, targetType, lhsCastKind);
+      Value rhsCast = doCast(rewriter, loc, rhs, targetType, rhsCastKind);
+      mapping.map(lhs, lhsCast);
+      mapping.map(rhs, rhsCast);
 
       Operation *newOp = rewriter.clone(*op, mapping);
+      copyIntegerRange(solver, op.getResult(), newOp->getResult(0));
       rewriter.replaceOp(op, newOp->getResults());
       return success();
     }
@@ -425,19 +439,23 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
 /// Fold index_cast(index_cast(%arg: i8, index), i8) -> %arg
 /// This pattern assumes all passed `targetBitwidths` are not wider than index
 /// type.
-struct FoldIndexCastChain final : OpRewritePattern<arith::IndexCastUIOp> {
+template <typename CastOp>
+struct FoldIndexCastChain final : OpRewritePattern<CastOp> {
   FoldIndexCastChain(MLIRContext *context, ArrayRef<unsigned> target)
-      : OpRewritePattern(context), targetBitwidths(target) {}
+      : OpRewritePattern<CastOp>(context), targetBitwidths(target) {}
 
-  LogicalResult matchAndRewrite(arith::IndexCastUIOp op,
+  LogicalResult matchAndRewrite(CastOp op,
                                 PatternRewriter &rewriter) const override {
-    auto srcOp = op.getIn().getDefiningOp<arith::IndexCastUIOp>();
+    auto srcOp = op.getIn().template getDefiningOp<CastOp>();
     if (!srcOp)
-      return failure();
+      return rewriter.notifyMatchFailure(op, "doesn't come from an index cast");
 
     Value src = srcOp.getIn();
     if (src.getType() != op.getType())
-      return failure();
+      return rewriter.notifyMatchFailure(op, "outer types don't match");
+
+    if (!srcOp.getType().isIndex())
+      return rewriter.notifyMatchFailure(op, "intermediate type isn't index");
 
     auto intType = dyn_cast<IntegerType>(op.getType());
     if (!intType || !llvm::is_contained(targetBitwidths, intType.getWidth()))
@@ -517,7 +535,9 @@ void mlir::arith::populateIntRangeNarrowingPatterns(
     ArrayRef<unsigned> bitwidthsSupported) {
   patterns.add<NarrowElementwise, NarrowCmpI>(patterns.getContext(), solver,
                                               bitwidthsSupported);
-  patterns.add<FoldIndexCastChain>(patterns.getContext(), bitwidthsSupported);
+  patterns.add<FoldIndexCastChain<arith::IndexCastUIOp>,
+               FoldIndexCastChain<arith::IndexCastOp>>(patterns.getContext(),
+                                                       bitwidthsSupported);
 }
 
 std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {
diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index 8893f299177ce..e16db6293560e 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -4,9 +4,14 @@
 // Some basic tests
 //===----------------------------------------------------------------------===//
 
-// Do not truncate negative values
+// Truncate possibly-negative values in a signed way
 // CHECK-LABEL: func @test_addi_neg
-// CHECK: %[[RES:.*]] = arith.addi %{{.*}}, %{{.*}} : index
+// CHECK: %[[POS:.*]] = test.with_bounds {smax = 1 : index, smin = 0 : index, umax = 1 : index, umin = 0 : index} : index
+// CHECK: %[[NEG:.*]] = test.with_bounds {smax = 0 : index, smin = -1 : index, umax = -1 : index, umin = 0 : index} : index
+// CHECK: %[[POS_I8:.*]] = arith.index_castui %[[POS]] : index to i8
+// CHECK: %[[NEG_I8:.*]] = arith.index_cast %[[NEG]] : index to i8
+// CHECK: %[[RES_I8:.*]] = arith.addi %[[POS_I8]], %[[NEG_I8]] : i8
+// CHECK: %[[RES:.*]] = arith.index_cast %[[RES_I8]] : i8 to index
 // CHECK: return %[[RES]] : index
 func.func @test_addi_neg() -> index {
   %0 = test.with_bounds { umin = 0 : index, umax = 1 : index, smin = 0 : index, smax = 1 : index } : index
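[Reviewer note, not part of the patch] A brute-force sanity check of the new @test_addi_neg expectation: with the operands confined to [0, 1] and [-1, 0], truncating each operand to i8 (signed or unsigned makes no difference for these ranges), adding in i8, and sign-extending the result back reproduces the index-width sum for every input, which is why the signed index_cast back to index is sound. Plain C++ arithmetic is used to model the IR semantics.

#include <cassert>
#include <cstdint>

int main() {
  for (int64_t pos = 0; pos <= 1; ++pos) {    // bounds of %[[POS]]
    for (int64_t neg = -1; neg <= 0; ++neg) { // bounds of %[[NEG]]
      int64_t wide = pos + neg;               // the original index-width addi
      int8_t narrow = static_cast<int8_t>(pos) + static_cast<int8_t>(neg);
      // index_cast of the i8 result back to index equals the wide result.
      assert(static_cast<int64_t>(narrow) == wide);
    }
  }
  return 0;
}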
@@ -146,14 +151,18 @@ func.func @addi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
   return %r : i32
 }
 
-// This case should not get optimized because of mixed extensions.
+// This can be optimized to i16 since we're dealing in [-128, 127] + [0, 255],
+// which is [-128, 382]
 //
 // CHECK-LABEL: func.func @addi_mixed_ext_i8
 // CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
 // CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
 // CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
-// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[EXT0]], %[[EXT1]] : i32
-// CHECK-NEXT: return %[[ADD]] : i32
+// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16
+// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16
+// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[LHS]], %[[RHS]] : i16
+// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[ADD]] : i16 to i32
+// CHECK-NEXT: return %[[RET]] : i32
 func.func @addi_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
   %a = arith.extsi %lhs : i8 to i32
   %b = arith.extui %rhs : i8 to i32
@@ -181,15 +190,15 @@ func.func @addi_extsi_i16(%lhs: i8, %rhs: i8) -> i16 {
 // arith.subi
 //===----------------------------------------------------------------------===//
 
-// This patterns should only apply to `arith.subi` ops with sign-extended
-// arguments.
-//
 // CHECK-LABEL: func.func @subi_extui_i8
 // CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
 // CHECK-NEXT: %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
 // CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
-// CHECK-NEXT: %[[SUB:.+]] = arith.subi %[[EXT0]], %[[EXT1]] : i32
-// CHECK-NEXT: return %[[SUB]] : i32
+// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16
+// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16
+// CHECK-NEXT: %[[SUB:.+]] = arith.subi %[[LHS]], %[[RHS]] : i16
+// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[SUB]] : i16 to i32
+// CHECK-NEXT: return %[[RET]] : i32
 func.func @subi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
   %a = arith.extui %lhs : i8 to i32
   %b = arith.extui %rhs : i8 to i32
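[Reviewer note, not part of the patch] @subi_extui_i8 now narrows even though both inputs are zero-extended: the difference of two values in [0, 255] lies in [-255, 255], which needs a signed i16 but not an i32. Exhaustive check, in plain C++, that the i16 subi plus extsi matches the original i32 computation:

#include <cassert>
#include <cstdint>

int main() {
  for (int32_t a = 0; a <= 255; ++a) {
    for (int32_t b = 0; b <= 255; ++b) {
      int32_t wide = a - b;                         // subi in i32
      int16_t narrow = static_cast<int16_t>(a) - static_cast<int16_t>(b);
      assert(static_cast<int32_t>(narrow) == wide); // extsi back to i32
    }
  }
  return 0;
}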
@@ -197,14 +206,17 @@ func.func @subi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
   return %r : i32
 }
 
-// This case should not get optimized because of mixed extensions.
+// Despite the mixed sign and zero extensions, we can optimize here
 //
 // CHECK-LABEL: func.func @subi_mixed_ext_i8
 // CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
 // CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
 // CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
-// CHECK-NEXT: %[[ADD:.+]] = arith.subi %[[EXT0]], %[[EXT1]] : i32
-// CHECK-NEXT: return %[[ADD]] : i32
+// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16
+// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16
+// CHECK-NEXT: %[[ADD:.+]] = arith.subi %[[LHS]], %[[RHS]] : i16
+// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[ADD]] : i16 to i32
+// CHECK-NEXT: return %[[RET]] : i32
 func.func @subi_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
   %a = arith.extsi %lhs : i8 to i32
   %b = arith.extui %rhs : i8 to i32
@@ -216,15 +228,14 @@ func.func @subi_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
 // arith.muli
 //===----------------------------------------------------------------------===//
 
-// TODO: This should be optimized into i16
 // CHECK-LABEL: func.func @muli_extui_i8
 // CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
 // CHECK-NEXT: %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
 // CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
-// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i24
-// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i24
-// CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[LHS]], %[[RHS]] : i24
-// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[MUL]] : i24 to i32
+// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16
+// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16
+// CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[LHS]], %[[RHS]] : i16
+// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[MUL]] : i16 to i32
 // CHECK-NEXT: return %[[RET]] : i32
 func.func @muli_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
   %a = arith.extui %lhs : i8 to i32
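[Reviewer note, not part of the patch] The @muli_extui_i8 expectation tightens from i24 to i16 because [0, 255] * [0, 255] tops out at 65025, which still fits an unsigned 16-bit value. Exhaustive check that the i16 muli plus extui matches the i32 product:

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t a = 0; a <= 255; ++a) {
    for (uint32_t b = 0; b <= 255; ++b) {
      uint32_t wide = a * b; // muli in i32
      // muli performed in 16 bits; 255 * 255 = 65025 <= 65535, so no wrap.
      uint16_t narrow = static_cast<uint16_t>(a) * static_cast<uint16_t>(b);
      assert(static_cast<uint32_t>(narrow) == wide); // extui back to i32
    }
  }
  return 0;
}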
@@ -249,17 +260,90 @@ func.func @muli_extsi_i32(%lhs: i16, %rhs: i16) -> i32 {
   return %r : i32
 }
 
-// This case should not get optimized because of mixed extensions.
+// The mixed extensions mean that we have [-128, 127] * [0, 255], which can
+// be computed exactly in i16.
 //
 // CHECK-LABEL: func.func @muli_mixed_ext_i8
 // CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
 // CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
 // CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
-// CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[EXT0]], %[[EXT1]] : i32
-// CHECK-NEXT: return %[[MUL]] : i32
+// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16
+// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16
+// CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[LHS]], %[[RHS]] : i16
+// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[MUL]] : i16 to i32
+// CHECK-NEXT: return %[[RET]] : i32
 func.func @muli_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
   %a = arith.extsi %lhs : i8 to i32
   %b = arith.extui %rhs : i8 to i32
   %r = arith.muli %a, %b : i32
   return %r : i32
 }
+
+// Can't reduce width here since we need the extra bits
+// CHECK-LABEL: func.func @i32_overflows_to_index
+// CHECK-SAME: (%[[ARG0:.+]]: i32)
+// CHECK: %[[CLAMPED:.+]] = arith.maxsi %[[ARG0]], %{{.*}} : i32
+// CHECK: %[[CAST:.+]] = arith.index_castui %[[CLAMPED]] : i32 to index
+// CHECK: %[[MUL:.+]] = arith.muli %[[CAST]], %{{.*}} : index
+// CHECK: return %[[MUL]] : index
+func.func @i32_overflows_to_index(%arg0: i32) -> index {
+  %c0_i32 = arith.constant 0 : i32
+  %c4 = arith.constant 4 : index
+  %clamped = arith.maxsi %arg0, %c0_i32 : i32
+  %cast = arith.index_castui %clamped : i32 to index
+  %mul = arith.muli %cast, %c4 : index
+  return %mul : index
+}
+
+// Can't reduce width here since we need the extra bits
+// CHECK-LABEL: func.func @i32_overflows_to_i64
+// CHECK-SAME: (%[[ARG0:.+]]: i32)
+// CHECK: %[[CLAMPED:.+]] = arith.maxsi %[[ARG0]], %{{.*}} : i32
+// CHECK: %[[CAST:.+]] = arith.extui %[[CLAMPED]] : i32 to i64
+// CHECK: %[[MUL:.+]] = arith.muli %[[CAST]], %{{.*}} : i64
+// CHECK: return %[[MUL]] : i64
+func.func @i32_overflows_to_i64(%arg0: i32) -> i64 {
+  %c0_i32 = arith.constant 0 : i32
+  %c4_i64 = arith.constant 4 : i64
+  %clamped = arith.maxsi %arg0, %c0_i32 : i32
+  %cast = arith.extui %clamped : i32 to i64
+  %mul = arith.muli %cast, %c4_i64 : i64
+  return %mul : i64
+}
+
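[Reviewer note, not part of the patch] For the two overflow tests above, the analysis sees %clamped in [0, 2^31 - 1], so the product with 4 can need up to 33 bits; replaying the muli in 32 bits and extending afterwards would wrap. A quick numeric check of the worst case:

#include <cassert>
#include <cstdint>

int main() {
  uint64_t clamped = 0x7fffffffULL; // largest value maxsi(%arg0, 0) can take
  uint64_t wide = clamped * 4;      // 0x1fffffffc, which needs 33 bits
  uint32_t low32 = static_cast<uint32_t>(wide); // what a 32-bit muli keeps
  assert(wide != static_cast<uint64_t>(low32)); // so narrowing would be wrong
  return 0;
}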
+// Motivating example for negative number support, added as a test case
+// and simplified
+// CHECK-LABEL: func.func @clamp_to_loop_bound_and_id()
+// CHECK: %[[TID:.+]] = test.with_bounds
+// CHECK-SAME: umax = 63
+// CHECK: %[[BOUND:.+]] = test.with_bounds
+// CHECK-SAME: umax = 112
+// CHECK: scf.for %[[ARG0:.+]] = %{{.*}} to %[[BOUND]] step %{{.*}}
+// CHECK-DAG: %[[BOUND_I8:.+]] = arith.index_castui %[[BOUND]] : index to i8
+// CHECK-DAG: %[[ARG0_I8:.+]] = arith.index_castui %[[ARG0]] : index to i8
+// CHECK: %[[V0_I8:.+]] = arith.subi %[[BOUND_I8]], %[[ARG0_I8]] : i8
+// CHECK: %[[V1_I8:.+]] = arith.minsi %[[V0_I8]], %{{.*}} : i8
+// CHECK: %[[V1_INDEX:.+]] = arith.index_cast %[[V1_I8]] : i8 to index
+// CHECK: %[[V1_I16:.+]] = arith.index_cast %[[V1_INDEX]] : index to i16
+// CHECK: %[[TID_I16:.+]] = arith.index_castui %[[TID]] : index to i16
+// CHECK: %[[V2_I16:.+]] = arith.subi %[[V1_I16]], %[[TID_I16]] : i16
+// CHECK: %[[V3:.+]] = arith.cmpi slt, %[[V2_I16]], %{{.*}} : i16
+// CHECK: scf.if %[[V3]]
+func.func @clamp_to_loop_bound_and_id() {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c64 = arith.constant 64 : index
+
+  %tid = test.with_bounds {smin = 0 : index, smax = 63 : index, umin = 0 : index, umax = 63 : index} : index
+  %bound = test.with_bounds {smin = 16 : index, smax = 112 : index, umin = 16 : index, umax = 112 : index} : index
+  scf.for %arg0 = %c16 to %bound step %c64 {
+    %0 = arith.subi %bound, %arg0 : index
+    %1 = arith.minsi %0, %c64 : index
+    %2 = arith.subi %1, %tid : index
+    %3 = arith.cmpi slt, %2, %c0 : index
+    scf.if %3 {
+      vector.print str "sideeffect"
+    }
+  }
+  return
+}
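[Reviewer note, not part of the patch] The widths FileCheck expects in @clamp_to_loop_bound_and_id follow from interval arithmetic on the test's bounds, assuming the analysis bounds the induction variable by the loop bounds [16, 112]: %bound - %arg0 stays within [-96, 96] and min(..., 64) within [-96, 64], both of which fit a signed i8, but subtracting %tid in [0, 63] can reach -159, which needs i16. A rough recomputation of those extremes:

#include <cassert>
#include <cstdint>

int main() {
  // Bounds from the test: %bound in [16, 112], %arg0 in [16, 112] at the
  // widest, %tid in [0, 63].
  int64_t v0Min = 16 - 112, v0Max = 112 - 16; // %0 = subi %bound, %arg0
  int64_t v1Min = v0Min, v1Max = 64;          // %1 = minsi %0, 64
  int64_t v2Min = v1Min - 63, v2Max = v1Max;  // %2 = subi %1, %tid
  // %0 and %1 fit a signed i8, so the first subi and the minsi can run in i8.
  assert(v0Min >= -128 && v0Max <= 127 && v1Min >= -128 && v1Max <= 127);
  // %2 can reach -159, below i8's -128, so the second subi and the cmpi
  // need i16 (and easily fit it).
  assert(v2Min < -128 && v2Min >= -32768 && v2Max <= 32767);
  return 0;
}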