Skip to content

Commit 93c0450

Browse files
committed
fix tests + associated issues in code
1 parent cadd309 commit 93c0450

File tree

4 files changed

+23
-17
lines changed

4 files changed

+23
-17
lines changed

clang/lib/Sema/SemaCast.cpp

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2777,19 +2777,7 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle,
27772777
? CheckedConversionKind::FunctionalCast
27782778
: CheckedConversionKind::CStyleCast;
27792779

2780-
// This case should not trigger on regular vector splat
27812780
QualType SrcTy = SrcExpr.get()->getType();
2782-
if (Self.getLangOpts().HLSL &&
2783-
Self.HLSL().CanPerformSplatCast(SrcExpr.get(), DestType)) {
2784-
const VectorType *VT = SrcTy->getAs<VectorType>();
2785-
// change splat from vec1 case to splat from scalar
2786-
if (VT && VT->getNumElements() == 1)
2787-
SrcExpr = Self.ImpCastExprToType(SrcExpr.get(), VT->getElementType(),
2788-
CK_HLSLVectorTruncation, VK_PRValue, nullptr, CCK);
2789-
Kind = CK_HLSLSplatCast;
2790-
return;
2791-
}
2792-
27932781
// This case should not trigger on regular vector cast, vector truncation
27942782
if (Self.getLangOpts().HLSL &&
27952783
Self.HLSL().CanPerformElementwiseCast(SrcExpr.get(), DestType)) {
@@ -2801,6 +2789,21 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle,
28012789
return;
28022790
}
28032791

2792+
// This case should not trigger on regular vector splat
2793+
// If the relative order of this and the HLSLElementWise cast checks
2794+
// are changed, it might change which cast handles what in a few cases
2795+
if (Self.getLangOpts().HLSL &&
2796+
Self.HLSL().CanPerformSplatCast(SrcExpr.get(), DestType)) {
2797+
const VectorType *VT = SrcTy->getAs<VectorType>();
2798+
// change splat from vec1 case to splat from scalar
2799+
if (VT && VT->getNumElements() == 1)
2800+
SrcExpr = Self.ImpCastExprToType(SrcExpr.get(), VT->getElementType(),
2801+
CK_HLSLVectorTruncation,
2802+
SrcExpr.get()->getValueKind(), nullptr, CCK);
2803+
Kind = CK_HLSLSplatCast;
2804+
return;
2805+
}
2806+
28042807
if (ValueKind == VK_PRValue && !DestType->isRecordType() &&
28052808
!isPlaceholder(BuiltinType::Overload)) {
28062809
SrcExpr = Self.DefaultFunctionArrayLvalueConversion(SrcExpr.get());

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2810,7 +2810,9 @@ bool SemaHLSL::ContainsBitField(QualType BaseTy) {
28102810
bool SemaHLSL::CanPerformSplatCast(Expr *Src, QualType DestTy) {
28112811

28122812
QualType SrcTy = Src->getType();
2813-
if (SrcTy->isScalarType() && DestTy->isVectorType())
2813+
// Not a valid HLSL Splat cast if Dest is a scalar or if this is going to
2814+
// be a vector splat from a scalar.
2815+
if ((SrcTy->isScalarType() && DestTy->isVectorType()) || DestTy->isScalarType())
28142816
return false;
28152817

28162818
const VectorType *SrcVecTy = SrcTy->getAs<VectorType>();

clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ export void call4() {
2020
// CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4
2121
// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 8, i1 false)
2222
// CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4
23+
// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i32 0
2324
// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 0
2425
// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 1
25-
// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0
2626
// CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
2727
// CHECK-NEXT: store i32 [[VL]], ptr [[G2]], align 4
2828
export void call8() {
@@ -37,7 +37,7 @@ export void call8() {
3737
// CHECK-NEXT: [[A:%.*]] = alloca <4 x i32>, align 16
3838
// CHECK-NEXT: store <1 x float> splat (float 1.000000e+00), ptr [[B]], align 4
3939
// CHECK-NEXT: [[L:%.*]] = load <1 x float>, ptr [[B]], align 4
40-
// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x float> [[L]], i64 0
40+
// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x float> [[L]], i32 0
4141
// CHECK-NEXT: [[C:%.*]] = fptosi float [[VL]] to i32
4242
// CHECK-NEXT: [[SI:%.*]] = insertelement <4 x i32> poison, i32 [[C]], i64 0
4343
// CHECK-NEXT: [[S:%.*]] = shufflevector <4 x i32> [[SI]], <4 x i32> poison, <4 x i32> zeroinitializer
@@ -58,9 +58,9 @@ struct S {
5858
// CHECK: [[s:%.*]] = alloca %struct.S, align 4
5959
// CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4
6060
// CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4
61+
// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i32 0
6162
// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 0
6263
// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 1
63-
// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0
6464
// CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
6565
// CHECK-NEXT: [[C:%.*]] = sitofp i32 [[VL]] to float
6666
// CHECK-NEXT: store float [[C]], ptr [[G2]], align 4
@@ -75,9 +75,9 @@ export void call3() {
7575
// CHECK-NEXT: [[s:%.*]] = alloca %struct.S, align 4
7676
// CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4
7777
// CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4
78+
// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i32 0
7879
// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 0
7980
// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 1
80-
// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0
8181
// CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
8282
// CHECK-NEXT: [[C:%.*]] = sitofp i32 [[VL]] to float
8383
// CHECK-NEXT: store float [[C]], ptr [[G2]], align 4

clang/test/SemaHLSL/Language/SplatCasts.hlsl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// splat from vec1 to vec
44
// CHECK-LABEL: call1
55
// CHECK: CStyleCastExpr {{.*}} 'int3':'vector<int, 3>' <HLSLSplatCast>
6+
// CHECK-NEXT: ImplicitCastExpr {{.*}} 'float' lvalue <HLSLVectorTruncation> part_of_explicit_cast
67
// CHECK-NEXT: DeclRefExpr {{.*}} 'float1':'vector<float, 1>' lvalue Var {{.*}} 'A' 'float1':'vector<float, 1>'
78
export void call1() {
89
float1 A = {1.0};

0 commit comments

Comments
 (0)