Skip to content

Commit e994824

Browse files
committed
splat cast wip
1 parent 2feced1 commit e994824

File tree

7 files changed

+97
-1
lines changed

7 files changed

+97
-1
lines changed

clang/include/clang/AST/OperationKinds.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,9 @@ CAST_OPERATION(HLSLArrayRValue)
370370
// Aggregate by Value cast (HLSL only).
371371
CAST_OPERATION(HLSLElementwiseCast)
372372

373+
// Splat cast for Aggregates (HLSL only).
374+
CAST_OPERATION(HLSLSplatCast)
375+
373376
//===- Binary Operations -------------------------------------------------===//
374377
// Operators listed in order of precedence.
375378
// Note that additions to this should also update the StmtVisitor class,

clang/include/clang/Sema/SemaHLSL.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ class SemaHLSL : public SemaBase {
144144
bool CanPerformScalarCast(QualType SrcTy, QualType DestTy);
145145
bool ContainsBitField(QualType BaseTy);
146146
bool CanPerformElementwiseCast(Expr *Src, QualType DestType);
147+
bool CanPerformSplat(Expr *Src, QualType DestType);
147148
ExprResult ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg);
148149

149150
QualType getInoutParameterType(QualType Ty);

clang/lib/CodeGen/CGExprAgg.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,33 @@ static bool isTrivialFiller(Expr *E) {
491491
return false;
492492
}
493493

494+
// Emit an HLSL splat cast whose destination is an aggregate: the single source
// value is converted to the type of every flattened destination element and
// stored into that element.
//
//   DestVal / DestTy - address and type of the aggregate being written.
//   SrcVal / SrcTy   - value being splatted; a scalar, or a vector that must
//                      have exactly one element.
//   Loc              - source location handed to each scalar conversion.
static void EmitHLSLSplatCast(CodeGenFunction &CGF, Address DestVal,
                              QualType DestTy, llvm::Value *SrcVal,
                              QualType SrcTy, SourceLocation Loc) {
  // Flatten the destination first: FlatTypes receives the flattened element
  // types and FlatStores the matching accesses into DestVal to store through.
  SmallVector<QualType> FlatTypes;
  SmallVector<llvm::Value *, 4> GEPIndices;
  SmallVector<std::pair<Address, llvm::Value *>, 16> FlatStores;
  CGF.FlattenAccessAndType(DestVal, DestTy, GEPIndices, FlatStores, FlatTypes);

  // A one-element vector source is unwrapped to its only element.
  const VectorType *SrcVecTy = SrcTy->getAs<VectorType>();
  if (SrcVecTy) {
    assert(SrcVecTy->getNumElements() == 1 && "Invalid HLSL splat cast.");

    SrcTy = SrcVecTy->getElementType();
    SrcVal = CGF.Builder.CreateExtractElement(SrcVal, static_cast<uint64_t>(0),
                                              "vec.load");
  }
  assert(SrcTy->isScalarType() && "Invalid HLSL splat cast.");

  // Convert the scalar per destination element type, then store it.
  for (unsigned Idx = 0, End = FlatStores.size(); Idx != End; ++Idx) {
    llvm::Value *Converted =
        CGF.EmitScalarConversion(SrcVal, SrcTy, FlatTypes[Idx], Loc);
    CGF.PerformStore(FlatStores[Idx], Converted);
  }
}
520+
494521
// emit a flat cast where the RHS is a scalar, including vector
495522
static void EmitHLSLScalarFlatCast(CodeGenFunction &CGF, Address DestVal,
496523
QualType DestTy, llvm::Value *SrcVal,
@@ -963,6 +990,21 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
963990
case CK_HLSLArrayRValue:
964991
Visit(E->getSubExpr());
965992
break;
993+
case CK_HLSLSplatCast: {
994+
Expr *Src = E->getSubExpr();
995+
QualType SrcTy = Src->getType();
996+
RValue RV = CGF.EmitAnyExpr(Src);
997+
QualType DestTy = E->getType();
998+
Address DestVal = Dest.getAddress();
999+
SourceLocation Loc = E->getExprLoc();
1000+
1001+
if (RV.isScalar()) {
1002+
llvm::Value *SrcVal = RV.getScalarVal();
1003+
EmitHLSLSplatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
1004+
break;
1005+
}
1006+
llvm_unreachable("RHS of HLSL splat cast must be a scalar or vector.");
1007+
}
9661008
case CK_HLSLElementwiseCast: {
9671009
Expr *Src = E->getSubExpr();
9681010
QualType SrcTy = Src->getType();

clang/lib/CodeGen/CGExprScalar.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2795,6 +2795,22 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
27952795
llvm::Value *Zero = llvm::Constant::getNullValue(CGF.SizeTy);
27962796
return Builder.CreateExtractElement(Vec, Zero, "cast.vtrunc");
27972797
}
2798+
case CK_HLSLSplatCast: {
2799+
assert(DestTy->isVectorType() && "Destination type must be a vector.");
2800+
auto *DestVecTy = DestTy->getAs<VectorType>();
2801+
QualType SrcTy = E->getType();
2802+
SourceLocation Loc = CE->getExprLoc();
2803+
Value *V = Visit(const_cast<Expr *>(E));
2804+
if (auto *VecTy = SrcTy->getAs<VectorType>()) {
2805+
assert(VecTy->getNumElements() == 1 && "Invalid HLSL splat cast.");
2806+
V = CGF.Builder.CreateExtractElement(V, (uint64_t)0, "vec.load");
2807+
SrcTy = VecTy->getElementType();
2808+
}
2809+
assert(SrcTy->isScalarType() && "Invalid HLSL splat cast.");
2810+
Value *Cast = EmitScalarConversion(V, SrcTy,
2811+
DestVecTy->getElementType(), Loc);
2812+
return Builder.CreateVectorSplat(DestVecTy->getNumElements(), Cast, "splat");
2813+
}
27982814
case CK_HLSLElementwiseCast: {
27992815
RValue RV = CGF.EmitAnyExpr(E);
28002816
SourceLocation Loc = CE->getExprLoc();

clang/lib/Sema/Sema.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -709,6 +709,7 @@ ExprResult Sema::ImpCastExprToType(Expr *E, QualType Ty,
709709
case CK_ToVoid:
710710
case CK_NonAtomicToAtomic:
711711
case CK_HLSLArrayRValue:
712+
case CK_HLSLSplatCast:
712713
break;
713714
}
714715
}

clang/lib/Sema/SemaCast.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2776,9 +2776,16 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle,
27762776
CheckedConversionKind CCK = FunctionalStyle
27772777
? CheckedConversionKind::FunctionalCast
27782778
: CheckedConversionKind::CStyleCast;
2779+
27792780
// This case should not trigger on regular vector splat
2780-
// vector cast, vector truncation, or special hlsl splat cases
27812781
QualType SrcTy = SrcExpr.get()->getType();
2782+
if (Self.getLangOpts().HLSL &&
2783+
Self.HLSL().CanPerformSplat(SrcExpr.get(), DestType)) {
2784+
Kind = CK_HLSLSplatCast;
2785+
return;
2786+
}
2787+
2788+
// This case should not trigger on regular vector cast, vector truncation
27822789
if (Self.getLangOpts().HLSL &&
27832790
Self.HLSL().CanPerformElementwiseCast(SrcExpr.get(), DestType)) {
27842791
if (SrcTy->isConstantArrayType())

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2804,6 +2804,32 @@ bool SemaHLSL::ContainsBitField(QualType BaseTy) {
28042804
return false;
28052805
}
28062806

2807+
// Can perform an HLSL splat cast if the Dest is an aggregate and the
2808+
// Src is a scalar or a vector of length 1
2809+
// Or if Dest is a vector and Src is a vector of length 1
2810+
bool SemaHLSL::CanPerformSplat(Expr *Src, QualType DestTy) {
2811+
2812+
QualType SrcTy = Src->getType();
2813+
if (SrcTy->isScalarType() && DestTy->isVectorType())
2814+
return false;
2815+
2816+
const VectorType *SrcVecTy = SrcTy->getAs<VectorType>();
2817+
if (!(SrcTy->isScalarType() || (SrcVecTy && SrcVecTy->getNumElements() == 1)))
2818+
return false;
2819+
2820+
if (SrcVecTy)
2821+
SrcTy = SrcVecTy->getElementType();
2822+
2823+
llvm::SmallVector<QualType> DestTypes;
2824+
BuildFlattenedTypeList(DestTy, DestTypes);
2825+
2826+
for(unsigned i = 0; i < DestTypes.size(); i ++) {
2827+
if (!CanPerformScalarCast(SrcTy, DestTypes[i]))
2828+
return false;
2829+
}
2830+
return true;
2831+
}
2832+
28072833
// Can we perform an HLSL Elementwise cast?
28082834
// TODO: update this code when matrices are added; see issue #88060
28092835
bool SemaHLSL::CanPerformElementwiseCast(Expr *Src, QualType DestTy) {

0 commit comments

Comments
 (0)