Skip to content

Commit d88dbc3

Browse files
andykaylormikolaj-pirog
authored andcommitted
[CIR] Add support for exact dynamic casts (llvm#164007)
This adds support for handling exact dynamic casts when optimizations are enabled.
1 parent 706dc1b commit d88dbc3

File tree

5 files changed

+362
-2
lines changed

5 files changed

+362
-2
lines changed

clang/lib/CIR/CodeGen/CIRGenCall.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,22 @@ void CallArg::copyInto(CIRGenFunction &cgf, Address addr,
690690
isUsed = true;
691691
}
692692

693+
mlir::Value CIRGenFunction::emitRuntimeCall(mlir::Location loc,
694+
cir::FuncOp callee,
695+
ArrayRef<mlir::Value> args) {
696+
// TODO(cir): set the calling convention to this runtime call.
697+
assert(!cir::MissingFeatures::opFuncCallingConv());
698+
699+
cir::CallOp call = builder.createCallOp(loc, callee, args);
700+
assert(call->getNumResults() <= 1 &&
701+
"runtime functions have at most 1 result");
702+
703+
if (call->getNumResults() == 0)
704+
return nullptr;
705+
706+
return call->getResult(0);
707+
}
708+
693709
void CIRGenFunction::emitCallArg(CallArgList &args, const clang::Expr *e,
694710
clang::QualType argType) {
695711
assert(argType->isReferenceType() == e->isGLValue() &&

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1461,6 +1461,9 @@ class CIRGenFunction : public CIRGenTypeCache {
14611461

14621462
void emitReturnOfRValue(mlir::Location loc, RValue rv, QualType ty);
14631463

1464+
mlir::Value emitRuntimeCall(mlir::Location loc, cir::FuncOp callee,
1465+
llvm::ArrayRef<mlir::Value> args = {});
1466+
14641467
/// Emit the computation of the specified expression of scalar type.
14651468
mlir::Value emitScalarExpr(const clang::Expr *e);
14661469

clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp

Lines changed: 160 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1874,6 +1874,15 @@ static cir::FuncOp getBadCastFn(CIRGenFunction &cgf) {
18741874
return cgf.cgm.createRuntimeFunction(fnTy, "__cxa_bad_cast");
18751875
}
18761876

1877+
static void emitCallToBadCast(CIRGenFunction &cgf, mlir::Location loc) {
1878+
// TODO(cir): set the calling convention to the runtime function.
1879+
assert(!cir::MissingFeatures::opFuncCallingConv());
1880+
1881+
cgf.emitRuntimeCall(loc, getBadCastFn(cgf));
1882+
cir::UnreachableOp::create(cgf.getBuilder(), loc);
1883+
cgf.getBuilder().clearInsertionPoint();
1884+
}
1885+
18771886
// TODO(cir): This could be shared with classic codegen.
18781887
static CharUnits computeOffsetHint(ASTContext &astContext,
18791888
const CXXRecordDecl *src,
@@ -1959,6 +1968,136 @@ static Address emitDynamicCastToVoid(CIRGenFunction &cgf, mlir::Location loc,
19591968
return Address{ptr, src.getAlignment()};
19601969
}
19611970

1971+
static mlir::Value emitExactDynamicCast(CIRGenItaniumCXXABI &abi,
1972+
CIRGenFunction &cgf, mlir::Location loc,
1973+
QualType srcRecordTy,
1974+
QualType destRecordTy,
1975+
cir::PointerType destCIRTy,
1976+
bool isRefCast, Address src) {
1977+
// Find all the inheritance paths from SrcRecordTy to DestRecordTy.
1978+
const CXXRecordDecl *srcDecl = srcRecordTy->getAsCXXRecordDecl();
1979+
const CXXRecordDecl *destDecl = destRecordTy->getAsCXXRecordDecl();
1980+
CXXBasePaths paths(/*FindAmbiguities=*/true, /*RecordPaths=*/true,
1981+
/*DetectVirtual=*/false);
1982+
(void)destDecl->isDerivedFrom(srcDecl, paths);
1983+
1984+
// Find an offset within `destDecl` where a `srcDecl` instance and its vptr
1985+
// might appear.
1986+
std::optional<CharUnits> offset;
1987+
for (const CXXBasePath &path : paths) {
1988+
// dynamic_cast only finds public inheritance paths.
1989+
if (path.Access != AS_public)
1990+
continue;
1991+
1992+
CharUnits pathOffset;
1993+
for (const CXXBasePathElement &pathElement : path) {
1994+
// Find the offset along this inheritance step.
1995+
const CXXRecordDecl *base =
1996+
pathElement.Base->getType()->getAsCXXRecordDecl();
1997+
if (pathElement.Base->isVirtual()) {
1998+
// For a virtual base class, we know that the derived class is exactly
1999+
// destDecl, so we can use the vbase offset from its layout.
2000+
const ASTRecordLayout &layout =
2001+
cgf.getContext().getASTRecordLayout(destDecl);
2002+
pathOffset = layout.getVBaseClassOffset(base);
2003+
} else {
2004+
const ASTRecordLayout &layout =
2005+
cgf.getContext().getASTRecordLayout(pathElement.Class);
2006+
pathOffset += layout.getBaseClassOffset(base);
2007+
}
2008+
}
2009+
2010+
if (!offset) {
2011+
offset = pathOffset;
2012+
} else if (offset != pathOffset) {
2013+
// base appears in at least two different places. Find the most-derived
2014+
// object and see if it's a DestDecl. Note that the most-derived object
2015+
// must be at least as aligned as this base class subobject, and must
2016+
// have a vptr at offset 0.
2017+
src = emitDynamicCastToVoid(cgf, loc, srcRecordTy, src);
2018+
srcDecl = destDecl;
2019+
offset = CharUnits::Zero();
2020+
break;
2021+
}
2022+
}
2023+
2024+
CIRGenBuilderTy &builder = cgf.getBuilder();
2025+
2026+
if (!offset) {
2027+
// If there are no public inheritance paths, the cast always fails.
2028+
mlir::Value nullPtrValue = builder.getNullPtr(destCIRTy, loc);
2029+
if (isRefCast) {
2030+
mlir::Region *currentRegion = builder.getBlock()->getParent();
2031+
emitCallToBadCast(cgf, loc);
2032+
2033+
// The call to bad_cast will terminate the block. Create a new block to
2034+
// hold any follow up code.
2035+
builder.createBlock(currentRegion, currentRegion->end());
2036+
}
2037+
2038+
return nullPtrValue;
2039+
}
2040+
2041+
// Compare the vptr against the expected vptr for the destination type at
2042+
// this offset. Note that we do not know what type src points to in the case
2043+
// where the derived class multiply inherits from the base class so we can't
2044+
// use getVTablePtr, so we load the vptr directly instead.
2045+
2046+
mlir::Value expectedVPtr =
2047+
abi.getVTableAddressPoint(BaseSubobject(srcDecl, *offset), destDecl);
2048+
2049+
// TODO(cir): handle address space here.
2050+
assert(!cir::MissingFeatures::addressSpace());
2051+
mlir::Type vptrTy = expectedVPtr.getType();
2052+
mlir::Type vptrPtrTy = builder.getPointerTo(vptrTy);
2053+
Address srcVPtrPtr(builder.createBitcast(src.getPointer(), vptrPtrTy),
2054+
src.getAlignment());
2055+
mlir::Value srcVPtr = builder.createLoad(loc, srcVPtrPtr);
2056+
2057+
// TODO(cir): decorate SrcVPtr with TBAA info.
2058+
assert(!cir::MissingFeatures::opTBAA());
2059+
2060+
mlir::Value success =
2061+
builder.createCompare(loc, cir::CmpOpKind::eq, srcVPtr, expectedVPtr);
2062+
2063+
auto emitCastResult = [&] {
2064+
if (offset->isZero())
2065+
return builder.createBitcast(src.getPointer(), destCIRTy);
2066+
2067+
// TODO(cir): handle address space here.
2068+
assert(!cir::MissingFeatures::addressSpace());
2069+
mlir::Type u8PtrTy = builder.getUInt8PtrTy();
2070+
2071+
mlir::Value strideToApply =
2072+
builder.getConstInt(loc, builder.getUInt64Ty(), -offset->getQuantity());
2073+
mlir::Value srcU8Ptr = builder.createBitcast(src.getPointer(), u8PtrTy);
2074+
mlir::Value resultU8Ptr = cir::PtrStrideOp::create(builder, loc, u8PtrTy,
2075+
srcU8Ptr, strideToApply);
2076+
return builder.createBitcast(resultU8Ptr, destCIRTy);
2077+
};
2078+
2079+
if (isRefCast) {
2080+
mlir::Value failed = builder.createNot(success);
2081+
cir::IfOp::create(builder, loc, failed, /*withElseRegion=*/false,
2082+
[&](mlir::OpBuilder &, mlir::Location) {
2083+
emitCallToBadCast(cgf, loc);
2084+
});
2085+
return emitCastResult();
2086+
}
2087+
2088+
return cir::TernaryOp::create(
2089+
builder, loc, success,
2090+
[&](mlir::OpBuilder &, mlir::Location) {
2091+
auto result = emitCastResult();
2092+
builder.createYield(loc, result);
2093+
},
2094+
[&](mlir::OpBuilder &, mlir::Location) {
2095+
mlir::Value nullPtrValue = builder.getNullPtr(destCIRTy, loc);
2096+
builder.createYield(loc, nullPtrValue);
2097+
})
2098+
.getResult();
2099+
}
2100+
19622101
static cir::DynamicCastInfoAttr emitDynamicCastInfo(CIRGenFunction &cgf,
19632102
mlir::Location loc,
19642103
QualType srcRecordTy,
@@ -2000,8 +2139,27 @@ mlir::Value CIRGenItaniumCXXABI::emitDynamicCast(CIRGenFunction &cgf,
20002139
// if the dynamic type of the pointer is exactly the destination type.
20012140
if (destRecordTy->getAsCXXRecordDecl()->isEffectivelyFinal() &&
20022141
cgf.cgm.getCodeGenOpts().OptimizationLevel > 0) {
2003-
cgm.errorNYI(loc, "emitExactDynamicCast");
2004-
return {};
2142+
CIRGenBuilderTy &builder = cgf.getBuilder();
2143+
// If this isn't a reference cast, check the pointer to see if it's null.
2144+
if (!isRefCast) {
2145+
mlir::Value srcPtrIsNull = builder.createPtrIsNull(src.getPointer());
2146+
return cir::TernaryOp::create(
2147+
builder, loc, srcPtrIsNull,
2148+
[&](mlir::OpBuilder, mlir::Location) {
2149+
builder.createYield(
2150+
loc, builder.getNullPtr(destCIRTy, loc).getResult());
2151+
},
2152+
[&](mlir::OpBuilder &, mlir::Location) {
2153+
mlir::Value exactCast = emitExactDynamicCast(
2154+
*this, cgf, loc, srcRecordTy, destRecordTy, destCIRTy,
2155+
isRefCast, src);
2156+
builder.createYield(loc, exactCast);
2157+
})
2158+
.getResult();
2159+
}
2160+
2161+
return emitExactDynamicCast(*this, cgf, loc, srcRecordTy, destRecordTy,
2162+
destCIRTy, isRefCast, src);
20052163
}
20062164

20072165
cir::DynamicCastInfoAttr castInfo =

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2411,6 +2411,15 @@ mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite(
24112411
return mlir::success();
24122412
}
24132413

2414+
if (auto vptrTy = mlir::dyn_cast<cir::VPtrType>(type)) {
2415+
// !cir.vptr is a special case, but it's just a pointer to LLVM.
2416+
auto kind = convertCmpKindToICmpPredicate(cmpOp.getKind(),
2417+
/* isSigned=*/false);
2418+
rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
2419+
cmpOp, kind, adaptor.getLhs(), adaptor.getRhs());
2420+
return mlir::success();
2421+
}
2422+
24142423
if (mlir::isa<cir::FPTypeInterface>(type)) {
24152424
mlir::LLVM::FCmpPredicate kind =
24162425
convertCmpKindToFCmpPredicate(cmpOp.getKind());

0 commit comments

Comments
 (0)