Skip to content

[HLSL] Implement HLSL Aggregate splatting #118992

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions clang/include/clang/AST/OperationKinds.def
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,9 @@ CAST_OPERATION(HLSLArrayRValue)
// Aggregate by Value cast (HLSL only).
CAST_OPERATION(HLSLElementwiseCast)

// Splat cast for Aggregates (HLSL only).
CAST_OPERATION(HLSLSplatCast)

//===- Binary Operations -------------------------------------------------===//
// Operators listed in order of precedence.
// Note that additions to this should also update the StmtVisitor class,
Expand Down
1 change: 1 addition & 0 deletions clang/include/clang/Sema/SemaHLSL.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ class SemaHLSL : public SemaBase {
bool CanPerformScalarCast(QualType SrcTy, QualType DestTy);
bool ContainsBitField(QualType BaseTy);
bool CanPerformElementwiseCast(Expr *Src, QualType DestType);
bool CanPerformSplatCast(Expr *Src, QualType DestType);
ExprResult ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg);

QualType getInoutParameterType(QualType Ty);
Expand Down
1 change: 1 addition & 0 deletions clang/lib/AST/Expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1957,6 +1957,7 @@ bool CastExpr::CastConsistency() const {
case CK_HLSLArrayRValue:
case CK_HLSLVectorTruncation:
case CK_HLSLElementwiseCast:
case CK_HLSLSplatCast:
CheckNoBasePath:
assert(path_empty() && "Cast kind should not have a base path!");
break;
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/AST/ExprConstant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15029,6 +15029,7 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
case CK_FixedPointCast:
case CK_IntegralToFixedPoint:
case CK_MatrixCast:
case CK_HLSLSplatCast:
llvm_unreachable("invalid cast kind for integral value");

case CK_BitCast:
Expand Down Expand Up @@ -15907,6 +15908,7 @@ bool ComplexExprEvaluator::VisitCastExpr(const CastExpr *E) {
case CK_MatrixCast:
case CK_HLSLVectorTruncation:
case CK_HLSLElementwiseCast:
case CK_HLSLSplatCast:
llvm_unreachable("invalid cast kind for complex value");

case CK_LValueToRValue:
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CodeGen/CGExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5339,6 +5339,7 @@ LValue CodeGenFunction::EmitCastLValue(const CastExpr *E) {
case CK_HLSLVectorTruncation:
case CK_HLSLArrayRValue:
case CK_HLSLElementwiseCast:
case CK_HLSLSplatCast:
return EmitUnsupportedLValue(E, "unexpected cast lvalue");

case CK_Dependent:
Expand Down
39 changes: 39 additions & 0 deletions clang/lib/CodeGen/CGExprAgg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,31 @@ static bool isTrivialFiller(Expr *E) {
return false;
}

static void EmitHLSLSplatCast(CodeGenFunction &CGF, Address DestVal,
QualType DestTy, llvm::Value *SrcVal,
QualType SrcTy, SourceLocation Loc) {
// Flatten our destination
SmallVector<QualType> DestTypes; // Flattened type
SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
// ^^ Flattened accesses to DestVal we want to store into
CGF.FlattenAccessAndType(DestVal, DestTy, StoreGEPList, DestTypes);

assert(SrcTy->isScalarType() && "Invalid HLSL splat cast.");
for (unsigned I = 0, Size = StoreGEPList.size(); I < Size; I++) {
llvm::Value *Cast =
CGF.EmitScalarConversion(SrcVal, SrcTy, DestTypes[I], Loc);

// store back
llvm::Value *Idx = StoreGEPList[I].second;
if (Idx) {
llvm::Value *V =
CGF.Builder.CreateLoad(StoreGEPList[I].first, "load.for.insert");
Cast = CGF.Builder.CreateInsertElement(V, Cast, Idx);
}
CGF.Builder.CreateStore(Cast, StoreGEPList[I].first);
}
}

// emit a flat cast where the RHS is a scalar, including vector
static void EmitHLSLScalarFlatCast(CodeGenFunction &CGF, Address DestVal,
QualType DestTy, llvm::Value *SrcVal,
Expand Down Expand Up @@ -963,6 +988,19 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
case CK_HLSLArrayRValue:
Visit(E->getSubExpr());
break;
case CK_HLSLSplatCast: {
Expr *Src = E->getSubExpr();
QualType SrcTy = Src->getType();
RValue RV = CGF.EmitAnyExpr(Src);
QualType DestTy = E->getType();
Address DestVal = Dest.getAddress();
SourceLocation Loc = E->getExprLoc();

assert(RV.isScalar() && "RHS of HLSL splat cast must be a scalar.");
llvm::Value *SrcVal = RV.getScalarVal();
EmitHLSLSplatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
break;
}
case CK_HLSLElementwiseCast: {
Expr *Src = E->getSubExpr();
QualType SrcTy = Src->getType();
Expand Down Expand Up @@ -1553,6 +1591,7 @@ static bool castPreservesZero(const CastExpr *CE) {
case CK_AtomicToNonAtomic:
case CK_HLSLVectorTruncation:
case CK_HLSLElementwiseCast:
case CK_HLSLSplatCast:
return true;

case CK_BaseToDerivedMemberPointer:
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CodeGen/CGExprComplex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,7 @@ ComplexPairTy ComplexExprEmitter::EmitCast(CastKind CK, Expr *Op,
case CK_HLSLVectorTruncation:
case CK_HLSLArrayRValue:
case CK_HLSLElementwiseCast:
case CK_HLSLSplatCast:
llvm_unreachable("invalid cast kind for complex value");

case CK_FloatingRealToComplex:
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CodeGen/CGExprConstant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1336,6 +1336,7 @@ class ConstExprEmitter
case CK_HLSLVectorTruncation:
case CK_HLSLArrayRValue:
case CK_HLSLElementwiseCast:
case CK_HLSLSplatCast:
return nullptr;
}
llvm_unreachable("Invalid CastKind");
Expand Down
16 changes: 16 additions & 0 deletions clang/lib/CodeGen/CGExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2795,6 +2795,22 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
llvm::Value *Zero = llvm::Constant::getNullValue(CGF.SizeTy);
return Builder.CreateExtractElement(Vec, Zero, "cast.vtrunc");
}
case CK_HLSLSplatCast: {
// This cast should only handle splatting from vectors of length 1.
// But in Sema a cast should have been inserted to convert the vec1 to a
// scalar
assert(DestTy->isVectorType() && "Destination type must be a vector.");
auto *DestVecTy = DestTy->getAs<VectorType>();
QualType SrcTy = E->getType();
SourceLocation Loc = CE->getExprLoc();
Value *V = Visit(const_cast<Expr *>(E));
assert(SrcTy->isBuiltinType() && "Invalid HLSL splat cast.");

Value *Cast =
EmitScalarConversion(V, SrcTy, DestVecTy->getElementType(), Loc);
return Builder.CreateVectorSplat(DestVecTy->getNumElements(), Cast,
"splat");
}
case CK_HLSLElementwiseCast: {
RValue RV = CGF.EmitAnyExpr(E);
SourceLocation Loc = CE->getExprLoc();
Expand Down
1 change: 1 addition & 0 deletions clang/lib/Edit/RewriteObjCFoundationAPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1086,6 +1086,7 @@ static bool rewriteToNumericBoxedExpression(const ObjCMessageExpr *Msg,

case CK_HLSLVectorTruncation:
case CK_HLSLElementwiseCast:
case CK_HLSLSplatCast:
llvm_unreachable("HLSL-specific cast in Objective-C?");
break;

Expand Down
1 change: 1 addition & 0 deletions clang/lib/Sema/Sema.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,7 @@ ExprResult Sema::ImpCastExprToType(Expr *E, QualType Ty,
case CK_ToVoid:
case CK_NonAtomicToAtomic:
case CK_HLSLArrayRValue:
case CK_HLSLSplatCast:
break;
}
}
Expand Down
19 changes: 17 additions & 2 deletions clang/lib/Sema/SemaCast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2776,9 +2776,9 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle,
CheckedConversionKind CCK = FunctionalStyle
? CheckedConversionKind::FunctionalCast
: CheckedConversionKind::CStyleCast;
// This case should not trigger on regular vector splat
// vector cast, vector truncation, or special hlsl splat cases

QualType SrcTy = SrcExpr.get()->getType();
// This case should not trigger on regular vector cast, vector truncation
if (Self.getLangOpts().HLSL &&
Self.HLSL().CanPerformElementwiseCast(SrcExpr.get(), DestType)) {
if (SrcTy->isConstantArrayType())
Expand All @@ -2789,6 +2789,21 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle,
return;
}

// This case should not trigger on regular vector splat
// If the relative order of this and the HLSLElementWise cast checks
// are changed, it might change which cast handles what in a few cases
if (Self.getLangOpts().HLSL &&
Self.HLSL().CanPerformSplatCast(SrcExpr.get(), DestType)) {
const VectorType *VT = SrcTy->getAs<VectorType>();
// change splat from vec1 case to splat from scalar
if (VT && VT->getNumElements() == 1)
SrcExpr = Self.ImpCastExprToType(
SrcExpr.get(), VT->getElementType(), CK_HLSLVectorTruncation,
SrcExpr.get()->getValueKind(), nullptr, CCK);
Kind = CK_HLSLSplatCast;
return;
}

if (ValueKind == VK_PRValue && !DestType->isRecordType() &&
!isPlaceholder(BuiltinType::Overload)) {
SrcExpr = Self.DefaultFunctionArrayLvalueConversion(SrcExpr.get());
Expand Down
38 changes: 37 additions & 1 deletion clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2771,7 +2771,7 @@ bool SemaHLSL::CanPerformScalarCast(QualType SrcTy, QualType DestTy) {
}

// Detect if a type contains a bitfield. Will be removed when
// bitfield support is added to HLSLElementwiseCast
// bitfield support is added to HLSLElementwiseCast and HLSLSplatCast
bool SemaHLSL::ContainsBitField(QualType BaseTy) {
llvm::SmallVector<QualType, 16> WorkList;
WorkList.push_back(BaseTy);
Expand Down Expand Up @@ -2804,6 +2804,42 @@ bool SemaHLSL::ContainsBitField(QualType BaseTy) {
return false;
}

// Can perform an HLSL splat cast if the Dest is an aggregate and the
// Src is a scalar or a vector of length 1
// Or if Dest is a vector and Src is a vector of length 1
bool SemaHLSL::CanPerformSplatCast(Expr *Src, QualType DestTy) {

QualType SrcTy = Src->getType();
// Not a valid HLSL Splat cast if Dest is a scalar or if this is going to
// be a vector splat from a scalar.
if ((SrcTy->isScalarType() && DestTy->isVectorType()) ||
DestTy->isScalarType())
return false;

const VectorType *SrcVecTy = SrcTy->getAs<VectorType>();

// Src isn't a scalar or a vector of length 1
if (!SrcTy->isScalarType() && !(SrcVecTy && SrcVecTy->getNumElements() == 1))
return false;

if (SrcVecTy)
SrcTy = SrcVecTy->getElementType();

if (ContainsBitField(DestTy))
return false;

llvm::SmallVector<QualType> DestTypes;
BuildFlattenedTypeList(DestTy, DestTypes);

for (unsigned I = 0, Size = DestTypes.size(); I < Size; I++) {
if (DestTypes[I]->isUnionType())
return false;
if (!CanPerformScalarCast(SrcTy, DestTypes[I]))
return false;
}
return true;
}

// Can we perform an HLSL Elementwise cast?
// TODO: update this code when matrices are added; see issue #88060
bool SemaHLSL::CanPerformElementwiseCast(Expr *Src, QualType DestTy) {
Expand Down
1 change: 1 addition & 0 deletions clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,7 @@ void ExprEngine::VisitCast(const CastExpr *CastE, const Expr *Ex,
case CK_MatrixCast:
case CK_VectorSplat:
case CK_HLSLElementwiseCast:
case CK_HLSLSplatCast:
case CK_HLSLVectorTruncation: {
QualType resultType = CastE->getType();
if (CastE->isGLValue())
Expand Down
87 changes: 87 additions & 0 deletions clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s

// array splat
// CHECK-LABEL: define void {{.*}}call4
// CHECK: [[B:%.*]] = alloca [2 x i32], align 4
// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 8, i1 false)
// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 0
// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 1
// CHECK-NEXT: store i32 3, ptr [[G1]], align 4
// CHECK-NEXT: store i32 3, ptr [[G2]], align 4
export void call4() {
int B[2] = {1,2};
B = (int[2])3;
}

// splat from vector of length 1
// CHECK-LABEL: define void {{.*}}call8
// CHECK: [[A:%.*]] = alloca <1 x i32>, align 4
// CHECK-NEXT: [[B:%.*]] = alloca [2 x i32], align 4
// CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4
// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 8, i1 false)
// CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4
// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i32 0
// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 0
// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 1
// CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
// CHECK-NEXT: store i32 [[VL]], ptr [[G2]], align 4
export void call8() {
int1 A = {1};
int B[2] = {1,2};
B = (int[2])A;
}

// vector splat from vector of length 1
// CHECK-LABEL: define void {{.*}}call1
// CHECK: [[B:%.*]] = alloca <1 x float>, align 4
// CHECK-NEXT: [[A:%.*]] = alloca <4 x i32>, align 16
// CHECK-NEXT: store <1 x float> splat (float 1.000000e+00), ptr [[B]], align 4
// CHECK-NEXT: [[L:%.*]] = load <1 x float>, ptr [[B]], align 4
// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x float> [[L]], i32 0
// CHECK-NEXT: [[C:%.*]] = fptosi float [[VL]] to i32
// CHECK-NEXT: [[SI:%.*]] = insertelement <4 x i32> poison, i32 [[C]], i64 0
// CHECK-NEXT: [[S:%.*]] = shufflevector <4 x i32> [[SI]], <4 x i32> poison, <4 x i32> zeroinitializer
// CHECK-NEXT: store <4 x i32> [[S]], ptr [[A]], align 16
export void call1() {
float1 B = {1.0};
int4 A = (int4)B;
}

struct S {
int X;
float Y;
};

// struct splats?
// CHECK-LABEL: define void {{.*}}call3
// CHECK: [[A:%.*]] = alloca <1 x i32>, align 4
// CHECK: [[s:%.*]] = alloca %struct.S, align 4
// CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4
// CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4
// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i32 0
// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 0
// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 1
// CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
// CHECK-NEXT: [[C:%.*]] = sitofp i32 [[VL]] to float
// CHECK-NEXT: store float [[C]], ptr [[G2]], align 4
export void call3() {
int1 A = {1};
S s = (S)A;
}

// struct splat from vector of length 1
// CHECK-LABEL: define void {{.*}}call5
// CHECK: [[A:%.*]] = alloca <1 x i32>, align 4
// CHECK-NEXT: [[s:%.*]] = alloca %struct.S, align 4
// CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4
// CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4
// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i32 0
// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 0
// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 1
// CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
// CHECK-NEXT: [[C:%.*]] = sitofp i32 [[VL]] to float
// CHECK-NEXT: store float [[C]], ptr [[G2]], align 4
export void call5() {
int1 A = {1};
S s = (S)A;
}
20 changes: 20 additions & 0 deletions clang/test/SemaHLSL/Language/ElementwiseCast-errors.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,23 @@ export void cantCast3() {
S s = (S)C;
// expected-error@-1 {{no matching conversion for C-style cast from 'int2' (aka 'vector<int, 2>') to 'S'}}
}

struct R {
// expected-note@-1 {{candidate constructor (the implicit copy constructor) not viable: no known conversion from 'int2' (aka 'vector<int, 2>') to 'const R' for 1st argument}}
// expected-note@-2 {{candidate constructor (the implicit move constructor) not viable: no known conversion from 'int2' (aka 'vector<int, 2>') to 'R' for 1st argument}}
// expected-note@-3 {{candidate constructor (the implicit default constructor) not viable: requires 0 arguments, but 1 was provided}}
int A;
union {
float F;
int4 G;
};
};

export void cantCast4() {
int2 A = {1,2};
R r = R(A);
// expected-error@-1 {{no matching conversion for functional-style cast from 'int2' (aka 'vector<int, 2>') to 'R'}}
R r2 = {1, 2};
int2 B = (int2)r2;
// expected-error@-1 {{cannot convert 'R' to 'int2' (aka 'vector<int, 2>') without a conversion operator}}
}
Loading
Loading