Skip to content

Commit e038166

Browse files
committed
splat cast wip
1 parent 89709ad commit e038166

File tree

7 files changed

+97
-1
lines changed

7 files changed

+97
-1
lines changed

clang/include/clang/AST/OperationKinds.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,9 @@ CAST_OPERATION(HLSLArrayRValue)
370370
// Aggregate by Value cast (HLSL only).
371371
CAST_OPERATION(HLSLAggregateCast)
372372

373+
// Splat cast for Aggregates (HLSL only).
374+
CAST_OPERATION(HLSLSplatCast)
375+
373376
//===- Binary Operations -------------------------------------------------===//
374377
// Operators listed in order of precedence.
375378
// Note that additions to this should also update the StmtVisitor class,

clang/include/clang/Sema/SemaHLSL.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ class SemaHLSL : public SemaBase {
142142

143143
bool CanPerformScalarCast(QualType SrcTy, QualType DestTy);
144144
bool CanPerformAggregateCast(Expr *Src, QualType DestType);
145+
bool CanPerformSplat(Expr *Src, QualType DestType);
145146
ExprResult ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg);
146147

147148
QualType getInoutParameterType(QualType Ty);

clang/lib/CodeGen/CGExprAgg.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,33 @@ static bool isTrivialFiller(Expr *E) {
491491
return false;
492492
}
493493

494+
static void EmitHLSLSplatCast(CodeGenFunction &CGF, Address DestVal,
495+
QualType DestTy, llvm::Value *SrcVal,
496+
QualType SrcTy, SourceLocation Loc) {
497+
// Flatten our destination
498+
SmallVector<QualType> DestTypes; // Flattened type
499+
SmallVector<llvm::Value *, 4> IdxList;
500+
SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
501+
// ^^ Flattened accesses to DestVal we want to store into
502+
CGF.FlattenAccessAndType(DestVal, DestTy, IdxList, StoreGEPList,
503+
DestTypes);
504+
505+
if (const VectorType *VT = SrcTy->getAs<VectorType>()) {
506+
assert(VT->getNumElements() == 1 && "Invalid HLSL splat cast.");
507+
508+
SrcTy = VT->getElementType();
509+
SrcVal = CGF.Builder.CreateExtractElement(SrcVal, (uint64_t)0,
510+
"vec.load");
511+
}
512+
assert(SrcTy->isScalarType() && "Invalid HLSL splat cast.");
513+
for(unsigned i = 0; i < StoreGEPList.size(); i ++) {
514+
llvm::Value *Cast = CGF.EmitScalarConversion(SrcVal, SrcTy,
515+
DestTypes[i],
516+
Loc);
517+
CGF.PerformStore(StoreGEPList[i], Cast);
518+
}
519+
}
520+
494521
// emit a flat cast where the RHS is a scalar, including vector
495522
static void EmitHLSLScalarFlatCast(CodeGenFunction &CGF, Address DestVal,
496523
QualType DestTy, llvm::Value *SrcVal,
@@ -965,6 +992,21 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
965992
case CK_HLSLArrayRValue:
966993
Visit(E->getSubExpr());
967994
break;
995+
case CK_HLSLSplatCast: {
996+
Expr *Src = E->getSubExpr();
997+
QualType SrcTy = Src->getType();
998+
RValue RV = CGF.EmitAnyExpr(Src);
999+
QualType DestTy = E->getType();
1000+
Address DestVal = Dest.getAddress();
1001+
SourceLocation Loc = E->getExprLoc();
1002+
1003+
if (RV.isScalar()) {
1004+
llvm::Value *SrcVal = RV.getScalarVal();
1005+
EmitHLSLSplatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
1006+
break;
1007+
}
1008+
llvm_unreachable("RHS of HLSL splat cast must be a scalar or vector.");
1009+
}
9681010
case CK_HLSLAggregateCast: {
9691011
Expr *Src = E->getSubExpr();
9701012
QualType SrcTy = Src->getType();

clang/lib/CodeGen/CGExprScalar.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2787,6 +2787,22 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
27872787
llvm::Value *Zero = llvm::Constant::getNullValue(CGF.SizeTy);
27882788
return Builder.CreateExtractElement(Vec, Zero, "cast.vtrunc");
27892789
}
2790+
case CK_HLSLSplatCast: {
2791+
assert(DestTy->isVectorType() && "Destination type must be a vector.");
2792+
auto *DestVecTy = DestTy->getAs<VectorType>();
2793+
QualType SrcTy = E->getType();
2794+
SourceLocation Loc = CE->getExprLoc();
2795+
Value *V = Visit(const_cast<Expr *>(E));
2796+
if (auto *VecTy = SrcTy->getAs<VectorType>()) {
2797+
assert(VecTy->getNumElements() == 1 && "Invalid HLSL splat cast.");
2798+
V = CGF.Builder.CreateExtractElement(V, (uint64_t)0, "vec.load");
2799+
SrcTy = VecTy->getElementType();
2800+
}
2801+
assert(SrcTy->isScalarType() && "Invalid HLSL splat cast.");
2802+
Value *Cast = EmitScalarConversion(V, SrcTy,
2803+
DestVecTy->getElementType(), Loc);
2804+
return Builder.CreateVectorSplat(DestVecTy->getNumElements(), Cast, "splat");
2805+
}
27902806
case CK_HLSLAggregateCast: {
27912807
RValue RV = CGF.EmitAnyExpr(E);
27922808
SourceLocation Loc = CE->getExprLoc();

clang/lib/Sema/Sema.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -708,6 +708,7 @@ ExprResult Sema::ImpCastExprToType(Expr *E, QualType Ty,
708708
case CK_NonAtomicToAtomic:
709709
case CK_HLSLArrayRValue:
710710
case CK_HLSLAggregateCast:
711+
case CK_HLSLSplatCast:
711712
break;
712713
}
713714
}

clang/lib/Sema/SemaCast.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2772,9 +2772,16 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle,
27722772
CheckedConversionKind CCK = FunctionalStyle
27732773
? CheckedConversionKind::FunctionalCast
27742774
: CheckedConversionKind::CStyleCast;
2775+
27752776
// This case should not trigger on regular vector splat
2776-
// vector cast, vector truncation, or special hlsl splat cases
27772777
QualType SrcTy = SrcExpr.get()->getType();
2778+
if (Self.getLangOpts().HLSL &&
2779+
Self.HLSL().CanPerformSplat(SrcExpr.get(), DestType)) {
2780+
Kind = CK_HLSLSplatCast;
2781+
return;
2782+
}
2783+
2784+
// This case should not trigger on regular vector cast, vector truncation
27782785
if (Self.getLangOpts().HLSL &&
27792786
Self.HLSL().CanPerformAggregateCast(SrcExpr.get(), DestType)) {
27802787
if (SrcTy->isConstantArrayType())

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2476,6 +2476,32 @@ bool SemaHLSL::CanPerformScalarCast(QualType SrcTy, QualType DestTy) {
24762476
llvm_unreachable("Unhandled scalar cast");
24772477
}
24782478

2479+
// Can perform an HLSL splat cast if the Dest is an aggregate and the
2480+
// Src is a scalar or a vector of length 1
2481+
// Or if Dest is a vector and Src is a vector of length 1
2482+
bool SemaHLSL::CanPerformSplat(Expr *Src, QualType DestTy) {
2483+
2484+
QualType SrcTy = Src->getType();
2485+
if (SrcTy->isScalarType() && DestTy->isVectorType())
2486+
return false;
2487+
2488+
const VectorType *SrcVecTy = SrcTy->getAs<VectorType>();
2489+
if (!(SrcTy->isScalarType() || (SrcVecTy && SrcVecTy->getNumElements() == 1)))
2490+
return false;
2491+
2492+
if (SrcVecTy)
2493+
SrcTy = SrcVecTy->getElementType();
2494+
2495+
llvm::SmallVector<QualType> DestTypes;
2496+
BuildFlattenedTypeList(DestTy, DestTypes);
2497+
2498+
for(unsigned i = 0; i < DestTypes.size(); i ++) {
2499+
if (!CanPerformScalarCast(SrcTy, DestTypes[i]))
2500+
return false;
2501+
}
2502+
return true;
2503+
}
2504+
24792505
// Can we perform an HLSL Flattened cast?
24802506
// TODO: update this code when matrices are added
24812507
bool SemaHLSL::CanPerformAggregateCast(Expr *Src, QualType DestTy) {

0 commit comments

Comments
 (0)