Skip to content

Commit b801199

Browse files
committed
[CIR] Upstream initial support for union type
1 parent 1b2671f commit b801199

File tree

7 files changed

+276
-36
lines changed

7 files changed

+276
-36
lines changed

clang/include/clang/CIR/Dialect/IR/CIRTypes.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,7 @@ def CIR_RecordType : CIR_Type<"Record", "record",
494494
bool isComplete() const { return !isIncomplete(); };
495495
bool isIncomplete() const;
496496

497+
mlir::Type getLargestMember(const mlir::DataLayout &dataLayout) const;
497498
size_t getNumElements() const { return getMembers().size(); };
498499
std::string getKindAsStr() {
499500
switch (getKind()) {

clang/lib/CIR/CodeGen/CIRGenExpr.cpp

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -317,20 +317,25 @@ LValue CIRGenFunction::emitLValueForField(LValue base, const FieldDecl *field) {
317317
}
318318

319319
unsigned recordCVR = base.getVRQualifiers();
320-
if (rec->isUnion()) {
321-
cgm.errorNYI(field->getSourceRange(), "emitLValueForField: union");
322-
return LValue();
323-
}
324320

325-
assert(!cir::MissingFeatures::preservedAccessIndexRegion());
326321
llvm::StringRef fieldName = field->getName();
327-
const CIRGenRecordLayout &layout =
328-
cgm.getTypes().getCIRGenRecordLayout(field->getParent());
329-
unsigned fieldIndex = layout.getCIRFieldNo(field);
330322

331-
assert(!cir::MissingFeatures::lambdaFieldToName());
323+
if (rec->isUnion()) {
324+
unsigned fieldIndex = field->getFieldIndex();
325+
assert(!cir::MissingFeatures::lambdaFieldToName());
326+
addr = emitAddrOfFieldStorage(addr, field, fieldName, fieldIndex);
327+
328+
} else {
329+
assert(!cir::MissingFeatures::preservedAccessIndexRegion());
330+
331+
const CIRGenRecordLayout &layout =
332+
cgm.getTypes().getCIRGenRecordLayout(field->getParent());
333+
unsigned fieldIndex = layout.getCIRFieldNo(field);
332334

333-
addr = emitAddrOfFieldStorage(addr, field, fieldName, fieldIndex);
335+
assert(!cir::MissingFeatures::lambdaFieldToName());
336+
337+
addr = emitAddrOfFieldStorage(addr, field, fieldName, fieldIndex);
338+
}
334339

335340
// If this is a reference field, load the reference right now.
336341
if (fieldType->isReferenceType()) {

clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===--- CIRGenExprAgg.cpp - Emit CIR Code from Aggregate Expressions -----===//
1+
//===- CIRGenExprAggregrate.cpp - Emit CIR Code from Aggregate Expressions ===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.

clang/lib/CIR/CodeGen/CIRGenRecordLayoutBuilder.cpp

Lines changed: 88 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,15 @@ struct CIRRecordLowering final {
5656
};
5757
// The constructor.
5858
CIRRecordLowering(CIRGenTypes &cirGenTypes, const RecordDecl *recordDecl,
59-
bool isPacked);
59+
bool packed);
6060

6161
/// Constructs a MemberInfo instance from an offset and mlir::Type.
6262
MemberInfo makeStorageInfo(CharUnits offset, mlir::Type data) {
6363
return MemberInfo(offset, MemberInfo::InfoKind::Field, data);
6464
}
6565

6666
void lower();
67+
void lowerUnion();
6768

6869
/// Determines if we need a packed llvm struct.
6970
void determinePacked();
@@ -83,6 +84,10 @@ struct CIRRecordLowering final {
8384
return CharUnits::fromQuantity(dataLayout.layout.getTypeABIAlignment(Ty));
8485
}
8586

87+
bool isZeroInitializable(const FieldDecl *fd) {
88+
return cirGenTypes.isZeroInitializable(fd->getType());
89+
}
90+
8691
/// Wraps cir::IntType with some implicit arguments.
8792
mlir::Type getUIntNType(uint64_t numBits) {
8893
unsigned alignedBits = llvm::PowerOf2Ceil(numBits);
@@ -121,6 +126,13 @@ struct CIRRecordLowering final {
121126
/// Fills out the structures that are ultimately consumed.
122127
void fillOutputFields();
123128

129+
void appendPaddingBytes(CharUnits size) {
130+
if (!size.isZero()) {
131+
fieldTypes.push_back(getByteArrayType(size));
132+
padded = true;
133+
}
134+
}
135+
124136
CIRGenTypes &cirGenTypes;
125137
CIRGenBuilderTy &builder;
126138
const ASTContext &astContext;
@@ -136,6 +148,8 @@ struct CIRRecordLowering final {
136148
LLVM_PREFERRED_TYPE(bool)
137149
unsigned zeroInitializable : 1;
138150
LLVM_PREFERRED_TYPE(bool)
151+
unsigned zeroInitializableAsBase : 1;
152+
LLVM_PREFERRED_TYPE(bool)
139153
unsigned packed : 1;
140154
LLVM_PREFERRED_TYPE(bool)
141155
unsigned padded : 1;
@@ -147,19 +161,19 @@ struct CIRRecordLowering final {
147161
} // namespace
148162

149163
CIRRecordLowering::CIRRecordLowering(CIRGenTypes &cirGenTypes,
150-
const RecordDecl *recordDecl,
151-
bool isPacked)
164+
const RecordDecl *recordDecl, bool packed)
152165
: cirGenTypes(cirGenTypes), builder(cirGenTypes.getBuilder()),
153166
astContext(cirGenTypes.getASTContext()), recordDecl(recordDecl),
154167
astRecordLayout(
155168
cirGenTypes.getASTContext().getASTRecordLayout(recordDecl)),
156169
dataLayout(cirGenTypes.getCGModule().getModule()),
157-
zeroInitializable(true), packed(isPacked), padded(false) {}
170+
zeroInitializable(true), zeroInitializableAsBase(true), packed(packed),
171+
padded(false) {}
158172

159173
void CIRRecordLowering::lower() {
160174
if (recordDecl->isUnion()) {
161-
cirGenTypes.getCGModule().errorNYI(recordDecl->getSourceRange(),
162-
"lower: union");
175+
lowerUnion();
176+
assert(!cir::MissingFeatures::bitfields());
163177
return;
164178
}
165179

@@ -306,3 +320,71 @@ CIRGenTypes::computeRecordLayout(const RecordDecl *rd, cir::RecordType *ty) {
306320
// TODO: implement verification
307321
return rl;
308322
}
323+
324+
void CIRRecordLowering::lowerUnion() {
325+
CharUnits layoutSize = astRecordLayout.getSize();
326+
mlir::Type storageType = nullptr;
327+
bool seenNamedMember = false;
328+
329+
// Iterate through the fields setting bitFieldInfo and the Fields array. Also
330+
// locate the "most appropriate" storage type. The heuristic for finding the
331+
// storage type isn't necessary, the first (non-0-length-bitfield) field's
332+
// type would work fine and be simpler but would be different than what we've
333+
// been doing and cause lit tests to change.
334+
for (const FieldDecl *field : recordDecl->fields()) {
335+
mlir::Type fieldType;
336+
if (field->isBitField())
337+
cirGenTypes.getCGModule().errorNYI(recordDecl->getSourceRange(),
338+
"bitfields in lowerUnion");
339+
else
340+
fieldType = getStorageType(field);
341+
342+
fields[field->getCanonicalDecl()] = 0;
343+
344+
// Compute zero-initializable status.
345+
// This union might not be zero initialized: it may contain a pointer to
346+
// data member which might have some exotic initialization sequence.
347+
// If this is the case, then we aught not to try and come up with a "better"
348+
// type, it might not be very easy to come up with a Constant which
349+
// correctly initializes it.
350+
if (!seenNamedMember) {
351+
seenNamedMember = field->getIdentifier();
352+
if (!seenNamedMember)
353+
if (const RecordDecl *fieldRD = field->getType()->getAsRecordDecl())
354+
seenNamedMember = fieldRD->findFirstNamedDataMember();
355+
if (seenNamedMember && !isZeroInitializable(field)) {
356+
zeroInitializable = zeroInitializableAsBase = false;
357+
storageType = fieldType;
358+
}
359+
}
360+
361+
// Because our union isn't zero initializable, we won't be getting a better
362+
// storage type.
363+
if (!zeroInitializable)
364+
continue;
365+
366+
// Conditionally update our storage type if we've got a new "better" one.
367+
if (!storageType || getAlignment(fieldType) > getAlignment(storageType) ||
368+
(getAlignment(fieldType) == getAlignment(storageType) &&
369+
getSize(fieldType) > getSize(storageType)))
370+
storageType = fieldType;
371+
372+
// NOTE(cir): Track all union member's types, not just the largest one. It
373+
// allows for proper type-checking and retain more info for analisys.
374+
fieldTypes.push_back(fieldType);
375+
}
376+
377+
if (!storageType)
378+
cirGenTypes.getCGModule().errorNYI(recordDecl->getSourceRange(),
379+
"No-storage Union NYI");
380+
381+
if (layoutSize < getSize(storageType))
382+
storageType = getByteArrayType(layoutSize);
383+
384+
// NOTE(cir): Defer padding calculations to the lowering process.
385+
appendPaddingBytes(layoutSize - getSize(storageType));
386+
387+
// Set packed if we need it.
388+
if (layoutSize % getAlignment(storageType))
389+
packed = true;
390+
}

clang/lib/CIR/Dialect/IR/CIRTypes.cpp

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -230,17 +230,41 @@ void RecordType::complete(ArrayRef<Type> members, bool packed, bool padded) {
230230
llvm_unreachable("failed to complete record");
231231
}
232232

233+
/// Return the largest member of in the type.
234+
///
235+
/// Recurses into union members never returning a union as the largest member.
236+
Type RecordType::getLargestMember(const ::mlir::DataLayout &dataLayout) const {
237+
assert(isUnion() && "Only call getLargestMember on unions");
238+
Type largestMember;
239+
unsigned largestMemberSize = 0;
240+
unsigned numElements = getNumElements();
241+
auto members = getMembers();
242+
if (getPadded())
243+
numElements -= 1; // The last element is padding.
244+
for (unsigned i = 0; i < numElements; ++i) {
245+
Type ty = members[i];
246+
if (!largestMember ||
247+
dataLayout.getTypeABIAlignment(ty) >
248+
dataLayout.getTypeABIAlignment(largestMember) ||
249+
(dataLayout.getTypeABIAlignment(ty) ==
250+
dataLayout.getTypeABIAlignment(largestMember) &&
251+
dataLayout.getTypeSize(ty) > largestMemberSize)) {
252+
largestMember = ty;
253+
largestMemberSize = dataLayout.getTypeSize(largestMember);
254+
}
255+
}
256+
return largestMember;
257+
}
258+
233259
//===----------------------------------------------------------------------===//
234260
// Data Layout information for types
235261
//===----------------------------------------------------------------------===//
236262

237263
llvm::TypeSize
238264
RecordType::getTypeSizeInBits(const mlir::DataLayout &dataLayout,
239265
mlir::DataLayoutEntryListRef params) const {
240-
if (isUnion()) {
241-
// TODO(CIR): Implement union layout.
242-
return llvm::TypeSize::getFixed(8);
243-
}
266+
if (isUnion())
267+
return dataLayout.getTypeSize(getLargestMember(dataLayout));
244268

245269
unsigned recordSize = computeStructSize(dataLayout);
246270
return llvm::TypeSize::getFixed(recordSize * 8);
@@ -249,10 +273,8 @@ RecordType::getTypeSizeInBits(const mlir::DataLayout &dataLayout,
249273
uint64_t
250274
RecordType::getABIAlignment(const ::mlir::DataLayout &dataLayout,
251275
::mlir::DataLayoutEntryListRef params) const {
252-
if (isUnion()) {
253-
// TODO(CIR): Implement union layout.
254-
return 8;
255-
}
276+
if (isUnion())
277+
return dataLayout.getTypeABIAlignment(getLargestMember(dataLayout));
256278

257279
// Packed structures always have an ABI alignment of 1.
258280
if (getPacked())
@@ -268,8 +290,6 @@ RecordType::computeStructSize(const mlir::DataLayout &dataLayout) const {
268290
unsigned recordSize = 0;
269291
uint64_t recordAlignment = 1;
270292

271-
// We can't use a range-based for loop here because we might be ignoring the
272-
// last element.
273293
for (mlir::Type ty : getMembers()) {
274294
// This assumes that we're calculating size based on the ABI alignment, not
275295
// the preferred alignment for each type.

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1431,7 +1431,14 @@ static void prepareTypeConverter(mlir::LLVMTypeConverter &converter,
14311431
break;
14321432
// Unions are lowered as only the largest member.
14331433
case cir::RecordType::Union:
1434-
llvm_unreachable("Lowering of unions is NYI");
1434+
if (auto largestMember = type.getLargestMember(dataLayout))
1435+
llvmMembers.push_back(
1436+
convertTypeForMemory(converter, dataLayout, largestMember));
1437+
if (type.getPadded()) {
1438+
auto last = *type.getMembers().rbegin();
1439+
llvmMembers.push_back(
1440+
convertTypeForMemory(converter, dataLayout, last));
1441+
}
14351442
break;
14361443
}
14371444

@@ -1604,7 +1611,11 @@ mlir::LogicalResult CIRToLLVMGetMemberOpLowering::matchAndRewrite(
16041611
return mlir::success();
16051612
}
16061613
case cir::RecordType::Union:
1607-
return op.emitError() << "NYI: union get_member lowering";
1614+
// Union members share the address space, so we just need a bitcast to
1615+
// conform to type-checking.
1616+
rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(op, llResTy,
1617+
adaptor.getAddr());
1618+
return mlir::success();
16081619
}
16091620
}
16101621

0 commit comments

Comments
 (0)