Skip to content

Commit 5fb0b9f

Browse files
committed
In vector case do more heavy lifting in sema so codegen can reuse VectorSplat codegen
1 parent 7331af6 commit 5fb0b9f

File tree

4 files changed

+15
-18
lines changed

4 files changed

+15
-18
lines changed

clang/include/clang/Sema/SemaHLSL.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ class SemaHLSL : public SemaBase {
144144
bool CanPerformScalarCast(QualType SrcTy, QualType DestTy);
145145
bool ContainsBitField(QualType BaseTy);
146146
bool CanPerformElementwiseCast(Expr *Src, QualType DestType);
147-
bool CanPerformSplatCast(Expr *Src, QualType DestType);
147+
bool CanPerformAggregateSplatCast(Expr *Src, QualType DestType);
148148
ExprResult ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg);
149149

150150
QualType getInoutParameterType(QualType Ty);

clang/lib/CodeGen/CGExprScalar.cpp

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2643,6 +2643,11 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
26432643
return EmitScalarConversion(Visit(E), E->getType(), DestTy,
26442644
CE->getExprLoc());
26452645
}
2646+
// CK_HLSLAggregateSplatCast only handles splatting to vectors from a vec1
2647+
// Casts were inserted in Sema to Cast the Src Expr to a Scalar and
2648+
// To perform any necessary Scalar Cast, so this Cast can be handled
2649+
// by the regular Vector Splat cast code.
2650+
case CK_HLSLAggregateSplatCast:
26462651
case CK_VectorSplat: {
26472652
llvm::Type *DstTy = ConvertType(DestTy);
26482653
Value *Elt = Visit(const_cast<Expr *>(E));
@@ -2795,22 +2800,6 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
27952800
llvm::Value *Zero = llvm::Constant::getNullValue(CGF.SizeTy);
27962801
return Builder.CreateExtractElement(Vec, Zero, "cast.vtrunc");
27972802
}
2798-
case CK_HLSLAggregateSplatCast: {
2799-
// This cast should only handle splatting from vectors of length 1.
2800-
// But in Sema a cast should have been inserted to convert the vec1 to a
2801-
// scalar
2802-
assert(DestTy->isVectorType() && "Destination type must be a vector.");
2803-
auto *DestVecTy = DestTy->getAs<VectorType>();
2804-
QualType SrcTy = E->getType();
2805-
SourceLocation Loc = CE->getExprLoc();
2806-
Value *V = Visit(const_cast<Expr *>(E));
2807-
assert(SrcTy->isBuiltinType() && "Invalid HLSL Aggregate splat cast.");
2808-
2809-
Value *Cast =
2810-
EmitScalarConversion(V, SrcTy, DestVecTy->getElementType(), Loc);
2811-
return Builder.CreateVectorSplat(DestVecTy->getNumElements(), Cast,
2812-
"splat");
2813-
}
28142803
case CK_HLSLElementwiseCast: {
28152804
RValue RV = CGF.EmitAnyExpr(E);
28162805
SourceLocation Loc = CE->getExprLoc();

clang/lib/Sema/SemaCast.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2793,13 +2793,20 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle,
27932793
// If the relative order of this and the HLSLElementWise cast checks
27942794
// are changed, it might change which cast handles what in a few cases
27952795
if (Self.getLangOpts().HLSL &&
2796-
Self.HLSL().CanPerformSplatCast(SrcExpr.get(), DestType)) {
2796+
Self.HLSL().CanPerformAggregateSplatCast(SrcExpr.get(), DestType)) {
27972797
const VectorType *VT = SrcTy->getAs<VectorType>();
27982798
// change splat from vec1 case to splat from scalar
27992799
if (VT && VT->getNumElements() == 1)
28002800
SrcExpr = Self.ImpCastExprToType(
28012801
SrcExpr.get(), VT->getElementType(), CK_HLSLVectorTruncation,
28022802
SrcExpr.get()->getValueKind(), nullptr, CCK);
2803+
// Inserting a scalar cast here allows for a simplified codegen in
2804+
// the case the destTy is a vector
2805+
if (const VectorType *DVT = DestType->getAs<VectorType>())
2806+
SrcExpr = Self.ImpCastExprToType(
2807+
SrcExpr.get(), DVT->getElementType(),
2808+
Self.PrepareScalarCast(SrcExpr, DVT->getElementType()),
2809+
SrcExpr.get()->getValueKind(), nullptr, CCK);
28032810
Kind = CK_HLSLAggregateSplatCast;
28042811
return;
28052812
}

clang/test/SemaHLSL/Language/AggregateSplatCasts.hlsl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// splat from vec1 to vec
44
// CHECK-LABEL: call1
55
// CHECK: CStyleCastExpr {{.*}} 'int3':'vector<int, 3>' <HLSLAggregateSplatCast>
6+
// CHECK-NEXT: ImplicitCastExpr {{.*}} 'int' lvalue <FloatingToIntegral> part_of_explicit_cast
67
// CHECK-NEXT: ImplicitCastExpr {{.*}} 'float' lvalue <HLSLVectorTruncation> part_of_explicit_cast
78
// CHECK-NEXT: DeclRefExpr {{.*}} 'float1':'vector<float, 1>' lvalue Var {{.*}} 'A' 'float1':'vector<float, 1>'
89
export void call1() {

0 commit comments

Comments
 (0)