Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,22 @@ void CallArg::copyInto(CIRGenFunction &cgf, Address addr,
isUsed = true;
}

mlir::Value CIRGenFunction::emitRuntimeCall(mlir::Location loc,
cir::FuncOp callee,
ArrayRef<mlir::Value> args) {
// TODO(cir): set the calling convention to this runtime call.
assert(!cir::MissingFeatures::opFuncCallingConv());

cir::CallOp call = builder.createCallOp(loc, callee, args);
assert(call->getNumResults() <= 1 &&
"runtime functions have at most 1 result");

if (call->getNumResults() == 0)
return nullptr;

return call->getResult(0);
}

void CIRGenFunction::emitCallArg(CallArgList &args, const clang::Expr *e,
clang::QualType argType) {
assert(argType->isReferenceType() == e->isGLValue() &&
Expand Down
3 changes: 3 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -1380,6 +1380,9 @@ class CIRGenFunction : public CIRGenTypeCache {

void emitReturnOfRValue(mlir::Location loc, RValue rv, QualType ty);

mlir::Value emitRuntimeCall(mlir::Location loc, cir::FuncOp callee,
llvm::ArrayRef<mlir::Value> args = {});

/// Emit the computation of the specified expression of scalar type.
mlir::Value emitScalarExpr(const clang::Expr *e);

Expand Down
162 changes: 160 additions & 2 deletions clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1869,6 +1869,15 @@ static cir::FuncOp getBadCastFn(CIRGenFunction &cgf) {
return cgf.cgm.createRuntimeFunction(fnTy, "__cxa_bad_cast");
}

static void emitCallToBadCast(CIRGenFunction &cgf, mlir::Location loc) {
// TODO(cir): set the calling convention to the runtime function.
assert(!cir::MissingFeatures::opFuncCallingConv());

cgf.emitRuntimeCall(loc, getBadCastFn(cgf));
cir::UnreachableOp::create(cgf.getBuilder(), loc);
cgf.getBuilder().clearInsertionPoint();
}

// TODO(cir): This could be shared with classic codegen.
static CharUnits computeOffsetHint(ASTContext &astContext,
const CXXRecordDecl *src,
Expand Down Expand Up @@ -1954,6 +1963,136 @@ static Address emitDynamicCastToVoid(CIRGenFunction &cgf, mlir::Location loc,
return Address{ptr, src.getAlignment()};
}

static mlir::Value emitExactDynamicCast(CIRGenItaniumCXXABI &abi,
CIRGenFunction &cgf, mlir::Location loc,
QualType srcRecordTy,
QualType destRecordTy,
cir::PointerType destCIRTy,
bool isRefCast, Address src) {
// Find all the inheritance paths from SrcRecordTy to DestRecordTy.
const CXXRecordDecl *srcDecl = srcRecordTy->getAsCXXRecordDecl();
const CXXRecordDecl *destDecl = destRecordTy->getAsCXXRecordDecl();
CXXBasePaths paths(/*FindAmbiguities=*/true, /*RecordPaths=*/true,
/*DetectVirtual=*/false);
(void)destDecl->isDerivedFrom(srcDecl, paths);

// Find an offset within `destDecl` where a `srcDecl` instance and its vptr
// might appear.
std::optional<CharUnits> offset;
for (const CXXBasePath &path : paths) {
// dynamic_cast only finds public inheritance paths.
if (path.Access != AS_public)
continue;

CharUnits pathOffset;
for (const CXXBasePathElement &pathElement : path) {
// Find the offset along this inheritance step.
const CXXRecordDecl *base =
pathElement.Base->getType()->getAsCXXRecordDecl();
if (pathElement.Base->isVirtual()) {
// For a virtual base class, we know that the derived class is exactly
// destDecl, so we can use the vbase offset from its layout.
const ASTRecordLayout &layout =
cgf.getContext().getASTRecordLayout(destDecl);
pathOffset = layout.getVBaseClassOffset(base);
} else {
const ASTRecordLayout &layout =
cgf.getContext().getASTRecordLayout(pathElement.Class);
pathOffset += layout.getBaseClassOffset(base);
}
}

if (!offset) {
offset = pathOffset;
} else if (offset != pathOffset) {
// base appears in at least two different places. Find the most-derived
// object and see if it's a DestDecl. Note that the most-derived object
// must be at least as aligned as this base class subobject, and must
// have a vptr at offset 0.
src = emitDynamicCastToVoid(cgf, loc, srcRecordTy, src);
srcDecl = destDecl;
offset = CharUnits::Zero();
break;
}
}

CIRGenBuilderTy &builder = cgf.getBuilder();

if (!offset) {
// If there are no public inheritance paths, the cast always fails.
mlir::Value nullPtrValue = builder.getNullPtr(destCIRTy, loc);
if (isRefCast) {
mlir::Region *currentRegion = builder.getBlock()->getParent();
emitCallToBadCast(cgf, loc);

// The call to bad_cast will terminate the block. Create a new block to
// hold any follow up code.
builder.createBlock(currentRegion, currentRegion->end());
}

return nullPtrValue;
}

// Compare the vptr against the expected vptr for the destination type at
// this offset. Note that we do not know what type src points to in the case
// where the derived class multiply inherits from the base class so we can't
// use getVTablePtr, so we load the vptr directly instead.

mlir::Value expectedVPtr =
abi.getVTableAddressPoint(BaseSubobject(srcDecl, *offset), destDecl);

// TODO(cir): handle address space here.
assert(!cir::MissingFeatures::addressSpace());
mlir::Type vptrTy = expectedVPtr.getType();
mlir::Type vptrPtrTy = builder.getPointerTo(vptrTy);
Address srcVPtrPtr(builder.createBitcast(src.getPointer(), vptrPtrTy),
src.getAlignment());
mlir::Value srcVPtr = builder.createLoad(loc, srcVPtrPtr);

// TODO(cir): decorate SrcVPtr with TBAA info.
assert(!cir::MissingFeatures::opTBAA());

mlir::Value success =
builder.createCompare(loc, cir::CmpOpKind::eq, srcVPtr, expectedVPtr);

auto emitCastResult = [&] {
if (offset->isZero())
return builder.createBitcast(src.getPointer(), destCIRTy);

// TODO(cir): handle address space here.
assert(!cir::MissingFeatures::addressSpace());
mlir::Type u8PtrTy = builder.getUInt8PtrTy();

mlir::Value strideToApply =
builder.getConstInt(loc, builder.getUInt64Ty(), offset->getQuantity());
mlir::Value srcU8Ptr = builder.createBitcast(src.getPointer(), u8PtrTy);
mlir::Value resultU8Ptr = cir::PtrStrideOp::create(builder, loc, u8PtrTy,
srcU8Ptr, strideToApply);
return builder.createBitcast(resultU8Ptr, destCIRTy);
};

if (isRefCast) {
mlir::Value failed = builder.createNot(success);
cir::IfOp::create(builder, loc, failed, /*withElseRegion=*/false,
[&](mlir::OpBuilder &, mlir::Location) {
emitCallToBadCast(cgf, loc);
});
return emitCastResult();
}

return cir::TernaryOp::create(
builder, loc, success,
[&](mlir::OpBuilder &, mlir::Location) {
auto result = emitCastResult();
builder.createYield(loc, result);
},
[&](mlir::OpBuilder &, mlir::Location) {
mlir::Value nullPtrValue = builder.getNullPtr(destCIRTy, loc);
builder.createYield(loc, nullPtrValue);
})
.getResult();
}

static cir::DynamicCastInfoAttr emitDynamicCastInfo(CIRGenFunction &cgf,
mlir::Location loc,
QualType srcRecordTy,
Expand Down Expand Up @@ -1995,8 +2134,27 @@ mlir::Value CIRGenItaniumCXXABI::emitDynamicCast(CIRGenFunction &cgf,
// if the dynamic type of the pointer is exactly the destination type.
if (destRecordTy->getAsCXXRecordDecl()->isEffectivelyFinal() &&
cgf.cgm.getCodeGenOpts().OptimizationLevel > 0) {
cgm.errorNYI(loc, "emitExactDynamicCast");
return {};
CIRGenBuilderTy &builder = cgf.getBuilder();
// If this isn't a reference cast, check the pointer to see if it's null.
if (!isRefCast) {
mlir::Value srcPtrIsNull = builder.createPtrIsNull(src.getPointer());
return cir::TernaryOp::create(
builder, loc, srcPtrIsNull,
[&](mlir::OpBuilder, mlir::Location) {
builder.createYield(
loc, builder.getNullPtr(destCIRTy, loc).getResult());
},
[&](mlir::OpBuilder &, mlir::Location) {
mlir::Value exactCast = emitExactDynamicCast(
*this, cgf, loc, srcRecordTy, destRecordTy, destCIRTy,
isRefCast, src);
builder.createYield(loc, exactCast);
})
.getResult();
}

return emitExactDynamicCast(*this, cgf, loc, srcRecordTy, destRecordTy,
destCIRTy, isRefCast, src);
}

cir::DynamicCastInfoAttr castInfo =
Expand Down
9 changes: 9 additions & 0 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2223,6 +2223,15 @@ mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite(
return mlir::success();
}

if (auto vptrTy = mlir::dyn_cast<cir::VPtrType>(type)) {
// !cir.vptr is a special case, but it's just a pointer to LLVM.
auto kind = convertCmpKindToICmpPredicate(cmpOp.getKind(),
/* isSigned=*/false);
rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
cmpOp, kind, adaptor.getLhs(), adaptor.getRhs());
return mlir::success();
}

if (mlir::isa<cir::FPTypeInterface>(type)) {
mlir::LLVM::FCmpPredicate kind =
convertCmpKindToFCmpPredicate(cmpOp.getKind());
Expand Down
114 changes: 114 additions & 0 deletions clang/test/CIR/CodeGen/dynamic-cast-exact.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++20 -O1 -fclangir -clangir-disable-passes -emit-cir -o %t.cir %s
// RUN: FileCheck --input-file=%t.cir %s --check-prefix=CIR
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++20 -O1 -fclangir -emit-llvm -o %t-cir.ll %s
// RUN: FileCheck --input-file=%t-cir.ll --check-prefix=LLVM %s
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++20 -O1 -emit-llvm -o %t.ll %s
// RUN: FileCheck --input-file=%t.ll --check-prefix=OGCG %s

struct Base1 {
virtual ~Base1();
};

struct Base2 {
virtual ~Base2();
};

struct Derived final : Base1 {};

Derived *ptr_cast(Base1 *ptr) {
return dynamic_cast<Derived *>(ptr);
}

// CIR: cir.func {{.*}} @_Z8ptr_castP5Base1
// CIR: %[[SRC:.*]] = cir.load{{.*}} %{{.+}} : !cir.ptr<!cir.ptr<!rec_Base1>>, !cir.ptr<!rec_Base1>
// CIR-NEXT: %[[NULL_PTR:.*]] = cir.const #cir.ptr<null>
// CIR-NEXT: %[[SRC_IS_NULL:.*]] = cir.cmp(eq, %[[SRC]], %[[NULL_PTR]])
// CIR-NEXT: %[[RESULT:.*]] = cir.ternary(%[[SRC_IS_NULL]], true {
// CIR-NEXT: %[[NULL_PTR_DEST:.*]] = cir.const #cir.ptr<null> : !cir.ptr<!rec_Derived>
// CIR-NEXT: cir.yield %[[NULL_PTR_DEST]] : !cir.ptr<!rec_Derived>
// CIR-NEXT: }, false {
// CIR-NEXT: %[[EXPECTED_VPTR:.*]] = cir.vtable.address_point(@_ZTV7Derived, address_point = <index = 0, offset = 2>) : !cir.vptr
// CIR-NEXT: %[[SRC_VPTR_PTR:.*]] = cir.cast bitcast %[[SRC]] : !cir.ptr<!rec_Base1> -> !cir.ptr<!cir.vptr>
// CIR-NEXT: %[[SRC_VPTR:.*]] = cir.load{{.*}} %[[SRC_VPTR_PTR]] : !cir.ptr<!cir.vptr>, !cir.vptr
// CIR-NEXT: %[[SUCCESS:.*]] = cir.cmp(eq, %[[SRC_VPTR]], %[[EXPECTED_VPTR]]) : !cir.vptr, !cir.bool
// CIR-NEXT: %[[EXACT_RESULT:.*]] = cir.ternary(%[[SUCCESS]], true {
// CIR-NEXT: %[[RES:.*]] = cir.cast bitcast %[[SRC]] : !cir.ptr<!rec_Base1> -> !cir.ptr<!rec_Derived>
// CIR-NEXT: cir.yield %[[RES]] : !cir.ptr<!rec_Derived>
// CIR-NEXT: }, false {
// CIR-NEXT: %[[NULL:.*]] = cir.const #cir.ptr<null> : !cir.ptr<!rec_Derived>
// CIR-NEXT: cir.yield %[[NULL]] : !cir.ptr<!rec_Derived>
// CIR-NEXT: }) : (!cir.bool) -> !cir.ptr<!rec_Derived>
// CIR-NEXT: cir.yield %[[EXACT_RESULT]] : !cir.ptr<!rec_Derived>
// CIR-NEXT: }) : (!cir.bool) -> !cir.ptr<!rec_Derived>

// Note: The LLVM output omits the label for the entry block (which is
// implicitly %1), so we use %{{.*}} to match the implicit label in the
// phi check.

// LLVM: define dso_local ptr @_Z8ptr_castP5Base1(ptr{{.*}} %[[SRC:.*]])
// LLVM-NEXT: %[[SRC_IS_NULL:.*]] = icmp eq ptr %0, null
// LLVM-NEXT: br i1 %[[SRC_IS_NULL]], label %[[LABEL_END:.*]], label %[[LABEL_NOTNULL:.*]]
// LLVM: [[LABEL_NOTNULL]]:
// LLVM-NEXT: %[[VPTR:.*]] = load ptr, ptr %[[SRC]], align 8
// LLVM-NEXT: %[[SUCCESS:.*]] = icmp eq ptr %[[VPTR]], getelementptr inbounds nuw (i8, ptr @_ZTV7Derived, i64 16)
// LLVM-NEXT: %[[EXACT_RESULT:.*]] = select i1 %[[SUCCESS]], ptr %[[SRC]], ptr null
// LLVM-NEXT: br label %[[LABEL_END]]
// LLVM: [[LABEL_END]]:
// LLVM-NEXT: %[[RESULT:.*]] = phi ptr [ %[[EXACT_RESULT]], %[[LABEL_NOTNULL]] ], [ null, %{{.*}} ]
// LLVM-NEXT: ret ptr %[[RESULT]]
// LLVM-NEXT: }

// OGCG: define{{.*}} ptr @_Z8ptr_castP5Base1(ptr {{.*}} %[[SRC:.*]])
// OGCG-NEXT: entry:
// OGCG-NEXT: %[[NULL_CHECK:.*]] = icmp eq ptr %[[SRC]], null
// OGCG-NEXT: br i1 %[[NULL_CHECK]], label %[[LABEL_NULL:.*]], label %[[LABEL_NOTNULL:.*]]
// OGCG: [[LABEL_NOTNULL]]:
// OGCG-NEXT: %[[VTABLE:.*]] = load ptr, ptr %[[SRC]], align 8
// OGCG-NEXT: %[[VTABLE_CHECK:.*]] = icmp eq ptr %[[VTABLE]], getelementptr inbounds {{.*}} (i8, ptr @_ZTV7Derived, i64 16)
// OGCG-NEXT: br i1 %[[VTABLE_CHECK]], label %[[LABEL_END:.*]], label %[[LABEL_NULL]]
// OGCG: [[LABEL_NULL]]:
// OGCG-NEXT: br label %[[LABEL_END]]
// OGCG: [[LABEL_END]]:
// OGCG-NEXT: %[[RESULT:.*]] = phi ptr [ %[[SRC]], %[[LABEL_NOTNULL]] ], [ null, %[[LABEL_NULL]] ]
// OGCG-NEXT: ret ptr %[[RESULT]]
// OGCG-NEXT: }

Derived &ref_cast(Base1 &ref) {
return dynamic_cast<Derived &>(ref);
}

// CIR: cir.func {{.*}} @_Z8ref_castR5Base1
// CIR: %[[SRC:.*]] = cir.load{{.*}} %{{.+}} : !cir.ptr<!cir.ptr<!rec_Base1>>, !cir.ptr<!rec_Base1>
// CIR-NEXT: %[[EXPECTED_VPTR:.*]] = cir.vtable.address_point(@_ZTV7Derived, address_point = <index = 0, offset = 2>) : !cir.vptr
// CIR-NEXT: %[[SRC_VPTR_PTR:.*]] = cir.cast bitcast %[[SRC]] : !cir.ptr<!rec_Base1> -> !cir.ptr<!cir.vptr>
// CIR-NEXT: %[[SRC_VPTR:.*]] = cir.load{{.*}} %[[SRC_VPTR_PTR]] : !cir.ptr<!cir.vptr>, !cir.vptr
// CIR-NEXT: %[[SUCCESS:.*]] = cir.cmp(eq, %[[SRC_VPTR]], %[[EXPECTED_VPTR]]) : !cir.vptr, !cir.bool
// CIR-NEXT: %[[FAILED:.*]] = cir.unary(not, %[[SUCCESS]]) : !cir.bool, !cir.bool
// CIR-NEXT: cir.if %[[FAILED]] {
// CIR-NEXT: cir.call @__cxa_bad_cast() : () -> ()
// CIR-NEXT: cir.unreachable
// CIR-NEXT: }
// CIR-NEXT: %{{.+}} = cir.cast bitcast %[[SRC]] : !cir.ptr<!rec_Base1> -> !cir.ptr<!rec_Derived>

// LLVM: define{{.*}} ptr @_Z8ref_castR5Base1(ptr{{.*}} %[[SRC:.*]])
// LLVM-NEXT: %[[VPTR:.*]] = load ptr, ptr %[[SRC]], align 8
// LLVM-NEXT: %[[OK:.*]] = icmp eq ptr %[[VPTR]], getelementptr inbounds nuw (i8, ptr @_ZTV7Derived, i64 16)
// LLVM-NEXT: br i1 %[[OK]], label %[[LABEL_OK:.*]], label %[[LABEL_FAIL:.*]]
// LLVM: [[LABEL_FAIL]]:
// LLVM-NEXT: tail call void @__cxa_bad_cast()
// LLVM-NEXT: unreachable
// LLVM: [[LABEL_OK]]:
// LLVM-NEXT: ret ptr %[[SRC]]
// LLVM-NEXT: }

// OGCG: define{{.*}} ptr @_Z8ref_castR5Base1(ptr {{.*}} %[[REF:.*]])
// OGCG-NEXT: entry:
// OGCG-NEXT: %[[VTABLE:.*]] = load ptr, ptr %[[REF]], align 8
// OGCG-NEXT: %[[VTABLE_CHECK:.*]] = icmp eq ptr %[[VTABLE]], getelementptr inbounds {{.*}} (i8, ptr @_ZTV7Derived, i64 16)
// OGCG-NEXT: br i1 %[[VTABLE_CHECK]], label %[[LABEL_END:.*]], label %[[LABEL_NULL:.*]]
// OGCG: [[LABEL_NULL]]:
// OGCG-NEXT: {{.*}}call void @__cxa_bad_cast()
// OGCG-NEXT: unreachable
// OGCG: [[LABEL_END]]:
// OGCG-NEXT: ret ptr %[[REF]]
// OGCG-NEXT: }