Skip to content

Commit 7d74125

Browse files
committed
Self review
1 parent 245367c commit 7d74125

File tree

3 files changed

+38
-9
lines changed

3 files changed

+38
-9
lines changed

clang/lib/CodeGen/CGExprScalar.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2788,17 +2788,20 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
27882788
return Builder.CreateExtractElement(Vec, Zero, "cast.vtrunc");
27892789
}
27902790
case CK_HLSLSplatCast: {
2791+
// This code should only handle splatting from vectors of length 1.
27912792
assert(DestTy->isVectorType() && "Destination type must be a vector.");
27922793
auto *DestVecTy = DestTy->getAs<VectorType>();
27932794
QualType SrcTy = E->getType();
27942795
SourceLocation Loc = CE->getExprLoc();
27952796
Value *V = Visit(const_cast<Expr *>(E));
2796-
if (auto *VecTy = SrcTy->getAs<VectorType>()) {
2797-
assert(VecTy->getNumElements() == 1 && "Invalid HLSL splat cast.");
2798-
V = CGF.Builder.CreateExtractElement(V, (uint64_t)0, "vec.load");
2799-
SrcTy = VecTy->getElementType();
2800-
}
2801-
assert(SrcTy->isScalarType() && "Invalid HLSL splat cast.");
2797+
assert(SrcTy->isVectorType() && "Invalid HLSL splat cast.");
2798+
2799+
auto *VecTy = SrcTy->getAs<VectorType>();
2800+
assert(VecTy->getNumElements() == 1 && "Invalid HLSL splat cast.");
2801+
2802+
V = CGF.Builder.CreateExtractElement(V, (uint64_t)0, "vec.load");
2803+
SrcTy = VecTy->getElementType();
2804+
28022805
Value *Cast =
28032806
EmitScalarConversion(V, SrcTy, DestVecTy->getElementType(), Loc);
28042807
return Builder.CreateVectorSplat(DestVecTy->getNumElements(), Cast,

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2486,12 +2486,13 @@ bool SemaHLSL::CanPerformSplat(Expr *Src, QualType DestTy) {
24862486
return false;
24872487

24882488
const VectorType *SrcVecTy = SrcTy->getAs<VectorType>();
2489-
if (!(SrcTy->isScalarType() || (SrcVecTy && SrcVecTy->getNumElements() == 1)))
2490-
return false;
2491-
24922489
if (SrcVecTy)
24932490
SrcTy = SrcVecTy->getElementType();
24942491

2492+
// Src isn't a scalar or a vector of length 1
2493+
if (!SrcTy->isScalarType() && !(SrcVecTy && SrcVecTy->getNumElements() == 1))
2494+
return false;
2495+
24952496
llvm::SmallVector<QualType> DestTypes;
24962497
BuildFlattenedTypeList(DestTy, DestTypes);
24972498

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -finclude-default-header -fnative-half-type %s -ast-dump | FileCheck %s
2+
3+
// splat from vec1 to vec
4+
// CHECK-LABEL: call1
5+
// CHECK: CStyleCastExpr {{.*}} 'int3':'vector<int, 3>' <HLSLSplatCast>
6+
// CHECK-NEXT: DeclRefExpr {{.*}} 'float1':'vector<float, 1>' lvalue Var {{.*}} 'A' 'float1':'vector<float, 1>'
7+
export void call1() {
8+
float1 A = {1.0};
9+
int3 B = (int3)A;
10+
}
11+
12+
struct S {
13+
int A;
14+
float B;
15+
int C;
16+
float D;
17+
};
18+
19+
// splat from scalar to aggregate
20+
// CHECK-LABEL: call2
21+
// CHECK: CStyleCastExpr {{.*}} 'S' <HLSLSplatCast>
22+
// CHECK-NEXt: IntegerLiteral {{.*}} 'int' 5
23+
export void call2() {
24+
S s = (S)5;
25+
}

0 commit comments

Comments
 (0)