Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions clang/include/clang/AST/OperationKinds.def
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,9 @@ CAST_OPERATION(HLSLVectorTruncation)
// Non-decaying array RValue cast (HLSL only).
CAST_OPERATION(HLSLArrayRValue)

// Aggregate by Value cast (HLSL only).
CAST_OPERATION(HLSLAggregateCast)

//===- Binary Operations -------------------------------------------------===//
// Operators listed in order of precedence.
// Note that additions to this should also update the StmtVisitor class,
Expand Down
2 changes: 2 additions & 0 deletions clang/include/clang/Sema/SemaHLSL.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ class SemaHLSL : public SemaBase {
// Diagnose whether the input ID is uint/unit2/uint3 type.
bool diagnoseInputIDType(QualType T, const ParsedAttr &AL);

bool CanPerformScalarCast(QualType SrcTy, QualType DestTy);
bool CanPerformAggregateCast(Expr *Src, QualType DestType);
ExprResult ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg);

QualType getInoutParameterType(QualType Ty);
Expand Down
1 change: 1 addition & 0 deletions clang/lib/AST/Expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1942,6 +1942,7 @@ bool CastExpr::CastConsistency() const {
case CK_FixedPointToBoolean:
case CK_HLSLArrayRValue:
case CK_HLSLVectorTruncation:
case CK_HLSLAggregateCast:
CheckNoBasePath:
assert(path_empty() && "Cast kind should not have a base path!");
break;
Expand Down
1 change: 1 addition & 0 deletions clang/lib/AST/ExprConstant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15733,6 +15733,7 @@ bool ComplexExprEvaluator::VisitCastExpr(const CastExpr *E) {
case CK_IntegralToFixedPoint:
case CK_MatrixCast:
case CK_HLSLVectorTruncation:
case CK_HLSLAggregateCast:
llvm_unreachable("invalid cast kind for complex value");

case CK_LValueToRValue:
Expand Down
86 changes: 86 additions & 0 deletions clang/lib/CodeGen/CGExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5320,6 +5320,7 @@ LValue CodeGenFunction::EmitCastLValue(const CastExpr *E) {
case CK_MatrixCast:
case CK_HLSLVectorTruncation:
case CK_HLSLArrayRValue:
case CK_HLSLAggregateCast:
return EmitUnsupportedLValue(E, "unexpected cast lvalue");

case CK_Dependent:
Expand Down Expand Up @@ -6358,3 +6359,88 @@ RValue CodeGenFunction::EmitPseudoObjectRValue(const PseudoObjectExpr *E,
LValue CodeGenFunction::EmitPseudoObjectLValue(const PseudoObjectExpr *E) {
return emitPseudoObjectExpr(*this, E, true, AggValueSlot::ignored()).LV;
}

llvm::Value *
CodeGenFunction::PerformLoad(std::pair<Address, llvm::Value *> &GEP) {
Address GEPAddress = GEP.first;
llvm::Value *Idx = GEP.second;
llvm::Value *V = Builder.CreateLoad(GEPAddress, "load");
if (Idx) { // loading from a vector so perform an extract as well
return Builder.CreateExtractElement(V, Idx, "vec.load");
}
return V;
}

llvm::Value *
CodeGenFunction::PerformStore(std::pair<Address, llvm::Value *> &GEP,
llvm::Value *Val) {
Address GEPAddress = GEP.first;
llvm::Value *Idx = GEP.second;
if (Idx) {
llvm::Value *V = Builder.CreateLoad(GEPAddress, "load.for.insert");
return Builder.CreateInsertElement(V, Val, Idx);
} else {
return Builder.CreateStore(Val, GEPAddress);
}
}

void CodeGenFunction::FlattenAccessAndType(
Address Val, QualType SrcTy, SmallVector<llvm::Value *, 4> &IdxList,
SmallVector<std::pair<Address, llvm::Value *>, 16> &GEPList,
SmallVector<QualType> &FlatTypes) {
llvm::IntegerType *IdxTy = llvm::IntegerType::get(getLLVMContext(), 32);
if (const ConstantArrayType *CAT = dyn_cast<ConstantArrayType>(SrcTy)) {
uint64_t Size = CAT->getZExtSize();
for (unsigned i = 0; i < Size; i++) {
// flatten each member of the array
// add index of this element to index list
llvm::Value *Idx = llvm::ConstantInt::get(IdxTy, i);
IdxList.push_back(Idx);
// recur on this object
FlattenAccessAndType(Val, CAT->getElementType(), IdxList, GEPList,
FlatTypes);
// remove index of this element from index list
IdxList.pop_back();
}
} else if (const RecordType *RT = SrcTy->getAs<RecordType>()) {
RecordDecl *Record = RT->getDecl();
const CGRecordLayout &RL = getTypes().getCGRecordLayout(Record);
// do I need to check if its a cxx record decl?

for (auto fieldIter = Record->field_begin(), fieldEnd = Record->field_end();
fieldIter != fieldEnd; ++fieldIter) {
// get the field number
unsigned FieldNum = RL.getLLVMFieldNo(*fieldIter);
// can we just do *fieldIter->getFieldIndex();
// add that index to the index list
llvm::Value *Idx = llvm::ConstantInt::get(IdxTy, FieldNum);
IdxList.push_back(Idx);
// recur on the field
FlattenAccessAndType(Val, fieldIter->getType(), IdxList, GEPList,
FlatTypes);
// remove index of this element from index list
IdxList.pop_back();
}
} else if (const VectorType *VT = SrcTy->getAs<VectorType>()) {
llvm::Type *VTy = ConvertTypeForMem(SrcTy);
CharUnits Align = getContext().getTypeAlignInChars(SrcTy);
Address GEP =
Builder.CreateInBoundsGEP(Val, IdxList, VTy, Align, "vector.gep");
for (unsigned i = 0; i < VT->getNumElements(); i++) {
// add index to the list
llvm::Value *Idx = llvm::ConstantInt::get(IdxTy, i);
// create gep. no need to recur since its always a scalar
// gep on vector is not recommended so combine gep with extract/insert
GEPList.push_back({GEP, Idx});
FlatTypes.push_back(VT->getElementType());
}
} else { // should be a scalar should we assert or check?
// create a gep
llvm::Type *Ty = ConvertTypeForMem(SrcTy);
CharUnits Align = getContext().getTypeAlignInChars(SrcTy);
Address GEP = Builder.CreateInBoundsGEP(Val, IdxList, Ty, Align, "gep");
GEPList.push_back({GEP, NULL});
FlatTypes.push_back(SrcTy);
}
// target extension types?
}
79 changes: 78 additions & 1 deletion clang/lib/CodeGen/CGExprAgg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,65 @@ static bool isTrivialFiller(Expr *E) {
return false;
}

// emit a flat cast where the RHS is a scalar, including vector
static void EmitHLSLScalarFlatCast(CodeGenFunction &CGF, Address DestVal,
QualType DestTy, llvm::Value *SrcVal,
QualType SrcTy, SourceLocation Loc) {
// Flatten our destination
SmallVector<QualType> DestTypes; // Flattened type
SmallVector<llvm::Value *, 4> IdxList;
SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
// ^^ Flattened accesses to DestVal we want to store into
CGF.FlattenAccessAndType(DestVal, DestTy, IdxList, StoreGEPList, DestTypes);

if (const VectorType *VT = SrcTy->getAs<VectorType>()) {
SrcTy = VT->getElementType();
assert(StoreGEPList.size() <= VT->getNumElements() &&
"Cannot perform HLSL flat cast when vector source \
object has less elements than flattened destination \
object.");
for (unsigned i = 0; i < StoreGEPList.size(); i++) {
llvm::Value *Load =
CGF.Builder.CreateExtractElement(SrcVal, i, "vec.load");
llvm::Value *Cast =
CGF.EmitScalarConversion(Load, SrcTy, DestTypes[i], Loc);
CGF.PerformStore(StoreGEPList[i], Cast);
}
return;
}
llvm_unreachable("HLSL Flat cast doesn't handle splatting.");
}

// emit a flat cast where the RHS is an aggregate
static void EmitHLSLAggregateFlatCast(CodeGenFunction &CGF, Address DestVal,
QualType DestTy, Address SrcVal,
QualType SrcTy, SourceLocation Loc) {
// Flatten our destination
SmallVector<QualType> DestTypes; // Flattened type
SmallVector<llvm::Value *, 4> IdxList;
SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
// ^^ Flattened accesses to DestVal we want to store into
CGF.FlattenAccessAndType(DestVal, DestTy, IdxList, StoreGEPList, DestTypes);
// Flatten our src
SmallVector<QualType> SrcTypes; // Flattened type
SmallVector<std::pair<Address, llvm::Value *>, 16> LoadGEPList;
// ^^ Flattened accesses to SrcVal we want to load from
IdxList.clear();
CGF.FlattenAccessAndType(SrcVal, SrcTy, IdxList, LoadGEPList, SrcTypes);

assert(StoreGEPList.size() <= LoadGEPList.size() &&
"Cannot perform HLSL flat cast when flattened source object \
has less elements than flattened destination object.");
// apply casts to what we load from LoadGEPList
// and store result in Dest
for (unsigned i = 0; i < StoreGEPList.size(); i++) {
llvm::Value *Load = CGF.PerformLoad(LoadGEPList[i]);
llvm::Value *Cast =
CGF.EmitScalarConversion(Load, SrcTypes[i], DestTypes[i], Loc);
CGF.PerformStore(StoreGEPList[i], Cast);
}
}

/// Emit initialization of an array from an initializer list. ExprToVisit must
/// be either an InitListEpxr a CXXParenInitListExpr.
void AggExprEmitter::EmitArrayInit(Address DestPtr, llvm::ArrayType *AType,
Expand Down Expand Up @@ -890,7 +949,25 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
case CK_HLSLArrayRValue:
Visit(E->getSubExpr());
break;

case CK_HLSLAggregateCast: {
Expr *Src = E->getSubExpr();
QualType SrcTy = Src->getType();
RValue RV = CGF.EmitAnyExpr(Src);
QualType DestTy = E->getType();
Address DestVal = Dest.getAddress();
SourceLocation Loc = E->getExprLoc();

if (RV.isScalar()) {
llvm::Value *SrcVal = RV.getScalarVal();
EmitHLSLScalarFlatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
} else { // RHS is an aggregate
assert(RV.isAggregate() &&
"Can't perform HLSL Aggregate cast on a complex type.");
Address SrcVal = RV.getAggregateAddress();
EmitHLSLAggregateFlatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
}
break;
}
case CK_NoOp:
case CK_UserDefinedConversion:
case CK_ConstructorConversion:
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CodeGen/CGExprComplex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,7 @@ ComplexPairTy ComplexExprEmitter::EmitCast(CastKind CK, Expr *Op,
case CK_MatrixCast:
case CK_HLSLVectorTruncation:
case CK_HLSLArrayRValue:
case CK_HLSLAggregateCast:
llvm_unreachable("invalid cast kind for complex value");

case CK_FloatingRealToComplex:
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CodeGen/CGExprConstant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1335,6 +1335,7 @@ class ConstExprEmitter
case CK_MatrixCast:
case CK_HLSLVectorTruncation:
case CK_HLSLArrayRValue:
case CK_HLSLAggregateCast:
return nullptr;
}
llvm_unreachable("Invalid CastKind");
Expand Down
40 changes: 40 additions & 0 deletions clang/lib/CodeGen/CGExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2262,6 +2262,36 @@ bool CodeGenFunction::ShouldNullCheckClassCastValue(const CastExpr *CE) {
return true;
}

// RHS is an aggregate type
static Value *EmitHLSLAggregateFlatCast(CodeGenFunction &CGF, Address RHSVal,
QualType RHSTy, QualType LHSTy,
SourceLocation Loc) {
SmallVector<llvm::Value *, 4> IdxList;
SmallVector<std::pair<Address, llvm::Value *>, 16> LoadGEPList;
SmallVector<QualType> SrcTypes; // Flattened type
CGF.FlattenAccessAndType(RHSVal, RHSTy, IdxList, LoadGEPList, SrcTypes);
// LHS is either a vector or a builtin?
// if its a vector create a temp alloca to store into and return that
if (auto *VecTy = LHSTy->getAs<VectorType>()) {
llvm::Value *V =
CGF.Builder.CreateLoad(CGF.CreateIRTemp(LHSTy, "flatcast.tmp"));
// write to V.
for (unsigned i = 0; i < VecTy->getNumElements(); i++) {
llvm::Value *Load = CGF.PerformLoad(LoadGEPList[i]);
llvm::Value *Cast = CGF.EmitScalarConversion(
Load, SrcTypes[i], VecTy->getElementType(), Loc);
V = CGF.Builder.CreateInsertElement(V, Cast, i);
}
return V;
}
// i its a builtin just do an extract element or load.
assert(LHSTy->isBuiltinType() &&
"Destination type must be a vector or builtin type.");
// TODO add asserts about things being long enough
return CGF.EmitScalarConversion(CGF.PerformLoad(LoadGEPList[0]), LHSTy,
SrcTypes[0], Loc);
}

// VisitCastExpr - Emit code for an explicit or implicit cast. Implicit casts
// have to handle a more broad range of conversions than explicit casts, as they
// handle things like function to ptr-to-function decay etc.
Expand Down Expand Up @@ -2752,7 +2782,17 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
llvm::Value *Zero = llvm::Constant::getNullValue(CGF.SizeTy);
return Builder.CreateExtractElement(Vec, Zero, "cast.vtrunc");
}
case CK_HLSLAggregateCast: {
RValue RV = CGF.EmitAnyExpr(E);
SourceLocation Loc = CE->getExprLoc();
QualType SrcTy = E->getType();

if (RV.isAggregate()) { // RHS is an aggregate
Address SrcVal = RV.getAggregateAddress();
return EmitHLSLAggregateFlatCast(CGF, SrcVal, SrcTy, DestTy, Loc);
}
llvm_unreachable("Not a valid HLSL Flat Cast.");
}
} // end of switch

llvm_unreachable("unknown scalar cast");
Expand Down
8 changes: 8 additions & 0 deletions clang/lib/CodeGen/CodeGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -4359,6 +4359,14 @@ class CodeGenFunction : public CodeGenTypeCache {
AggValueSlot slot = AggValueSlot::ignored());
LValue EmitPseudoObjectLValue(const PseudoObjectExpr *e);

llvm::Value *PerformLoad(std::pair<Address, llvm::Value *> &GEP);
llvm::Value *PerformStore(std::pair<Address, llvm::Value *> &GEP,
llvm::Value *Val);
void FlattenAccessAndType(
Address Val, QualType SrcTy, SmallVector<llvm::Value *, 4> &IdxList,
SmallVector<std::pair<Address, llvm::Value *>, 16> &GEPList,
SmallVector<QualType> &FlatTypes);

llvm::Value *EmitIvarOffset(const ObjCInterfaceDecl *Interface,
const ObjCIvarDecl *Ivar);
llvm::Value *EmitIvarOffsetAsPointerDiff(const ObjCInterfaceDecl *Interface,
Expand Down
1 change: 1 addition & 0 deletions clang/lib/Edit/RewriteObjCFoundationAPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1085,6 +1085,7 @@ static bool rewriteToNumericBoxedExpression(const ObjCMessageExpr *Msg,
llvm_unreachable("OpenCL-specific cast in Objective-C?");

case CK_HLSLVectorTruncation:
case CK_HLSLAggregateCast:
llvm_unreachable("HLSL-specific cast in Objective-C?");
break;

Expand Down
1 change: 1 addition & 0 deletions clang/lib/Sema/Sema.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,7 @@ ExprResult Sema::ImpCastExprToType(Expr *E, QualType Ty,
case CK_ToVoid:
case CK_NonAtomicToAtomic:
case CK_HLSLArrayRValue:
case CK_HLSLAggregateCast:
break;
}
}
Expand Down
22 changes: 19 additions & 3 deletions clang/lib/Sema/SemaCast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "clang/Basic/TargetInfo.h"
#include "clang/Lex/Preprocessor.h"
#include "clang/Sema/Initialization.h"
#include "clang/Sema/SemaHLSL.h"
#include "clang/Sema/SemaObjC.h"
#include "clang/Sema/SemaRISCV.h"
#include "llvm/ADT/SmallVector.h"
Expand Down Expand Up @@ -2768,6 +2769,24 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle,
return;
}

CheckedConversionKind CCK = FunctionalStyle
? CheckedConversionKind::FunctionalCast
: CheckedConversionKind::CStyleCast;
// todo what else should i be doing lvalue to rvalue cast for?
// why dont they do it for records below?
// This case should not trigger on regular vector splat
// Or vector cast or vector truncation.
QualType SrcTy = SrcExpr.get()->getType();
if (Self.getLangOpts().HLSL &&
Self.HLSL().CanPerformAggregateCast(SrcExpr.get(), DestType)) {
if (SrcTy->isConstantArrayType())
SrcExpr = Self.ImpCastExprToType(
SrcExpr.get(), Self.Context.getArrayParameterType(SrcTy),
CK_HLSLArrayRValue, VK_PRValue, nullptr, CCK);
Kind = CK_HLSLAggregateCast;
return;
}

if (ValueKind == VK_PRValue && !DestType->isRecordType() &&
!isPlaceholder(BuiltinType::Overload)) {
SrcExpr = Self.DefaultFunctionArrayLvalueConversion(SrcExpr.get());
Expand Down Expand Up @@ -2820,9 +2839,6 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle,
if (isValidCast(tcr))
Kind = CK_NoOp;

CheckedConversionKind CCK = FunctionalStyle
? CheckedConversionKind::FunctionalCast
: CheckedConversionKind::CStyleCast;
if (tcr == TC_NotApplicable) {
tcr = TryAddressSpaceCast(Self, SrcExpr, DestType, /*CStyle*/ true, msg,
Kind);
Expand Down
Loading
Loading