Skip to content

Commit 5bf7e8a

Browse files
authored
[CIR] Upstream overflow builtins (#166643)
This implements the builtins that handle overflow. This fixes issue #163888
1 parent 13011fe commit 5bf7e8a

File tree

4 files changed

+735
-2
lines changed

4 files changed

+735
-2
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1640,6 +1640,82 @@ def CIR_CmpOp : CIR_Op<"cmp", [Pure, SameTypeOperands]> {
16401640
let isLLVMLoweringRecursive = true;
16411641
}
16421642

1643+
//===----------------------------------------------------------------------===//
1644+
// BinOpOverflowOp
1645+
//===----------------------------------------------------------------------===//
1646+
1647+
def CIR_BinOpOverflowKind : CIR_I32EnumAttr<
1648+
"BinOpOverflowKind", "checked binary arithmetic operation kind", [
1649+
I32EnumAttrCase<"Add", 0, "add">,
1650+
I32EnumAttrCase<"Sub", 1, "sub">,
1651+
I32EnumAttrCase<"Mul", 2, "mul">
1652+
]>;
1653+
1654+
def CIR_BinOpOverflowOp : CIR_Op<"binop.overflow", [Pure, SameTypeOperands]> {
1655+
let summary = "Perform binary integral arithmetic with overflow checking";
1656+
let description = [{
1657+
`cir.binop.overflow` performs binary arithmetic operations with overflow
1658+
checking on integral operands.
1659+
1660+
The `kind` argument specifies the kind of arithmetic operation to perform.
1661+
It can be either `add`, `sub`, or `mul`. The `lhs` and `rhs` arguments
1662+
specify the input operands of the arithmetic operation. The types of `lhs`
1663+
and `rhs` must be the same.
1664+
1665+
`cir.binop.overflow` produces two SSA values. `result` is the result of the
1666+
arithmetic operation truncated to its specified type. `overflow` is a
1667+
boolean value indicating whether overflow happens during the operation.
1668+
1669+
The exact semantic of this operation is as follows:
1670+
1671+
- `lhs` and `rhs` are promoted to an imaginary integral type that has
1672+
infinite precision.
1673+
- The arithmetic operation is performed on the promoted operands.
1674+
- The infinite-precision result is truncated to the type of `result`. The
1675+
truncated result is assigned to `result`.
1676+
- If the truncated result is equal to the un-truncated result, `overflow`
1677+
is assigned to false. Otherwise, `overflow` is assigned to true.
1678+
}];
1679+
1680+
let arguments = (ins
1681+
CIR_BinOpOverflowKind:$kind,
1682+
CIR_IntType:$lhs,
1683+
CIR_IntType:$rhs
1684+
);
1685+
1686+
let results = (outs CIR_IntType:$result, CIR_BoolType:$overflow);
1687+
1688+
let assemblyFormat = [{
1689+
`(` $kind `,` $lhs `,` $rhs `)` `:` qualified(type($lhs)) `,`
1690+
`(` qualified(type($result)) `,` qualified(type($overflow)) `)`
1691+
attr-dict
1692+
}];
1693+
1694+
let builders = [
1695+
OpBuilder<(ins "cir::IntType":$resultTy,
1696+
"cir::BinOpOverflowKind":$kind,
1697+
"mlir::Value":$lhs,
1698+
"mlir::Value":$rhs), [{
1699+
auto overflowTy = cir::BoolType::get($_builder.getContext());
1700+
build($_builder, $_state, resultTy, overflowTy, kind, lhs, rhs);
1701+
}]>
1702+
];
1703+
1704+
let extraLLVMLoweringPatternDecl = [{
1705+
static std::string getLLVMIntrinName(cir::BinOpOverflowKind opKind,
1706+
bool isSigned, unsigned width);
1707+
1708+
struct EncompassedTypeInfo {
1709+
bool sign;
1710+
unsigned width;
1711+
};
1712+
1713+
static EncompassedTypeInfo computeEncompassedTypeWidth(cir::IntType operandTy,
1714+
cir::IntType resultTy);
1715+
}];
1716+
}
1717+
1718+
16431719
//===----------------------------------------------------------------------===//
16441720
// BinOp
16451721
//===----------------------------------------------------------------------===//

clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp

Lines changed: 171 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,45 @@ static RValue emitBuiltinBitOp(CIRGenFunction &cgf, const CallExpr *e,
5858
return RValue::get(result);
5959
}
6060

61+
namespace {
62+
struct WidthAndSignedness {
63+
unsigned width;
64+
bool isSigned;
65+
};
66+
} // namespace
67+
68+
static WidthAndSignedness
69+
getIntegerWidthAndSignedness(const clang::ASTContext &astContext,
70+
const clang::QualType type) {
71+
assert(type->isIntegerType() && "Given type is not an integer.");
72+
unsigned width = type->isBooleanType() ? 1
73+
: type->isBitIntType() ? astContext.getIntWidth(type)
74+
: astContext.getTypeInfo(type).Width;
75+
bool isSigned = type->isSignedIntegerType();
76+
return {width, isSigned};
77+
}
78+
79+
// Given one or more integer types, this function produces an integer type that
80+
// encompasses them: any value in one of the given types could be expressed in
81+
// the encompassing type.
82+
static struct WidthAndSignedness
83+
EncompassingIntegerType(ArrayRef<struct WidthAndSignedness> types) {
84+
assert(types.size() > 0 && "Empty list of types.");
85+
86+
// If any of the given types is signed, we must return a signed type.
87+
bool isSigned = llvm::any_of(types, [](const auto &t) { return t.isSigned; });
88+
89+
// The encompassing type must have a width greater than or equal to the width
90+
// of the specified types. Additionally, if the encompassing type is signed,
91+
// its width must be strictly greater than the width of any unsigned types
92+
// given.
93+
unsigned width = 0;
94+
for (const auto &type : types)
95+
width = std::max(width, type.width + (isSigned && !type.isSigned));
96+
97+
return {width, isSigned};
98+
}
99+
61100
RValue CIRGenFunction::emitRotate(const CallExpr *e, bool isRotateLeft) {
62101
mlir::Value input = emitScalarExpr(e->getArg(0));
63102
mlir::Value amount = emitScalarExpr(e->getArg(1));
@@ -899,9 +938,85 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl &gd, unsigned builtinID,
899938
case Builtin::BI__builtin_subc:
900939
case Builtin::BI__builtin_subcl:
901940
case Builtin::BI__builtin_subcll:
941+
return errorBuiltinNYI(*this, e, builtinID);
942+
902943
case Builtin::BI__builtin_add_overflow:
903944
case Builtin::BI__builtin_sub_overflow:
904-
case Builtin::BI__builtin_mul_overflow:
945+
case Builtin::BI__builtin_mul_overflow: {
946+
const clang::Expr *leftArg = e->getArg(0);
947+
const clang::Expr *rightArg = e->getArg(1);
948+
const clang::Expr *resultArg = e->getArg(2);
949+
950+
clang::QualType resultQTy =
951+
resultArg->getType()->castAs<clang::PointerType>()->getPointeeType();
952+
953+
WidthAndSignedness leftInfo =
954+
getIntegerWidthAndSignedness(cgm.getASTContext(), leftArg->getType());
955+
WidthAndSignedness rightInfo =
956+
getIntegerWidthAndSignedness(cgm.getASTContext(), rightArg->getType());
957+
WidthAndSignedness resultInfo =
958+
getIntegerWidthAndSignedness(cgm.getASTContext(), resultQTy);
959+
960+
// Note we compute the encompassing type with the consideration to the
961+
// result type, so later in LLVM lowering we don't get redundant integral
962+
// extension casts.
963+
WidthAndSignedness encompassingInfo =
964+
EncompassingIntegerType({leftInfo, rightInfo, resultInfo});
965+
966+
auto encompassingCIRTy = cir::IntType::get(
967+
&getMLIRContext(), encompassingInfo.width, encompassingInfo.isSigned);
968+
auto resultCIRTy = mlir::cast<cir::IntType>(cgm.convertType(resultQTy));
969+
970+
mlir::Value left = emitScalarExpr(leftArg);
971+
mlir::Value right = emitScalarExpr(rightArg);
972+
Address resultPtr = emitPointerWithAlignment(resultArg);
973+
974+
// Extend each operand to the encompassing type, if necessary.
975+
if (left.getType() != encompassingCIRTy)
976+
left =
977+
builder.createCast(cir::CastKind::integral, left, encompassingCIRTy);
978+
if (right.getType() != encompassingCIRTy)
979+
right =
980+
builder.createCast(cir::CastKind::integral, right, encompassingCIRTy);
981+
982+
// Perform the operation on the extended values.
983+
cir::BinOpOverflowKind opKind;
984+
switch (builtinID) {
985+
default:
986+
llvm_unreachable("Unknown overflow builtin id.");
987+
case Builtin::BI__builtin_add_overflow:
988+
opKind = cir::BinOpOverflowKind::Add;
989+
break;
990+
case Builtin::BI__builtin_sub_overflow:
991+
opKind = cir::BinOpOverflowKind::Sub;
992+
break;
993+
case Builtin::BI__builtin_mul_overflow:
994+
opKind = cir::BinOpOverflowKind::Mul;
995+
break;
996+
}
997+
998+
mlir::Location loc = getLoc(e->getSourceRange());
999+
auto arithOp = cir::BinOpOverflowOp::create(builder, loc, resultCIRTy,
1000+
opKind, left, right);
1001+
1002+
// Here is a slight difference from the original clang CodeGen:
1003+
// - In the original clang CodeGen, the checked arithmetic result is
1004+
// first computed as a value of the encompassing type, and then it is
1005+
// truncated to the actual result type with a second overflow checking.
1006+
// - In CIRGen, the checked arithmetic operation directly produce the
1007+
// checked arithmetic result in its expected type.
1008+
//
1009+
// So we don't need a truncation and a second overflow checking here.
1010+
1011+
// Finally, store the result using the pointer.
1012+
bool isVolatile =
1013+
resultArg->getType()->getPointeeType().isVolatileQualified();
1014+
builder.createStore(loc, emitToMemory(arithOp.getResult(), resultQTy),
1015+
resultPtr, isVolatile);
1016+
1017+
return RValue::get(arithOp.getOverflow());
1018+
}
1019+
9051020
case Builtin::BI__builtin_uadd_overflow:
9061021
case Builtin::BI__builtin_uaddl_overflow:
9071022
case Builtin::BI__builtin_uaddll_overflow:
@@ -919,7 +1034,61 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl &gd, unsigned builtinID,
9191034
case Builtin::BI__builtin_ssubll_overflow:
9201035
case Builtin::BI__builtin_smul_overflow:
9211036
case Builtin::BI__builtin_smull_overflow:
922-
case Builtin::BI__builtin_smulll_overflow:
1037+
case Builtin::BI__builtin_smulll_overflow: {
1038+
// Scalarize our inputs.
1039+
mlir::Value x = emitScalarExpr(e->getArg(0));
1040+
mlir::Value y = emitScalarExpr(e->getArg(1));
1041+
1042+
const clang::Expr *resultArg = e->getArg(2);
1043+
Address resultPtr = emitPointerWithAlignment(resultArg);
1044+
1045+
// Decide which of the arithmetic operation we are lowering to:
1046+
cir::BinOpOverflowKind arithKind;
1047+
switch (builtinID) {
1048+
default:
1049+
llvm_unreachable("Unknown overflow builtin id.");
1050+
case Builtin::BI__builtin_uadd_overflow:
1051+
case Builtin::BI__builtin_uaddl_overflow:
1052+
case Builtin::BI__builtin_uaddll_overflow:
1053+
case Builtin::BI__builtin_sadd_overflow:
1054+
case Builtin::BI__builtin_saddl_overflow:
1055+
case Builtin::BI__builtin_saddll_overflow:
1056+
arithKind = cir::BinOpOverflowKind::Add;
1057+
break;
1058+
case Builtin::BI__builtin_usub_overflow:
1059+
case Builtin::BI__builtin_usubl_overflow:
1060+
case Builtin::BI__builtin_usubll_overflow:
1061+
case Builtin::BI__builtin_ssub_overflow:
1062+
case Builtin::BI__builtin_ssubl_overflow:
1063+
case Builtin::BI__builtin_ssubll_overflow:
1064+
arithKind = cir::BinOpOverflowKind::Sub;
1065+
break;
1066+
case Builtin::BI__builtin_umul_overflow:
1067+
case Builtin::BI__builtin_umull_overflow:
1068+
case Builtin::BI__builtin_umulll_overflow:
1069+
case Builtin::BI__builtin_smul_overflow:
1070+
case Builtin::BI__builtin_smull_overflow:
1071+
case Builtin::BI__builtin_smulll_overflow:
1072+
arithKind = cir::BinOpOverflowKind::Mul;
1073+
break;
1074+
}
1075+
1076+
clang::QualType resultQTy =
1077+
resultArg->getType()->castAs<clang::PointerType>()->getPointeeType();
1078+
auto resultCIRTy = mlir::cast<cir::IntType>(cgm.convertType(resultQTy));
1079+
1080+
mlir::Location loc = getLoc(e->getSourceRange());
1081+
cir::BinOpOverflowOp arithOp = cir::BinOpOverflowOp::create(
1082+
builder, loc, resultCIRTy, arithKind, x, y);
1083+
1084+
bool isVolatile =
1085+
resultArg->getType()->getPointeeType().isVolatileQualified();
1086+
builder.createStore(loc, emitToMemory(arithOp.getResult(), resultQTy),
1087+
resultPtr, isVolatile);
1088+
1089+
return RValue::get(arithOp.getOverflow());
1090+
}
1091+
9231092
case Builtin::BIaddressof:
9241093
case Builtin::BI__addressof:
9251094
case Builtin::BI__builtin_addressof:

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2586,6 +2586,120 @@ mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite(
25862586
return cmpOp.emitError() << "unsupported type for CmpOp: " << type;
25872587
}
25882588

2589+
mlir::LogicalResult CIRToLLVMBinOpOverflowOpLowering::matchAndRewrite(
2590+
cir::BinOpOverflowOp op, OpAdaptor adaptor,
2591+
mlir::ConversionPatternRewriter &rewriter) const {
2592+
mlir::Location loc = op.getLoc();
2593+
cir::BinOpOverflowKind arithKind = op.getKind();
2594+
cir::IntType operandTy = op.getLhs().getType();
2595+
cir::IntType resultTy = op.getResult().getType();
2596+
2597+
EncompassedTypeInfo encompassedTyInfo =
2598+
computeEncompassedTypeWidth(operandTy, resultTy);
2599+
mlir::IntegerType encompassedLLVMTy =
2600+
rewriter.getIntegerType(encompassedTyInfo.width);
2601+
2602+
mlir::Value lhs = adaptor.getLhs();
2603+
mlir::Value rhs = adaptor.getRhs();
2604+
if (operandTy.getWidth() < encompassedTyInfo.width) {
2605+
if (operandTy.isSigned()) {
2606+
lhs = mlir::LLVM::SExtOp::create(rewriter, loc, encompassedLLVMTy, lhs);
2607+
rhs = mlir::LLVM::SExtOp::create(rewriter, loc, encompassedLLVMTy, rhs);
2608+
} else {
2609+
lhs = mlir::LLVM::ZExtOp::create(rewriter, loc, encompassedLLVMTy, lhs);
2610+
rhs = mlir::LLVM::ZExtOp::create(rewriter, loc, encompassedLLVMTy, rhs);
2611+
}
2612+
}
2613+
2614+
std::string intrinName = getLLVMIntrinName(arithKind, encompassedTyInfo.sign,
2615+
encompassedTyInfo.width);
2616+
auto intrinNameAttr = mlir::StringAttr::get(op.getContext(), intrinName);
2617+
2618+
mlir::IntegerType overflowLLVMTy = rewriter.getI1Type();
2619+
auto intrinRetTy = mlir::LLVM::LLVMStructType::getLiteral(
2620+
rewriter.getContext(), {encompassedLLVMTy, overflowLLVMTy});
2621+
2622+
auto callLLVMIntrinOp = mlir::LLVM::CallIntrinsicOp::create(
2623+
rewriter, loc, intrinRetTy, intrinNameAttr, mlir::ValueRange{lhs, rhs});
2624+
mlir::Value intrinRet = callLLVMIntrinOp.getResult(0);
2625+
2626+
mlir::Value result = mlir::LLVM::ExtractValueOp::create(
2627+
rewriter, loc, intrinRet, ArrayRef<int64_t>{0})
2628+
.getResult();
2629+
mlir::Value overflow = mlir::LLVM::ExtractValueOp::create(
2630+
rewriter, loc, intrinRet, ArrayRef<int64_t>{1})
2631+
.getResult();
2632+
2633+
if (resultTy.getWidth() < encompassedTyInfo.width) {
2634+
mlir::Type resultLLVMTy = getTypeConverter()->convertType(resultTy);
2635+
auto truncResult =
2636+
mlir::LLVM::TruncOp::create(rewriter, loc, resultLLVMTy, result);
2637+
2638+
// Extend the truncated result back to the encompassing type to check for
2639+
// any overflows during the truncation.
2640+
mlir::Value truncResultExt;
2641+
if (resultTy.isSigned())
2642+
truncResultExt = mlir::LLVM::SExtOp::create(
2643+
rewriter, loc, encompassedLLVMTy, truncResult);
2644+
else
2645+
truncResultExt = mlir::LLVM::ZExtOp::create(
2646+
rewriter, loc, encompassedLLVMTy, truncResult);
2647+
auto truncOverflow = mlir::LLVM::ICmpOp::create(
2648+
rewriter, loc, mlir::LLVM::ICmpPredicate::ne, truncResultExt, result);
2649+
2650+
result = truncResult;
2651+
overflow = mlir::LLVM::OrOp::create(rewriter, loc, overflow, truncOverflow);
2652+
}
2653+
2654+
mlir::Type boolLLVMTy =
2655+
getTypeConverter()->convertType(op.getOverflow().getType());
2656+
if (boolLLVMTy != rewriter.getI1Type())
2657+
overflow = mlir::LLVM::ZExtOp::create(rewriter, loc, boolLLVMTy, overflow);
2658+
2659+
rewriter.replaceOp(op, mlir::ValueRange{result, overflow});
2660+
2661+
return mlir::success();
2662+
}
2663+
2664+
std::string CIRToLLVMBinOpOverflowOpLowering::getLLVMIntrinName(
2665+
cir::BinOpOverflowKind opKind, bool isSigned, unsigned width) {
2666+
// The intrinsic name is `@llvm.{s|u}{opKind}.with.overflow.i{width}`
2667+
2668+
std::string name = "llvm.";
2669+
2670+
if (isSigned)
2671+
name.push_back('s');
2672+
else
2673+
name.push_back('u');
2674+
2675+
switch (opKind) {
2676+
case cir::BinOpOverflowKind::Add:
2677+
name.append("add.");
2678+
break;
2679+
case cir::BinOpOverflowKind::Sub:
2680+
name.append("sub.");
2681+
break;
2682+
case cir::BinOpOverflowKind::Mul:
2683+
name.append("mul.");
2684+
break;
2685+
}
2686+
2687+
name.append("with.overflow.i");
2688+
name.append(std::to_string(width));
2689+
2690+
return name;
2691+
}
2692+
2693+
CIRToLLVMBinOpOverflowOpLowering::EncompassedTypeInfo
2694+
CIRToLLVMBinOpOverflowOpLowering::computeEncompassedTypeWidth(
2695+
cir::IntType operandTy, cir::IntType resultTy) {
2696+
bool sign = operandTy.getIsSigned() || resultTy.getIsSigned();
2697+
unsigned width =
2698+
std::max(operandTy.getWidth() + (sign && operandTy.isUnsigned()),
2699+
resultTy.getWidth() + (sign && resultTy.isUnsigned()));
2700+
return {sign, width};
2701+
}
2702+
25892703
mlir::LogicalResult CIRToLLVMShiftOpLowering::matchAndRewrite(
25902704
cir::ShiftOp op, OpAdaptor adaptor,
25912705
mlir::ConversionPatternRewriter &rewriter) const {

0 commit comments

Comments
 (0)