Skip to content

Commit fa0e9c9

Browse files
committed
respond to pr comments
1 parent 162c2b5 commit fa0e9c9

File tree

4 files changed

+40
-42
lines changed

4 files changed

+40
-42
lines changed

clang/lib/CodeGen/CGExpr.cpp

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6385,23 +6385,15 @@ void CodeGenFunction::FlattenAccessAndType(
63856385
assert(!isa<MatrixType>(T) && "Matrix types not yet supported in HLSL");
63866386
if (const auto *CAT = dyn_cast<ConstantArrayType>(T)) {
63876387
uint64_t Size = CAT->getZExtSize();
6388-
for (int64_t i = Size - 1; i > -1; i--) {
6388+
for (int64_t I = Size - 1; I > -1; I--) {
63896389
llvm::SmallVector<llvm::Value *, 4> IdxListCopy = IdxList;
6390-
IdxListCopy.push_back(llvm::ConstantInt::get(IdxTy, i));
6390+
IdxListCopy.push_back(llvm::ConstantInt::get(IdxTy, I));
63916391
WorkList.insert(WorkList.end(), {CAT->getElementType(), IdxListCopy});
63926392
}
63936393
} else if (const auto *RT = dyn_cast<RecordType>(T)) {
63946394
const RecordDecl *Record = RT->getDecl();
6395-
if (Record->isUnion()) {
6396-
IdxList.push_back(llvm::ConstantInt::get(IdxTy, 0));
6397-
llvm::Type *LLVMT = ConvertTypeForMem(T);
6398-
CharUnits Align = getContext().getTypeAlignInChars(T);
6399-
Address GEP =
6400-
Builder.CreateInBoundsGEP(Addr, IdxList, LLVMT, Align, "union.gep");
6401-
AccessList.push_back({GEP, NULL});
6402-
FlatTypes.push_back(T);
6403-
continue;
6404-
}
6395+
assert(!Record->isUnion() && "Union types not supported in flat cast.");
6396+
64056397
const CXXRecordDecl *CXXD = dyn_cast<CXXRecordDecl>(Record);
64066398

64076399
llvm::SmallVector<QualType, 16> FieldTypes;

clang/lib/CodeGen/CGExprAgg.cpp

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -501,30 +501,28 @@ static void EmitHLSLScalarFlatCast(CodeGenFunction &CGF, Address DestVal,
501501
// ^^ Flattened accesses to DestVal we want to store into
502502
CGF.FlattenAccessAndType(DestVal, DestTy, StoreGEPList, DestTypes);
503503

504-
if (const VectorType *VT = SrcTy->getAs<VectorType>()) {
505-
SrcTy = VT->getElementType();
506-
assert(StoreGEPList.size() <= VT->getNumElements() &&
507-
"Cannot perform HLSL flat cast when vector source \
508-
object has less elements than flattened destination \
509-
object.");
510-
for (unsigned i = 0; i < StoreGEPList.size(); i++) {
511-
llvm::Value *Load =
512-
CGF.Builder.CreateExtractElement(SrcVal, i, "vec.load");
513-
llvm::Value *Cast =
514-
CGF.EmitScalarConversion(Load, SrcTy, DestTypes[i], Loc);
515-
516-
// store back
517-
llvm::Value *Idx = StoreGEPList[i].second;
518-
if (Idx) {
519-
llvm::Value *V =
520-
CGF.Builder.CreateLoad(StoreGEPList[i].first, "load.for.insert");
521-
Cast = CGF.Builder.CreateInsertElement(V, Cast, Idx);
522-
}
523-
CGF.Builder.CreateStore(Cast, StoreGEPList[i].first);
504+
assert(SrcTy->isVectorType() && "HLSL Flat cast doesn't handle splatting.");
505+
const VectorType *VT = SrcTy->getAs<VectorType>();
506+
SrcTy = VT->getElementType();
507+
assert(StoreGEPList.size() <= VT->getNumElements() &&
508+
"Cannot perform HLSL flat cast when vector source \
509+
object has less elements than flattened destination \
510+
object.");
511+
for (unsigned i = 0; i < StoreGEPList.size(); i++) {
512+
llvm::Value *Load = CGF.Builder.CreateExtractElement(SrcVal, i, "vec.load");
513+
llvm::Value *Cast =
514+
CGF.EmitScalarConversion(Load, SrcTy, DestTypes[i], Loc);
515+
516+
// store back
517+
llvm::Value *Idx = StoreGEPList[i].second;
518+
if (Idx) {
519+
llvm::Value *V =
520+
CGF.Builder.CreateLoad(StoreGEPList[i].first, "load.for.insert");
521+
Cast = CGF.Builder.CreateInsertElement(V, Cast, Idx);
524522
}
525-
return;
523+
CGF.Builder.CreateStore(Cast, StoreGEPList[i].first);
526524
}
527-
llvm_unreachable("HLSL Flat cast doesn't handle splatting.");
525+
return;
528526
}
529527

530528
// emit a flat cast where the RHS is an aggregate

clang/lib/CodeGen/CGExprScalar.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2792,11 +2792,10 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
27922792
SourceLocation Loc = CE->getExprLoc();
27932793
QualType SrcTy = E->getType();
27942794

2795-
if (RV.isAggregate()) { // RHS is an aggregate
2796-
Address SrcVal = RV.getAggregateAddress();
2797-
return EmitHLSLAggregateFlatCast(CGF, SrcVal, SrcTy, DestTy, Loc);
2798-
}
2799-
llvm_unreachable("Not a valid HLSL Flat Cast.");
2795+
assert(RV.isAggregate() && "Not a valid HLSL Flat Cast.");
2796+
// RHS is an aggregate
2797+
Address SrcVal = RV.getAggregateAddress();
2798+
return EmitHLSLAggregateFlatCast(CGF, SrcVal, SrcTy, DestTy, Loc);
28002799
}
28012800
} // end of switch
28022801

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2477,7 +2477,7 @@ bool SemaHLSL::CanPerformScalarCast(QualType SrcTy, QualType DestTy) {
24772477
}
24782478

24792479
// Can we perform an HLSL Flattened cast?
2480-
// TODO: update this code when matrices are added
2480+
// TODO: update this code when matrices are added; see issue #88060
24812481
bool SemaHLSL::CanPerformAggregateCast(Expr *Src, QualType DestTy) {
24822482

24832483
// Don't handle casts where LHS and RHS are any combination of scalar/vector
@@ -2500,11 +2500,20 @@ bool SemaHLSL::CanPerformAggregateCast(Expr *Src, QualType DestTy) {
25002500
if (SrcTypes.size() < DestTypes.size())
25012501
return false;
25022502

2503-
for (unsigned i = 0; i < DestTypes.size() && i < SrcTypes.size(); i++) {
2504-
if (!CanPerformScalarCast(SrcTypes[i], DestTypes[i])) {
2503+
unsigned I;
2504+
for (I = 0; I < DestTypes.size() && I < SrcTypes.size(); I++) {
2505+
if (SrcTypes[I]->isUnionType() || DestTypes[I]->isUnionType())
2506+
return false;
2507+
if (!CanPerformScalarCast(SrcTypes[I], DestTypes[I])) {
25052508
return false;
25062509
}
25072510
}
2511+
2512+
// check the rest of the source type for unions.
2513+
for (; I < SrcTypes.size(); I++) {
2514+
if (SrcTypes[I]->isUnionType())
2515+
return false;
2516+
}
25082517
return true;
25092518
}
25102519

0 commit comments

Comments
 (0)