@@ -40,15 +40,18 @@ struct CIRRecordLowering final {
4040 // member type that ensures correct rounding.
4141 struct MemberInfo final {
4242 CharUnits offset;
43- enum class InfoKind { Field } kind;
43+ enum class InfoKind { Field, Base } kind;
4444 mlir::Type data;
4545 union {
4646 const FieldDecl *fieldDecl;
47- // CXXRecordDecl will be used here when base types are supported.
47+ const CXXRecordDecl *cxxRecordDecl;
4848 };
4949 MemberInfo (CharUnits offset, InfoKind kind, mlir::Type data,
5050 const FieldDecl *fieldDecl = nullptr )
51- : offset(offset), kind(kind), data(data), fieldDecl(fieldDecl) {};
51+ : offset{offset}, kind{kind}, data{data}, fieldDecl{fieldDecl} {}
52+ MemberInfo (CharUnits offset, InfoKind kind, mlir::Type data,
53+ const CXXRecordDecl *rd)
54+ : offset{offset}, kind{kind}, data{data}, cxxRecordDecl{rd} {}
5255 // MemberInfos are sorted so we define a < operator.
5356 bool operator <(const MemberInfo &other) const {
5457 return offset < other.offset ;
@@ -71,6 +74,8 @@ struct CIRRecordLowering final {
7174 // / Inserts padding everywhere it's needed.
7275 void insertPadding ();
7376
77+ void accumulateBases (const CXXRecordDecl *cxxRecordDecl);
78+ void accumulateVPtrs ();
7479 void accumulateFields ();
7580
7681 CharUnits bitsToCharUnits (uint64_t bitOffset) {
@@ -89,6 +94,9 @@ struct CIRRecordLowering final {
8994 bool isZeroInitializable (const FieldDecl *fd) {
9095 return cirGenTypes.isZeroInitializable (fd->getType ());
9196 }
97+ bool isZeroInitializable (const RecordDecl *rd) {
98+ return cirGenTypes.isZeroInitializable (rd);
99+ }
92100
93101 // / Wraps cir::IntType with some implicit arguments.
94102 mlir::Type getUIntNType (uint64_t numBits) {
@@ -112,6 +120,11 @@ struct CIRRecordLowering final {
112120 : cir::ArrayType::get (type, numberOfChars.getQuantity ());
113121 }
114122
123+ // Gets the CIR BaseSubobject type from a CXXRecordDecl.
124+ mlir::Type getStorageType (const CXXRecordDecl *RD) {
125+ return cirGenTypes.getCIRGenRecordLayout (RD).getBaseSubobjectCIRType ();
126+ }
127+
115128 mlir::Type getStorageType (const FieldDecl *fieldDecl) {
116129 mlir::Type type = cirGenTypes.convertTypeForMem (fieldDecl->getType ());
117130 if (fieldDecl->isBitField ()) {
@@ -145,6 +158,7 @@ struct CIRRecordLowering final {
145158 // Output fields, consumed by CIRGenTypes::computeRecordLayout
146159 llvm::SmallVector<mlir::Type, 16 > fieldTypes;
147160 llvm::DenseMap<const FieldDecl *, unsigned > fieldIdxMap;
161+ llvm::DenseMap<const CXXRecordDecl *, unsigned > nonVirtualBases;
148162 cir::CIRDataLayout dataLayout;
149163
150164 LLVM_PREFERRED_TYPE (bool )
@@ -179,24 +193,20 @@ void CIRRecordLowering::lower() {
179193 return ;
180194 }
181195
182- assert (!cir::MissingFeatures::cxxSupport ());
183-
196+ assert (!cir::MissingFeatures::recordLayoutVirtualBases ());
184197 CharUnits size = astRecordLayout.getSize ();
185198
186199 accumulateFields ();
187200
188201 if (const auto *cxxRecordDecl = dyn_cast<CXXRecordDecl>(recordDecl)) {
189- if (cxxRecordDecl->getNumBases () > 0 ) {
190- CIRGenModule &cgm = cirGenTypes.getCGModule ();
191- cgm.errorNYI (recordDecl->getSourceRange (),
192- " CIRRecordLowering::lower: derived CXXRecordDecl" );
193- return ;
194- }
202+ accumulateVPtrs ();
203+ accumulateBases (cxxRecordDecl);
195204 if (members.empty ()) {
196205 appendPaddingBytes (size);
197206 assert (!cir::MissingFeatures::bitfields ());
198207 return ;
199208 }
209+ assert (!cir::MissingFeatures::recordLayoutVirtualBases ());
200210 }
201211
202212 llvm::stable_sort (members);
@@ -223,8 +233,10 @@ void CIRRecordLowering::fillOutputFields() {
223233 fieldTypes.size () - 1 ;
224234 // A field without storage must be a bitfield.
225235 assert (!cir::MissingFeatures::bitfields ());
236+ } else if (member.kind == MemberInfo::InfoKind::Base) {
237+ nonVirtualBases[member.cxxRecordDecl ] = fieldTypes.size () - 1 ;
226238 }
227- assert (!cir::MissingFeatures::cxxSupport ());
239+ assert (!cir::MissingFeatures::recordLayoutVirtualBases ());
228240 }
229241}
230242
@@ -254,9 +266,14 @@ void CIRRecordLowering::calculateZeroInit() {
254266 continue ;
255267 zeroInitializable = zeroInitializableAsBase = false ;
256268 return ;
269+ } else if (member.kind == MemberInfo::InfoKind::Base) {
270+ if (isZeroInitializable (member.cxxRecordDecl ))
271+ continue ;
272+ zeroInitializable = false ;
273+ if (member.kind == MemberInfo::InfoKind::Base)
274+ zeroInitializableAsBase = false ;
257275 }
258- // TODO(cir): handle base types
259- assert (!cir::MissingFeatures::cxxSupport ());
276+ assert (!cir::MissingFeatures::recordLayoutVirtualBases ());
260277 }
261278}
262279
@@ -317,6 +334,27 @@ CIRGenTypes::computeRecordLayout(const RecordDecl *rd, cir::RecordType *ty) {
317334 lowering.lower ();
318335
319336 // If we're in C++, compute the base subobject type.
337+ cir::RecordType baseTy;
338+ if (llvm::isa<CXXRecordDecl>(rd) && !rd->isUnion () &&
339+ !rd->hasAttr <FinalAttr>()) {
340+ baseTy = *ty;
341+ if (lowering.astRecordLayout .getNonVirtualSize () !=
342+ lowering.astRecordLayout .getSize ()) {
343+ CIRRecordLowering baseLowering (*this , rd, /* Packed=*/ lowering.packed );
344+ baseLowering.lower ();
345+ std::string baseIdentifier = getRecordTypeName (rd, " .base" );
346+ baseTy =
347+ builder.getCompleteRecordTy (baseLowering.fieldTypes , baseIdentifier,
348+ baseLowering.packed , baseLowering.padded );
349+ // TODO(cir): add something like addRecordTypeName
350+
351+ // BaseTy and Ty must agree on their packedness for getCIRFieldNo to work
352+ // on both of them with the same index.
353+ assert (lowering.packed == baseLowering.packed &&
354+ " Non-virtual and complete types must agree on packedness" );
355+ }
356+ }
357+
320358 if (llvm::isa<CXXRecordDecl>(rd) && !rd->isUnion () &&
321359 !rd->hasAttr <FinalAttr>()) {
322360 if (lowering.astRecordLayout .getNonVirtualSize () !=
@@ -332,10 +370,13 @@ CIRGenTypes::computeRecordLayout(const RecordDecl *rd, cir::RecordType *ty) {
332370 ty->complete (lowering.fieldTypes , lowering.packed , lowering.padded );
333371
334372 auto rl = std::make_unique<CIRGenRecordLayout>(
335- ty ? *ty : cir::RecordType (), ( bool )lowering. zeroInitializable ,
336- (bool )lowering.zeroInitializableAsBase );
373+ ty ? *ty : cir::RecordType{}, baseTy ? baseTy : cir::RecordType{} ,
374+ (bool )lowering.zeroInitializable , ( bool )lowering. zeroInitializableAsBase );
337375
338376 assert (!cir::MissingFeatures::recordZeroInit ());
377+
378+ rl->nonVirtualBases .swap (lowering.nonVirtualBases );
379+
339380 assert (!cir::MissingFeatures::cxxSupport ());
340381 assert (!cir::MissingFeatures::bitfields ());
341382
@@ -415,3 +456,38 @@ void CIRRecordLowering::lowerUnion() {
415456 if (layoutSize % getAlignment (storageType))
416457 packed = true ;
417458}
459+
460+ void CIRRecordLowering::accumulateBases (const CXXRecordDecl *cxxRecordDecl) {
461+ // If we've got a primary virtual base, we need to add it with the bases.
462+ if (astRecordLayout.isPrimaryBaseVirtual ()) {
463+ cirGenTypes.getCGModule ().errorNYI (recordDecl->getSourceRange (),
464+ " accumulateBases: primary virtual base" );
465+ }
466+
467+ // Accumulate the non-virtual bases.
468+ for ([[maybe_unused]] const auto &base : cxxRecordDecl->bases ()) {
469+ if (base.isVirtual ()) {
470+ cirGenTypes.getCGModule ().errorNYI (recordDecl->getSourceRange (),
471+ " accumulateBases: virtual base" );
472+ continue ;
473+ }
474+ // Bases can be zero-sized even if not technically empty if they
475+ // contain only a trailing array member.
476+ const CXXRecordDecl *baseDecl = base.getType ()->getAsCXXRecordDecl ();
477+ if (!baseDecl->isEmpty () &&
478+ !astContext.getASTRecordLayout (baseDecl).getNonVirtualSize ().isZero ()) {
479+ members.push_back (MemberInfo (astRecordLayout.getBaseClassOffset (baseDecl),
480+ MemberInfo::InfoKind::Base,
481+ getStorageType (baseDecl), baseDecl));
482+ }
483+ }
484+ }
485+
486+ void CIRRecordLowering::accumulateVPtrs () {
487+ if (astRecordLayout.hasOwnVFPtr ())
488+ cirGenTypes.getCGModule ().errorNYI (recordDecl->getSourceRange (),
489+ " accumulateVPtrs: hasOwnVFPtr" );
490+ if (astRecordLayout.hasOwnVBPtr ())
491+ cirGenTypes.getCGModule ().errorNYI (recordDecl->getSourceRange (),
492+ " accumulateVPtrs: hasOwnVBPtr" );
493+ }
0 commit comments