diff --git a/clang/lib/CIR/CodeGen/CIRGenCall.cpp b/clang/lib/CIR/CodeGen/CIRGenCall.cpp index 61072f0883728..88aef89ddd2b9 100644 --- a/clang/lib/CIR/CodeGen/CIRGenCall.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenCall.cpp @@ -690,6 +690,22 @@ void CallArg::copyInto(CIRGenFunction &cgf, Address addr, isUsed = true; } +mlir::Value CIRGenFunction::emitRuntimeCall(mlir::Location loc, + cir::FuncOp callee, + ArrayRef args) { + // TODO(cir): set the calling convention to this runtime call. + assert(!cir::MissingFeatures::opFuncCallingConv()); + + cir::CallOp call = builder.createCallOp(loc, callee, args); + assert(call->getNumResults() <= 1 && + "runtime functions have at most 1 result"); + + if (call->getNumResults() == 0) + return nullptr; + + return call->getResult(0); +} + void CIRGenFunction::emitCallArg(CallArgList &args, const clang::Expr *e, clang::QualType argType) { assert(argType->isReferenceType() == e->isGLValue() && diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.h b/clang/lib/CIR/CodeGen/CIRGenFunction.h index 3c36f5c697118..84b4ba293b3aa 100644 --- a/clang/lib/CIR/CodeGen/CIRGenFunction.h +++ b/clang/lib/CIR/CodeGen/CIRGenFunction.h @@ -1380,6 +1380,9 @@ class CIRGenFunction : public CIRGenTypeCache { void emitReturnOfRValue(mlir::Location loc, RValue rv, QualType ty); + mlir::Value emitRuntimeCall(mlir::Location loc, cir::FuncOp callee, + llvm::ArrayRef args = {}); + /// Emit the computation of the specified expression of scalar type. mlir::Value emitScalarExpr(const clang::Expr *e); diff --git a/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp b/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp index d54d2e9cb29e5..ef91288ab6155 100644 --- a/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp @@ -1869,6 +1869,15 @@ static cir::FuncOp getBadCastFn(CIRGenFunction &cgf) { return cgf.cgm.createRuntimeFunction(fnTy, "__cxa_bad_cast"); } +static void emitCallToBadCast(CIRGenFunction &cgf, mlir::Location loc) { + // TODO(cir): set the calling convention to the runtime function. + assert(!cir::MissingFeatures::opFuncCallingConv()); + + cgf.emitRuntimeCall(loc, getBadCastFn(cgf)); + cir::UnreachableOp::create(cgf.getBuilder(), loc); + cgf.getBuilder().clearInsertionPoint(); +} + // TODO(cir): This could be shared with classic codegen. static CharUnits computeOffsetHint(ASTContext &astContext, const CXXRecordDecl *src, @@ -1954,6 +1963,136 @@ static Address emitDynamicCastToVoid(CIRGenFunction &cgf, mlir::Location loc, return Address{ptr, src.getAlignment()}; } +static mlir::Value emitExactDynamicCast(CIRGenItaniumCXXABI &abi, + CIRGenFunction &cgf, mlir::Location loc, + QualType srcRecordTy, + QualType destRecordTy, + cir::PointerType destCIRTy, + bool isRefCast, Address src) { + // Find all the inheritance paths from SrcRecordTy to DestRecordTy. + const CXXRecordDecl *srcDecl = srcRecordTy->getAsCXXRecordDecl(); + const CXXRecordDecl *destDecl = destRecordTy->getAsCXXRecordDecl(); + CXXBasePaths paths(/*FindAmbiguities=*/true, /*RecordPaths=*/true, + /*DetectVirtual=*/false); + (void)destDecl->isDerivedFrom(srcDecl, paths); + + // Find an offset within `destDecl` where a `srcDecl` instance and its vptr + // might appear. + std::optional offset; + for (const CXXBasePath &path : paths) { + // dynamic_cast only finds public inheritance paths. + if (path.Access != AS_public) + continue; + + CharUnits pathOffset; + for (const CXXBasePathElement &pathElement : path) { + // Find the offset along this inheritance step. + const CXXRecordDecl *base = + pathElement.Base->getType()->getAsCXXRecordDecl(); + if (pathElement.Base->isVirtual()) { + // For a virtual base class, we know that the derived class is exactly + // destDecl, so we can use the vbase offset from its layout. + const ASTRecordLayout &layout = + cgf.getContext().getASTRecordLayout(destDecl); + pathOffset = layout.getVBaseClassOffset(base); + } else { + const ASTRecordLayout &layout = + cgf.getContext().getASTRecordLayout(pathElement.Class); + pathOffset += layout.getBaseClassOffset(base); + } + } + + if (!offset) { + offset = pathOffset; + } else if (offset != pathOffset) { + // base appears in at least two different places. Find the most-derived + // object and see if it's a DestDecl. Note that the most-derived object + // must be at least as aligned as this base class subobject, and must + // have a vptr at offset 0. + src = emitDynamicCastToVoid(cgf, loc, srcRecordTy, src); + srcDecl = destDecl; + offset = CharUnits::Zero(); + break; + } + } + + CIRGenBuilderTy &builder = cgf.getBuilder(); + + if (!offset) { + // If there are no public inheritance paths, the cast always fails. + mlir::Value nullPtrValue = builder.getNullPtr(destCIRTy, loc); + if (isRefCast) { + mlir::Region *currentRegion = builder.getBlock()->getParent(); + emitCallToBadCast(cgf, loc); + + // The call to bad_cast will terminate the block. Create a new block to + // hold any follow up code. + builder.createBlock(currentRegion, currentRegion->end()); + } + + return nullPtrValue; + } + + // Compare the vptr against the expected vptr for the destination type at + // this offset. Note that we do not know what type src points to in the case + // where the derived class multiply inherits from the base class so we can't + // use getVTablePtr, so we load the vptr directly instead. + + mlir::Value expectedVPtr = + abi.getVTableAddressPoint(BaseSubobject(srcDecl, *offset), destDecl); + + // TODO(cir): handle address space here. + assert(!cir::MissingFeatures::addressSpace()); + mlir::Type vptrTy = expectedVPtr.getType(); + mlir::Type vptrPtrTy = builder.getPointerTo(vptrTy); + Address srcVPtrPtr(builder.createBitcast(src.getPointer(), vptrPtrTy), + src.getAlignment()); + mlir::Value srcVPtr = builder.createLoad(loc, srcVPtrPtr); + + // TODO(cir): decorate SrcVPtr with TBAA info. + assert(!cir::MissingFeatures::opTBAA()); + + mlir::Value success = + builder.createCompare(loc, cir::CmpOpKind::eq, srcVPtr, expectedVPtr); + + auto emitCastResult = [&] { + if (offset->isZero()) + return builder.createBitcast(src.getPointer(), destCIRTy); + + // TODO(cir): handle address space here. + assert(!cir::MissingFeatures::addressSpace()); + mlir::Type u8PtrTy = builder.getUInt8PtrTy(); + + mlir::Value strideToApply = + builder.getConstInt(loc, builder.getUInt64Ty(), offset->getQuantity()); + mlir::Value srcU8Ptr = builder.createBitcast(src.getPointer(), u8PtrTy); + mlir::Value resultU8Ptr = cir::PtrStrideOp::create(builder, loc, u8PtrTy, + srcU8Ptr, strideToApply); + return builder.createBitcast(resultU8Ptr, destCIRTy); + }; + + if (isRefCast) { + mlir::Value failed = builder.createNot(success); + cir::IfOp::create(builder, loc, failed, /*withElseRegion=*/false, + [&](mlir::OpBuilder &, mlir::Location) { + emitCallToBadCast(cgf, loc); + }); + return emitCastResult(); + } + + return cir::TernaryOp::create( + builder, loc, success, + [&](mlir::OpBuilder &, mlir::Location) { + auto result = emitCastResult(); + builder.createYield(loc, result); + }, + [&](mlir::OpBuilder &, mlir::Location) { + mlir::Value nullPtrValue = builder.getNullPtr(destCIRTy, loc); + builder.createYield(loc, nullPtrValue); + }) + .getResult(); +} + static cir::DynamicCastInfoAttr emitDynamicCastInfo(CIRGenFunction &cgf, mlir::Location loc, QualType srcRecordTy, @@ -1995,8 +2134,27 @@ mlir::Value CIRGenItaniumCXXABI::emitDynamicCast(CIRGenFunction &cgf, // if the dynamic type of the pointer is exactly the destination type. if (destRecordTy->getAsCXXRecordDecl()->isEffectivelyFinal() && cgf.cgm.getCodeGenOpts().OptimizationLevel > 0) { - cgm.errorNYI(loc, "emitExactDynamicCast"); - return {}; + CIRGenBuilderTy &builder = cgf.getBuilder(); + // If this isn't a reference cast, check the pointer to see if it's null. + if (!isRefCast) { + mlir::Value srcPtrIsNull = builder.createPtrIsNull(src.getPointer()); + return cir::TernaryOp::create( + builder, loc, srcPtrIsNull, + [&](mlir::OpBuilder, mlir::Location) { + builder.createYield( + loc, builder.getNullPtr(destCIRTy, loc).getResult()); + }, + [&](mlir::OpBuilder &, mlir::Location) { + mlir::Value exactCast = emitExactDynamicCast( + *this, cgf, loc, srcRecordTy, destRecordTy, destCIRTy, + isRefCast, src); + builder.createYield(loc, exactCast); + }) + .getResult(); + } + + return emitExactDynamicCast(*this, cgf, loc, srcRecordTy, destRecordTy, + destCIRTy, isRefCast, src); } cir::DynamicCastInfoAttr castInfo = diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 0243bf120f396..51dba33338cd6 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -2223,6 +2223,15 @@ mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite( return mlir::success(); } + if (auto vptrTy = mlir::dyn_cast(type)) { + // !cir.vptr is a special case, but it's just a pointer to LLVM. + auto kind = convertCmpKindToICmpPredicate(cmpOp.getKind(), + /* isSigned=*/false); + rewriter.replaceOpWithNewOp( + cmpOp, kind, adaptor.getLhs(), adaptor.getRhs()); + return mlir::success(); + } + if (mlir::isa(type)) { mlir::LLVM::FCmpPredicate kind = convertCmpKindToFCmpPredicate(cmpOp.getKind()); diff --git a/clang/test/CIR/CodeGen/dynamic-cast-exact.cpp b/clang/test/CIR/CodeGen/dynamic-cast-exact.cpp new file mode 100644 index 0000000000000..41a70ce53db5e --- /dev/null +++ b/clang/test/CIR/CodeGen/dynamic-cast-exact.cpp @@ -0,0 +1,114 @@ +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++20 -O1 -fclangir -clangir-disable-passes -emit-cir -o %t.cir %s +// RUN: FileCheck --input-file=%t.cir %s --check-prefix=CIR +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++20 -O1 -fclangir -emit-llvm -o %t-cir.ll %s +// RUN: FileCheck --input-file=%t-cir.ll --check-prefix=LLVM %s +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++20 -O1 -emit-llvm -o %t.ll %s +// RUN: FileCheck --input-file=%t.ll --check-prefix=OGCG %s + +struct Base1 { + virtual ~Base1(); +}; + +struct Base2 { + virtual ~Base2(); +}; + +struct Derived final : Base1 {}; + +Derived *ptr_cast(Base1 *ptr) { + return dynamic_cast(ptr); +} + +// CIR: cir.func {{.*}} @_Z8ptr_castP5Base1 +// CIR: %[[SRC:.*]] = cir.load{{.*}} %{{.+}} : !cir.ptr>, !cir.ptr +// CIR-NEXT: %[[NULL_PTR:.*]] = cir.const #cir.ptr +// CIR-NEXT: %[[SRC_IS_NULL:.*]] = cir.cmp(eq, %[[SRC]], %[[NULL_PTR]]) +// CIR-NEXT: %[[RESULT:.*]] = cir.ternary(%[[SRC_IS_NULL]], true { +// CIR-NEXT: %[[NULL_PTR_DEST:.*]] = cir.const #cir.ptr : !cir.ptr +// CIR-NEXT: cir.yield %[[NULL_PTR_DEST]] : !cir.ptr +// CIR-NEXT: }, false { +// CIR-NEXT: %[[EXPECTED_VPTR:.*]] = cir.vtable.address_point(@_ZTV7Derived, address_point = ) : !cir.vptr +// CIR-NEXT: %[[SRC_VPTR_PTR:.*]] = cir.cast bitcast %[[SRC]] : !cir.ptr -> !cir.ptr +// CIR-NEXT: %[[SRC_VPTR:.*]] = cir.load{{.*}} %[[SRC_VPTR_PTR]] : !cir.ptr, !cir.vptr +// CIR-NEXT: %[[SUCCESS:.*]] = cir.cmp(eq, %[[SRC_VPTR]], %[[EXPECTED_VPTR]]) : !cir.vptr, !cir.bool +// CIR-NEXT: %[[EXACT_RESULT:.*]] = cir.ternary(%[[SUCCESS]], true { +// CIR-NEXT: %[[RES:.*]] = cir.cast bitcast %[[SRC]] : !cir.ptr -> !cir.ptr +// CIR-NEXT: cir.yield %[[RES]] : !cir.ptr +// CIR-NEXT: }, false { +// CIR-NEXT: %[[NULL:.*]] = cir.const #cir.ptr : !cir.ptr +// CIR-NEXT: cir.yield %[[NULL]] : !cir.ptr +// CIR-NEXT: }) : (!cir.bool) -> !cir.ptr +// CIR-NEXT: cir.yield %[[EXACT_RESULT]] : !cir.ptr +// CIR-NEXT: }) : (!cir.bool) -> !cir.ptr + +// Note: The LLVM output omits the label for the entry block (which is +// implicitly %1), so we use %{{.*}} to match the implicit label in the +// phi check. + +// LLVM: define dso_local ptr @_Z8ptr_castP5Base1(ptr{{.*}} %[[SRC:.*]]) +// LLVM-NEXT: %[[SRC_IS_NULL:.*]] = icmp eq ptr %0, null +// LLVM-NEXT: br i1 %[[SRC_IS_NULL]], label %[[LABEL_END:.*]], label %[[LABEL_NOTNULL:.*]] +// LLVM: [[LABEL_NOTNULL]]: +// LLVM-NEXT: %[[VPTR:.*]] = load ptr, ptr %[[SRC]], align 8 +// LLVM-NEXT: %[[SUCCESS:.*]] = icmp eq ptr %[[VPTR]], getelementptr inbounds nuw (i8, ptr @_ZTV7Derived, i64 16) +// LLVM-NEXT: %[[EXACT_RESULT:.*]] = select i1 %[[SUCCESS]], ptr %[[SRC]], ptr null +// LLVM-NEXT: br label %[[LABEL_END]] +// LLVM: [[LABEL_END]]: +// LLVM-NEXT: %[[RESULT:.*]] = phi ptr [ %[[EXACT_RESULT]], %[[LABEL_NOTNULL]] ], [ null, %{{.*}} ] +// LLVM-NEXT: ret ptr %[[RESULT]] +// LLVM-NEXT: } + +// OGCG: define{{.*}} ptr @_Z8ptr_castP5Base1(ptr {{.*}} %[[SRC:.*]]) +// OGCG-NEXT: entry: +// OGCG-NEXT: %[[NULL_CHECK:.*]] = icmp eq ptr %[[SRC]], null +// OGCG-NEXT: br i1 %[[NULL_CHECK]], label %[[LABEL_NULL:.*]], label %[[LABEL_NOTNULL:.*]] +// OGCG: [[LABEL_NOTNULL]]: +// OGCG-NEXT: %[[VTABLE:.*]] = load ptr, ptr %[[SRC]], align 8 +// OGCG-NEXT: %[[VTABLE_CHECK:.*]] = icmp eq ptr %[[VTABLE]], getelementptr inbounds {{.*}} (i8, ptr @_ZTV7Derived, i64 16) +// OGCG-NEXT: br i1 %[[VTABLE_CHECK]], label %[[LABEL_END:.*]], label %[[LABEL_NULL]] +// OGCG: [[LABEL_NULL]]: +// OGCG-NEXT: br label %[[LABEL_END]] +// OGCG: [[LABEL_END]]: +// OGCG-NEXT: %[[RESULT:.*]] = phi ptr [ %[[SRC]], %[[LABEL_NOTNULL]] ], [ null, %[[LABEL_NULL]] ] +// OGCG-NEXT: ret ptr %[[RESULT]] +// OGCG-NEXT: } + +Derived &ref_cast(Base1 &ref) { + return dynamic_cast(ref); +} + +// CIR: cir.func {{.*}} @_Z8ref_castR5Base1 +// CIR: %[[SRC:.*]] = cir.load{{.*}} %{{.+}} : !cir.ptr>, !cir.ptr +// CIR-NEXT: %[[EXPECTED_VPTR:.*]] = cir.vtable.address_point(@_ZTV7Derived, address_point = ) : !cir.vptr +// CIR-NEXT: %[[SRC_VPTR_PTR:.*]] = cir.cast bitcast %[[SRC]] : !cir.ptr -> !cir.ptr +// CIR-NEXT: %[[SRC_VPTR:.*]] = cir.load{{.*}} %[[SRC_VPTR_PTR]] : !cir.ptr, !cir.vptr +// CIR-NEXT: %[[SUCCESS:.*]] = cir.cmp(eq, %[[SRC_VPTR]], %[[EXPECTED_VPTR]]) : !cir.vptr, !cir.bool +// CIR-NEXT: %[[FAILED:.*]] = cir.unary(not, %[[SUCCESS]]) : !cir.bool, !cir.bool +// CIR-NEXT: cir.if %[[FAILED]] { +// CIR-NEXT: cir.call @__cxa_bad_cast() : () -> () +// CIR-NEXT: cir.unreachable +// CIR-NEXT: } +// CIR-NEXT: %{{.+}} = cir.cast bitcast %[[SRC]] : !cir.ptr -> !cir.ptr + +// LLVM: define{{.*}} ptr @_Z8ref_castR5Base1(ptr{{.*}} %[[SRC:.*]]) +// LLVM-NEXT: %[[VPTR:.*]] = load ptr, ptr %[[SRC]], align 8 +// LLVM-NEXT: %[[OK:.*]] = icmp eq ptr %[[VPTR]], getelementptr inbounds nuw (i8, ptr @_ZTV7Derived, i64 16) +// LLVM-NEXT: br i1 %[[OK]], label %[[LABEL_OK:.*]], label %[[LABEL_FAIL:.*]] +// LLVM: [[LABEL_FAIL]]: +// LLVM-NEXT: tail call void @__cxa_bad_cast() +// LLVM-NEXT: unreachable +// LLVM: [[LABEL_OK]]: +// LLVM-NEXT: ret ptr %[[SRC]] +// LLVM-NEXT: } + +// OGCG: define{{.*}} ptr @_Z8ref_castR5Base1(ptr {{.*}} %[[REF:.*]]) +// OGCG-NEXT: entry: +// OGCG-NEXT: %[[VTABLE:.*]] = load ptr, ptr %[[REF]], align 8 +// OGCG-NEXT: %[[VTABLE_CHECK:.*]] = icmp eq ptr %[[VTABLE]], getelementptr inbounds {{.*}} (i8, ptr @_ZTV7Derived, i64 16) +// OGCG-NEXT: br i1 %[[VTABLE_CHECK]], label %[[LABEL_END:.*]], label %[[LABEL_NULL:.*]] +// OGCG: [[LABEL_NULL]]: +// OGCG-NEXT: {{.*}}call void @__cxa_bad_cast() +// OGCG-NEXT: unreachable +// OGCG: [[LABEL_END]]: +// OGCG-NEXT: ret ptr %[[REF]] +// OGCG-NEXT: }