Skip to content

Commit 81f4a37

Browse files
committed
[Sema][CodeGen] Support __builtin_<op>_overflow with __intcap
Morello LLVM has downstream support for this, but it's both incomplete (see https://git.morello-project.org/morello/llvm-project/-/issues/80) and incorrect with regards to provenance (in that it takes a naive type-based approach rather than considering the cheri_no_provenance attribute, meaning it differs from the binary operators in provenance semantics). This is a from-scratch implementation that aims to not have the same shortcomings. Fixes #720
1 parent 9947f48 commit 81f4a37

File tree

4 files changed

+5663
-13
lines changed

4 files changed

+5663
-13
lines changed

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 117 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "llvm/ADT/StringExtras.h"
3737
#include "llvm/Analysis/AssumptionCache.h"
3838
#include "llvm/Analysis/ValueTracking.h"
39+
#include "llvm/IR/Constants.h"
3940
#include "llvm/IR/DataLayout.h"
4041
#include "llvm/IR/InlineAsm.h"
4142
#include "llvm/IR/IntrinsicInst.h"
@@ -771,9 +772,7 @@ static WidthAndSignedness
771772
getIntegerWidthAndSignedness(const clang::ASTContext &context,
772773
const clang::QualType Type) {
773774
assert(Type->isIntegerType() && "Given type is not an integer.");
774-
unsigned Width = Type->isBooleanType() ? 1
775-
: Type->isBitIntType() ? context.getIntWidth(Type)
776-
: context.getTypeInfo(Type).Width;
775+
unsigned Width = context.getIntWidth(Type);
777776
bool Signed = Type->isSignedIntegerType();
778777
return {Width, Signed};
779778
}
@@ -1996,14 +1995,40 @@ static RValue EmitCheckedUnsignedMultiplySignedResult(
19961995
CodeGenFunction &CGF, const clang::Expr *Op1, WidthAndSignedness Op1Info,
19971996
const clang::Expr *Op2, WidthAndSignedness Op2Info,
19981997
const clang::Expr *ResultArg, QualType ResultQTy,
1999-
WidthAndSignedness ResultInfo) {
1998+
WidthAndSignedness ResultInfo, SourceLocation Loc) {
20001999
assert(isSpecialUnsignedMultiplySignedResult(
20012000
Builtin::BI__builtin_mul_overflow, Op1Info, Op2Info, ResultInfo) &&
20022001
"Cannot specialize this multiply");
20032002

2003+
clang::QualType Op1QTy = Op1->getType();
2004+
clang::QualType Op2QTy = Op2->getType();
2005+
bool Op1IsCap = Op1QTy->isCHERICapabilityType(CGF.getContext());
2006+
bool Op2IsCap = Op2QTy->isCHERICapabilityType(CGF.getContext());
2007+
bool ResultIsCap = ResultQTy->isCHERICapabilityType(CGF.getContext());
2008+
20042009
llvm::Value *V1 = CGF.EmitScalarExpr(Op1);
20052010
llvm::Value *V2 = CGF.EmitScalarExpr(Op2);
20062011

2012+
llvm::Value *ProvenanceCap = nullptr;
2013+
if (ResultIsCap) {
2014+
bool Op1NoProvenance =
2015+
!Op1IsCap || Op1QTy->hasAttr(attr::CHERINoProvenance);
2016+
bool Op2NoProvenance =
2017+
!Op2IsCap || Op2QTy->hasAttr(attr::CHERINoProvenance);
2018+
if (Op1NoProvenance && Op2NoProvenance)
2019+
ProvenanceCap = llvm::ConstantPointerNull::get(CGF.Int8CheriCapTy);
2020+
else if (Op1NoProvenance)
2021+
ProvenanceCap = V2;
2022+
else
2023+
ProvenanceCap = V1;
2024+
}
2025+
2026+
if (Op1IsCap)
2027+
V1 = CGF.getCapabilityIntegerValue(V1);
2028+
2029+
if (Op2IsCap)
2030+
V2 = CGF.getCapabilityIntegerValue(V2);
2031+
20072032
llvm::Value *HasOverflow;
20082033
llvm::Value *Result = EmitOverflowIntrinsic(
20092034
CGF, llvm::Intrinsic::umul_with_overflow, V1, V2, HasOverflow);
@@ -2017,6 +2042,9 @@ static RValue EmitCheckedUnsignedMultiplySignedResult(
20172042
llvm::Value *IntMaxOverflow = CGF.Builder.CreateICmpUGT(Result, IntMaxValue);
20182043
HasOverflow = CGF.Builder.CreateOr(HasOverflow, IntMaxOverflow);
20192044

2045+
if (ResultIsCap)
2046+
Result = CGF.setCapabilityIntegerValue(ProvenanceCap, Result, Loc);
2047+
20202048
bool isVolatile =
20212049
ResultArg->getType()->getPointeeType().isVolatileQualified();
20222050
Address ResultPtr = CGF.EmitPointerWithAlignment(ResultArg);
@@ -2042,18 +2070,47 @@ EmitCheckedMixedSignMultiply(CodeGenFunction &CGF, const clang::Expr *Op1,
20422070
WidthAndSignedness Op1Info, const clang::Expr *Op2,
20432071
WidthAndSignedness Op2Info,
20442072
const clang::Expr *ResultArg, QualType ResultQTy,
2045-
WidthAndSignedness ResultInfo) {
2073+
WidthAndSignedness ResultInfo,
2074+
SourceLocation Loc) {
20462075
assert(isSpecialMixedSignMultiply(Builtin::BI__builtin_mul_overflow, Op1Info,
20472076
Op2Info, ResultInfo) &&
20482077
"Not a mixed-sign multipliction we can specialize");
20492078

2079+
QualType Op1QTy = Op1->getType();
2080+
QualType Op2QTy = Op2->getType();
2081+
bool Op1IsCap = Op1QTy->isCHERICapabilityType(CGF.getContext());
2082+
bool Op2IsCap = Op2QTy->isCHERICapabilityType(CGF.getContext());
2083+
bool ResultIsCap = ResultQTy->isCHERICapabilityType(CGF.getContext());
2084+
20502085
// Emit the signed and unsigned operands.
20512086
const clang::Expr *SignedOp = Op1Info.Signed ? Op1 : Op2;
20522087
const clang::Expr *UnsignedOp = Op1Info.Signed ? Op2 : Op1;
20532088
llvm::Value *Signed = CGF.EmitScalarExpr(SignedOp);
20542089
llvm::Value *Unsigned = CGF.EmitScalarExpr(UnsignedOp);
20552090
unsigned SignedOpWidth = Op1Info.Signed ? Op1Info.Width : Op2Info.Width;
20562091
unsigned UnsignedOpWidth = Op1Info.Signed ? Op2Info.Width : Op1Info.Width;
2092+
bool SignedIsCap = Op1Info.Signed ? Op1IsCap : Op2IsCap;
2093+
bool UnsignedIsCap = Op1Info.Signed ? Op2IsCap : Op1IsCap;
2094+
2095+
llvm::Value *ProvenanceCap = nullptr;
2096+
if (ResultIsCap) {
2097+
bool Op1NoProvenance =
2098+
!Op1IsCap || Op1QTy->hasAttr(attr::CHERINoProvenance);
2099+
bool Op2NoProvenance =
2100+
!Op2IsCap || Op2QTy->hasAttr(attr::CHERINoProvenance);
2101+
if (Op1NoProvenance && Op2NoProvenance)
2102+
ProvenanceCap = llvm::ConstantPointerNull::get(CGF.Int8CheriCapTy);
2103+
else if (Op1NoProvenance)
2104+
ProvenanceCap = Op1Info.Signed ? Unsigned : Signed;
2105+
else
2106+
ProvenanceCap = Op1Info.Signed ? Signed : Unsigned;
2107+
}
2108+
2109+
if (SignedIsCap)
2110+
Signed = CGF.getCapabilityIntegerValue(Signed);
2111+
2112+
if (UnsignedIsCap)
2113+
Unsigned = CGF.getCapabilityIntegerValue(Unsigned);
20572114

20582115
// One of the operands may be smaller than the other. If so, [s|z]ext it.
20592116
if (SignedOpWidth < UnsignedOpWidth)
@@ -2064,7 +2121,9 @@ EmitCheckedMixedSignMultiply(CodeGenFunction &CGF, const clang::Expr *Op1,
20642121
llvm::Type *OpTy = Signed->getType();
20652122
llvm::Value *Zero = llvm::Constant::getNullValue(OpTy);
20662123
Address ResultPtr = CGF.EmitPointerWithAlignment(ResultArg);
2067-
llvm::Type *ResTy = ResultPtr.getElementType();
2124+
llvm::Type *ResTy = ResultIsCap ? llvm::IntegerType::get(CGF.getLLVMContext(),
2125+
ResultInfo.Width)
2126+
: ResultPtr.getElementType();
20682127
unsigned OpWidth = std::max(Op1Info.Width, Op2Info.Width);
20692128

20702129
// Take the absolute value of the signed operand.
@@ -2103,8 +2162,7 @@ EmitCheckedMixedSignMultiply(CodeGenFunction &CGF, const clang::Expr *Op1,
21032162
IsNegative, CGF.Builder.CreateIsNotNull(UnsignedResult));
21042163
Overflow = CGF.Builder.CreateOr(UnsignedOverflow, Underflow);
21052164
if (ResultInfo.Width < OpWidth) {
2106-
auto IntMax =
2107-
llvm::APInt::getMaxValue(ResultInfo.Width).zext(OpWidth);
2165+
auto IntMax = llvm::APInt::getMaxValue(ResultInfo.Width).zext(OpWidth);
21082166
llvm::Value *TruncOverflow = CGF.Builder.CreateICmpUGT(
21092167
UnsignedResult, llvm::ConstantInt::get(OpTy, IntMax));
21102168
Overflow = CGF.Builder.CreateOr(Overflow, TruncOverflow);
@@ -2118,6 +2176,9 @@ EmitCheckedMixedSignMultiply(CodeGenFunction &CGF, const clang::Expr *Op1,
21182176
}
21192177
assert(Overflow && Result && "Missing overflow or result");
21202178

2179+
if (ResultIsCap)
2180+
Result = CGF.setCapabilityIntegerValue(ProvenanceCap, Result, Loc);
2181+
21212182
bool isVolatile =
21222183
ResultArg->getType()->getPointeeType().isVolatileQualified();
21232184
CGF.Builder.CreateStore(CGF.EmitToMemory(Result, ResultQTy), ResultPtr,
@@ -4636,13 +4697,18 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
46364697
const clang::Expr *RightArg = E->getArg(1);
46374698
const clang::Expr *ResultArg = E->getArg(2);
46384699

4700+
clang::QualType LeftQTy = LeftArg->getType();
4701+
clang::QualType RightQTy = RightArg->getType();
46394702
clang::QualType ResultQTy =
46404703
ResultArg->getType()->castAs<PointerType>()->getPointeeType();
46414704

4705+
bool LeftIsCap = LeftQTy->isCHERICapabilityType(CGM.getContext());
4706+
bool RightIsCap = RightQTy->isCHERICapabilityType(CGM.getContext());
4707+
bool ResultIsCap = ResultQTy->isCHERICapabilityType(CGM.getContext());
46424708
WidthAndSignedness LeftInfo =
4643-
getIntegerWidthAndSignedness(CGM.getContext(), LeftArg->getType());
4709+
getIntegerWidthAndSignedness(CGM.getContext(), LeftQTy);
46444710
WidthAndSignedness RightInfo =
4645-
getIntegerWidthAndSignedness(CGM.getContext(), RightArg->getType());
4711+
getIntegerWidthAndSignedness(CGM.getContext(), RightQTy);
46464712
WidthAndSignedness ResultInfo =
46474713
getIntegerWidthAndSignedness(CGM.getContext(), ResultQTy);
46484714

@@ -4651,37 +4717,44 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
46514717
if (isSpecialMixedSignMultiply(BuiltinID, LeftInfo, RightInfo, ResultInfo))
46524718
return EmitCheckedMixedSignMultiply(*this, LeftArg, LeftInfo, RightArg,
46534719
RightInfo, ResultArg, ResultQTy,
4654-
ResultInfo);
4720+
ResultInfo, E->getExprLoc());
46554721

46564722
if (isSpecialUnsignedMultiplySignedResult(BuiltinID, LeftInfo, RightInfo,
46574723
ResultInfo))
46584724
return EmitCheckedUnsignedMultiplySignedResult(
46594725
*this, LeftArg, LeftInfo, RightArg, RightInfo, ResultArg, ResultQTy,
4660-
ResultInfo);
4726+
ResultInfo, E->getExprLoc());
46614727

46624728
WidthAndSignedness EncompassingInfo =
46634729
EncompassingIntegerType({LeftInfo, RightInfo, ResultInfo});
46644730

46654731
llvm::Type *EncompassingLLVMTy =
46664732
llvm::IntegerType::get(CGM.getLLVMContext(), EncompassingInfo.Width);
46674733

4668-
llvm::Type *ResultLLVMTy = CGM.getTypes().ConvertType(ResultQTy);
4734+
llvm::Type *ResultLLVMTy =
4735+
ResultIsCap
4736+
? llvm::IntegerType::get(CGM.getLLVMContext(), ResultInfo.Width)
4737+
: CGM.getTypes().ConvertType(ResultQTy);
46694738

46704739
llvm::Intrinsic::ID IntrinsicId;
4740+
bool Commutative;
46714741
switch (BuiltinID) {
46724742
default:
46734743
llvm_unreachable("Unknown overflow builtin id.");
46744744
case Builtin::BI__builtin_add_overflow:
4745+
Commutative = true;
46754746
IntrinsicId = EncompassingInfo.Signed
46764747
? llvm::Intrinsic::sadd_with_overflow
46774748
: llvm::Intrinsic::uadd_with_overflow;
46784749
break;
46794750
case Builtin::BI__builtin_sub_overflow:
4751+
Commutative = false;
46804752
IntrinsicId = EncompassingInfo.Signed
46814753
? llvm::Intrinsic::ssub_with_overflow
46824754
: llvm::Intrinsic::usub_with_overflow;
46834755
break;
46844756
case Builtin::BI__builtin_mul_overflow:
4757+
Commutative = true;
46854758
IntrinsicId = EncompassingInfo.Signed
46864759
? llvm::Intrinsic::smul_with_overflow
46874760
: llvm::Intrinsic::umul_with_overflow;
@@ -4692,6 +4765,33 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
46924765
llvm::Value *Right = EmitScalarExpr(RightArg);
46934766
Address ResultPtr = EmitPointerWithAlignment(ResultArg);
46944767

4768+
llvm::Value *ProvenanceCap = nullptr;
4769+
if (ResultIsCap) {
4770+
if (!Commutative) {
4771+
if (LeftIsCap)
4772+
ProvenanceCap = Left;
4773+
else
4774+
ProvenanceCap = llvm::ConstantPointerNull::get(Int8CheriCapTy);
4775+
} else {
4776+
bool LeftNoProvenance =
4777+
!LeftIsCap || LeftQTy->hasAttr(attr::CHERINoProvenance);
4778+
bool RightNoProvenance =
4779+
!RightIsCap || RightQTy->hasAttr(attr::CHERINoProvenance);
4780+
if (LeftNoProvenance && RightNoProvenance)
4781+
ProvenanceCap = llvm::ConstantPointerNull::get(Int8CheriCapTy);
4782+
else if (LeftNoProvenance)
4783+
ProvenanceCap = Right;
4784+
else
4785+
ProvenanceCap = Left;
4786+
}
4787+
}
4788+
4789+
if (LeftIsCap)
4790+
Left = getCapabilityIntegerValue(Left);
4791+
4792+
if (RightIsCap)
4793+
Right = getCapabilityIntegerValue(Right);
4794+
46954795
// Extend each operand to the encompassing type.
46964796
Left = Builder.CreateIntCast(Left, EncompassingLLVMTy, LeftInfo.Signed);
46974797
Right = Builder.CreateIntCast(Right, EncompassingLLVMTy, RightInfo.Signed);
@@ -4716,6 +4816,10 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
47164816
Result = ResultTrunc;
47174817
}
47184818

4819+
if (ResultIsCap)
4820+
Result =
4821+
setCapabilityIntegerValue(ProvenanceCap, Result, E->getExprLoc());
4822+
47194823
// Finally, store the result using the pointer.
47204824
bool isVolatile =
47214825
ResultArg->getType()->getPointeeType().isVolatileQualified();

clang/lib/Sema/SemaChecking.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,10 @@
3737
#include "clang/AST/TypeLoc.h"
3838
#include "clang/AST/UnresolvedSet.h"
3939
#include "clang/Basic/AddressSpaces.h"
40+
#include "clang/Basic/Builtins.h"
4041
#include "clang/Basic/CharInfo.h"
4142
#include "clang/Basic/Diagnostic.h"
43+
#include "clang/Basic/DiagnosticFrontend.h"
4244
#include "clang/Basic/IdentifierTable.h"
4345
#include "clang/Basic/LLVM.h"
4446
#include "clang/Basic/LangOptions.h"
@@ -475,6 +477,18 @@ static bool SemaBuiltinOverflow(Sema &S, CallExpr *TheCall,
475477
}
476478
}
477479

480+
// ScalarExprEmitter::EmitSub's diagnostics aren't included here since
481+
// they're generally unhelpful, grouped under pedantic warnings, and would be
482+
// confusing without also taking into account the type of the result.
483+
if (BuiltinID != Builtin::BI__builtin_sub_overflow) {
484+
assert((BuiltinID == Builtin::BI__builtin_add_overflow ||
485+
BuiltinID == Builtin::BI__builtin_mul_overflow) &&
486+
"Unexpected overflow builtin");
487+
488+
S.DiagnoseAmbiguousProvenance(TheCall->getArg(0), TheCall->getArg(1),
489+
TheCall->getExprLoc(), false);
490+
}
491+
478492
return false;
479493
}
480494

0 commit comments

Comments
 (0)