@@ -491,6 +491,79 @@ static bool isTrivialFiller(Expr *E) {
491491 return false ;
492492}
493493
494+ // emit a flat cast where the RHS is a scalar, including vector
495+ static void EmitHLSLScalarFlatCast (CodeGenFunction &CGF, Address DestVal,
496+ QualType DestTy, llvm::Value *SrcVal,
497+ QualType SrcTy, SourceLocation Loc) {
498+ // Flatten our destination
499+ SmallVector<QualType, 16 > DestTypes; // Flattened type
500+ SmallVector<std::pair<Address, llvm::Value *>, 16 > StoreGEPList;
501+ // ^^ Flattened accesses to DestVal we want to store into
502+ CGF.FlattenAccessAndType (DestVal, DestTy, StoreGEPList, DestTypes);
503+
504+ assert (SrcTy->isVectorType () && " HLSL Flat cast doesn't handle splatting." );
505+ const VectorType *VT = SrcTy->getAs <VectorType>();
506+ SrcTy = VT->getElementType ();
507+ assert (StoreGEPList.size () <= VT->getNumElements () &&
508+ " Cannot perform HLSL flat cast when vector source \
509+ object has less elements than flattened destination \
510+ object." );
511+ for (unsigned I = 0 , Size = StoreGEPList.size (); I < Size; I++) {
512+ llvm::Value *Load = CGF.Builder .CreateExtractElement (SrcVal, I, " vec.load" );
513+ llvm::Value *Cast =
514+ CGF.EmitScalarConversion (Load, SrcTy, DestTypes[I], Loc);
515+
516+ // store back
517+ llvm::Value *Idx = StoreGEPList[I].second ;
518+ if (Idx) {
519+ llvm::Value *V =
520+ CGF.Builder .CreateLoad (StoreGEPList[I].first , " load.for.insert" );
521+ Cast = CGF.Builder .CreateInsertElement (V, Cast, Idx);
522+ }
523+ CGF.Builder .CreateStore (Cast, StoreGEPList[I].first );
524+ }
525+ return ;
526+ }
527+
528+ // emit a flat cast where the RHS is an aggregate
529+ static void EmitHLSLElementwiseCast (CodeGenFunction &CGF, Address DestVal,
530+ QualType DestTy, Address SrcVal,
531+ QualType SrcTy, SourceLocation Loc) {
532+ // Flatten our destination
533+ SmallVector<QualType, 16 > DestTypes; // Flattened type
534+ SmallVector<std::pair<Address, llvm::Value *>, 16 > StoreGEPList;
535+ // ^^ Flattened accesses to DestVal we want to store into
536+ CGF.FlattenAccessAndType (DestVal, DestTy, StoreGEPList, DestTypes);
537+ // Flatten our src
538+ SmallVector<QualType, 16 > SrcTypes; // Flattened type
539+ SmallVector<std::pair<Address, llvm::Value *>, 16 > LoadGEPList;
540+ // ^^ Flattened accesses to SrcVal we want to load from
541+ CGF.FlattenAccessAndType (SrcVal, SrcTy, LoadGEPList, SrcTypes);
542+
543+ assert (StoreGEPList.size () <= LoadGEPList.size () &&
544+ " Cannot perform HLSL flat cast when flattened source object \
545+ has less elements than flattened destination object." );
546+ // apply casts to what we load from LoadGEPList
547+ // and store result in Dest
548+ for (unsigned I = 0 , E = StoreGEPList.size (); I < E; I++) {
549+ llvm::Value *Idx = LoadGEPList[I].second ;
550+ llvm::Value *Load = CGF.Builder .CreateLoad (LoadGEPList[I].first , " load" );
551+ Load =
552+ Idx ? CGF.Builder .CreateExtractElement (Load, Idx, " vec.extract" ) : Load;
553+ llvm::Value *Cast =
554+ CGF.EmitScalarConversion (Load, SrcTypes[I], DestTypes[I], Loc);
555+
556+ // store back
557+ Idx = StoreGEPList[I].second ;
558+ if (Idx) {
559+ llvm::Value *V =
560+ CGF.Builder .CreateLoad (StoreGEPList[I].first , " load.for.insert" );
561+ Cast = CGF.Builder .CreateInsertElement (V, Cast, Idx);
562+ }
563+ CGF.Builder .CreateStore (Cast, StoreGEPList[I].first );
564+ }
565+ }
566+
494567// / Emit initialization of an array from an initializer list. ExprToVisit must
495568// / be either an InitListEpxr a CXXParenInitListExpr.
496569void AggExprEmitter::EmitArrayInit (Address DestPtr, llvm::ArrayType *AType,
@@ -890,7 +963,25 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
890963 case CK_HLSLArrayRValue:
891964 Visit (E->getSubExpr ());
892965 break ;
893-
966+ case CK_HLSLElementwiseCast: {
967+ Expr *Src = E->getSubExpr ();
968+ QualType SrcTy = Src->getType ();
969+ RValue RV = CGF.EmitAnyExpr (Src);
970+ QualType DestTy = E->getType ();
971+ Address DestVal = Dest.getAddress ();
972+ SourceLocation Loc = E->getExprLoc ();
973+
974+ if (RV.isScalar ()) {
975+ llvm::Value *SrcVal = RV.getScalarVal ();
976+ EmitHLSLScalarFlatCast (CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
977+ } else {
978+ assert (RV.isAggregate () &&
979+ " Can't perform HLSL Aggregate cast on a complex type." );
980+ Address SrcVal = RV.getAggregateAddress ();
981+ EmitHLSLElementwiseCast (CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
982+ }
983+ break ;
984+ }
894985 case CK_NoOp:
895986 case CK_UserDefinedConversion:
896987 case CK_ConstructorConversion:
@@ -1461,6 +1552,7 @@ static bool castPreservesZero(const CastExpr *CE) {
14611552 case CK_NonAtomicToAtomic:
14621553 case CK_AtomicToNonAtomic:
14631554 case CK_HLSLVectorTruncation:
1555+ case CK_HLSLElementwiseCast:
14641556 return true ;
14651557
14661558 case CK_BaseToDerivedMemberPointer:
0 commit comments