Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions clang/include/clang/AST/OperationKinds.def
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,9 @@ CAST_OPERATION(HLSLVectorTruncation)
// Non-decaying array RValue cast (HLSL only).
CAST_OPERATION(HLSLArrayRValue)

// Aggregate by Value cast (HLSL only).
CAST_OPERATION(HLSLAggregateCast)

//===- Binary Operations -------------------------------------------------===//
// Operators listed in order of precedence.
// Note that additions to this should also update the StmtVisitor class,
Expand Down
2 changes: 2 additions & 0 deletions clang/include/clang/Sema/SemaHLSL.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ class SemaHLSL : public SemaBase {
// Diagnose whether the input ID is uint/unit2/uint3 type.
bool diagnoseInputIDType(QualType T, const ParsedAttr &AL);

bool CanPerformScalarCast(QualType SrcTy, QualType DestTy);
bool CanPerformAggregateCast(Expr *Src, QualType DestType);
ExprResult ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg);

QualType getInoutParameterType(QualType Ty);
Expand Down
1 change: 1 addition & 0 deletions clang/lib/AST/Expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1942,6 +1942,7 @@ bool CastExpr::CastConsistency() const {
case CK_FixedPointToBoolean:
case CK_HLSLArrayRValue:
case CK_HLSLVectorTruncation:
case CK_HLSLAggregateCast:
CheckNoBasePath:
assert(path_empty() && "Cast kind should not have a base path!");
break;
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/AST/ExprConstant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14857,6 +14857,7 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
case CK_FixedPointCast:
case CK_IntegralToFixedPoint:
case CK_MatrixCast:
// TODO does CK_HLSLAggregateCast belong here?
llvm_unreachable("invalid cast kind for integral value");

case CK_BitCast:
Expand Down Expand Up @@ -15733,6 +15734,7 @@ bool ComplexExprEvaluator::VisitCastExpr(const CastExpr *E) {
case CK_IntegralToFixedPoint:
case CK_MatrixCast:
case CK_HLSLVectorTruncation:
case CK_HLSLAggregateCast:
llvm_unreachable("invalid cast kind for complex value");

case CK_LValueToRValue:
Expand Down
87 changes: 87 additions & 0 deletions clang/lib/CodeGen/CGExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5320,6 +5320,7 @@ LValue CodeGenFunction::EmitCastLValue(const CastExpr *E) {
case CK_MatrixCast:
case CK_HLSLVectorTruncation:
case CK_HLSLArrayRValue:
case CK_HLSLAggregateCast:
return EmitUnsupportedLValue(E, "unexpected cast lvalue");

case CK_Dependent:
Expand Down Expand Up @@ -6358,3 +6359,89 @@ RValue CodeGenFunction::EmitPseudoObjectRValue(const PseudoObjectExpr *E,
LValue CodeGenFunction::EmitPseudoObjectLValue(const PseudoObjectExpr *E) {
return emitPseudoObjectExpr(*this, E, true, AggValueSlot::ignored()).LV;
}

void CodeGenFunction::FlattenAccessAndType(
Address Addr, QualType AddrType,
SmallVectorImpl<std::pair<Address, llvm::Value *>> &AccessList,
SmallVectorImpl<QualType> &FlatTypes) {
// WorkList is list of type we are processing + the Index List to access
// the field of that type in Addr for use in a GEP
llvm::SmallVector<std::pair<QualType, llvm::SmallVector<llvm::Value *, 4>>,
16>
WorkList;
llvm::IntegerType *IdxTy = llvm::IntegerType::get(getLLVMContext(), 32);
WorkList.push_back(
{AddrType,
{llvm::ConstantInt::get(
IdxTy,
0)}}); // Addr should be a pointer so we need to 'dereference' it

while (!WorkList.empty()) {
std::pair<QualType, llvm::SmallVector<llvm::Value *, 4>> P =
WorkList.pop_back_val();
QualType T = P.first;
llvm::SmallVector<llvm::Value *, 4> IdxList = P.second;
T = T.getCanonicalType().getUnqualifiedType();
assert(!isa<MatrixType>(T) && "Matrix types not yet supported in HLSL");
if (const auto *CAT = dyn_cast<ConstantArrayType>(T)) {
uint64_t Size = CAT->getZExtSize();
for (int64_t i = Size - 1; i > -1; i--) {
llvm::SmallVector<llvm::Value *, 4> IdxListCopy = IdxList;
IdxListCopy.push_back(llvm::ConstantInt::get(IdxTy, i));
WorkList.insert(WorkList.end(), {CAT->getElementType(), IdxListCopy});
}
} else if (const auto *RT = dyn_cast<RecordType>(T)) {
const RecordDecl *Record = RT->getDecl();
if (Record->isUnion()) {
IdxList.push_back(llvm::ConstantInt::get(IdxTy, 0));
llvm::Type *LLVMT = ConvertTypeForMem(T);
CharUnits Align = getContext().getTypeAlignInChars(T);
Address GEP =
Builder.CreateInBoundsGEP(Addr, IdxList, LLVMT, Align, "union.gep");
AccessList.push_back({GEP, NULL});
FlatTypes.push_back(T);
continue;
}
const CXXRecordDecl *CXXD = dyn_cast<CXXRecordDecl>(Record);

llvm::SmallVector<QualType, 16> FieldTypes;
if (CXXD && CXXD->isStandardLayout())
Record = CXXD->getStandardLayoutBaseWithFields();

// deal with potential base classes
if (CXXD && !CXXD->isStandardLayout()) {
for (auto &Base : CXXD->bases())
FieldTypes.push_back(Base.getType());
}

for (auto *FD : Record->fields())
FieldTypes.push_back(FD->getType());

for (int64_t i = FieldTypes.size() - 1; i > -1; i--) {
llvm::SmallVector<llvm::Value *, 4> IdxListCopy = IdxList;
IdxListCopy.push_back(llvm::ConstantInt::get(IdxTy, i));
WorkList.insert(WorkList.end(), {FieldTypes[i], IdxListCopy});
}
} else if (const auto *VT = dyn_cast<VectorType>(T)) {
llvm::Type *LLVMT = ConvertTypeForMem(T);
CharUnits Align = getContext().getTypeAlignInChars(T);
Address GEP =
Builder.CreateInBoundsGEP(Addr, IdxList, LLVMT, Align, "vector.gep");
for (unsigned i = 0; i < VT->getNumElements(); i++) {
llvm::Value *Idx = llvm::ConstantInt::get(IdxTy, i);
// gep on vector fields is not recommended so combine gep with
// extract/insert
AccessList.push_back({GEP, Idx});
FlatTypes.push_back(VT->getElementType());
}
} else {
// a scalar/builtin type
llvm::Type *LLVMT = ConvertTypeForMem(T);
CharUnits Align = getContext().getTypeAlignInChars(T);
Address GEP =
Builder.CreateInBoundsGEP(Addr, IdxList, LLVMT, Align, "gep");
AccessList.push_back({GEP, NULL});
FlatTypes.push_back(T);
}
}
}
96 changes: 95 additions & 1 deletion clang/lib/CodeGen/CGExprAgg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,81 @@ static bool isTrivialFiller(Expr *E) {
return false;
}

// emit a flat cast where the RHS is a scalar, including vector
static void EmitHLSLScalarFlatCast(CodeGenFunction &CGF, Address DestVal,
QualType DestTy, llvm::Value *SrcVal,
QualType SrcTy, SourceLocation Loc) {
// Flatten our destination
SmallVector<QualType, 16> DestTypes; // Flattened type
SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
// ^^ Flattened accesses to DestVal we want to store into
CGF.FlattenAccessAndType(DestVal, DestTy, StoreGEPList, DestTypes);

if (const VectorType *VT = SrcTy->getAs<VectorType>()) {
SrcTy = VT->getElementType();
assert(StoreGEPList.size() <= VT->getNumElements() &&
"Cannot perform HLSL flat cast when vector source \
object has less elements than flattened destination \
object.");
for (unsigned i = 0; i < StoreGEPList.size(); i++) {
llvm::Value *Load =
CGF.Builder.CreateExtractElement(SrcVal, i, "vec.load");
llvm::Value *Cast =
CGF.EmitScalarConversion(Load, SrcTy, DestTypes[i], Loc);

// store back
llvm::Value *Idx = StoreGEPList[i].second;
if (Idx) {
llvm::Value *V =
CGF.Builder.CreateLoad(StoreGEPList[i].first, "load.for.insert");
Cast = CGF.Builder.CreateInsertElement(V, Cast, Idx);
}
CGF.Builder.CreateStore(Cast, StoreGEPList[i].first);
}
return;
}
llvm_unreachable("HLSL Flat cast doesn't handle splatting.");
}

// emit a flat cast where the RHS is an aggregate
static void EmitHLSLAggregateFlatCast(CodeGenFunction &CGF, Address DestVal,
QualType DestTy, Address SrcVal,
QualType SrcTy, SourceLocation Loc) {
// Flatten our destination
SmallVector<QualType, 16> DestTypes; // Flattened type
SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
// ^^ Flattened accesses to DestVal we want to store into
CGF.FlattenAccessAndType(DestVal, DestTy, StoreGEPList, DestTypes);
// Flatten our src
SmallVector<QualType, 16> SrcTypes; // Flattened type
SmallVector<std::pair<Address, llvm::Value *>, 16> LoadGEPList;
// ^^ Flattened accesses to SrcVal we want to load from
CGF.FlattenAccessAndType(SrcVal, SrcTy, LoadGEPList, SrcTypes);

assert(StoreGEPList.size() <= LoadGEPList.size() &&
"Cannot perform HLSL flat cast when flattened source object \
has less elements than flattened destination object.");
// apply casts to what we load from LoadGEPList
// and store result in Dest
for (unsigned i = 0; i < StoreGEPList.size(); i++) {
llvm::Value *Idx = LoadGEPList[i].second;
llvm::Value *Load = CGF.Builder.CreateLoad(LoadGEPList[i].first, "load");
Load =
Idx ? CGF.Builder.CreateExtractElement(Load, Idx, "vec.extract") : Load;
llvm::Value *Cast =
CGF.EmitScalarConversion(Load, SrcTypes[i], DestTypes[i], Loc);

// store back
Idx = StoreGEPList[i].second;
if (Idx) {
llvm::Value *V =
CGF.Builder.CreateLoad(StoreGEPList[i].first, "load.for.insert");
Cast = CGF.Builder.CreateInsertElement(V, Cast, Idx);
}
CGF.Builder.CreateStore(Cast, StoreGEPList[i].first);
}
}

/// Emit initialization of an array from an initializer list. ExprToVisit must
/// be either an InitListEpxr a CXXParenInitListExpr.
void AggExprEmitter::EmitArrayInit(Address DestPtr, llvm::ArrayType *AType,
Expand Down Expand Up @@ -890,7 +965,25 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
case CK_HLSLArrayRValue:
Visit(E->getSubExpr());
break;

case CK_HLSLAggregateCast: {
Expr *Src = E->getSubExpr();
QualType SrcTy = Src->getType();
RValue RV = CGF.EmitAnyExpr(Src);
QualType DestTy = E->getType();
Address DestVal = Dest.getAddress();
SourceLocation Loc = E->getExprLoc();

if (RV.isScalar()) {
llvm::Value *SrcVal = RV.getScalarVal();
EmitHLSLScalarFlatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
} else { // RHS is an aggregate
assert(RV.isAggregate() &&
"Can't perform HLSL Aggregate cast on a complex type.");
Address SrcVal = RV.getAggregateAddress();
EmitHLSLAggregateFlatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
}
break;
}
case CK_NoOp:
case CK_UserDefinedConversion:
case CK_ConstructorConversion:
Expand Down Expand Up @@ -1461,6 +1554,7 @@ static bool castPreservesZero(const CastExpr *CE) {
case CK_NonAtomicToAtomic:
case CK_AtomicToNonAtomic:
case CK_HLSLVectorTruncation:
// TODO does CK_HLSLAggregateCast preserve zero?
return true;

case CK_BaseToDerivedMemberPointer:
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CodeGen/CGExprComplex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,7 @@ ComplexPairTy ComplexExprEmitter::EmitCast(CastKind CK, Expr *Op,
case CK_MatrixCast:
case CK_HLSLVectorTruncation:
case CK_HLSLArrayRValue:
case CK_HLSLAggregateCast:
llvm_unreachable("invalid cast kind for complex value");

case CK_FloatingRealToComplex:
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CodeGen/CGExprConstant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1335,6 +1335,7 @@ class ConstExprEmitter
case CK_MatrixCast:
case CK_HLSLVectorTruncation:
case CK_HLSLArrayRValue:
case CK_HLSLAggregateCast:
return nullptr;
}
llvm_unreachable("Invalid CastKind");
Expand Down
45 changes: 45 additions & 0 deletions clang/lib/CodeGen/CGExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2262,6 +2262,41 @@ bool CodeGenFunction::ShouldNullCheckClassCastValue(const CastExpr *CE) {
return true;
}

// RHS is an aggregate type
static Value *EmitHLSLAggregateFlatCast(CodeGenFunction &CGF, Address RHSVal,
QualType RHSTy, QualType LHSTy,
SourceLocation Loc) {
SmallVector<std::pair<Address, llvm::Value *>, 16> LoadGEPList;
SmallVector<QualType, 16> SrcTypes; // Flattened type
CGF.FlattenAccessAndType(RHSVal, RHSTy, LoadGEPList, SrcTypes);
// LHS is either a vector or a builtin?
// if its a vector create a temp alloca to store into and return that
if (auto *VecTy = LHSTy->getAs<VectorType>()) {
llvm::Value *V =
CGF.Builder.CreateLoad(CGF.CreateIRTemp(LHSTy, "flatcast.tmp"));
// write to V.
for (unsigned i = 0; i < VecTy->getNumElements(); i++) {
llvm::Value *Load = CGF.Builder.CreateLoad(LoadGEPList[i].first, "load");
llvm::Value *Idx = LoadGEPList[i].second;
Load = Idx ? CGF.Builder.CreateExtractElement(Load, Idx, "vec.extract")
: Load;
llvm::Value *Cast = CGF.EmitScalarConversion(
Load, SrcTypes[i], VecTy->getElementType(), Loc);
V = CGF.Builder.CreateInsertElement(V, Cast, i);
}
return V;
}
// i its a builtin just do an extract element or load.
assert(LHSTy->isBuiltinType() &&
"Destination type must be a vector or builtin type.");
// TODO add asserts about things being long enough
llvm::Value *Load = CGF.Builder.CreateLoad(LoadGEPList[0].first, "load");
llvm::Value *Idx = LoadGEPList[0].second;
Load =
Idx ? CGF.Builder.CreateExtractElement(Load, Idx, "vec.extract") : Load;
return CGF.EmitScalarConversion(Load, LHSTy, SrcTypes[0], Loc);
}

// VisitCastExpr - Emit code for an explicit or implicit cast. Implicit casts
// have to handle a more broad range of conversions than explicit casts, as they
// handle things like function to ptr-to-function decay etc.
Expand Down Expand Up @@ -2752,7 +2787,17 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
llvm::Value *Zero = llvm::Constant::getNullValue(CGF.SizeTy);
return Builder.CreateExtractElement(Vec, Zero, "cast.vtrunc");
}
case CK_HLSLAggregateCast: {
RValue RV = CGF.EmitAnyExpr(E);
SourceLocation Loc = CE->getExprLoc();
QualType SrcTy = E->getType();

if (RV.isAggregate()) { // RHS is an aggregate
Address SrcVal = RV.getAggregateAddress();
return EmitHLSLAggregateFlatCast(CGF, SrcVal, SrcTy, DestTy, Loc);
}
llvm_unreachable("Not a valid HLSL Flat Cast.");
}
} // end of switch

llvm_unreachable("unknown scalar cast");
Expand Down
5 changes: 5 additions & 0 deletions clang/lib/CodeGen/CodeGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -4359,6 +4359,11 @@ class CodeGenFunction : public CodeGenTypeCache {
AggValueSlot slot = AggValueSlot::ignored());
LValue EmitPseudoObjectLValue(const PseudoObjectExpr *e);

void FlattenAccessAndType(
Address Addr, QualType AddrTy,
SmallVectorImpl<std::pair<Address, llvm::Value *>> &AccessList,
SmallVectorImpl<QualType> &FlatTypes);

llvm::Value *EmitIvarOffset(const ObjCInterfaceDecl *Interface,
const ObjCIvarDecl *Ivar);
llvm::Value *EmitIvarOffsetAsPointerDiff(const ObjCInterfaceDecl *Interface,
Expand Down
1 change: 1 addition & 0 deletions clang/lib/Edit/RewriteObjCFoundationAPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1085,6 +1085,7 @@ static bool rewriteToNumericBoxedExpression(const ObjCMessageExpr *Msg,
llvm_unreachable("OpenCL-specific cast in Objective-C?");

case CK_HLSLVectorTruncation:
case CK_HLSLAggregateCast:
llvm_unreachable("HLSL-specific cast in Objective-C?");
break;

Expand Down
1 change: 1 addition & 0 deletions clang/lib/Sema/Sema.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,7 @@ ExprResult Sema::ImpCastExprToType(Expr *E, QualType Ty,
case CK_ToVoid:
case CK_NonAtomicToAtomic:
case CK_HLSLArrayRValue:
case CK_HLSLAggregateCast:
break;
}
}
Expand Down
Loading
Loading