Skip to content

Commit 34d9537

Browse files
committed
Flatten LValues instead of addresses, reusing existing code so that bitfields can be handled; update tests.
1 parent f8bb4f9 commit 34d9537

File tree

11 files changed

+384
-198
lines changed

11 files changed

+384
-198
lines changed

clang/lib/CodeGen/CGExpr.cpp

Lines changed: 66 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -6784,74 +6784,104 @@ LValue CodeGenFunction::EmitPseudoObjectLValue(const PseudoObjectExpr *E) {
67846784
return emitPseudoObjectExpr(*this, E, true, AggValueSlot::ignored()).LV;
67856785
}
67866786

6787-
void CodeGenFunction::FlattenAccessAndType(
6788-
Address Addr, QualType AddrType,
6789-
SmallVectorImpl<std::pair<Address, llvm::Value *>> &AccessList,
6790-
SmallVectorImpl<QualType> &FlatTypes) {
6791-
// WorkList is list of type we are processing + the Index List to access
6792-
// the field of that type in Addr for use in a GEP
6793-
llvm::SmallVector<std::pair<QualType, llvm::SmallVector<llvm::Value *, 4>>,
6794-
16>
6787+
void CodeGenFunction::FlattenAccessAndTypeLValue(
6788+
LValue Val, SmallVectorImpl<LValue> &AccessList) {
6789+
6790+
llvm::SmallVector<
6791+
std::tuple<LValue, QualType, llvm::SmallVector<llvm::Value *, 4>>, 16>
67956792
WorkList;
67966793
llvm::IntegerType *IdxTy = llvm::IntegerType::get(getLLVMContext(), 32);
6797-
// Addr should be a pointer so we need to 'dereference' it
6798-
WorkList.push_back({AddrType, {llvm::ConstantInt::get(IdxTy, 0)}});
6794+
WorkList.push_back({Val, Val.getType(), {llvm::ConstantInt::get(IdxTy, 0)}});
67996795

68006796
while (!WorkList.empty()) {
6801-
auto [T, IdxList] = WorkList.pop_back_val();
6797+
auto [LVal, T, IdxList] = WorkList.pop_back_val();
68026798
T = T.getCanonicalType().getUnqualifiedType();
68036799
assert(!isa<MatrixType>(T) && "Matrix types not yet supported in HLSL");
6800+
68046801
if (const auto *CAT = dyn_cast<ConstantArrayType>(T)) {
68056802
uint64_t Size = CAT->getZExtSize();
68066803
for (int64_t I = Size - 1; I > -1; I--) {
68076804
llvm::SmallVector<llvm::Value *, 4> IdxListCopy = IdxList;
68086805
IdxListCopy.push_back(llvm::ConstantInt::get(IdxTy, I));
6809-
WorkList.emplace_back(CAT->getElementType(), IdxListCopy);
6806+
WorkList.push_back({LVal, CAT->getElementType(), IdxListCopy});
68106807
}
68116808
} else if (const auto *RT = dyn_cast<RecordType>(T)) {
68126809
const RecordDecl *Record = RT->getOriginalDecl()->getDefinitionOrSelf();
68136810
assert(!Record->isUnion() && "Union types not supported in flat cast.");
68146811

68156812
const CXXRecordDecl *CXXD = dyn_cast<CXXRecordDecl>(Record);
68166813

6817-
llvm::SmallVector<QualType, 16> FieldTypes;
6814+
llvm::SmallVector<
6815+
std::tuple<LValue, QualType, llvm::SmallVector<llvm::Value *, 4>>, 16>
6816+
ReverseList;
68186817
if (CXXD && CXXD->isStandardLayout())
68196818
Record = CXXD->getStandardLayoutBaseWithFields();
68206819

68216820
// deal with potential base classes
68226821
if (CXXD && !CXXD->isStandardLayout()) {
6823-
for (auto &Base : CXXD->bases())
6824-
FieldTypes.push_back(Base.getType());
6822+
if (CXXD->getNumBases() > 0) {
6823+
assert(CXXD->getNumBases() == 1 &&
6824+
"HLSL doesn't support multiple inheritance.");
6825+
auto Base = CXXD->bases_begin();
6826+
llvm::SmallVector<llvm::Value *, 4> IdxListCopy = IdxList;
6827+
IdxListCopy.push_back(llvm::ConstantInt::get(
6828+
IdxTy, 0)); // base struct should be at index zero
6829+
ReverseList.insert(ReverseList.end(),
6830+
{LVal, Base->getType(), IdxListCopy});
6831+
}
68256832
}
68266833

6827-
for (auto *FD : Record->fields())
6828-
FieldTypes.push_back(FD->getType());
6834+
const CGRecordLayout &Layout = CGM.getTypes().getCGRecordLayout(Record);
68296835

6830-
for (int64_t I = FieldTypes.size() - 1; I > -1; I--) {
6831-
llvm::SmallVector<llvm::Value *, 4> IdxListCopy = IdxList;
6832-
IdxListCopy.push_back(llvm::ConstantInt::get(IdxTy, I));
6833-
WorkList.insert(WorkList.end(), {FieldTypes[I], IdxListCopy});
6836+
llvm::Type *LLVMT = ConvertTypeForMem(T);
6837+
CharUnits Align = getContext().getTypeAlignInChars(T);
6838+
LValue RLValue;
6839+
bool createdGEP = false;
6840+
for (auto *FD : Record->fields()) {
6841+
if (FD->isBitField()) {
6842+
if (FD->isUnnamedBitField())
6843+
continue;
6844+
if (!createdGEP) {
6845+
createdGEP = true;
6846+
Address GEP = Builder.CreateInBoundsGEP(LVal.getAddress(), IdxList,
6847+
LLVMT, Align, "gep");
6848+
RLValue = MakeAddrLValue(GEP, T);
6849+
}
6850+
LValue FieldLVal = EmitLValueForField(RLValue, FD, true);
6851+
ReverseList.insert(ReverseList.end(), {FieldLVal, FD->getType(), {}});
6852+
} else {
6853+
llvm::SmallVector<llvm::Value *, 4> IdxListCopy = IdxList;
6854+
IdxListCopy.push_back(
6855+
llvm::ConstantInt::get(IdxTy, Layout.getLLVMFieldNo(FD)));
6856+
ReverseList.insert(ReverseList.end(),
6857+
{LVal, FD->getType(), IdxListCopy});
6858+
}
68346859
}
6860+
6861+
std::reverse(ReverseList.begin(), ReverseList.end());
6862+
llvm::append_range(WorkList, ReverseList);
68356863
} else if (const auto *VT = dyn_cast<VectorType>(T)) {
68366864
llvm::Type *LLVMT = ConvertTypeForMem(T);
68376865
CharUnits Align = getContext().getTypeAlignInChars(T);
6838-
Address GEP =
6839-
Builder.CreateInBoundsGEP(Addr, IdxList, LLVMT, Align, "vector.gep");
6866+
Address GEP = Builder.CreateInBoundsGEP(LVal.getAddress(), IdxList, LLVMT,
6867+
Align, "vector.gep");
6868+
LValue Base = MakeAddrLValue(GEP, T);
68406869
for (unsigned I = 0, E = VT->getNumElements(); I < E; I++) {
6841-
llvm::Value *Idx = llvm::ConstantInt::get(IdxTy, I);
6842-
// gep on vector fields is not recommended so combine gep with
6843-
// extract/insert
6844-
AccessList.emplace_back(GEP, Idx);
6845-
FlatTypes.push_back(VT->getElementType());
6870+
llvm::Constant *Idx = llvm::ConstantInt::get(IdxTy, I);
6871+
LValue LV =
6872+
LValue::MakeVectorElt(Base.getAddress(), Idx, VT->getElementType(),
6873+
Base.getBaseInfo(), TBAAAccessInfo());
6874+
AccessList.emplace_back(LV);
68466875
}
6847-
} else {
6848-
// a scalar/builtin type
6849-
llvm::Type *LLVMT = ConvertTypeForMem(T);
6850-
CharUnits Align = getContext().getTypeAlignInChars(T);
6851-
Address GEP =
6852-
Builder.CreateInBoundsGEP(Addr, IdxList, LLVMT, Align, "gep");
6853-
AccessList.emplace_back(GEP, nullptr);
6854-
FlatTypes.push_back(T);
6876+
} else { // a scalar/builtin type
6877+
if (!IdxList.empty()) {
6878+
llvm::Type *LLVMT = ConvertTypeForMem(T);
6879+
CharUnits Align = getContext().getTypeAlignInChars(T);
6880+
Address GEP = Builder.CreateInBoundsGEP(LVal.getAddress(), IdxList,
6881+
LLVMT, Align, "gep");
6882+
AccessList.emplace_back(MakeAddrLValue(GEP, T));
6883+
} else // must be a bitfield we already created an lvalue for
6884+
AccessList.emplace_back(LVal);
68556885
}
68566886
}
68576887
}

clang/lib/CodeGen/CGExprAgg.cpp

Lines changed: 55 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -488,100 +488,62 @@ static bool isTrivialFiller(Expr *E) {
488488
return false;
489489
}
490490

491-
static void EmitHLSLAggregateSplatCast(CodeGenFunction &CGF, Address DestVal,
492-
QualType DestTy, llvm::Value *SrcVal,
493-
QualType SrcTy, SourceLocation Loc) {
491+
// emit an elementwise cast where the RHS is a scalar or vector
492+
// or emit an aggregate splat cast
493+
static void EmitHLSLScalarElementwiseAndSplatCasts(CodeGenFunction &CGF,
494+
LValue DestVal,
495+
llvm::Value *SrcVal,
496+
QualType SrcTy,
497+
SourceLocation Loc) {
494498
// Flatten our destination
495-
SmallVector<QualType> DestTypes; // Flattened type
496-
SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
497-
// ^^ Flattened accesses to DestVal we want to store into
498-
CGF.FlattenAccessAndType(DestVal, DestTy, StoreGEPList, DestTypes);
499-
500-
assert(SrcTy->isScalarType() && "Invalid HLSL Aggregate splat cast.");
501-
for (unsigned I = 0, Size = StoreGEPList.size(); I < Size; ++I) {
502-
llvm::Value *Cast =
503-
CGF.EmitScalarConversion(SrcVal, SrcTy, DestTypes[I], Loc);
504-
505-
// store back
506-
llvm::Value *Idx = StoreGEPList[I].second;
507-
if (Idx) {
508-
llvm::Value *V =
509-
CGF.Builder.CreateLoad(StoreGEPList[I].first, "load.for.insert");
510-
Cast = CGF.Builder.CreateInsertElement(V, Cast, Idx);
511-
}
512-
CGF.Builder.CreateStore(Cast, StoreGEPList[I].first);
513-
}
514-
}
515-
516-
// emit a flat cast where the RHS is a scalar, including vector
517-
static void EmitHLSLScalarFlatCast(CodeGenFunction &CGF, Address DestVal,
518-
QualType DestTy, llvm::Value *SrcVal,
519-
QualType SrcTy, SourceLocation Loc) {
520-
// Flatten our destination
521-
SmallVector<QualType, 16> DestTypes; // Flattened type
522-
SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
523-
// ^^ Flattened accesses to DestVal we want to store into
524-
CGF.FlattenAccessAndType(DestVal, DestTy, StoreGEPList, DestTypes);
525-
526-
assert(SrcTy->isVectorType() && "HLSL Flat cast doesn't handle splatting.");
527-
const VectorType *VT = SrcTy->getAs<VectorType>();
528-
SrcTy = VT->getElementType();
529-
assert(StoreGEPList.size() <= VT->getNumElements() &&
530-
"Cannot perform HLSL flat cast when vector source \
531-
object has less elements than flattened destination \
532-
object.");
533-
for (unsigned I = 0, Size = StoreGEPList.size(); I < Size; I++) {
534-
llvm::Value *Load = CGF.Builder.CreateExtractElement(SrcVal, I, "vec.load");
499+
SmallVector<LValue, 16> StoreList;
500+
CGF.FlattenAccessAndTypeLValue(DestVal, StoreList);
501+
502+
bool isVector = false;
503+
if (auto *VT = SrcTy->getAs<VectorType>()) {
504+
isVector = true;
505+
SrcTy = VT->getElementType();
506+
assert(StoreList.size() <= VT->getNumElements() &&
507+
"Cannot perform HLSL flat cast when vector source \
508+
object has less elements than flattened destination \
509+
object.");
510+
}
511+
512+
for (unsigned I = 0, Size = StoreList.size(); I < Size; I++) {
513+
LValue DestLVal = StoreList[I];
514+
llvm::Value *Load =
515+
isVector ? CGF.Builder.CreateExtractElement(SrcVal, I, "vec.load")
516+
: SrcVal;
535517
llvm::Value *Cast =
536-
CGF.EmitScalarConversion(Load, SrcTy, DestTypes[I], Loc);
537-
538-
// store back
539-
llvm::Value *Idx = StoreGEPList[I].second;
540-
if (Idx) {
541-
llvm::Value *V =
542-
CGF.Builder.CreateLoad(StoreGEPList[I].first, "load.for.insert");
543-
Cast = CGF.Builder.CreateInsertElement(V, Cast, Idx);
544-
}
545-
CGF.Builder.CreateStore(Cast, StoreGEPList[I].first);
518+
CGF.EmitScalarConversion(Load, SrcTy, DestLVal.getType(), Loc);
519+
CGF.EmitStoreThroughLValue(RValue::get(Cast), DestLVal);
546520
}
547521
}
548522

549523
// emit a flat cast where the RHS is an aggregate
550-
static void EmitHLSLElementwiseCast(CodeGenFunction &CGF, Address DestVal,
551-
QualType DestTy, Address SrcVal,
552-
QualType SrcTy, SourceLocation Loc) {
524+
static void EmitHLSLElementwiseCast(CodeGenFunction &CGF, LValue DestVal,
525+
LValue SrcVal, SourceLocation Loc) {
553526
// Flatten our destination
554-
SmallVector<QualType, 16> DestTypes; // Flattened type
555-
SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
556-
// ^^ Flattened accesses to DestVal we want to store into
557-
CGF.FlattenAccessAndType(DestVal, DestTy, StoreGEPList, DestTypes);
527+
SmallVector<LValue, 16> StoreList;
528+
CGF.FlattenAccessAndTypeLValue(DestVal, StoreList);
558529
// Flatten our src
559-
SmallVector<QualType, 16> SrcTypes; // Flattened type
560-
SmallVector<std::pair<Address, llvm::Value *>, 16> LoadGEPList;
561-
// ^^ Flattened accesses to SrcVal we want to load from
562-
CGF.FlattenAccessAndType(SrcVal, SrcTy, LoadGEPList, SrcTypes);
530+
SmallVector<LValue, 16> LoadList;
531+
CGF.FlattenAccessAndTypeLValue(SrcVal, LoadList);
563532

564-
assert(StoreGEPList.size() <= LoadGEPList.size() &&
565-
"Cannot perform HLSL flat cast when flattened source object \
533+
assert(StoreList.size() <= LoadList.size() &&
534+
"Cannot perform HLSL elementwise cast when flattened source object \
566535
has less elements than flattened destination object.");
567-
// apply casts to what we load from LoadGEPList
536+
// apply casts to what we load from LoadList
568537
// and store result in Dest
569-
for (unsigned I = 0, E = StoreGEPList.size(); I < E; I++) {
570-
llvm::Value *Idx = LoadGEPList[I].second;
571-
llvm::Value *Load = CGF.Builder.CreateLoad(LoadGEPList[I].first, "load");
572-
Load =
573-
Idx ? CGF.Builder.CreateExtractElement(Load, Idx, "vec.extract") : Load;
574-
llvm::Value *Cast =
575-
CGF.EmitScalarConversion(Load, SrcTypes[I], DestTypes[I], Loc);
576-
577-
// store back
578-
Idx = StoreGEPList[I].second;
579-
if (Idx) {
580-
llvm::Value *V =
581-
CGF.Builder.CreateLoad(StoreGEPList[I].first, "load.for.insert");
582-
Cast = CGF.Builder.CreateInsertElement(V, Cast, Idx);
583-
}
584-
CGF.Builder.CreateStore(Cast, StoreGEPList[I].first);
538+
for (unsigned I = 0, E = StoreList.size(); I < E; I++) {
539+
LValue DestLVal = StoreList[I];
540+
LValue SrcLVal = LoadList[I];
541+
RValue RVal = CGF.EmitLoadOfLValue(SrcLVal, Loc);
542+
assert(RVal.isScalar() && "All flattened source values should be scalars");
543+
llvm::Value *Val = RVal.getScalarVal();
544+
llvm::Value *Cast = CGF.EmitScalarConversion(Val, SrcLVal.getType(),
545+
DestLVal.getType(), Loc);
546+
CGF.EmitStoreThroughLValue(RValue::get(Cast), DestLVal);
585547
}
586548
}
587549

@@ -988,31 +950,33 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
988950
Expr *Src = E->getSubExpr();
989951
QualType SrcTy = Src->getType();
990952
RValue RV = CGF.EmitAnyExpr(Src);
991-
QualType DestTy = E->getType();
992-
Address DestVal = Dest.getAddress();
953+
LValue DestLVal = CGF.MakeAddrLValue(Dest.getAddress(), E->getType());
993954
SourceLocation Loc = E->getExprLoc();
994955

995-
assert(RV.isScalar() && "RHS of HLSL splat cast must be a scalar.");
956+
assert(RV.isScalar() && SrcTy->isScalarType() &&
957+
"RHS of HLSL splat cast must be a scalar.");
996958
llvm::Value *SrcVal = RV.getScalarVal();
997-
EmitHLSLAggregateSplatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
959+
EmitHLSLScalarElementwiseAndSplatCasts(CGF, DestLVal, SrcVal, SrcTy, Loc);
998960
break;
999961
}
1000962
case CK_HLSLElementwiseCast: {
1001963
Expr *Src = E->getSubExpr();
1002964
QualType SrcTy = Src->getType();
1003965
RValue RV = CGF.EmitAnyExpr(Src);
1004-
QualType DestTy = E->getType();
1005-
Address DestVal = Dest.getAddress();
966+
LValue DestLVal = CGF.MakeAddrLValue(Dest.getAddress(), E->getType());
1006967
SourceLocation Loc = E->getExprLoc();
1007968

1008969
if (RV.isScalar()) {
1009970
llvm::Value *SrcVal = RV.getScalarVal();
1010-
EmitHLSLScalarFlatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
971+
assert(SrcTy->isVectorType() &&
972+
"HLSL Elementwise cast doesn't handle splatting.");
973+
EmitHLSLScalarElementwiseAndSplatCasts(CGF, DestLVal, SrcVal, SrcTy, Loc);
1011974
} else {
1012975
assert(RV.isAggregate() &&
1013976
"Can't perform HLSL Aggregate cast on a complex type.");
1014977
Address SrcVal = RV.getAggregateAddress();
1015-
EmitHLSLElementwiseCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
978+
EmitHLSLElementwiseCast(CGF, DestLVal, CGF.MakeAddrLValue(SrcVal, SrcTy),
979+
Loc);
1016980
}
1017981
break;
1018982
}

0 commit comments

Comments (0)