Skip to content

Commit f924b13

Browse files
committed
Self review
1 parent 0650840 commit f924b13

File tree

3 files changed

+38
-9
lines changed

3 files changed

+38
-9
lines changed

clang/lib/CodeGen/CGExprScalar.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2796,17 +2796,20 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
27962796
return Builder.CreateExtractElement(Vec, Zero, "cast.vtrunc");
27972797
}
27982798
case CK_HLSLSplatCast: {
2799+
// This code should only handle splatting from vectors of length 1.
27992800
assert(DestTy->isVectorType() && "Destination type must be a vector.");
28002801
auto *DestVecTy = DestTy->getAs<VectorType>();
28012802
QualType SrcTy = E->getType();
28022803
SourceLocation Loc = CE->getExprLoc();
28032804
Value *V = Visit(const_cast<Expr *>(E));
2804-
if (auto *VecTy = SrcTy->getAs<VectorType>()) {
2805-
assert(VecTy->getNumElements() == 1 && "Invalid HLSL splat cast.");
2806-
V = CGF.Builder.CreateExtractElement(V, (uint64_t)0, "vec.load");
2807-
SrcTy = VecTy->getElementType();
2808-
}
2809-
assert(SrcTy->isScalarType() && "Invalid HLSL splat cast.");
2805+
assert(SrcTy->isVectorType() && "Invalid HLSL splat cast.");
2806+
2807+
auto *VecTy = SrcTy->getAs<VectorType>();
2808+
assert(VecTy->getNumElements() == 1 && "Invalid HLSL splat cast.");
2809+
2810+
V = CGF.Builder.CreateExtractElement(V, (uint64_t)0, "vec.load");
2811+
SrcTy = VecTy->getElementType();
2812+
28102813
Value *Cast =
28112814
EmitScalarConversion(V, SrcTy, DestVecTy->getElementType(), Loc);
28122815
return Builder.CreateVectorSplat(DestVecTy->getNumElements(), Cast,

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2814,12 +2814,13 @@ bool SemaHLSL::CanPerformSplat(Expr *Src, QualType DestTy) {
28142814
return false;
28152815

28162816
const VectorType *SrcVecTy = SrcTy->getAs<VectorType>();
2817-
if (!(SrcTy->isScalarType() || (SrcVecTy && SrcVecTy->getNumElements() == 1)))
2818-
return false;
2819-
28202817
if (SrcVecTy)
28212818
SrcTy = SrcVecTy->getElementType();
28222819

2820+
// Src isn't a scalar or a vector of length 1
2821+
if (!SrcTy->isScalarType() && !(SrcVecTy && SrcVecTy->getNumElements() == 1))
2822+
return false;
2823+
28232824
llvm::SmallVector<QualType> DestTypes;
28242825
BuildFlattenedTypeList(DestTy, DestTypes);
28252826

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -finclude-default-header -fnative-half-type %s -ast-dump | FileCheck %s
2+
3+
// splat from vec1 to vec
4+
// CHECK-LABEL: call1
5+
// CHECK: CStyleCastExpr {{.*}} 'int3':'vector<int, 3>' <HLSLSplatCast>
6+
// CHECK-NEXT: DeclRefExpr {{.*}} 'float1':'vector<float, 1>' lvalue Var {{.*}} 'A' 'float1':'vector<float, 1>'
7+
export void call1() {
8+
float1 A = {1.0};
9+
int3 B = (int3)A;
10+
}
11+
12+
struct S {
13+
int A;
14+
float B;
15+
int C;
16+
float D;
17+
};
18+
19+
// splat from scalar to aggregate
20+
// CHECK-LABEL: call2
21+
// CHECK: CStyleCastExpr {{.*}} 'S' <HLSLSplatCast>
22+
// CHECK-NEXt: IntegerLiteral {{.*}} 'int' 5
23+
export void call2() {
24+
S s = (S)5;
25+
}

0 commit comments

Comments
 (0)