Skip to content

Commit 2e932a5

Browse files
committed
Flat casts WIP
1 parent e6aec2c commit 2e932a5

File tree

15 files changed

+384
-4
lines changed

15 files changed

+384
-4
lines changed

clang/include/clang/AST/OperationKinds.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,9 @@ CAST_OPERATION(HLSLVectorTruncation)
367367
// Non-decaying array RValue cast (HLSL only).
368368
CAST_OPERATION(HLSLArrayRValue)
369369

370+
// Aggregate by Value cast (HLSL only).
371+
CAST_OPERATION(HLSLAggregateCast)
372+
370373
//===- Binary Operations -------------------------------------------------===//
371374
// Operators listed in order of precedence.
372375
// Note that additions to this should also update the StmtVisitor class,

clang/include/clang/Sema/SemaHLSL.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,8 @@ class SemaHLSL : public SemaBase {
140140
// Diagnose whether the input ID is uint/uint2/uint3 type.
141141
bool diagnoseInputIDType(QualType T, const ParsedAttr &AL);
142142

143+
bool CanPerformScalarCast(QualType SrcTy, QualType DestTy);
144+
bool CanPerformAggregateCast(Expr *Src, QualType DestType);
143145
ExprResult ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg);
144146

145147
QualType getInoutParameterType(QualType Ty);

clang/lib/AST/Expr.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1942,6 +1942,7 @@ bool CastExpr::CastConsistency() const {
19421942
case CK_FixedPointToBoolean:
19431943
case CK_HLSLArrayRValue:
19441944
case CK_HLSLVectorTruncation:
1945+
case CK_HLSLAggregateCast:
19451946
CheckNoBasePath:
19461947
assert(path_empty() && "Cast kind should not have a base path!");
19471948
break;

clang/lib/AST/ExprConstant.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15733,6 +15733,7 @@ bool ComplexExprEvaluator::VisitCastExpr(const CastExpr *E) {
1573315733
case CK_IntegralToFixedPoint:
1573415734
case CK_MatrixCast:
1573515735
case CK_HLSLVectorTruncation:
15736+
case CK_HLSLAggregateCast:
1573615737
llvm_unreachable("invalid cast kind for complex value");
1573715738

1573815739
case CK_LValueToRValue:

clang/lib/CodeGen/CGExpr.cpp

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5320,6 +5320,7 @@ LValue CodeGenFunction::EmitCastLValue(const CastExpr *E) {
53205320
case CK_MatrixCast:
53215321
case CK_HLSLVectorTruncation:
53225322
case CK_HLSLArrayRValue:
5323+
case CK_HLSLAggregateCast:
53235324
return EmitUnsupportedLValue(E, "unexpected cast lvalue");
53245325

53255326
case CK_Dependent:
@@ -6358,3 +6359,86 @@ RValue CodeGenFunction::EmitPseudoObjectRValue(const PseudoObjectExpr *E,
63586359
LValue CodeGenFunction::EmitPseudoObjectLValue(const PseudoObjectExpr *E) {
63596360
return emitPseudoObjectExpr(*this, E, true, AggValueSlot::ignored()).LV;
63606361
}
6362+
6363+
/// Load the value designated by one flattened access.
///
/// \param GEP  Pair of (address to load from, optional element index). When
///             the index is non-null the address refers to a vector and the
///             index selects the lane to extract from the loaded vector.
/// \returns The loaded value, or the extracted lane when an index is present.
llvm::Value *
CodeGenFunction::PerformLoad(std::pair<Address, llvm::Value *> &GEP) {
  llvm::Value *Loaded = Builder.CreateLoad(GEP.first, "load");
  llvm::Value *Lane = GEP.second;
  // No lane index: the access already names the value itself.
  if (!Lane)
    return Loaded;
  // Vector access: the whole vector was loaded, now pick out the lane.
  return Builder.CreateExtractElement(Loaded, Lane, "vec.load");
}
6372+
6373+
/// Store \p Val through one flattened access.
///
/// \param GEP  Pair of (destination address, optional element index). When
///             the index is non-null the address refers to a vector and the
///             index selects the lane that receives \p Val.
/// \param Val  The (already converted) value to write.
/// \returns The store instruction that was emitted.
llvm::Value *
CodeGenFunction::PerformStore(std::pair<Address, llvm::Value *> &GEP,
                              llvm::Value *Val) {
  Address GEPAddress = GEP.first;
  llvm::Value *Idx = GEP.second;
  if (Idx) {
    // Writing a single lane: read the current vector, replace the lane, and
    // fall through to write the whole vector back. insertelement only
    // produces a new SSA value; without the store below the destination
    // memory would never be updated.
    llvm::Value *Vec = Builder.CreateLoad(GEPAddress, "load.for.insert");
    Val = Builder.CreateInsertElement(Vec, Val, Idx);
  }
  return Builder.CreateStore(Val, GEPAddress);
}
6384+
6385+
/// Recursively flatten an object of type \p SrcTy located at \p Val into a
/// list of element accesses and the matching list of element types.
///
/// Constant arrays and records are walked member by member; vectors
/// contribute one entry per lane (sharing a single GEP to the vector, paired
/// with the lane index); anything else is treated as a scalar leaf.
///
/// \param Val        Base address of the outermost object being flattened.
/// \param SrcTy      Type of the (sub)object currently being visited.
/// \param IdxList    GEP index path from \p Val to the current subobject.
///                   Callers pass an empty list for the outermost object; the
///                   list is restored to its entry state before returning.
/// \param GEPList    Out: one (address, lane index or null) pair per leaf.
/// \param FlatTypes  Out: the element type for each entry of \p GEPList.
void CodeGenFunction::FlattenAccessAndType(
    Address Val, QualType SrcTy, SmallVector<llvm::Value *, 4> &IdxList,
    SmallVector<std::pair<Address, llvm::Value *>, 16> &GEPList,
    SmallVector<QualType> &FlatTypes) {
  llvm::IntegerType *IdxTy = llvm::IntegerType::get(getLLVMContext(), 32);

  // The first GEP index steps over the base pointer itself, not into the
  // pointee. The outermost call (recognised by the empty index path) must
  // therefore seed the path with a zero so that later indices select array
  // elements / struct fields of the object at Val. Recursive calls always see
  // a non-empty path and skip this.
  const bool IsOutermost = IdxList.empty();
  if (IsOutermost)
    IdxList.push_back(llvm::ConstantInt::get(IdxTy, 0));

  if (const ConstantArrayType *CAT = dyn_cast<ConstantArrayType>(SrcTy)) {
    const uint64_t Size = CAT->getZExtSize();
    for (uint64_t I = 0; I < Size; ++I) {
      // Extend the path with this element, flatten it, then restore the path.
      IdxList.push_back(llvm::ConstantInt::get(IdxTy, I));
      FlattenAccessAndType(Val, CAT->getElementType(), IdxList, GEPList,
                           FlatTypes);
      IdxList.pop_back();
    }
  } else if (const RecordType *RT = SrcTy->getAs<RecordType>()) {
    RecordDecl *Record = RT->getDecl();
    const CGRecordLayout &RL = getTypes().getCGRecordLayout(Record);
    // TODO: only direct fields are visited here; base-class subobjects and
    // bit-fields are not handled -- confirm whether they need support.
    for (const FieldDecl *Field : Record->fields()) {
      // Use the LLVM struct field number (not the AST field index) because
      // the index feeds a GEP over the LLVM record layout.
      const unsigned FieldNum = RL.getLLVMFieldNo(Field);
      IdxList.push_back(llvm::ConstantInt::get(IdxTy, FieldNum));
      FlattenAccessAndType(Val, Field->getType(), IdxList, GEPList, FlatTypes);
      IdxList.pop_back();
    }
  } else if (const VectorType *VT = SrcTy->getAs<VectorType>()) {
    llvm::Type *VTy = ConvertTypeForMem(SrcTy);
    CharUnits Align = getContext().getTypeAlignInChars(SrcTy);
    // One GEP to the vector as a whole; lanes are reached with
    // extract/insertelement rather than by GEP-ing into the vector.
    Address GEP =
        Builder.CreateInBoundsGEP(Val, IdxList, VTy, Align, "vector.gep");
    for (unsigned I = 0, E = VT->getNumElements(); I < E; ++I) {
      llvm::Value *Idx = llvm::ConstantInt::get(IdxTy, I);
      GEPList.push_back({GEP, Idx});
      FlatTypes.push_back(VT->getElementType());
    }
  } else {
    // Everything else is treated as a scalar leaf.
    // TODO: decide whether to assert on unexpected types (e.g. target
    // extension types) instead of silently treating them as scalars.
    llvm::Type *Ty = ConvertTypeForMem(SrcTy);
    CharUnits Align = getContext().getTypeAlignInChars(SrcTy);
    Address GEP = Builder.CreateInBoundsGEP(Val, IdxList, Ty, Align, "gep");
    GEPList.push_back({GEP, nullptr});
    FlatTypes.push_back(SrcTy);
  }

  // Hand the caller's index path back exactly as it was received.
  if (IsOutermost)
    IdxList.pop_back();
}

clang/lib/CodeGen/CGExprAgg.cpp

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,70 @@ static bool isTrivialFiller(Expr *E) {
491491
return false;
492492
}
493493

494+
495+
496+
// emit a flat cast where the RHS is a scalar, including vector
497+
static void EmitHLSLScalarFlatCast(CodeGenFunction &CGF, Address DestVal,
498+
QualType DestTy, llvm::Value *SrcVal,
499+
QualType SrcTy, SourceLocation Loc) {
500+
// Flatten our destination
501+
SmallVector<QualType> DestTypes; // Flattened type
502+
SmallVector<llvm::Value *, 4> IdxList;
503+
SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
504+
// ^^ Flattened accesses to DestVal we want to store into
505+
CGF.FlattenAccessAndType(DestVal, DestTy, IdxList, StoreGEPList,
506+
DestTypes);
507+
508+
if (const VectorType *VT = SrcTy->getAs<VectorType>()) {
509+
SrcTy = VT->getElementType();
510+
assert(StoreGEPList.size() <= VT->getNumElements() &&
511+
"Cannot perform HLSL flat cast when vector source \
512+
object has less elements than flattened destination \
513+
object.");
514+
for(unsigned i = 0; i < StoreGEPList.size(); i ++) {
515+
llvm::Value *Load = CGF.Builder.CreateExtractElement(SrcVal, i,
516+
"vec.load");
517+
llvm::Value *Cast = CGF.EmitScalarConversion(Load, SrcTy,
518+
DestTypes[i],
519+
Loc);
520+
CGF.PerformStore(StoreGEPList[i], Cast);
521+
}
522+
return;
523+
}
524+
llvm_unreachable("HLSL Flat cast doesn't handle splatting.");
525+
}
526+
527+
// emit a flat cast where the RHS is an aggregate
528+
static void EmitHLSLAggregateFlatCast(CodeGenFunction &CGF, Address DestVal,
529+
QualType DestTy, Address SrcVal,
530+
QualType SrcTy, SourceLocation Loc) {
531+
// Flatten our destination
532+
SmallVector<QualType> DestTypes; // Flattened type
533+
SmallVector<llvm::Value *, 4> IdxList;
534+
SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
535+
// ^^ Flattened accesses to DestVal we want to store into
536+
CGF.FlattenAccessAndType(DestVal, DestTy, IdxList, StoreGEPList,
537+
DestTypes);
538+
// Flatten our src
539+
SmallVector<QualType> SrcTypes; // Flattened type
540+
SmallVector<std::pair<Address, llvm::Value *>, 16> LoadGEPList;
541+
// ^^ Flattened accesses to SrcVal we want to load from
542+
IdxList.clear();
543+
CGF.FlattenAccessAndType(SrcVal, SrcTy, IdxList, LoadGEPList, SrcTypes);
544+
545+
assert(StoreGEPList.size() <= LoadGEPList.size() &&
546+
"Cannot perform HLSL flat cast when flattened source object \
547+
has less elements than flattened destination object.");
548+
// apply casts to what we load from LoadGEPList
549+
// and store result in Dest
550+
for(unsigned i = 0; i < StoreGEPList.size(); i ++) {
551+
llvm::Value *Load = CGF.PerformLoad(LoadGEPList[i]);
552+
llvm::Value *Cast = CGF.EmitScalarConversion(Load, SrcTypes[i],
553+
DestTypes[i], Loc);
554+
CGF.PerformStore(StoreGEPList[i], Cast);
555+
}
556+
}
557+
494558
/// Emit initialization of an array from an initializer list. ExprToVisit must
495559
/// be either an InitListExpr or a CXXParenInitListExpr.
496560
void AggExprEmitter::EmitArrayInit(Address DestPtr, llvm::ArrayType *AType,
@@ -890,7 +954,24 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
890954
case CK_HLSLArrayRValue:
891955
Visit(E->getSubExpr());
892956
break;
893-
957+
case CK_HLSLAggregateCast: {
958+
Expr *Src = E->getSubExpr();
959+
QualType SrcTy = Src->getType();
960+
RValue RV = CGF.EmitAnyExpr(Src);
961+
QualType DestTy = E->getType();
962+
Address DestVal = Dest.getAddress();
963+
SourceLocation Loc = E->getExprLoc();
964+
965+
if (RV.isScalar()) {
966+
llvm::Value *SrcVal = RV.getScalarVal();
967+
EmitHLSLScalarFlatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
968+
} else { // RHS is an aggregate
969+
assert(RV.isAggregate() &&
970+
"Can't perform HLSL Aggregate cast on a complex type.");
971+
Address SrcVal = RV.getAggregateAddress();
972+
EmitHLSLAggregateFlatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
973+
}
974+
break; }
894975
case CK_NoOp:
895976
case CK_UserDefinedConversion:
896977
case CK_ConstructorConversion:

clang/lib/CodeGen/CGExprComplex.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,7 @@ ComplexPairTy ComplexExprEmitter::EmitCast(CastKind CK, Expr *Op,
610610
case CK_MatrixCast:
611611
case CK_HLSLVectorTruncation:
612612
case CK_HLSLArrayRValue:
613+
case CK_HLSLAggregateCast:
613614
llvm_unreachable("invalid cast kind for complex value");
614615

615616
case CK_FloatingRealToComplex:

clang/lib/CodeGen/CGExprConstant.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1335,6 +1335,7 @@ class ConstExprEmitter
13351335
case CK_MatrixCast:
13361336
case CK_HLSLVectorTruncation:
13371337
case CK_HLSLArrayRValue:
1338+
case CK_HLSLAggregateCast:
13381339
return nullptr;
13391340
}
13401341
llvm_unreachable("Invalid CastKind");

clang/lib/CodeGen/CGExprScalar.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2262,6 +2262,35 @@ bool CodeGenFunction::ShouldNullCheckClassCastValue(const CastExpr *CE) {
22622262
return true;
22632263
}
22642264

2265+
// RHS is an aggregate type
2266+
static Value *EmitHLSLAggregateFlatCast(CodeGenFunction &CGF, Address RHSVal,
2267+
QualType RHSTy, QualType LHSTy,
2268+
SourceLocation Loc) {
2269+
SmallVector<llvm::Value *, 4> IdxList;
2270+
SmallVector<std::pair<Address, llvm::Value *>, 16> LoadGEPList;
2271+
SmallVector<QualType> SrcTypes; // Flattened type
2272+
CGF.FlattenAccessAndType(RHSVal, RHSTy, IdxList, LoadGEPList, SrcTypes);
2273+
// LHS is either a vector or a builtin?
2274+
// if its a vector create a temp alloca to store into and return that
2275+
if (auto *VecTy = LHSTy->getAs<VectorType>()) {
2276+
llvm::Value *V = CGF.Builder.CreateLoad(CGF.CreateIRTemp(LHSTy, "flatcast.tmp"));
2277+
// write to V.
2278+
for(unsigned i = 0; i < VecTy->getNumElements(); i ++) {
2279+
llvm::Value *Load = CGF.PerformLoad(LoadGEPList[i]);
2280+
llvm::Value *Cast = CGF.EmitScalarConversion(Load, SrcTypes[i],
2281+
VecTy->getElementType(), Loc);
2282+
V = CGF.Builder.CreateInsertElement(V, Cast, i);
2283+
}
2284+
return V;
2285+
}
2286+
// i its a builtin just do an extract element or load.
2287+
assert(LHSTy->isBuiltinType() &&
2288+
"Destination type must be a vector or builtin type.");
2289+
// TODO add asserts about things being long enough
2290+
return CGF.EmitScalarConversion(CGF.PerformLoad(LoadGEPList[0]),
2291+
LHSTy, SrcTypes[0], Loc);
2292+
}
2293+
22652294
// VisitCastExpr - Emit code for an explicit or implicit cast. Implicit casts
22662295
// have to handle a more broad range of conversions than explicit casts, as they
22672296
// handle things like function to ptr-to-function decay etc.
@@ -2752,7 +2781,17 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
27522781
llvm::Value *Zero = llvm::Constant::getNullValue(CGF.SizeTy);
27532782
return Builder.CreateExtractElement(Vec, Zero, "cast.vtrunc");
27542783
}
2784+
case CK_HLSLAggregateCast: {
2785+
RValue RV = CGF.EmitAnyExpr(E);
2786+
SourceLocation Loc = CE->getExprLoc();
2787+
QualType SrcTy = E->getType();
27552788

2789+
if (RV.isAggregate()) { // RHS is an aggregate
2790+
Address SrcVal = RV.getAggregateAddress();
2791+
return EmitHLSLAggregateFlatCast(CGF, SrcVal, SrcTy, DestTy, Loc);
2792+
}
2793+
llvm_unreachable("Not a valid HLSL Flat Cast.");
2794+
}
27562795
} // end of switch
27572796

27582797
llvm_unreachable("unknown scalar cast");

clang/lib/CodeGen/CodeGenFunction.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4359,6 +4359,13 @@ class CodeGenFunction : public CodeGenTypeCache {
43594359
AggValueSlot slot = AggValueSlot::ignored());
43604360
LValue EmitPseudoObjectLValue(const PseudoObjectExpr *e);
43614361

4362+
llvm::Value *PerformLoad(std::pair<Address, llvm::Value *> &GEP);
4363+
llvm::Value *PerformStore(std::pair<Address, llvm::Value *> &GEP, llvm::Value *Val);
4364+
void FlattenAccessAndType(Address Val, QualType SrcTy,
4365+
SmallVector<llvm::Value *, 4> &IdxList,
4366+
SmallVector<std::pair<Address, llvm::Value *>, 16> &GEPList,
4367+
SmallVector<QualType> &FlatTypes);
4368+
43624369
llvm::Value *EmitIvarOffset(const ObjCInterfaceDecl *Interface,
43634370
const ObjCIvarDecl *Ivar);
43644371
llvm::Value *EmitIvarOffsetAsPointerDiff(const ObjCInterfaceDecl *Interface,

0 commit comments

Comments
 (0)