Skip to content

Commit 89709ad

Browse files
committed
self review continued. Make FlattenAccessAndTypes not recursive and handle records correctly.
1 parent f4819b8 commit 89709ad

File tree

4 files changed

+92
-83
lines changed

4 files changed

+92
-83
lines changed

clang/lib/CodeGen/CGExpr.cpp

Lines changed: 81 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -6361,62 +6361,87 @@ LValue CodeGenFunction::EmitPseudoObjectLValue(const PseudoObjectExpr *E) {
63616361
}
63626362

63636363
void CodeGenFunction::FlattenAccessAndType(
6364-
Address Val, QualType SrcTy, SmallVector<llvm::Value *, 4> &IdxList,
6365-
SmallVector<std::pair<Address, llvm::Value *>, 16> &GEPList,
6366-
SmallVector<QualType> &FlatTypes) {
6364+
Address Addr, QualType AddrType,
6365+
SmallVectorImpl<std::pair<Address, llvm::Value *>> &AccessList,
6366+
SmallVectorImpl<QualType> &FlatTypes) {
6367+
// WorkList is list of type we are processing + the Index List to access
6368+
// the field of that type in Addr for use in a GEP
6369+
llvm::SmallVector<std::pair<QualType, llvm::SmallVector<llvm::Value *, 4>>,
6370+
16>
6371+
WorkList;
63676372
llvm::IntegerType *IdxTy = llvm::IntegerType::get(getLLVMContext(), 32);
6368-
if (const ConstantArrayType *CAT = dyn_cast<ConstantArrayType>(SrcTy)) {
6369-
uint64_t Size = CAT->getZExtSize();
6370-
for (unsigned i = 0; i < Size; i++) {
6371-
// flatten each member of the array
6372-
// add index of this element to index list
6373-
llvm::Value *Idx = llvm::ConstantInt::get(IdxTy, i);
6374-
IdxList.push_back(Idx);
6375-
// recur on this object
6376-
FlattenAccessAndType(Val, CAT->getElementType(), IdxList, GEPList,
6377-
FlatTypes);
6378-
// remove index of this element from index list
6379-
IdxList.pop_back();
6380-
}
6381-
} else if (const RecordType *RT = SrcTy->getAs<RecordType>()) {
6382-
RecordDecl *Record = RT->getDecl();
6383-
const CGRecordLayout &RL = getTypes().getCGRecordLayout(Record);
6384-
// do I need to check if its a cxx record decl?
6385-
6386-
for (auto fieldIter = Record->field_begin(), fieldEnd = Record->field_end();
6387-
fieldIter != fieldEnd; ++fieldIter) {
6388-
// get the field number
6389-
unsigned FieldNum = RL.getLLVMFieldNo(*fieldIter);
6390-
// can we just do *fieldIter->getFieldIndex();
6391-
// add that index to the index list
6392-
llvm::Value *Idx = llvm::ConstantInt::get(IdxTy, FieldNum);
6393-
IdxList.push_back(Idx);
6394-
// recur on the field
6395-
FlattenAccessAndType(Val, fieldIter->getType(), IdxList, GEPList,
6396-
FlatTypes);
6397-
// remove index of this element from index list
6398-
IdxList.pop_back();
6399-
}
6400-
} else if (const VectorType *VT = SrcTy->getAs<VectorType>()) {
6401-
llvm::Type *VTy = ConvertTypeForMem(SrcTy);
6402-
CharUnits Align = getContext().getTypeAlignInChars(SrcTy);
6403-
Address GEP =
6404-
Builder.CreateInBoundsGEP(Val, IdxList, VTy, Align, "vector.gep");
6405-
for (unsigned i = 0; i < VT->getNumElements(); i++) {
6406-
// add index to the list
6407-
llvm::Value *Idx = llvm::ConstantInt::get(IdxTy, i);
6408-
// create gep. no need to recur since its always a scalar
6409-
// gep on vector is not recommended so combine gep with extract/insert
6410-
GEPList.push_back({GEP, Idx});
6411-
FlatTypes.push_back(VT->getElementType());
6373+
WorkList.push_back(
6374+
{AddrType,
6375+
{llvm::ConstantInt::get(
6376+
IdxTy,
6377+
0)}}); // Addr should be a pointer so we need to 'dereference' it
6378+
6379+
while (!WorkList.empty()) {
6380+
std::pair<QualType, llvm::SmallVector<llvm::Value *, 4>> P =
6381+
WorkList.pop_back_val();
6382+
QualType T = P.first;
6383+
llvm::SmallVector<llvm::Value *, 4> IdxList = P.second;
6384+
T = T.getCanonicalType().getUnqualifiedType();
6385+
assert(!isa<MatrixType>(T) && "Matrix types not yet supported in HLSL");
6386+
if (const auto *CAT = dyn_cast<ConstantArrayType>(T)) {
6387+
uint64_t Size = CAT->getZExtSize();
6388+
for (int64_t i = Size - 1; i > -1; i--) {
6389+
llvm::SmallVector<llvm::Value *, 4> IdxListCopy = IdxList;
6390+
IdxListCopy.push_back(llvm::ConstantInt::get(IdxTy, i));
6391+
WorkList.insert(WorkList.end(), {CAT->getElementType(), IdxListCopy});
6392+
}
6393+
} else if (const auto *RT = dyn_cast<RecordType>(T)) {
6394+
const RecordDecl *Record = RT->getDecl();
6395+
if (Record->isUnion()) {
6396+
IdxList.push_back(llvm::ConstantInt::get(IdxTy, 0));
6397+
llvm::Type *LLVMT = ConvertTypeForMem(T);
6398+
CharUnits Align = getContext().getTypeAlignInChars(T);
6399+
Address GEP =
6400+
Builder.CreateInBoundsGEP(Addr, IdxList, LLVMT, Align, "union.gep");
6401+
AccessList.push_back({GEP, NULL});
6402+
FlatTypes.push_back(T);
6403+
continue;
6404+
}
6405+
const CXXRecordDecl *CXXD = dyn_cast<CXXRecordDecl>(Record);
6406+
6407+
llvm::SmallVector<QualType, 16> FieldTypes;
6408+
if (CXXD && CXXD->isStandardLayout())
6409+
Record = CXXD->getStandardLayoutBaseWithFields();
6410+
6411+
// deal with potential base classes
6412+
if (CXXD && !CXXD->isStandardLayout()) {
6413+
for (auto &Base : CXXD->bases())
6414+
FieldTypes.push_back(Base.getType());
6415+
}
6416+
6417+
for (auto *FD : Record->fields())
6418+
FieldTypes.push_back(FD->getType());
6419+
6420+
for (int64_t i = FieldTypes.size() - 1; i > -1; i--) {
6421+
llvm::SmallVector<llvm::Value *, 4> IdxListCopy = IdxList;
6422+
IdxListCopy.push_back(llvm::ConstantInt::get(IdxTy, i));
6423+
WorkList.insert(WorkList.end(), {FieldTypes[i], IdxListCopy});
6424+
}
6425+
} else if (const auto *VT = dyn_cast<VectorType>(T)) {
6426+
llvm::Type *LLVMT = ConvertTypeForMem(T);
6427+
CharUnits Align = getContext().getTypeAlignInChars(T);
6428+
Address GEP =
6429+
Builder.CreateInBoundsGEP(Addr, IdxList, LLVMT, Align, "vector.gep");
6430+
for (unsigned i = 0; i < VT->getNumElements(); i++) {
6431+
llvm::Value *Idx = llvm::ConstantInt::get(IdxTy, i);
6432+
// gep on vector fields is not recommended so combine gep with
6433+
// extract/insert
6434+
AccessList.push_back({GEP, Idx});
6435+
FlatTypes.push_back(VT->getElementType());
6436+
}
6437+
} else {
6438+
// a scalar/builtin type
6439+
llvm::Type *LLVMT = ConvertTypeForMem(T);
6440+
CharUnits Align = getContext().getTypeAlignInChars(T);
6441+
Address GEP =
6442+
Builder.CreateInBoundsGEP(Addr, IdxList, LLVMT, Align, "gep");
6443+
AccessList.push_back({GEP, NULL});
6444+
FlatTypes.push_back(T);
64126445
}
6413-
} else { // should be a scalar should we assert or check?
6414-
// create a gep
6415-
llvm::Type *Ty = ConvertTypeForMem(SrcTy);
6416-
CharUnits Align = getContext().getTypeAlignInChars(SrcTy);
6417-
Address GEP = Builder.CreateInBoundsGEP(Val, IdxList, Ty, Align, "gep");
6418-
GEPList.push_back({GEP, NULL});
6419-
FlatTypes.push_back(SrcTy);
6420-
}
6421-
// target extension types?
6446+
}
64226447
}

clang/lib/CodeGen/CGExprAgg.cpp

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -496,14 +496,10 @@ static void EmitHLSLScalarFlatCast(CodeGenFunction &CGF, Address DestVal,
496496
QualType DestTy, llvm::Value *SrcVal,
497497
QualType SrcTy, SourceLocation Loc) {
498498
// Flatten our destination
499-
SmallVector<QualType> DestTypes; // Flattened type
500-
SmallVector<llvm::Value *, 4> IdxList;
501-
IdxList.push_back(
502-
llvm::ConstantInt::get(llvm::IntegerType::get(CGF.getLLVMContext(), 32),
503-
0)); // because an Address is a pointer
499+
SmallVector<QualType, 16> DestTypes; // Flattened type
504500
SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
505501
// ^^ Flattened accesses to DestVal we want to store into
506-
CGF.FlattenAccessAndType(DestVal, DestTy, IdxList, StoreGEPList, DestTypes);
502+
CGF.FlattenAccessAndType(DestVal, DestTy, StoreGEPList, DestTypes);
507503

508504
if (const VectorType *VT = SrcTy->getAs<VectorType>()) {
509505
SrcTy = VT->getElementType();
@@ -536,23 +532,15 @@ static void EmitHLSLAggregateFlatCast(CodeGenFunction &CGF, Address DestVal,
536532
QualType DestTy, Address SrcVal,
537533
QualType SrcTy, SourceLocation Loc) {
538534
// Flatten our destination
539-
SmallVector<QualType> DestTypes; // Flattened type
540-
SmallVector<llvm::Value *, 4> IdxList;
541-
IdxList.push_back(
542-
llvm::ConstantInt::get(llvm::IntegerType::get(CGF.getLLVMContext(), 32),
543-
0)); // Because an Address is a pointer
535+
SmallVector<QualType, 16> DestTypes; // Flattened type
544536
SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
545537
// ^^ Flattened accesses to DestVal we want to store into
546-
CGF.FlattenAccessAndType(DestVal, DestTy, IdxList, StoreGEPList, DestTypes);
538+
CGF.FlattenAccessAndType(DestVal, DestTy, StoreGEPList, DestTypes);
547539
// Flatten our src
548-
SmallVector<QualType> SrcTypes; // Flattened type
540+
SmallVector<QualType, 16> SrcTypes; // Flattened type
549541
SmallVector<std::pair<Address, llvm::Value *>, 16> LoadGEPList;
550542
// ^^ Flattened accesses to SrcVal we want to load from
551-
IdxList.clear();
552-
IdxList.push_back(
553-
llvm::ConstantInt::get(llvm::IntegerType::get(CGF.getLLVMContext(), 32),
554-
0)); // Because an Address is a pointer
555-
CGF.FlattenAccessAndType(SrcVal, SrcTy, IdxList, LoadGEPList, SrcTypes);
543+
CGF.FlattenAccessAndType(SrcVal, SrcTy, LoadGEPList, SrcTypes);
556544

557545
assert(StoreGEPList.size() <= LoadGEPList.size() &&
558546
"Cannot perform HLSL flat cast when flattened source object \

clang/lib/CodeGen/CGExprScalar.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2266,13 +2266,9 @@ bool CodeGenFunction::ShouldNullCheckClassCastValue(const CastExpr *CE) {
22662266
static Value *EmitHLSLAggregateFlatCast(CodeGenFunction &CGF, Address RHSVal,
22672267
QualType RHSTy, QualType LHSTy,
22682268
SourceLocation Loc) {
2269-
SmallVector<llvm::Value *, 4> IdxList;
2270-
IdxList.push_back(
2271-
llvm::ConstantInt::get(llvm::IntegerType::get(CGF.getLLVMContext(), 32),
2272-
0)); // because an Address is a pointer
22732269
SmallVector<std::pair<Address, llvm::Value *>, 16> LoadGEPList;
2274-
SmallVector<QualType> SrcTypes; // Flattened type
2275-
CGF.FlattenAccessAndType(RHSVal, RHSTy, IdxList, LoadGEPList, SrcTypes);
2270+
SmallVector<QualType, 16> SrcTypes; // Flattened type
2271+
CGF.FlattenAccessAndType(RHSVal, RHSTy, LoadGEPList, SrcTypes);
22762272
// LHS is either a vector or a builtin?
22772273
// if its a vector create a temp alloca to store into and return that
22782274
if (auto *VecTy = LHSTy->getAs<VectorType>()) {

clang/lib/CodeGen/CodeGenFunction.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4360,9 +4360,9 @@ class CodeGenFunction : public CodeGenTypeCache {
43604360
LValue EmitPseudoObjectLValue(const PseudoObjectExpr *e);
43614361

43624362
void FlattenAccessAndType(
4363-
Address Val, QualType SrcTy, SmallVector<llvm::Value *, 4> &IdxList,
4364-
SmallVector<std::pair<Address, llvm::Value *>, 16> &GEPList,
4365-
SmallVector<QualType> &FlatTypes);
4363+
Address Addr, QualType AddrTy,
4364+
SmallVectorImpl<std::pair<Address, llvm::Value *>> &AccessList,
4365+
SmallVectorImpl<QualType> &FlatTypes);
43664366

43674367
llvm::Value *EmitIvarOffset(const ObjCInterfaceDecl *Interface,
43684368
const ObjCIvarDecl *Ivar);

0 commit comments

Comments
 (0)