Skip to content

Commit d147555

Browse files
andykaylorlanza
authored andcommitted
[CIR] Check for null source point before exact dynamic cast (#1953)
While upstreaming the code for handling exact dynamic casts, I noticed that we were not checking to see if the source pointer was null before using it to load the vtable. This change adds that check.
1 parent 98b04ee commit d147555

File tree

2 files changed

+78
-15
lines changed

2 files changed

+78
-15
lines changed

clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2724,9 +2724,29 @@ mlir::Value CIRGenItaniumCXXABI::emitDynamicCast(CIRGenFunction &CGF,
27242724
// If the destination is effectively final, the cast succeeds if and only
27252725
// if the dynamic type of the pointer is exactly the destination type.
27262726
if (DestRecordTy->getAsCXXRecordDecl()->isEffectivelyFinal() &&
2727-
CGF.CGM.getCodeGenOpts().OptimizationLevel > 0)
2727+
CGF.CGM.getCodeGenOpts().OptimizationLevel > 0) {
2728+
CIRGenBuilderTy &builder = CGF.getBuilder();
2729+
// If this isn't a reference cast, check the pointer to see if it's null.
2730+
if (!isRefCast) {
2731+
mlir::Value srcPtrIsNull = builder.createPtrIsNull(Src.getPointer());
2732+
return cir::TernaryOp::create(
2733+
builder, Loc, srcPtrIsNull,
2734+
[&](mlir::OpBuilder, mlir::Location) {
2735+
builder.createYield(
2736+
Loc, builder.getNullPtr(DestCIRTy, Loc).getResult());
2737+
},
2738+
[&](mlir::OpBuilder &, mlir::Location) {
2739+
mlir::Value exactCast = emitExactDynamicCast(
2740+
*this, CGF, Loc, SrcRecordTy, DestRecordTy, DestCIRTy,
2741+
isRefCast, Src);
2742+
builder.createYield(Loc, exactCast);
2743+
})
2744+
.getResult();
2745+
}
2746+
27282747
return emitExactDynamicCast(*this, CGF, Loc, SrcRecordTy, DestRecordTy,
27292748
DestCIRTy, isRefCast, Src);
2749+
}
27302750

27312751
auto castInfo = emitDynamicCastInfo(CGF, Loc, SrcRecordTy, DestRecordTy);
27322752
return CGF.getBuilder().createDynCast(Loc, Src.getPointer(), DestCIRTy,

clang/test/CIR/CodeGen/dynamic-cast-exact.cpp

Lines changed: 57 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++20 -O1 -fclangir -clangir-disable-passes -emit-cir -o %t.cir %s
22
// RUN: FileCheck --input-file=%t.cir %s
3-
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++20 -O1 -fclangir -emit-llvm -fno-clangir-call-conv-lowering -o %t.ll %s
4-
// RUN: FileCheck --input-file=%t.ll --check-prefix=LLVM %s
3+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++20 -O1 -fclangir -emit-llvm -fno-clangir-call-conv-lowering -o %t-cir.ll %s
4+
// RUN: FileCheck --input-file=%t-cir.ll --check-prefix=LLVM %s
5+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++20 -O1 -emit-llvm -o %t.ll %s
6+
// RUN: FileCheck --input-file=%t.ll --check-prefix=OGCG %s
57

68
struct Base1 {
79
virtual ~Base1();
@@ -16,26 +18,55 @@ struct Derived final : Base1 {};
1618
Derived *ptr_cast(Base1 *ptr) {
1719
return dynamic_cast<Derived *>(ptr);
1820
// CHECK: %[[#SRC:]] = cir.load{{.*}} %{{.+}} : !cir.ptr<!cir.ptr<!rec_Base1>>, !cir.ptr<!rec_Base1>
19-
// CHECK-NEXT: %[[#EXPECTED_VPTR:]] = cir.vtable.address_point(@_ZTV7Derived, address_point = <index = 0, offset = 2>) : !cir.vptr
20-
// CHECK-NEXT: %[[#SRC_VPTR_PTR:]] = cir.cast bitcast %[[#SRC]] : !cir.ptr<!rec_Base1> -> !cir.ptr<!cir.vptr>
21-
// CHECK-NEXT: %[[#SRC_VPTR:]] = cir.load{{.*}} %[[#SRC_VPTR_PTR]] : !cir.ptr<!cir.vptr>, !cir.vptr
22-
// CHECK-NEXT: %[[#SUCCESS:]] = cir.cmp(eq, %[[#SRC_VPTR]], %[[#EXPECTED_VPTR]]) : !cir.vptr, !cir.bool
23-
// CHECK-NEXT: %{{.+}} = cir.ternary(%[[#SUCCESS]], true {
24-
// CHECK-NEXT: %[[#RES:]] = cir.cast bitcast %[[#SRC]] : !cir.ptr<!rec_Base1> -> !cir.ptr<!rec_Derived>
25-
// CHECK-NEXT: cir.yield %[[#RES]] : !cir.ptr<!rec_Derived>
21+
// CHECK-NEXT: %[[#SRC_IS_NONNULL:]] = cir.cast ptr_to_bool %[[#SRC]] : !cir.ptr<!rec_Base1> -> !cir.bool
22+
// CHECK-NEXT: %[[#SRC_IS_NULL:]] = cir.unary(not, %[[#SRC_IS_NONNULL]]) : !cir.bool, !cir.bool
23+
// CHECK-NEXT: %[[#RESULT:]] = cir.ternary(%4, true {
24+
// CHECK-NEXT: %[[#NULL_DEST_PTR:]] = cir.const #cir.ptr<null> : !cir.ptr<!rec_Derived>
25+
// CHECK-NEXT: cir.yield %[[#NULL_DEST_PTR]] : !cir.ptr<!rec_Derived>
2626
// CHECK-NEXT: }, false {
27-
// CHECK-NEXT: %[[#NULL:]] = cir.const #cir.ptr<null> : !cir.ptr<!rec_Derived>
28-
// CHECK-NEXT: cir.yield %[[#NULL]] : !cir.ptr<!rec_Derived>
27+
// CHECK-NEXT: %[[#EXPECTED_VPTR:]] = cir.vtable.address_point(@_ZTV7Derived, address_point = <index = 0, offset = 2>) : !cir.vptr
28+
// CHECK-NEXT: %[[#SRC_VPTR_PTR:]] = cir.cast bitcast %[[#SRC]] : !cir.ptr<!rec_Base1> -> !cir.ptr<!cir.vptr>
29+
// CHECK-NEXT: %[[#SRC_VPTR:]] = cir.load{{.*}} %[[#SRC_VPTR_PTR]] : !cir.ptr<!cir.vptr>, !cir.vptr
30+
// CHECK-NEXT: %[[#SUCCESS:]] = cir.cmp(eq, %[[#SRC_VPTR]], %[[#EXPECTED_VPTR]]) : !cir.vptr, !cir.bool
31+
// CHECK-NEXT: %[[#EXACT_RESULT:]] = cir.ternary(%[[#SUCCESS]], true {
32+
// CHECK-NEXT: %[[#RES:]] = cir.cast bitcast %[[#SRC]] : !cir.ptr<!rec_Base1> -> !cir.ptr<!rec_Derived>
33+
// CHECK-NEXT: cir.yield %[[#RES]] : !cir.ptr<!rec_Derived>
34+
// CHECK-NEXT: }, false {
35+
// CHECK-NEXT: %[[#NULL:]] = cir.const #cir.ptr<null> : !cir.ptr<!rec_Derived>
36+
// CHECK-NEXT: cir.yield %[[#NULL]] : !cir.ptr<!rec_Derived>
37+
// CHECK-NEXT: }) : (!cir.bool) -> !cir.ptr<!rec_Derived>
38+
// CHECK-NEXT: cir.yield %[[#EXACT_RESULT]] : !cir.ptr<!rec_Derived>
2939
// CHECK-NEXT: }) : (!cir.bool) -> !cir.ptr<!rec_Derived>
3040
}
3141

32-
// LLVM: define dso_local ptr @_Z8ptr_castP5Base1(ptr readonly captures(ret: address, provenance) %[[#SRC:]])
42+
// LLVM: define dso_local ptr @_Z8ptr_castP5Base1(ptr {{.*}} %[[#SRC:]])
43+
// LLVM-NEXT: %[[SRC_IS_NULL:.*]] = icmp eq ptr %[[#SRC]], null
44+
// LLVM-NEXT: br i1 %[[SRC_IS_NULL]], label %[[#LABEL_END:]], label %[[#LABEL_NONNULL:]]
45+
// LLVM: [[#LABEL_NONNULL]]
3346
// LLVM-NEXT: %[[#VPTR:]] = load ptr, ptr %[[#SRC]], align 8
3447
// LLVM-NEXT: %[[#SUCCESS:]] = icmp eq ptr %[[#VPTR]], getelementptr inbounds nuw (i8, ptr @_ZTV7Derived, i64 16)
35-
// LLVM-NEXT: %[[RESULT:.+]] = select i1 %[[#SUCCESS]], ptr %[[#SRC]], ptr null
36-
// LLVM-NEXT: ret ptr %[[RESULT]]
48+
// LLVM-NEXT: %[[EXACT_RESULT:.*]] = select i1 %[[#SUCCESS]], ptr %[[#SRC]], ptr null
49+
// LLVM-NEXT: br label %[[#LABEL_END]]
50+
// LLVM: [[#LABEL_END]]
51+
// LLVM-NEXT: %[[#RESULT:]] = phi ptr [ %[[EXACT_RESULT]], %[[#LABEL_NONNULL]] ], [ null, %{{.*}} ]
52+
// LLVM-NEXT: ret ptr %[[#RESULT]]
3753
// LLVM-NEXT: }
3854

55+
// OGCG: define{{.*}} ptr @_Z8ptr_castP5Base1(ptr {{.*}} %[[SRC:.*]])
56+
// OGCG-NEXT: entry:
57+
// OGCG-NEXT: %[[NULL_CHECK:.*]] = icmp eq ptr %[[SRC]], null
58+
// OGCG-NEXT: br i1 %[[NULL_CHECK]], label %[[LABEL_NULL:.*]], label %[[LABEL_NOTNULL:.*]]
59+
// OGCG: [[LABEL_NOTNULL]]:
60+
// OGCG-NEXT: %[[VTABLE:.*]] = load ptr, ptr %[[SRC]], align 8
61+
// OGCG-NEXT: %[[VTABLE_CHECK:.*]] = icmp eq ptr %[[VTABLE]], getelementptr inbounds {{.*}} (i8, ptr @_ZTV7Derived, i64 16)
62+
// OGCG-NEXT: br i1 %[[VTABLE_CHECK]], label %[[LABEL_END:.*]], label %[[LABEL_NULL]]
63+
// OGCG: [[LABEL_NULL]]:
64+
// OGCG-NEXT: br label %[[LABEL_END]]
65+
// OGCG: [[LABEL_END]]:
66+
// OGCG-NEXT: %[[RESULT:.*]] = phi ptr [ %[[SRC]], %[[LABEL_NOTNULL]] ], [ null, %[[LABEL_NULL]] ]
67+
// OGCG-NEXT: ret ptr %[[RESULT]]
68+
// OGCG-NEXT: }
69+
3970
Derived &ref_cast(Base1 &ref) {
4071
return dynamic_cast<Derived &>(ref);
4172
// CHECK: %[[#SRC:]] = cir.load{{.*}} %{{.+}} : !cir.ptr<!cir.ptr<!rec_Base1>>, !cir.ptr<!rec_Base1>
@@ -62,6 +93,18 @@ Derived &ref_cast(Base1 &ref) {
6293
// LLVM-NEXT: ret ptr %[[#SRC]]
6394
// LLVM-NEXT: }
6495

96+
// OGCG: define{{.*}} ptr @_Z8ref_castR5Base1(ptr {{.*}} %[[REF:.*]])
97+
// OGCG-NEXT: entry:
98+
// OGCG-NEXT: %[[VTABLE:.*]] = load ptr, ptr %[[REF]], align 8
99+
// OGCG-NEXT: %[[VTABLE_CHECK:.*]] = icmp eq ptr %[[VTABLE]], getelementptr inbounds {{.*}} (i8, ptr @_ZTV7Derived, i64 16)
100+
// OGCG-NEXT: br i1 %[[VTABLE_CHECK]], label %[[LABEL_END:.*]], label %[[LABEL_NULL:.*]]
101+
// OGCG: [[LABEL_NULL]]:
102+
// OGCG-NEXT: {{.*}}call void @__cxa_bad_cast()
103+
// OGCG-NEXT: unreachable
104+
// OGCG: [[LABEL_END]]:
105+
// OGCG-NEXT: ret ptr %[[REF]]
106+
// OGCG-NEXT: }
107+
65108
Derived *ptr_cast_always_fail(Base2 *ptr) {
66109
return dynamic_cast<Derived *>(ptr);
67110
// CHECK: %{{.+}} = cir.load{{.*}} %{{.+}} : !cir.ptr<!cir.ptr<!rec_Base2>>, !cir.ptr<!rec_Base2>

0 commit comments

Comments
 (0)