Skip to content

Commit a8a6019

Browse files
committed
[CIR] Add support for exact dynamic casts
This adds support for handling exact dynamic casts when optimizations are enabled.
1 parent a99e32b commit a8a6019

File tree

5 files changed

+302
-2
lines changed

5 files changed

+302
-2
lines changed

clang/lib/CIR/CodeGen/CIRGenCall.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,22 @@ void CallArg::copyInto(CIRGenFunction &cgf, Address addr,
690690
isUsed = true;
691691
}
692692

693+
mlir::Value CIRGenFunction::emitRuntimeCall(mlir::Location loc,
694+
cir::FuncOp callee,
695+
ArrayRef<mlir::Value> args) {
696+
// TODO(cir): set the calling convention to this runtime call.
697+
assert(!cir::MissingFeatures::opFuncCallingConv());
698+
699+
cir::CallOp call = builder.createCallOp(loc, callee, args);
700+
assert(call->getNumResults() <= 1 &&
701+
"runtime functions have at most 1 result");
702+
703+
if (call->getNumResults() == 0)
704+
return nullptr;
705+
706+
return call->getResult(0);
707+
}
708+
693709
void CIRGenFunction::emitCallArg(CallArgList &args, const clang::Expr *e,
694710
clang::QualType argType) {
695711
assert(argType->isReferenceType() == e->isGLValue() &&

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1380,6 +1380,9 @@ class CIRGenFunction : public CIRGenTypeCache {
13801380

13811381
void emitReturnOfRValue(mlir::Location loc, RValue rv, QualType ty);
13821382

1383+
mlir::Value emitRuntimeCall(mlir::Location loc, cir::FuncOp callee,
1384+
llvm::ArrayRef<mlir::Value> args = {});
1385+
13831386
/// Emit the computation of the specified expression of scalar type.
13841387
mlir::Value emitScalarExpr(const clang::Expr *e);
13851388

clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp

Lines changed: 160 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1869,6 +1869,15 @@ static cir::FuncOp getBadCastFn(CIRGenFunction &cgf) {
18691869
return cgf.cgm.createRuntimeFunction(fnTy, "__cxa_bad_cast");
18701870
}
18711871

1872+
static void emitCallToBadCast(CIRGenFunction &cgf, mlir::Location loc) {
1873+
// TODO(cir): set the calling convention to the runtime function.
1874+
assert(!cir::MissingFeatures::opFuncCallingConv());
1875+
1876+
cgf.emitRuntimeCall(loc, getBadCastFn(cgf));
1877+
cir::UnreachableOp::create(cgf.getBuilder(), loc);
1878+
cgf.getBuilder().clearInsertionPoint();
1879+
}
1880+
18721881
// TODO(cir): This could be shared with classic codegen.
18731882
static CharUnits computeOffsetHint(ASTContext &astContext,
18741883
const CXXRecordDecl *src,
@@ -1954,6 +1963,136 @@ static Address emitDynamicCastToVoid(CIRGenFunction &cgf, mlir::Location loc,
19541963
return Address{ptr, src.getAlignment()};
19551964
}
19561965

1966+
static mlir::Value emitExactDynamicCast(CIRGenItaniumCXXABI &abi,
1967+
CIRGenFunction &cgf, mlir::Location loc,
1968+
QualType srcRecordTy,
1969+
QualType destRecordTy,
1970+
cir::PointerType destCIRTy,
1971+
bool isRefCast, Address src) {
1972+
// Find all the inheritance paths from SrcRecordTy to DestRecordTy.
1973+
const CXXRecordDecl *srcDecl = srcRecordTy->getAsCXXRecordDecl();
1974+
const CXXRecordDecl *destDecl = destRecordTy->getAsCXXRecordDecl();
1975+
CXXBasePaths paths(/*FindAmbiguities=*/true, /*RecordPaths=*/true,
1976+
/*DetectVirtual=*/false);
1977+
(void)destDecl->isDerivedFrom(srcDecl, paths);
1978+
1979+
// Find an offset within `destDecl` where a `srcDecl` instance and its vptr
1980+
// might appear.
1981+
std::optional<CharUnits> offset;
1982+
for (const CXXBasePath &path : paths) {
1983+
// dynamic_cast only finds public inheritance paths.
1984+
if (path.Access != AS_public)
1985+
continue;
1986+
1987+
CharUnits pathOffset;
1988+
for (const CXXBasePathElement &pathElement : path) {
1989+
// Find the offset along this inheritance step.
1990+
const CXXRecordDecl *base =
1991+
pathElement.Base->getType()->getAsCXXRecordDecl();
1992+
if (pathElement.Base->isVirtual()) {
1993+
// For a virtual base class, we know that the derived class is exactly
1994+
// destDecl, so we can use the vbase offset from its layout.
1995+
const ASTRecordLayout &layout =
1996+
cgf.getContext().getASTRecordLayout(destDecl);
1997+
pathOffset = layout.getVBaseClassOffset(base);
1998+
} else {
1999+
const ASTRecordLayout &layout =
2000+
cgf.getContext().getASTRecordLayout(pathElement.Class);
2001+
pathOffset += layout.getBaseClassOffset(base);
2002+
}
2003+
}
2004+
2005+
if (!offset) {
2006+
offset = pathOffset;
2007+
} else if (offset != pathOffset) {
2008+
// base appears in at least two different places. Find the most-derived
2009+
// object and see if it's a DestDecl. Note that the most-derived object
2010+
// must be at least as aligned as this base class subobject, and must
2011+
// have a vptr at offset 0.
2012+
src = emitDynamicCastToVoid(cgf, loc, srcRecordTy, src);
2013+
srcDecl = destDecl;
2014+
offset = CharUnits::Zero();
2015+
break;
2016+
}
2017+
}
2018+
2019+
CIRGenBuilderTy &builder = cgf.getBuilder();
2020+
2021+
if (!offset) {
2022+
// If there are no public inheritance paths, the cast always fails.
2023+
mlir::Value nullPtrValue = builder.getNullPtr(destCIRTy, loc);
2024+
if (isRefCast) {
2025+
mlir::Region *currentRegion = builder.getBlock()->getParent();
2026+
emitCallToBadCast(cgf, loc);
2027+
2028+
// The call to bad_cast will terminate the block. Create a new block to
2029+
// hold any follow up code.
2030+
builder.createBlock(currentRegion, currentRegion->end());
2031+
}
2032+
2033+
return nullPtrValue;
2034+
}
2035+
2036+
// Compare the vptr against the expected vptr for the destination type at
2037+
// this offset. Note that we do not know what type src points to in the case
2038+
// where the derived class multiply inherits from the base class so we can't
2039+
// use getVTablePtr, so we load the vptr directly instead.
2040+
2041+
mlir::Value expectedVPtr =
2042+
abi.getVTableAddressPoint(BaseSubobject(srcDecl, *offset), destDecl);
2043+
2044+
// TODO(cir): handle address space here.
2045+
assert(!cir::MissingFeatures::addressSpace());
2046+
mlir::Type vptrTy = expectedVPtr.getType();
2047+
mlir::Type vptrPtrTy = builder.getPointerTo(vptrTy);
2048+
Address srcVPtrPtr(builder.createBitcast(src.getPointer(), vptrPtrTy),
2049+
src.getAlignment());
2050+
mlir::Value srcVPtr = builder.createLoad(loc, srcVPtrPtr);
2051+
2052+
// TODO(cir): decorate SrcVPtr with TBAA info.
2053+
assert(!cir::MissingFeatures::opTBAA());
2054+
2055+
mlir::Value success =
2056+
builder.createCompare(loc, cir::CmpOpKind::eq, srcVPtr, expectedVPtr);
2057+
2058+
auto emitCastResult = [&] {
2059+
if (offset->isZero())
2060+
return builder.createBitcast(src.getPointer(), destCIRTy);
2061+
2062+
// TODO(cir): handle address space here.
2063+
assert(!cir::MissingFeatures::addressSpace());
2064+
mlir::Type u8PtrTy = builder.getUInt8PtrTy();
2065+
2066+
mlir::Value strideToApply =
2067+
builder.getConstInt(loc, builder.getUInt64Ty(), offset->getQuantity());
2068+
mlir::Value srcU8Ptr = builder.createBitcast(src.getPointer(), u8PtrTy);
2069+
mlir::Value resultU8Ptr = cir::PtrStrideOp::create(builder, loc, u8PtrTy,
2070+
srcU8Ptr, strideToApply);
2071+
return builder.createBitcast(resultU8Ptr, destCIRTy);
2072+
};
2073+
2074+
if (isRefCast) {
2075+
mlir::Value failed = builder.createNot(success);
2076+
cir::IfOp::create(builder, loc, failed, /*withElseRegion=*/false,
2077+
[&](mlir::OpBuilder &, mlir::Location) {
2078+
emitCallToBadCast(cgf, loc);
2079+
});
2080+
return emitCastResult();
2081+
}
2082+
2083+
return cir::TernaryOp::create(
2084+
builder, loc, success,
2085+
[&](mlir::OpBuilder &, mlir::Location) {
2086+
auto result = emitCastResult();
2087+
builder.createYield(loc, result);
2088+
},
2089+
[&](mlir::OpBuilder &, mlir::Location) {
2090+
mlir::Value nullPtrValue = builder.getNullPtr(destCIRTy, loc);
2091+
builder.createYield(loc, nullPtrValue);
2092+
})
2093+
.getResult();
2094+
}
2095+
19572096
static cir::DynamicCastInfoAttr emitDynamicCastInfo(CIRGenFunction &cgf,
19582097
mlir::Location loc,
19592098
QualType srcRecordTy,
@@ -1995,8 +2134,27 @@ mlir::Value CIRGenItaniumCXXABI::emitDynamicCast(CIRGenFunction &cgf,
19952134
// if the dynamic type of the pointer is exactly the destination type.
19962135
if (destRecordTy->getAsCXXRecordDecl()->isEffectivelyFinal() &&
19972136
cgf.cgm.getCodeGenOpts().OptimizationLevel > 0) {
1998-
cgm.errorNYI(loc, "emitExactDynamicCast");
1999-
return {};
2137+
CIRGenBuilderTy &builder = cgf.getBuilder();
2138+
// If this isn't a reference cast, check the pointer to see if it's null.
2139+
if (!isRefCast) {
2140+
mlir::Value srcPtrIsNull = builder.createPtrIsNull(src.getPointer());
2141+
return cir::TernaryOp::create(
2142+
builder, loc, srcPtrIsNull,
2143+
[&](mlir::OpBuilder, mlir::Location) {
2144+
builder.createYield(
2145+
loc, builder.getNullPtr(destCIRTy, loc).getResult());
2146+
},
2147+
[&](mlir::OpBuilder &, mlir::Location) {
2148+
mlir::Value exactCast = emitExactDynamicCast(
2149+
*this, cgf, loc, srcRecordTy, destRecordTy, destCIRTy,
2150+
isRefCast, src);
2151+
builder.createYield(loc, exactCast);
2152+
})
2153+
.getResult();
2154+
}
2155+
2156+
return emitExactDynamicCast(*this, cgf, loc, srcRecordTy, destRecordTy,
2157+
destCIRTy, isRefCast, src);
20002158
}
20012159

20022160
cir::DynamicCastInfoAttr castInfo =

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2223,6 +2223,15 @@ mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite(
22232223
return mlir::success();
22242224
}
22252225

2226+
if (auto vptrTy = mlir::dyn_cast<cir::VPtrType>(type)) {
2227+
// !cir.vptr is a special case, but it's just a pointer to LLVM.
2228+
auto kind = convertCmpKindToICmpPredicate(cmpOp.getKind(),
2229+
/* isSigned=*/false);
2230+
rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
2231+
cmpOp, kind, adaptor.getLhs(), adaptor.getRhs());
2232+
return mlir::success();
2233+
}
2234+
22262235
if (mlir::isa<cir::FPTypeInterface>(type)) {
22272236
mlir::LLVM::FCmpPredicate kind =
22282237
convertCmpKindToFCmpPredicate(cmpOp.getKind());
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++20 -O1 -fclangir -clangir-disable-passes -emit-cir -o %t.cir %s
2+
// RUN: FileCheck --input-file=%t.cir %s --check-prefix=CIR
3+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++20 -O1 -fclangir -emit-llvm -o %t-cir.ll %s
4+
// RUN: FileCheck --input-file=%t-cir.ll --check-prefix=LLVM %s
5+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++20 -O1 -emit-llvm -o %t.ll %s
6+
// RUN: FileCheck --input-file=%t.ll --check-prefix=OGCG %s
7+
8+
struct Base1 {
9+
virtual ~Base1();
10+
};
11+
12+
struct Base2 {
13+
virtual ~Base2();
14+
};
15+
16+
struct Derived final : Base1 {};
17+
18+
Derived *ptr_cast(Base1 *ptr) {
19+
return dynamic_cast<Derived *>(ptr);
20+
}
21+
22+
// CIR: cir.func {{.*}} @_Z8ptr_castP5Base1
23+
// CIR: %[[SRC:.*]] = cir.load{{.*}} %{{.+}} : !cir.ptr<!cir.ptr<!rec_Base1>>, !cir.ptr<!rec_Base1>
24+
// CIR-NEXT: %[[NULL_PTR:.*]] = cir.const #cir.ptr<null>
25+
// CIR-NEXT: %[[SRC_IS_NULL:.*]] = cir.cmp(eq, %[[SRC]], %[[NULL_PTR]])
26+
// CIR-NEXT: %[[RESULT:.*]] = cir.ternary(%[[SRC_IS_NULL]], true {
27+
// CIR-NEXT: %[[NULL_PTR_DEST:.*]] = cir.const #cir.ptr<null> : !cir.ptr<!rec_Derived>
28+
// CIR-NEXT: cir.yield %[[NULL_PTR_DEST]] : !cir.ptr<!rec_Derived>
29+
// CIR-NEXT: }, false {
30+
// CIR-NEXT: %[[EXPECTED_VPTR:.*]] = cir.vtable.address_point(@_ZTV7Derived, address_point = <index = 0, offset = 2>) : !cir.vptr
31+
// CIR-NEXT: %[[SRC_VPTR_PTR:.*]] = cir.cast bitcast %[[SRC]] : !cir.ptr<!rec_Base1> -> !cir.ptr<!cir.vptr>
32+
// CIR-NEXT: %[[SRC_VPTR:.*]] = cir.load{{.*}} %[[SRC_VPTR_PTR]] : !cir.ptr<!cir.vptr>, !cir.vptr
33+
// CIR-NEXT: %[[SUCCESS:.*]] = cir.cmp(eq, %[[SRC_VPTR]], %[[EXPECTED_VPTR]]) : !cir.vptr, !cir.bool
34+
// CIR-NEXT: %[[EXACT_RESULT:.*]] = cir.ternary(%[[SUCCESS]], true {
35+
// CIR-NEXT: %[[RES:.*]] = cir.cast bitcast %[[SRC]] : !cir.ptr<!rec_Base1> -> !cir.ptr<!rec_Derived>
36+
// CIR-NEXT: cir.yield %[[RES]] : !cir.ptr<!rec_Derived>
37+
// CIR-NEXT: }, false {
38+
// CIR-NEXT: %[[NULL:.*]] = cir.const #cir.ptr<null> : !cir.ptr<!rec_Derived>
39+
// CIR-NEXT: cir.yield %[[NULL]] : !cir.ptr<!rec_Derived>
40+
// CIR-NEXT: }) : (!cir.bool) -> !cir.ptr<!rec_Derived>
41+
// CIR-NEXT: cir.yield %[[EXACT_RESULT]] : !cir.ptr<!rec_Derived>
42+
// CIR-NEXT: }) : (!cir.bool) -> !cir.ptr<!rec_Derived>
43+
44+
// Note: The LLVM output omits the label for the entry block (which is
45+
// implicitly %1), so we use %{{.*}} to match the implicit label in the
46+
// phi check.
47+
48+
// LLVM: define dso_local ptr @_Z8ptr_castP5Base1(ptr{{.*}} %[[SRC:.*]])
49+
// LLVM-NEXT: %[[SRC_IS_NULL:.*]] = icmp eq ptr %0, null
50+
// LLVM-NEXT: br i1 %[[SRC_IS_NULL]], label %[[LABEL_END:.*]], label %[[LABEL_NOTNULL:.*]]
51+
// LLVM: [[LABEL_NOTNULL]]:
52+
// LLVM-NEXT: %[[VPTR:.*]] = load ptr, ptr %[[SRC]], align 8
53+
// LLVM-NEXT: %[[SUCCESS:.*]] = icmp eq ptr %[[VPTR]], getelementptr inbounds nuw (i8, ptr @_ZTV7Derived, i64 16)
54+
// LLVM-NEXT: %[[EXACT_RESULT:.*]] = select i1 %[[SUCCESS]], ptr %[[SRC]], ptr null
55+
// LLVM-NEXT: br label %[[LABEL_END]]
56+
// LLVM: [[LABEL_END]]:
57+
// LLVM-NEXT: %[[RESULT:.*]] = phi ptr [ %[[EXACT_RESULT]], %[[LABEL_NOTNULL]] ], [ null, %{{.*}} ]
58+
// LLVM-NEXT: ret ptr %[[RESULT]]
59+
// LLVM-NEXT: }
60+
61+
// OGCG: define{{.*}} ptr @_Z8ptr_castP5Base1(ptr {{.*}} %[[SRC:.*]])
62+
// OGCG-NEXT: entry:
63+
// OGCG-NEXT: %[[NULL_CHECK:.*]] = icmp eq ptr %[[SRC]], null
64+
// OGCG-NEXT: br i1 %[[NULL_CHECK]], label %[[LABEL_NULL:.*]], label %[[LABEL_NOTNULL:.*]]
65+
// OGCG: [[LABEL_NOTNULL]]:
66+
// OGCG-NEXT: %[[VTABLE:.*]] = load ptr, ptr %[[SRC]], align 8
67+
// OGCG-NEXT: %[[VTABLE_CHECK:.*]] = icmp eq ptr %[[VTABLE]], getelementptr inbounds {{.*}} (i8, ptr @_ZTV7Derived, i64 16)
68+
// OGCG-NEXT: br i1 %[[VTABLE_CHECK]], label %[[LABEL_END:.*]], label %[[LABEL_NULL]]
69+
// OGCG: [[LABEL_NULL]]:
70+
// OGCG-NEXT: br label %[[LABEL_END]]
71+
// OGCG: [[LABEL_END]]:
72+
// OGCG-NEXT: %[[RESULT:.*]] = phi ptr [ %[[SRC]], %[[LABEL_NOTNULL]] ], [ null, %[[LABEL_NULL]] ]
73+
// OGCG-NEXT: ret ptr %[[RESULT]]
74+
// OGCG-NEXT: }
75+
76+
Derived &ref_cast(Base1 &ref) {
77+
return dynamic_cast<Derived &>(ref);
78+
}
79+
80+
// CIR: cir.func {{.*}} @_Z8ref_castR5Base1
81+
// CIR: %[[SRC:.*]] = cir.load{{.*}} %{{.+}} : !cir.ptr<!cir.ptr<!rec_Base1>>, !cir.ptr<!rec_Base1>
82+
// CIR-NEXT: %[[EXPECTED_VPTR:.*]] = cir.vtable.address_point(@_ZTV7Derived, address_point = <index = 0, offset = 2>) : !cir.vptr
83+
// CIR-NEXT: %[[SRC_VPTR_PTR:.*]] = cir.cast bitcast %[[SRC]] : !cir.ptr<!rec_Base1> -> !cir.ptr<!cir.vptr>
84+
// CIR-NEXT: %[[SRC_VPTR:.*]] = cir.load{{.*}} %[[SRC_VPTR_PTR]] : !cir.ptr<!cir.vptr>, !cir.vptr
85+
// CIR-NEXT: %[[SUCCESS:.*]] = cir.cmp(eq, %[[SRC_VPTR]], %[[EXPECTED_VPTR]]) : !cir.vptr, !cir.bool
86+
// CIR-NEXT: %[[FAILED:.*]] = cir.unary(not, %[[SUCCESS]]) : !cir.bool, !cir.bool
87+
// CIR-NEXT: cir.if %[[FAILED]] {
88+
// CIR-NEXT: cir.call @__cxa_bad_cast() : () -> ()
89+
// CIR-NEXT: cir.unreachable
90+
// CIR-NEXT: }
91+
// CIR-NEXT: %{{.+}} = cir.cast bitcast %[[SRC]] : !cir.ptr<!rec_Base1> -> !cir.ptr<!rec_Derived>
92+
93+
// LLVM: define{{.*}} ptr @_Z8ref_castR5Base1(ptr{{.*}} %[[SRC:.*]])
94+
// LLVM-NEXT: %[[VPTR:.*]] = load ptr, ptr %[[SRC]], align 8
95+
// LLVM-NEXT: %[[OK:.*]] = icmp eq ptr %[[VPTR]], getelementptr inbounds nuw (i8, ptr @_ZTV7Derived, i64 16)
96+
// LLVM-NEXT: br i1 %[[OK]], label %[[LABEL_OK:.*]], label %[[LABEL_FAIL:.*]]
97+
// LLVM: [[LABEL_FAIL]]:
98+
// LLVM-NEXT: tail call void @__cxa_bad_cast()
99+
// LLVM-NEXT: unreachable
100+
// LLVM: [[LABEL_OK]]:
101+
// LLVM-NEXT: ret ptr %[[SRC]]
102+
// LLVM-NEXT: }
103+
104+
// OGCG: define{{.*}} ptr @_Z8ref_castR5Base1(ptr {{.*}} %[[REF:.*]])
105+
// OGCG-NEXT: entry:
106+
// OGCG-NEXT: %[[VTABLE:.*]] = load ptr, ptr %[[REF]], align 8
107+
// OGCG-NEXT: %[[VTABLE_CHECK:.*]] = icmp eq ptr %[[VTABLE]], getelementptr inbounds {{.*}} (i8, ptr @_ZTV7Derived, i64 16)
108+
// OGCG-NEXT: br i1 %[[VTABLE_CHECK]], label %[[LABEL_END:.*]], label %[[LABEL_NULL:.*]]
109+
// OGCG: [[LABEL_NULL]]:
110+
// OGCG-NEXT: {{.*}}call void @__cxa_bad_cast()
111+
// OGCG-NEXT: unreachable
112+
// OGCG: [[LABEL_END]]:
113+
// OGCG-NEXT: ret ptr %[[REF]]
114+
// OGCG-NEXT: }

0 commit comments

Comments
 (0)