Merged
100 changes: 64 additions & 36 deletions clang/lib/CodeGen/CGExpr.cpp
@@ -6784,74 +6784,102 @@ LValue CodeGenFunction::EmitPseudoObjectLValue(const PseudoObjectExpr *E) {
return emitPseudoObjectExpr(*this, E, true, AggValueSlot::ignored()).LV;
}

void CodeGenFunction::FlattenAccessAndType(
Address Addr, QualType AddrType,
SmallVectorImpl<std::pair<Address, llvm::Value *>> &AccessList,
SmallVectorImpl<QualType> &FlatTypes) {
// WorkList is list of type we are processing + the Index List to access
// the field of that type in Addr for use in a GEP
llvm::SmallVector<std::pair<QualType, llvm::SmallVector<llvm::Value *, 4>>,
16>
void CodeGenFunction::FlattenAccessAndTypeLValue(
LValue Val, SmallVectorImpl<LValue> &AccessList) {

llvm::SmallVector<
std::tuple<LValue, QualType, llvm::SmallVector<llvm::Value *, 4>>, 16>
WorkList;
llvm::IntegerType *IdxTy = llvm::IntegerType::get(getLLVMContext(), 32);
// Addr should be a pointer so we need to 'dereference' it
WorkList.push_back({AddrType, {llvm::ConstantInt::get(IdxTy, 0)}});
WorkList.push_back({Val, Val.getType(), {llvm::ConstantInt::get(IdxTy, 0)}});

while (!WorkList.empty()) {
auto [T, IdxList] = WorkList.pop_back_val();
auto [LVal, T, IdxList] = WorkList.pop_back_val();
T = T.getCanonicalType().getUnqualifiedType();
assert(!isa<MatrixType>(T) && "Matrix types not yet supported in HLSL");

if (const auto *CAT = dyn_cast<ConstantArrayType>(T)) {
uint64_t Size = CAT->getZExtSize();
for (int64_t I = Size - 1; I > -1; I--) {
llvm::SmallVector<llvm::Value *, 4> IdxListCopy = IdxList;
IdxListCopy.push_back(llvm::ConstantInt::get(IdxTy, I));
WorkList.emplace_back(CAT->getElementType(), IdxListCopy);
WorkList.emplace_back(LVal, CAT->getElementType(), IdxListCopy);
}
} else if (const auto *RT = dyn_cast<RecordType>(T)) {
const RecordDecl *Record = RT->getOriginalDecl()->getDefinitionOrSelf();
assert(!Record->isUnion() && "Union types not supported in flat cast.");

const CXXRecordDecl *CXXD = dyn_cast<CXXRecordDecl>(Record);

llvm::SmallVector<QualType, 16> FieldTypes;
llvm::SmallVector<
std::tuple<LValue, QualType, llvm::SmallVector<llvm::Value *, 4>>, 16>
ReverseList;
if (CXXD && CXXD->isStandardLayout())
Record = CXXD->getStandardLayoutBaseWithFields();

// deal with potential base classes
if (CXXD && !CXXD->isStandardLayout()) {
for (auto &Base : CXXD->bases())
FieldTypes.push_back(Base.getType());
if (CXXD->getNumBases() > 0) {
assert(CXXD->getNumBases() == 1 &&
"HLSL doesn't support multiple inheritance.");
auto Base = CXXD->bases_begin();
llvm::SmallVector<llvm::Value *, 4> IdxListCopy = IdxList;
IdxListCopy.push_back(llvm::ConstantInt::get(
IdxTy, 0)); // base struct should be at index zero
ReverseList.emplace_back(LVal, Base->getType(), IdxListCopy);
}
}

for (auto *FD : Record->fields())
FieldTypes.push_back(FD->getType());
const CGRecordLayout &Layout = CGM.getTypes().getCGRecordLayout(Record);

for (int64_t I = FieldTypes.size() - 1; I > -1; I--) {
llvm::SmallVector<llvm::Value *, 4> IdxListCopy = IdxList;
IdxListCopy.push_back(llvm::ConstantInt::get(IdxTy, I));
WorkList.insert(WorkList.end(), {FieldTypes[I], IdxListCopy});
llvm::Type *LLVMT = ConvertTypeForMem(T);
CharUnits Align = getContext().getTypeAlignInChars(T);
LValue RLValue;
bool createdGEP = false;
for (auto *FD : Record->fields()) {
if (FD->isBitField()) {
if (FD->isUnnamedBitField())
continue;
if (!createdGEP) {
createdGEP = true;
Address GEP = Builder.CreateInBoundsGEP(LVal.getAddress(), IdxList,
LLVMT, Align, "gep");
Comment on lines +6843 to +6846

Member:
I would like to see (in a test) how it looks when there are multiple fields with bitfield annotations, and when the GEP gets skipped. And what if the set of bitfields spans multiple i32s?

Member:
Also, what if you have a set of bitfields, a regular field, and then another set of bitfields? Should the createdGEP flag be reset when you find a regular field?

Contributor Author:
No, it should only create a GEP once, since it's for the struct itself. I'll add a test.

Contributor Author:
The GEP should be skipped if it is a regular field and not a bitfield. EmitLValueForField should only be used if it is a bitfield and we need to construct a special LValue for it; this requires the GEP for the struct itself. Otherwise we wait as long as possible to generate the GEP to access a 'scalar' field, because we want to generate as few GEPs as possible.

Contributor Author:
@hekota I added a test to try and address your questions. Let me know if it looks sufficient to you.
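A minimal sketch of the kind of case under discussion, with hypothetical field names (the actual test added in the PR may differ):

struct S {
  unsigned A : 4;  // bitfield: triggers the one-time struct GEP plus a
                   // bitfield LValue via EmitLValueForField
  unsigned B : 30; // bitfield: the run now spans more than one i32
  float F;         // regular field: accessed through its own late GEP;
                   // createdGEP is intentionally not reset here
  unsigned C : 5;  // bitfield after a regular field: reuses the struct
                   // GEP created for A
};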

RLValue = MakeAddrLValue(GEP, T);
}
LValue FieldLVal = EmitLValueForField(RLValue, FD, true);
ReverseList.push_back({FieldLVal, FD->getType(), {}});
} else {
llvm::SmallVector<llvm::Value *, 4> IdxListCopy = IdxList;
IdxListCopy.push_back(
llvm::ConstantInt::get(IdxTy, Layout.getLLVMFieldNo(FD)));
ReverseList.emplace_back(LVal, FD->getType(), IdxListCopy);
}
}

std::reverse(ReverseList.begin(), ReverseList.end());
llvm::append_range(WorkList, ReverseList);
} else if (const auto *VT = dyn_cast<VectorType>(T)) {
llvm::Type *LLVMT = ConvertTypeForMem(T);
CharUnits Align = getContext().getTypeAlignInChars(T);
Address GEP =
Builder.CreateInBoundsGEP(Addr, IdxList, LLVMT, Align, "vector.gep");
Address GEP = Builder.CreateInBoundsGEP(LVal.getAddress(), IdxList, LLVMT,
Align, "vector.gep");
LValue Base = MakeAddrLValue(GEP, T);
for (unsigned I = 0, E = VT->getNumElements(); I < E; I++) {
llvm::Value *Idx = llvm::ConstantInt::get(IdxTy, I);
// gep on vector fields is not recommended so combine gep with
// extract/insert
AccessList.emplace_back(GEP, Idx);
FlatTypes.push_back(VT->getElementType());
llvm::Constant *Idx = llvm::ConstantInt::get(IdxTy, I);
LValue LV =
LValue::MakeVectorElt(Base.getAddress(), Idx, VT->getElementType(),
Base.getBaseInfo(), TBAAAccessInfo());
AccessList.emplace_back(LV);
}
} else {
// a scalar/builtin type
llvm::Type *LLVMT = ConvertTypeForMem(T);
CharUnits Align = getContext().getTypeAlignInChars(T);
Address GEP =
Builder.CreateInBoundsGEP(Addr, IdxList, LLVMT, Align, "gep");
AccessList.emplace_back(GEP, nullptr);
FlatTypes.push_back(T);
} else { // a scalar/builtin type
if (!IdxList.empty()) {
llvm::Type *LLVMT = ConvertTypeForMem(T);
CharUnits Align = getContext().getTypeAlignInChars(T);
Address GEP = Builder.CreateInBoundsGEP(LVal.getAddress(), IdxList,
LLVMT, Align, "gep");
AccessList.emplace_back(MakeAddrLValue(GEP, T));
} else // must be a bitfield; we already created an LValue for it
AccessList.emplace_back(LVal);
}
}
}
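For orientation, a hedged sketch (hypothetical types, not from the patch) of the order in which the worklist above emits flattened LValues; children are pushed in reverse so pop_back_val() visits them in declaration order:

struct Inner { int A; float B; };
struct Outer { Inner I[2]; int V; };
// FlattenAccessAndTypeLValue on an Outer lvalue appends one LValue per
// leaf, in declaration order:
//   I[0].A, I[0].B, I[1].A, I[1].B, V
// Vector-typed leaves instead yield one MakeVectorElt LValue per lane,
// and named bitfields yield bitfield LValues via EmitLValueForField.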
146 changes: 55 additions & 91 deletions clang/lib/CodeGen/CGExprAgg.cpp
@@ -488,100 +488,62 @@ static bool isTrivialFiller(Expr *E) {
return false;
}

static void EmitHLSLAggregateSplatCast(CodeGenFunction &CGF, Address DestVal,
QualType DestTy, llvm::Value *SrcVal,
QualType SrcTy, SourceLocation Loc) {
// emit an elementwise cast where the RHS is a scalar or vector
// or emit an aggregate splat cast
static void EmitHLSLScalarElementwiseAndSplatCasts(CodeGenFunction &CGF,
LValue DestVal,
llvm::Value *SrcVal,
QualType SrcTy,
SourceLocation Loc) {
// Flatten our destination
SmallVector<QualType> DestTypes; // Flattened type
SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
// ^^ Flattened accesses to DestVal we want to store into
CGF.FlattenAccessAndType(DestVal, DestTy, StoreGEPList, DestTypes);

assert(SrcTy->isScalarType() && "Invalid HLSL Aggregate splat cast.");
for (unsigned I = 0, Size = StoreGEPList.size(); I < Size; ++I) {
llvm::Value *Cast =
CGF.EmitScalarConversion(SrcVal, SrcTy, DestTypes[I], Loc);

// store back
llvm::Value *Idx = StoreGEPList[I].second;
if (Idx) {
llvm::Value *V =
CGF.Builder.CreateLoad(StoreGEPList[I].first, "load.for.insert");
Cast = CGF.Builder.CreateInsertElement(V, Cast, Idx);
}
CGF.Builder.CreateStore(Cast, StoreGEPList[I].first);
}
}

// emit a flat cast where the RHS is a scalar, including vector
static void EmitHLSLScalarFlatCast(CodeGenFunction &CGF, Address DestVal,
QualType DestTy, llvm::Value *SrcVal,
QualType SrcTy, SourceLocation Loc) {
// Flatten our destination
SmallVector<QualType, 16> DestTypes; // Flattened type
SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
// ^^ Flattened accesses to DestVal we want to store into
CGF.FlattenAccessAndType(DestVal, DestTy, StoreGEPList, DestTypes);

assert(SrcTy->isVectorType() && "HLSL Flat cast doesn't handle splatting.");
const VectorType *VT = SrcTy->getAs<VectorType>();
SrcTy = VT->getElementType();
assert(StoreGEPList.size() <= VT->getNumElements() &&
"Cannot perform HLSL flat cast when vector source \
object has less elements than flattened destination \
object.");
for (unsigned I = 0, Size = StoreGEPList.size(); I < Size; I++) {
llvm::Value *Load = CGF.Builder.CreateExtractElement(SrcVal, I, "vec.load");
SmallVector<LValue, 16> StoreList;
CGF.FlattenAccessAndTypeLValue(DestVal, StoreList);

bool isVector = false;
if (auto *VT = SrcTy->getAs<VectorType>()) {
isVector = true;
SrcTy = VT->getElementType();
assert(StoreList.size() <= VT->getNumElements() &&
"Cannot perform HLSL flat cast when vector source \
object has less elements than flattened destination \
object.");
}

for (unsigned I = 0, Size = StoreList.size(); I < Size; I++) {
LValue DestLVal = StoreList[I];
llvm::Value *Load =
isVector ? CGF.Builder.CreateExtractElement(SrcVal, I, "vec.load")
: SrcVal;
llvm::Value *Cast =
CGF.EmitScalarConversion(Load, SrcTy, DestTypes[I], Loc);

// store back
llvm::Value *Idx = StoreGEPList[I].second;
if (Idx) {
llvm::Value *V =
CGF.Builder.CreateLoad(StoreGEPList[I].first, "load.for.insert");
Cast = CGF.Builder.CreateInsertElement(V, Cast, Idx);
}
CGF.Builder.CreateStore(Cast, StoreGEPList[I].first);
CGF.EmitScalarConversion(Load, SrcTy, DestLVal.getType(), Loc);
CGF.EmitStoreThroughLValue(RValue::get(Cast), DestLVal);
}
}
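A rough source-level picture of the two cases this helper handles; the types and values here are hypothetical, not taken from the patch's tests:

struct Dst { float A; float B[2]; };  // flattens to: A, B[0], B[1]
// Splat cast (scalar source): every flattened destination element
// receives the same converted scalar, e.g. casting the int 1 yields
//   A = 1.0f; B[0] = 1.0f; B[1] = 1.0f;
// Elementwise cast (vector source): lane I of the source vector is
// extracted and converted into flattened element I, e.g. an int3
// holding (1, 2, 3) yields
//   A = 1.0f; B[0] = 2.0f; B[1] = 3.0f;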

// emit a flat cast where the RHS is an aggregate
static void EmitHLSLElementwiseCast(CodeGenFunction &CGF, Address DestVal,
QualType DestTy, Address SrcVal,
QualType SrcTy, SourceLocation Loc) {
static void EmitHLSLElementwiseCast(CodeGenFunction &CGF, LValue DestVal,
LValue SrcVal, SourceLocation Loc) {
// Flatten our destination
SmallVector<QualType, 16> DestTypes; // Flattened type
SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
// ^^ Flattened accesses to DestVal we want to store into
CGF.FlattenAccessAndType(DestVal, DestTy, StoreGEPList, DestTypes);
SmallVector<LValue, 16> StoreList;
CGF.FlattenAccessAndTypeLValue(DestVal, StoreList);
// Flatten our src
SmallVector<QualType, 16> SrcTypes; // Flattened type
SmallVector<std::pair<Address, llvm::Value *>, 16> LoadGEPList;
// ^^ Flattened accesses to SrcVal we want to load from
CGF.FlattenAccessAndType(SrcVal, SrcTy, LoadGEPList, SrcTypes);
SmallVector<LValue, 16> LoadList;
CGF.FlattenAccessAndTypeLValue(SrcVal, LoadList);

assert(StoreGEPList.size() <= LoadGEPList.size() &&
"Cannot perform HLSL flat cast when flattened source object \
assert(StoreList.size() <= LoadList.size() &&
"Cannot perform HLSL elementwise cast when flattened source object \
has less elements than flattened destination object.");
// apply casts to what we load from LoadGEPList
// apply casts to what we load from LoadList
// and store result in Dest
for (unsigned I = 0, E = StoreGEPList.size(); I < E; I++) {
llvm::Value *Idx = LoadGEPList[I].second;
llvm::Value *Load = CGF.Builder.CreateLoad(LoadGEPList[I].first, "load");
Load =
Idx ? CGF.Builder.CreateExtractElement(Load, Idx, "vec.extract") : Load;
llvm::Value *Cast =
CGF.EmitScalarConversion(Load, SrcTypes[I], DestTypes[I], Loc);

// store back
Idx = StoreGEPList[I].second;
if (Idx) {
llvm::Value *V =
CGF.Builder.CreateLoad(StoreGEPList[I].first, "load.for.insert");
Cast = CGF.Builder.CreateInsertElement(V, Cast, Idx);
}
CGF.Builder.CreateStore(Cast, StoreGEPList[I].first);
for (unsigned I = 0, E = StoreList.size(); I < E; I++) {
LValue DestLVal = StoreList[I];
LValue SrcLVal = LoadList[I];
RValue RVal = CGF.EmitLoadOfLValue(SrcLVal, Loc);
assert(RVal.isScalar() && "All flattened source values should be scalars");
llvm::Value *Val = RVal.getScalarVal();
llvm::Value *Cast = CGF.EmitScalarConversion(Val, SrcLVal.getType(),
DestLVal.getType(), Loc);
CGF.EmitStoreThroughLValue(RValue::get(Cast), DestLVal);
}
}
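And a similarly hedged sketch of the aggregate-to-aggregate case (hypothetical types):

struct Src { int A[2]; int B; };      // flattens to: A[0], A[1], B
struct Dst { float X; float Y[2]; };  // flattens to: X, Y[0], Y[1]
// The flattened load and store lists are paired positionally, one scalar
// conversion per pair:
//   X = (float)A[0]; Y[0] = (float)A[1]; Y[1] = (float)B;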

@@ -988,31 +950,33 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
Expr *Src = E->getSubExpr();
QualType SrcTy = Src->getType();
RValue RV = CGF.EmitAnyExpr(Src);
QualType DestTy = E->getType();
Address DestVal = Dest.getAddress();
LValue DestLVal = CGF.MakeAddrLValue(Dest.getAddress(), E->getType());
SourceLocation Loc = E->getExprLoc();

assert(RV.isScalar() && "RHS of HLSL splat cast must be a scalar.");
assert(RV.isScalar() && SrcTy->isScalarType() &&
"RHS of HLSL splat cast must be a scalar.");
llvm::Value *SrcVal = RV.getScalarVal();
EmitHLSLAggregateSplatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
EmitHLSLScalarElementwiseAndSplatCasts(CGF, DestLVal, SrcVal, SrcTy, Loc);
break;
}
case CK_HLSLElementwiseCast: {
Expr *Src = E->getSubExpr();
QualType SrcTy = Src->getType();
RValue RV = CGF.EmitAnyExpr(Src);
QualType DestTy = E->getType();
Address DestVal = Dest.getAddress();
LValue DestLVal = CGF.MakeAddrLValue(Dest.getAddress(), E->getType());
SourceLocation Loc = E->getExprLoc();

if (RV.isScalar()) {
llvm::Value *SrcVal = RV.getScalarVal();
EmitHLSLScalarFlatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
assert(SrcTy->isVectorType() &&
"HLSL Elementwise cast doesn't handle splatting.");
EmitHLSLScalarElementwiseAndSplatCasts(CGF, DestLVal, SrcVal, SrcTy, Loc);
} else {
assert(RV.isAggregate() &&
"Can't perform HLSL Aggregate cast on a complex type.");
Address SrcVal = RV.getAggregateAddress();
EmitHLSLElementwiseCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
EmitHLSLElementwiseCast(CGF, DestLVal, CGF.MakeAddrLValue(SrcVal, SrcTy),
Loc);
}
break;
}