
Commit b629981

[CIR] Add virtual base support to getAddressOfBaseClass (llvm#159162)
This patch enables calling virtual functions of virtual base classes of a derived class.
1 parent eef7a76 commit b629981
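
For context, the source pattern that previously hit the "getAddrOfBaseClass: virtual base" NYI error and now compiles under ClangIR looks like the minimal sketch below. It mirrors the updated test in clang/test/CIR/CodeGen/vbase.cpp; the exact definition of Base in the test may differ in detail.

class Base {
public:
  virtual void f();
};

class Derived : public virtual Base {};

void f() {
  Derived d;
  // Calling f() requires adjusting the Derived pointer to its virtual Base
  // subobject before the vtable dispatch; emitting that adjustment is what
  // getAddressOfBaseClass now supports for virtual bases.
  d.f();
}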

File tree: 2 files changed (+110, -14 lines)


clang/lib/CIR/CodeGen/CIRGenClass.cpp

Lines changed: 29 additions & 13 deletions
@@ -951,39 +951,55 @@ Address CIRGenFunction::getAddressOfBaseClass(
     bool nullCheckValue, SourceLocation loc) {
   assert(!path.empty() && "Base path should not be empty!");
 
+  CastExpr::path_const_iterator start = path.begin();
+  const CXXRecordDecl *vBase = nullptr;
+
   if ((*path.begin())->isVirtual()) {
-    // The implementation here is actually complete, but let's flag this
-    // as an error until the rest of the virtual base class support is in place.
-    cgm.errorNYI(loc, "getAddrOfBaseClass: virtual base");
-    return Address::invalid();
+    vBase = (*start)->getType()->castAsCXXRecordDecl();
+    ++start;
   }
 
   // Compute the static offset of the ultimate destination within its
   // allocating subobject (the virtual base, if there is one, or else
   // the "complete" object that we see).
-  CharUnits nonVirtualOffset =
-      cgm.computeNonVirtualBaseClassOffset(derived, path);
+  CharUnits nonVirtualOffset = cgm.computeNonVirtualBaseClassOffset(
+      vBase ? vBase : derived, {start, path.end()});
+
+  // If there's a virtual step, we can sometimes "devirtualize" it.
+  // For now, that's limited to when the derived type is final.
+  // TODO: "devirtualize" this for accesses to known-complete objects.
+  if (vBase && derived->hasAttr<FinalAttr>()) {
+    const ASTRecordLayout &layout = getContext().getASTRecordLayout(derived);
+    CharUnits vBaseOffset = layout.getVBaseClassOffset(vBase);
+    nonVirtualOffset += vBaseOffset;
+    vBase = nullptr; // we no longer have a virtual step
+  }
 
   // Get the base pointer type.
   mlir::Type baseValueTy = convertType((path.end()[-1])->getType());
   assert(!cir::MissingFeatures::addressSpace());
 
-  // The if statement here is redundant now, but it will be needed when we add
-  // support for virtual base classes.
   // If there is no virtual base, use cir.base_class_addr. It takes care of
   // the adjustment and the null pointer check.
-  if (nonVirtualOffset.isZero()) {
+  if (nonVirtualOffset.isZero() && !vBase) {
     assert(!cir::MissingFeatures::sanitizers());
     return builder.createBaseClassAddr(getLoc(loc), value, baseValueTy, 0,
                                        /*assumeNotNull=*/true);
   }
 
   assert(!cir::MissingFeatures::sanitizers());
 
-  // Apply the offset
-  value = builder.createBaseClassAddr(getLoc(loc), value, baseValueTy,
-                                      nonVirtualOffset.getQuantity(),
-                                      /*assumeNotNull=*/true);
+  // Compute the virtual offset.
+  mlir::Value virtualOffset = nullptr;
+  if (vBase) {
+    virtualOffset = cgm.getCXXABI().getVirtualBaseClassOffset(
+        getLoc(loc), *this, value, derived, vBase);
+  }
+
+  // Apply both offsets.
+  value = applyNonVirtualAndVirtualOffset(
+      getLoc(loc), *this, value, nonVirtualOffset, virtualOffset, derived,
+      vBase, baseValueTy, not nullCheckValue);
 
   // Cast to the destination type.
   value = value.withElementType(builder, baseValueTy);
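
The final-class branch above folds the virtual-base offset into the static offset when the dynamic type is fully known. A source-level sketch of when that applies, matching the DerivedFinal case added to the test below (the reference parameter is illustrative, not from the test):

class Base {
public:
  virtual void f();
};

class DerivedFinal final : public virtual Base {};

void g(DerivedFinal &df) {
  // DerivedFinal is final, so any DerivedFinal lvalue refers to a complete
  // object and its Base subobject sits at a statically known offset; the
  // hasAttr<FinalAttr>() branch clears vBase, and no runtime vtable load is
  // needed for the base adjustment.
  df.f();
}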

clang/test/CIR/CodeGen/vbase.cpp

Lines changed: 81 additions & 1 deletion
@@ -13,19 +13,29 @@ class Base {
 
 class Derived : public virtual Base {};
 
-// This is just here to force the record types to be emitted.
 void f() {
   Derived d;
+  d.f();
+}
+
+class DerivedFinal final : public virtual Base {};
+
+void g() {
+  DerivedFinal df;
+  df.f();
 }
 
 // CIR: !rec_Base = !cir.record<class "Base" {!cir.vptr}>
 // CIR: !rec_Derived = !cir.record<class "Derived" {!rec_Base}>
+// CIR: !rec_DerivedFinal = !cir.record<class "DerivedFinal" {!rec_Base}>
 
 // LLVM: %class.Derived = type { %class.Base }
 // LLVM: %class.Base = type { ptr }
+// LLVM: %class.DerivedFinal = type { %class.Base }
 
 // OGCG: %class.Derived = type { %class.Base }
 // OGCG: %class.Base = type { ptr }
+// OGCG: %class.DerivedFinal = type { %class.Base }
 
 // Test the constructor handling for a class with a virtual base.
 struct A {
@@ -47,6 +57,76 @@ void ppp() { B b; }
 
 // OGCG: @_ZTV1B = linkonce_odr unnamed_addr constant { [3 x ptr] } { [3 x ptr] [ptr inttoptr (i64 12 to ptr), ptr null, ptr @_ZTI1B] }, comdat, align 8
 
+// CIR: cir.func {{.*}}@_Z1fv() {
+// CIR: %[[D:.+]] = cir.alloca !rec_Derived, !cir.ptr<!rec_Derived>, ["d", init]
+// CIR: cir.call @_ZN7DerivedC1Ev(%[[D]]) nothrow : (!cir.ptr<!rec_Derived>) -> ()
+// CIR: %[[VPTR_PTR:.+]] = cir.vtable.get_vptr %[[D]] : !cir.ptr<!rec_Derived> -> !cir.ptr<!cir.vptr>
+// CIR: %[[VPTR:.+]] = cir.load {{.*}} %[[VPTR_PTR]] : !cir.ptr<!cir.vptr>, !cir.vptr
+// CIR: %[[VPTR_I8:.+]] = cir.cast(bitcast, %[[VPTR]] : !cir.vptr), !cir.ptr<!u8i>
+// CIR: %[[NEG32:.+]] = cir.const #cir.int<-32> : !s64i
+// CIR: %[[ADJ_VPTR_I8:.+]] = cir.ptr_stride(%[[VPTR_I8]] : !cir.ptr<!u8i>, %[[NEG32]] : !s64i), !cir.ptr<!u8i>
+// CIR: %[[OFFSET_PTR:.+]] = cir.cast(bitcast, %[[ADJ_VPTR_I8]] : !cir.ptr<!u8i>), !cir.ptr<!s64i>
+// CIR: %[[OFFSET:.+]] = cir.load {{.*}} %[[OFFSET_PTR]] : !cir.ptr<!s64i>, !s64i
+// CIR: %[[D_I8:.+]] = cir.cast(bitcast, %[[D]] : !cir.ptr<!rec_Derived>), !cir.ptr<!u8i>
+// CIR: %[[ADJ_THIS_I8:.+]] = cir.ptr_stride(%[[D_I8]] : !cir.ptr<!u8i>, %[[OFFSET]] : !s64i), !cir.ptr<!u8i>
+// CIR: %[[ADJ_THIS_D:.+]] = cir.cast(bitcast, %[[ADJ_THIS_I8]] : !cir.ptr<!u8i>), !cir.ptr<!rec_Derived>
+// CIR: %[[BASE_THIS:.+]] = cir.cast(bitcast, %[[ADJ_THIS_D]] : !cir.ptr<!rec_Derived>), !cir.ptr<!rec_Base>
+// CIR: %[[BASE_VPTR_PTR:.+]] = cir.vtable.get_vptr %[[BASE_THIS]] : !cir.ptr<!rec_Base> -> !cir.ptr<!cir.vptr>
+// CIR: %[[BASE_VPTR:.+]] = cir.load {{.*}} %[[BASE_VPTR_PTR]] : !cir.ptr<!cir.vptr>, !cir.vptr
+// CIR: %[[SLOT_PTR:.+]] = cir.vtable.get_virtual_fn_addr %[[BASE_VPTR]][0] : !cir.vptr -> !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!rec_Base>)>>>
+// CIR: %[[FN:.+]] = cir.load {{.*}} %[[SLOT_PTR]] : !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!rec_Base>)>>>, !cir.ptr<!cir.func<(!cir.ptr<!rec_Base>)>>
+// CIR: cir.call %[[FN]](%[[BASE_THIS]]) : (!cir.ptr<!cir.func<(!cir.ptr<!rec_Base>)>>, !cir.ptr<!rec_Base>) -> ()
+// CIR: cir.return
+
+// CIR: cir.func {{.*}}@_Z1gv() {
+// CIR: %[[DF:.+]] = cir.alloca !rec_DerivedFinal, !cir.ptr<!rec_DerivedFinal>, ["df", init]
+// CIR: cir.call @_ZN12DerivedFinalC1Ev(%[[DF]]) nothrow : (!cir.ptr<!rec_DerivedFinal>) -> ()
+// CIR: %[[BASE_THIS_2:.+]] = cir.base_class_addr %[[DF]] : !cir.ptr<!rec_DerivedFinal> nonnull [0] -> !cir.ptr<!rec_Base>
+// CIR: %[[BASE_VPTR_PTR_2:.+]] = cir.vtable.get_vptr %[[BASE_THIS_2]] : !cir.ptr<!rec_Base> -> !cir.ptr<!cir.vptr>
+// CIR: %[[BASE_VPTR_2:.+]] = cir.load {{.*}} %[[BASE_VPTR_PTR_2]] : !cir.ptr<!cir.vptr>, !cir.vptr
+// CIR: %[[SLOT_PTR_2:.+]] = cir.vtable.get_virtual_fn_addr %[[BASE_VPTR_2]][0] : !cir.vptr -> !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!rec_Base>)>>>
+// CIR: %[[FN_2:.+]] = cir.load {{.*}} %[[SLOT_PTR_2]] : !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!rec_Base>)>>>, !cir.ptr<!cir.func<(!cir.ptr<!rec_Base>)>>
+// CIR: cir.call %[[FN_2]](%[[BASE_THIS_2]]) : (!cir.ptr<!cir.func<(!cir.ptr<!rec_Base>)>>, !cir.ptr<!rec_Base>) -> ()
+// CIR: cir.return
+
+// LLVM: define {{.*}}void @_Z1fv()
+// LLVM: %[[D:.+]] = alloca {{.*}}
+// LLVM: call void @_ZN7DerivedC1Ev(ptr %[[D]])
+// LLVM: %[[VPTR_ADDR:.+]] = load ptr, ptr %[[D]]
+// LLVM: %[[NEG32_PTR:.+]] = getelementptr i8, ptr %[[VPTR_ADDR]], i64 -32
+// LLVM: %[[OFF:.+]] = load i64, ptr %[[NEG32_PTR]]
+// LLVM: %[[ADJ_THIS:.+]] = getelementptr i8, ptr %[[D]], i64 %[[OFF]]
+// LLVM: %[[VFN_TAB:.+]] = load ptr, ptr %[[ADJ_THIS]]
+// LLVM: %[[SLOT0:.+]] = getelementptr inbounds ptr, ptr %[[VFN_TAB]], i32 0
+// LLVM: %[[VFN:.+]] = load ptr, ptr %[[SLOT0]]
+// LLVM: call void %[[VFN]](ptr %[[ADJ_THIS]])
+// LLVM: ret void
+
+// LLVM: define {{.*}}void @_Z1gv()
+// LLVM: %[[DF:.+]] = alloca {{.*}}
+// LLVM: call void @_ZN12DerivedFinalC1Ev(ptr %[[DF]])
+// LLVM: %[[VPTR2:.+]] = load ptr, ptr %[[DF]]
+// LLVM: %[[SLOT0_2:.+]] = getelementptr inbounds ptr, ptr %[[VPTR2]], i32 0
+// LLVM: %[[VFN2:.+]] = load ptr, ptr %[[SLOT0_2]]
+// LLVM: call void %[[VFN2]](ptr %[[DF]])
+// LLVM: ret void
+
+// OGCG: define {{.*}}void @_Z1fv()
+// OGCG: %[[D:.+]] = alloca {{.*}}
+// OGCG: call void @_ZN7DerivedC1Ev(ptr {{.*}} %[[D]])
+// OGCG: %[[VTABLE:.+]] = load ptr, ptr %[[D]]
+// OGCG: %[[NEG32_PTR:.+]] = getelementptr i8, ptr %[[VTABLE]], i64 -32
+// OGCG: %[[OFF:.+]] = load i64, ptr %[[NEG32_PTR]]
+// OGCG: %[[ADJ_THIS:.+]] = getelementptr inbounds i8, ptr %[[D]], i64 %[[OFF]]
+// OGCG: call void @_ZN4Base1fEv(ptr {{.*}} %[[ADJ_THIS]])
+// OGCG: ret void
+
+// OGCG: define {{.*}}void @_Z1gv()
+// OGCG: %[[DF:.+]] = alloca {{.*}}
+// OGCG: call void @_ZN12DerivedFinalC1Ev(ptr {{.*}} %[[DF]])
+// OGCG: call void @_ZN4Base1fEv(ptr {{.*}} %[[DF]])
+// OGCG: ret void
+
 // Constructor for B
 // CIR: cir.func comdat linkonce_odr @_ZN1BC1Ev(%arg0: !cir.ptr<!rec_B>
 // CIR: %[[THIS_ADDR:.*]] = cir.alloca !cir.ptr<!rec_B>, !cir.ptr<!cir.ptr<!rec_B>>, ["this", init]