Commit 35c57a7

[HLSL] Add support for elementwise and aggregate splat casting struct types with bitfields (#161263)
Adds support for elementwise and aggregate splat casts of struct types that contain bitfields. Replaces the existing flattening function, which produced a list of GEPs representing a flattened object, with one that produces a list of LValues instead. The LValues can be consumed by EmitStoreThroughLValue and EmitLoadOfLValue, ensuring bitfields are loaded and stored correctly. This also simplifies the code in the elementwise and aggregate splat cast paths. Closes #125986
1 parent 45c4124 commit 35c57a7
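
For context, a minimal HLSL sketch of the two cast kinds this change covers; the struct and member names are illustrative, not taken from the commit's tests:

// Hypothetical example: a struct with bitfield members.
struct S {
  int A : 8;  // bitfield
  int B : 8;  // bitfield
  float F;
};

void fn() {
  // Aggregate splat cast: the single scalar initializes every
  // flattened element of S, including the bitfields A and B.
  S S1 = (S)1;

  // Elementwise cast: each vector element is converted and stored
  // into the corresponding flattened element of S, in order.
  float3 V = {1.0, 2.0, 3.0};
  S S2 = (S)V;
}

Before this change, both casts went through GEP-based stores, which mishandled bitfield members; routing stores through LValues lets EmitStoreThroughLValue emit the proper bitfield read-modify-write sequences.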

File tree

11 files changed: +489 −208 lines

clang/lib/CodeGen/CGExpr.cpp

Lines changed: 64 additions & 36 deletions
@@ -6784,74 +6784,102 @@ LValue CodeGenFunction::EmitPseudoObjectLValue(const PseudoObjectExpr *E) {
   return emitPseudoObjectExpr(*this, E, true, AggValueSlot::ignored()).LV;
 }
 
-void CodeGenFunction::FlattenAccessAndType(
-    Address Addr, QualType AddrType,
-    SmallVectorImpl<std::pair<Address, llvm::Value *>> &AccessList,
-    SmallVectorImpl<QualType> &FlatTypes) {
-  // WorkList is list of type we are processing + the Index List to access
-  // the field of that type in Addr for use in a GEP
-  llvm::SmallVector<std::pair<QualType, llvm::SmallVector<llvm::Value *, 4>>,
-                    16>
+void CodeGenFunction::FlattenAccessAndTypeLValue(
+    LValue Val, SmallVectorImpl<LValue> &AccessList) {
+
+  llvm::SmallVector<
+      std::tuple<LValue, QualType, llvm::SmallVector<llvm::Value *, 4>>, 16>
       WorkList;
   llvm::IntegerType *IdxTy = llvm::IntegerType::get(getLLVMContext(), 32);
-  // Addr should be a pointer so we need to 'dereference' it
-  WorkList.push_back({AddrType, {llvm::ConstantInt::get(IdxTy, 0)}});
+  WorkList.push_back({Val, Val.getType(), {llvm::ConstantInt::get(IdxTy, 0)}});
 
   while (!WorkList.empty()) {
-    auto [T, IdxList] = WorkList.pop_back_val();
+    auto [LVal, T, IdxList] = WorkList.pop_back_val();
     T = T.getCanonicalType().getUnqualifiedType();
     assert(!isa<MatrixType>(T) && "Matrix types not yet supported in HLSL");
+
     if (const auto *CAT = dyn_cast<ConstantArrayType>(T)) {
       uint64_t Size = CAT->getZExtSize();
       for (int64_t I = Size - 1; I > -1; I--) {
         llvm::SmallVector<llvm::Value *, 4> IdxListCopy = IdxList;
         IdxListCopy.push_back(llvm::ConstantInt::get(IdxTy, I));
-        WorkList.emplace_back(CAT->getElementType(), IdxListCopy);
+        WorkList.emplace_back(LVal, CAT->getElementType(), IdxListCopy);
       }
     } else if (const auto *RT = dyn_cast<RecordType>(T)) {
       const RecordDecl *Record = RT->getOriginalDecl()->getDefinitionOrSelf();
       assert(!Record->isUnion() && "Union types not supported in flat cast.");
 
       const CXXRecordDecl *CXXD = dyn_cast<CXXRecordDecl>(Record);
 
-      llvm::SmallVector<QualType, 16> FieldTypes;
+      llvm::SmallVector<
+          std::tuple<LValue, QualType, llvm::SmallVector<llvm::Value *, 4>>, 16>
+          ReverseList;
       if (CXXD && CXXD->isStandardLayout())
         Record = CXXD->getStandardLayoutBaseWithFields();
 
       // deal with potential base classes
       if (CXXD && !CXXD->isStandardLayout()) {
-        for (auto &Base : CXXD->bases())
-          FieldTypes.push_back(Base.getType());
+        if (CXXD->getNumBases() > 0) {
+          assert(CXXD->getNumBases() == 1 &&
+                 "HLSL doesn't support multiple inheritance.");
+          auto Base = CXXD->bases_begin();
+          llvm::SmallVector<llvm::Value *, 4> IdxListCopy = IdxList;
+          IdxListCopy.push_back(llvm::ConstantInt::get(
+              IdxTy, 0)); // base struct should be at index zero
+          ReverseList.emplace_back(LVal, Base->getType(), IdxListCopy);
+        }
       }
 
-      for (auto *FD : Record->fields())
-        FieldTypes.push_back(FD->getType());
+      const CGRecordLayout &Layout = CGM.getTypes().getCGRecordLayout(Record);
 
-      for (int64_t I = FieldTypes.size() - 1; I > -1; I--) {
-        llvm::SmallVector<llvm::Value *, 4> IdxListCopy = IdxList;
-        IdxListCopy.push_back(llvm::ConstantInt::get(IdxTy, I));
-        WorkList.insert(WorkList.end(), {FieldTypes[I], IdxListCopy});
+      llvm::Type *LLVMT = ConvertTypeForMem(T);
+      CharUnits Align = getContext().getTypeAlignInChars(T);
+      LValue RLValue;
+      bool createdGEP = false;
+      for (auto *FD : Record->fields()) {
+        if (FD->isBitField()) {
+          if (FD->isUnnamedBitField())
+            continue;
+          if (!createdGEP) {
+            createdGEP = true;
+            Address GEP = Builder.CreateInBoundsGEP(LVal.getAddress(), IdxList,
+                                                    LLVMT, Align, "gep");
+            RLValue = MakeAddrLValue(GEP, T);
+          }
+          LValue FieldLVal = EmitLValueForField(RLValue, FD, true);
+          ReverseList.push_back({FieldLVal, FD->getType(), {}});
+        } else {
+          llvm::SmallVector<llvm::Value *, 4> IdxListCopy = IdxList;
+          IdxListCopy.push_back(
+              llvm::ConstantInt::get(IdxTy, Layout.getLLVMFieldNo(FD)));
+          ReverseList.emplace_back(LVal, FD->getType(), IdxListCopy);
+        }
       }
+
+      std::reverse(ReverseList.begin(), ReverseList.end());
+      llvm::append_range(WorkList, ReverseList);
     } else if (const auto *VT = dyn_cast<VectorType>(T)) {
       llvm::Type *LLVMT = ConvertTypeForMem(T);
       CharUnits Align = getContext().getTypeAlignInChars(T);
-      Address GEP =
-          Builder.CreateInBoundsGEP(Addr, IdxList, LLVMT, Align, "vector.gep");
+      Address GEP = Builder.CreateInBoundsGEP(LVal.getAddress(), IdxList, LLVMT,
+                                              Align, "vector.gep");
+      LValue Base = MakeAddrLValue(GEP, T);
       for (unsigned I = 0, E = VT->getNumElements(); I < E; I++) {
-        llvm::Value *Idx = llvm::ConstantInt::get(IdxTy, I);
-        // gep on vector fields is not recommended so combine gep with
-        // extract/insert
-        AccessList.emplace_back(GEP, Idx);
-        FlatTypes.push_back(VT->getElementType());
+        llvm::Constant *Idx = llvm::ConstantInt::get(IdxTy, I);
+        LValue LV =
+            LValue::MakeVectorElt(Base.getAddress(), Idx, VT->getElementType(),
+                                  Base.getBaseInfo(), TBAAAccessInfo());
+        AccessList.emplace_back(LV);
       }
-    } else {
-      // a scalar/builtin type
-      llvm::Type *LLVMT = ConvertTypeForMem(T);
-      CharUnits Align = getContext().getTypeAlignInChars(T);
-      Address GEP =
-          Builder.CreateInBoundsGEP(Addr, IdxList, LLVMT, Align, "gep");
-      AccessList.emplace_back(GEP, nullptr);
-      FlatTypes.push_back(T);
+    } else { // a scalar/builtin type
+      if (!IdxList.empty()) {
+        llvm::Type *LLVMT = ConvertTypeForMem(T);
+        CharUnits Align = getContext().getTypeAlignInChars(T);
+        Address GEP = Builder.CreateInBoundsGEP(LVal.getAddress(), IdxList,
+                                                LLVMT, Align, "gep");
+        AccessList.emplace_back(MakeAddrLValue(GEP, T));
+      } else // must be a bitfield we already created an lvalue for
+        AccessList.emplace_back(LVal);
     }
   }
 }
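
The worklist above flattens types depth-first: a standard-layout base is visited first, then fields in declaration order, with arrays and vectors expanded element by element, unnamed bitfields skipped, and named bitfields represented as bitfield LValues rather than GEPs. A hypothetical HLSL type and the element list it would flatten to, as a sketch:

struct Base {
  uint U : 16;   // element 0: bitfield LValue via EmitLValueForField
};

struct Derived : Base {
  int : 8;       // unnamed bitfield: skipped entirely
  int I : 8;     // element 1: bitfield LValue
  int2 V;        // elements 2-3: vector-element LValues (MakeVectorElt)
  float A[2];    // elements 4-5: GEP-based scalar LValues
};

// FlattenAccessAndTypeLValue on a Derived lvalue would append six
// LValues to AccessList, in the order shown above.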

clang/lib/CodeGen/CGExprAgg.cpp

Lines changed: 55 additions & 91 deletions
@@ -488,100 +488,62 @@ static bool isTrivialFiller(Expr *E) {
   return false;
 }
 
-static void EmitHLSLAggregateSplatCast(CodeGenFunction &CGF, Address DestVal,
-                                       QualType DestTy, llvm::Value *SrcVal,
-                                       QualType SrcTy, SourceLocation Loc) {
+// emit an elementwise cast where the RHS is a scalar or vector
+// or emit an aggregate splat cast
+static void EmitHLSLScalarElementwiseAndSplatCasts(CodeGenFunction &CGF,
+                                                   LValue DestVal,
+                                                   llvm::Value *SrcVal,
+                                                   QualType SrcTy,
+                                                   SourceLocation Loc) {
   // Flatten our destination
-  SmallVector<QualType> DestTypes; // Flattened type
-  SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
-  // ^^ Flattened accesses to DestVal we want to store into
-  CGF.FlattenAccessAndType(DestVal, DestTy, StoreGEPList, DestTypes);
-
-  assert(SrcTy->isScalarType() && "Invalid HLSL Aggregate splat cast.");
-  for (unsigned I = 0, Size = StoreGEPList.size(); I < Size; ++I) {
-    llvm::Value *Cast =
-        CGF.EmitScalarConversion(SrcVal, SrcTy, DestTypes[I], Loc);
-
-    // store back
-    llvm::Value *Idx = StoreGEPList[I].second;
-    if (Idx) {
-      llvm::Value *V =
-          CGF.Builder.CreateLoad(StoreGEPList[I].first, "load.for.insert");
-      Cast = CGF.Builder.CreateInsertElement(V, Cast, Idx);
-    }
-    CGF.Builder.CreateStore(Cast, StoreGEPList[I].first);
-  }
-}
-
-// emit a flat cast where the RHS is a scalar, including vector
-static void EmitHLSLScalarFlatCast(CodeGenFunction &CGF, Address DestVal,
-                                   QualType DestTy, llvm::Value *SrcVal,
-                                   QualType SrcTy, SourceLocation Loc) {
-  // Flatten our destination
-  SmallVector<QualType, 16> DestTypes; // Flattened type
-  SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
-  // ^^ Flattened accesses to DestVal we want to store into
-  CGF.FlattenAccessAndType(DestVal, DestTy, StoreGEPList, DestTypes);
-
-  assert(SrcTy->isVectorType() && "HLSL Flat cast doesn't handle splatting.");
-  const VectorType *VT = SrcTy->getAs<VectorType>();
-  SrcTy = VT->getElementType();
-  assert(StoreGEPList.size() <= VT->getNumElements() &&
-         "Cannot perform HLSL flat cast when vector source \
-          object has less elements than flattened destination \
-          object.");
-  for (unsigned I = 0, Size = StoreGEPList.size(); I < Size; I++) {
-    llvm::Value *Load = CGF.Builder.CreateExtractElement(SrcVal, I, "vec.load");
+  SmallVector<LValue, 16> StoreList;
+  CGF.FlattenAccessAndTypeLValue(DestVal, StoreList);
+
+  bool isVector = false;
+  if (auto *VT = SrcTy->getAs<VectorType>()) {
+    isVector = true;
+    SrcTy = VT->getElementType();
+    assert(StoreList.size() <= VT->getNumElements() &&
+           "Cannot perform HLSL flat cast when vector source \
+            object has less elements than flattened destination \
+            object.");
+  }
+
+  for (unsigned I = 0, Size = StoreList.size(); I < Size; I++) {
+    LValue DestLVal = StoreList[I];
+    llvm::Value *Load =
+        isVector ? CGF.Builder.CreateExtractElement(SrcVal, I, "vec.load")
+                 : SrcVal;
     llvm::Value *Cast =
-        CGF.EmitScalarConversion(Load, SrcTy, DestTypes[I], Loc);
-
-    // store back
-    llvm::Value *Idx = StoreGEPList[I].second;
-    if (Idx) {
-      llvm::Value *V =
-          CGF.Builder.CreateLoad(StoreGEPList[I].first, "load.for.insert");
-      Cast = CGF.Builder.CreateInsertElement(V, Cast, Idx);
-    }
-    CGF.Builder.CreateStore(Cast, StoreGEPList[I].first);
+        CGF.EmitScalarConversion(Load, SrcTy, DestLVal.getType(), Loc);
+    CGF.EmitStoreThroughLValue(RValue::get(Cast), DestLVal);
   }
 }
 
 // emit a flat cast where the RHS is an aggregate
-static void EmitHLSLElementwiseCast(CodeGenFunction &CGF, Address DestVal,
-                                    QualType DestTy, Address SrcVal,
-                                    QualType SrcTy, SourceLocation Loc) {
+static void EmitHLSLElementwiseCast(CodeGenFunction &CGF, LValue DestVal,
                                    LValue SrcVal, SourceLocation Loc) {
   // Flatten our destination
-  SmallVector<QualType, 16> DestTypes; // Flattened type
-  SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
-  // ^^ Flattened accesses to DestVal we want to store into
-  CGF.FlattenAccessAndType(DestVal, DestTy, StoreGEPList, DestTypes);
+  SmallVector<LValue, 16> StoreList;
+  CGF.FlattenAccessAndTypeLValue(DestVal, StoreList);
   // Flatten our src
-  SmallVector<QualType, 16> SrcTypes; // Flattened type
-  SmallVector<std::pair<Address, llvm::Value *>, 16> LoadGEPList;
-  // ^^ Flattened accesses to SrcVal we want to load from
-  CGF.FlattenAccessAndType(SrcVal, SrcTy, LoadGEPList, SrcTypes);
+  SmallVector<LValue, 16> LoadList;
+  CGF.FlattenAccessAndTypeLValue(SrcVal, LoadList);
 
-  assert(StoreGEPList.size() <= LoadGEPList.size() &&
-         "Cannot perform HLSL flat cast when flattened source object \
+  assert(StoreList.size() <= LoadList.size() &&
+         "Cannot perform HLSL elementwise cast when flattened source object \
          has less elements than flattened destination object.");
-  // apply casts to what we load from LoadGEPList
+  // apply casts to what we load from LoadList
   // and store result in Dest
-  for (unsigned I = 0, E = StoreGEPList.size(); I < E; I++) {
-    llvm::Value *Idx = LoadGEPList[I].second;
-    llvm::Value *Load = CGF.Builder.CreateLoad(LoadGEPList[I].first, "load");
-    Load =
-        Idx ? CGF.Builder.CreateExtractElement(Load, Idx, "vec.extract") : Load;
-    llvm::Value *Cast =
-        CGF.EmitScalarConversion(Load, SrcTypes[I], DestTypes[I], Loc);
-
-    // store back
-    Idx = StoreGEPList[I].second;
-    if (Idx) {
-      llvm::Value *V =
-          CGF.Builder.CreateLoad(StoreGEPList[I].first, "load.for.insert");
-      Cast = CGF.Builder.CreateInsertElement(V, Cast, Idx);
-    }
-    CGF.Builder.CreateStore(Cast, StoreGEPList[I].first);
+  for (unsigned I = 0, E = StoreList.size(); I < E; I++) {
+    LValue DestLVal = StoreList[I];
+    LValue SrcLVal = LoadList[I];
+    RValue RVal = CGF.EmitLoadOfLValue(SrcLVal, Loc);
+    assert(RVal.isScalar() && "All flattened source values should be scalars");
+    llvm::Value *Val = RVal.getScalarVal();
+    llvm::Value *Cast = CGF.EmitScalarConversion(Val, SrcLVal.getType(),
+                                                 DestLVal.getType(), Loc);
+    CGF.EmitStoreThroughLValue(RValue::get(Cast), DestLVal);
   }
 }
 
@@ -988,31 +950,33 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
     Expr *Src = E->getSubExpr();
     QualType SrcTy = Src->getType();
     RValue RV = CGF.EmitAnyExpr(Src);
-    QualType DestTy = E->getType();
-    Address DestVal = Dest.getAddress();
+    LValue DestLVal = CGF.MakeAddrLValue(Dest.getAddress(), E->getType());
     SourceLocation Loc = E->getExprLoc();
 
-    assert(RV.isScalar() && "RHS of HLSL splat cast must be a scalar.");
+    assert(RV.isScalar() && SrcTy->isScalarType() &&
+           "RHS of HLSL splat cast must be a scalar.");
     llvm::Value *SrcVal = RV.getScalarVal();
-    EmitHLSLAggregateSplatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
+    EmitHLSLScalarElementwiseAndSplatCasts(CGF, DestLVal, SrcVal, SrcTy, Loc);
     break;
   }
   case CK_HLSLElementwiseCast: {
     Expr *Src = E->getSubExpr();
     QualType SrcTy = Src->getType();
     RValue RV = CGF.EmitAnyExpr(Src);
-    QualType DestTy = E->getType();
-    Address DestVal = Dest.getAddress();
+    LValue DestLVal = CGF.MakeAddrLValue(Dest.getAddress(), E->getType());
    SourceLocation Loc = E->getExprLoc();
 
     if (RV.isScalar()) {
       llvm::Value *SrcVal = RV.getScalarVal();
-      EmitHLSLScalarFlatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
+      assert(SrcTy->isVectorType() &&
+             "HLSL Elementwise cast doesn't handle splatting.");
+      EmitHLSLScalarElementwiseAndSplatCasts(CGF, DestLVal, SrcVal, SrcTy, Loc);
     } else {
       assert(RV.isAggregate() &&
              "Can't perform HLSL Aggregate cast on a complex type.");
       Address SrcVal = RV.getAggregateAddress();
-      EmitHLSLElementwiseCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
+      EmitHLSLElementwiseCast(CGF, DestLVal, CGF.MakeAddrLValue(SrcVal, SrcTy),
+                              Loc);
     }
     break;
   }
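
To tie the dispatch together, a hypothetical HLSL snippet exercising the three branches of VisitCastExpr shown above (types and names are illustrative):

struct Wide {
  int A : 4;     // bitfield
  int B;
  float C[2];    // Wide flattens to 4 elements: A, B, C[0], C[1]
};

struct Narrow {
  int X : 4;     // bitfield
  int Y;         // Narrow flattens to 2 elements: X, Y
};

void casts() {
  Wide W1 = (Wide)2;       // CK_HLSLAggregateSplatCast: scalar RHS
  int4 V4 = {0, 1, 2, 3};
  Wide W2 = (Wide)V4;      // CK_HLSLElementwiseCast: vector RHS
  Narrow N = (Narrow)W2;   // CK_HLSLElementwiseCast: aggregate RHS;
                           // source (4 elements) >= destination (2)
}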
