Skip to content

Commit c475f8e

Browse files
authored
[HLSL] Update vector swizzle elements individually (#169090)
When individual elements of a vector are updated through a vector swizzle, the update must be emitted as separate store operations to the individual vector elements. Clang treats a vector as a single unit, so when only part of a vector needs to be updated, the whole vector is loaded, some elements are modified, and then the whole vector is stored back. In HLSL, vector elements are handled separately, so we need to avoid this load/modify/store sequence to prevent overwriting other vector elements that might be getting updated in parallel. Fixes #152815
1 parent 1c9368e commit c475f8e

File tree

4 files changed

+158
-21
lines changed

4 files changed

+158
-21
lines changed

clang/lib/CodeGen/CGExpr.cpp

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2801,26 +2801,56 @@ void CodeGenFunction::EmitStoreThroughExtVectorComponentLValue(RValue Src,
28012801
LValue Dst) {
28022802
llvm::Value *SrcVal = Src.getScalarVal();
28032803
Address DstAddr = Dst.getExtVectorAddress();
2804+
const llvm::Constant *Elts = Dst.getExtVectorElts();
28042805
if (DstAddr.getElementType()->getScalarSizeInBits() >
28052806
SrcVal->getType()->getScalarSizeInBits())
28062807
SrcVal = Builder.CreateZExt(
28072808
SrcVal, convertTypeForLoadStore(Dst.getType(), SrcVal->getType()));
28082809

2809-
// HLSL allows storing to scalar values through ExtVector component LValues.
2810-
// To support this we need to handle the case where the destination address is
2811-
// a scalar.
2812-
if (!DstAddr.getElementType()->isVectorTy()) {
2813-
assert(!Dst.getType()->isVectorType() &&
2814-
"this should only occur for non-vector l-values");
2815-
Builder.CreateStore(SrcVal, DstAddr, Dst.isVolatileQualified());
2810+
if (getLangOpts().HLSL) {
2811+
llvm::Type *DestAddrTy = DstAddr.getElementType();
2812+
// HLSL allows storing to scalar values through ExtVector component LValues.
2813+
// To support this we need to handle the case where the destination address
2814+
// is a scalar.
2815+
if (!DestAddrTy->isVectorTy()) {
2816+
assert(!Dst.getType()->isVectorType() &&
2817+
"this should only occur for non-vector l-values");
2818+
Builder.CreateStore(SrcVal, DstAddr, Dst.isVolatileQualified());
2819+
return;
2820+
}
2821+
2822+
// HLSL allows direct access to vector elements, so storing to individual
2823+
// elements of a vector through ExtVector is handled as separate store
2824+
// instructions.
2825+
// If we are updating multiple elements, Dst and Src are vectors; for
2826+
// a single element update they are scalars.
2827+
const VectorType *VTy = Dst.getType()->getAs<VectorType>();
2828+
unsigned NumSrcElts = VTy ? VTy->getNumElements() : 1;
2829+
CharUnits ElemAlign = CharUnits::fromQuantity(
2830+
CGM.getDataLayout().getPrefTypeAlign(DestAddrTy->getScalarType()));
2831+
llvm::Value *Zero = llvm::ConstantInt::get(Int32Ty, 0);
2832+
2833+
for (unsigned I = 0; I != NumSrcElts; ++I) {
2834+
llvm::Value *Val = VTy ? Builder.CreateExtractElement(
2835+
SrcVal, llvm::ConstantInt::get(Int32Ty, I))
2836+
: SrcVal;
2837+
unsigned FieldNo = getAccessedFieldNo(I, Elts);
2838+
Address DstElemAddr = Address::invalid();
2839+
if (FieldNo == 0)
2840+
DstElemAddr = DstAddr.withAlignment(ElemAlign);
2841+
else
2842+
DstElemAddr = Builder.CreateGEP(
2843+
DstAddr, {Zero, llvm::ConstantInt::get(Int32Ty, FieldNo)},
2844+
DestAddrTy, ElemAlign);
2845+
Builder.CreateStore(Val, DstElemAddr, Dst.isVolatileQualified());
2846+
}
28162847
return;
28172848
}
28182849

28192850
// This access turns into a read/modify/write of the vector. Load the input
28202851
// value now.
28212852
llvm::Value *Vec = Builder.CreateLoad(DstAddr, Dst.isVolatileQualified());
28222853
llvm::Type *VecTy = Vec->getType();
2823-
const llvm::Constant *Elts = Dst.getExtVectorElts();
28242854

28252855
if (const VectorType *VTy = Dst.getType()->getAs<VectorType>()) {
28262856
unsigned NumSrcElts = VTy->getNumElements();

clang/test/CodeGenHLSL/BasicFeatures/OutputArguments.hlsl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,16 @@ void funky(inout int3 X) {
101101
// Call the function with the temporary.
102102
// CHECK: call void {{.*}}funky{{.*}}(ptr noalias noundef nonnull align 16 dereferenceable(16) [[ArgTmp]])
103103

104-
// Shuffle it back.
104+
// Write it back.
105105
// CHECK: [[RetVal:%.*]] = load <3 x i32>, ptr [[ArgTmp]]
106-
// CHECK: [[Vxyz:%.*]] = shufflevector <3 x i32> [[RetVal]], <3 x i32> poison, <3 x i32> <i32 2, i32 0, i32 1>
107-
// CHECK: store <3 x i32> [[Vxyz]], ptr [[V]]
106+
// CHECK: [[Src0:%.*]] = extractelement <3 x i32> [[RetVal]], i32 0
107+
// CHECK: [[PtrY:%.*]] = getelementptr <3 x i32>, ptr %V, i32 0, i32 1
108+
// CHECK: store i32 [[Src0]], ptr [[PtrY]], align 4
109+
// CHECK: [[Src1:%.*]] = extractelement <3 x i32> [[RetVal]], i32 1
110+
// CHECK: [[PtrZ:%.*]] = getelementptr <3 x i32>, ptr %V, i32 0, i32 2
111+
// CHECK: store i32 [[Src1]], ptr [[PtrZ]], align 4
112+
// CHECK: [[Src2:%.*]] = extractelement <3 x i32> [[RetVal]], i32 2
113+
// CHECK: store i32 [[Src2]], ptr %V, align 4
108114

109115
// OPT: ret <3 x i32> <i32 3, i32 1, i32 2>
110116
export int3 case4() {

clang/test/CodeGenHLSL/builtins/ScalarSwizzles.hlsl

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -259,9 +259,8 @@ bool AssignBool(bool V) {
259259
// CHECK-NEXT: [[B:%.*]] = load i32, ptr [[VAddr]], align 4
260260
// CHECK-NEXT: [[LV1:%.*]] = trunc i32 [[B]] to i1
261261
// CHECK-NEXT: [[D:%.*]] = zext i1 [[LV1]] to i32
262-
// CHECK-NEXT: [[C:%.*]] = load <2 x i32>, ptr [[X]], align 8
263-
// CHECK-NEXT: [[E:%.*]] = insertelement <2 x i32> [[C]], i32 [[D]], i32 1
264-
// CHECK-NEXT: store <2 x i32> [[E]], ptr [[X]], align 8
262+
// CHECK-NEXT: [[C:%.*]] = getelementptr <2 x i32>, ptr [[X]], i32 0, i32 1
263+
// CHECK-NEXT: store i32 [[D]], ptr [[C]], align 4
265264
// CHECK-NEXT: ret void
266265
void AssignBool2(bool V) {
267266
bool2 X = true.xx;
@@ -277,10 +276,13 @@ void AssignBool2(bool V) {
277276
// CHECK-NEXT: [[Z:%.*]] = load <2 x i32>, ptr [[VAddr]], align 8
278277
// CHECK-NEXT: [[LV:%.*]] = trunc <2 x i32> [[Z]] to <2 x i1>
279278
// CHECK-NEXT: [[B:%.*]] = zext <2 x i1> [[LV]] to <2 x i32>
280-
// CHECK-NEXT: [[A:%.*]] = load <2 x i32>, ptr [[X]], align 8
281-
// CHECK-NEXT: [[C:%.*]] = shufflevector <2 x i32> [[B]], <2 x i32> poison, <2 x i32> <i32 0, i32 1>
282-
// CHECK-NEXT: store <2 x i32> [[C]], ptr [[X]], align 8
279+
// CHECK-NEXT: [[V1:%.*]] = extractelement <2 x i32> [[B]], i32 0
280+
// CHECK-NEXT: store i32 [[V1]], ptr [[X]], align 4
281+
// CHECK-NEXT: [[V2:%.*]] = extractelement <2 x i32> [[B]], i32 1
282+
// CHECK-NEXT: [[X2:%.*]] = getelementptr <2 x i32>, ptr [[X]], i32 0, i32 1
283+
// CHECK-NEXT: store i32 [[V2]], ptr [[X2]], align 4
283284
// CHECK-NEXT: ret void
285+
284286
void AssignBool3(bool2 V) {
285287
bool2 X = {true,true};
286288
X.xy = V;
@@ -313,10 +315,13 @@ bool2 AccessBools() {
313315
// CHECK-NEXT: [[L1:%.*]] = shufflevector <1 x i32> [[L0]], <1 x i32> poison, <3 x i32> zeroinitializer
314316
// CHECK-NEXT: [[TruncV:%.*]] = trunc <3 x i32> [[L1]] to <3 x i1>
315317
// CHECK-NEXT: [[L2:%.*]] = zext <3 x i1> [[TruncV]] to <3 x i32>
316-
// CHECK-NEXT: [[L3:%.*]] = load <4 x i32>, ptr [[B]], align 16
317-
// CHECK-NEXT: [[L4:%.*]] = shufflevector <3 x i32> [[L2]], <3 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 poison>
318-
// CHECK-NEXT: [[L5:%.*]] = shufflevector <4 x i32> [[L3]], <4 x i32> [[L4]], <4 x i32> <i32 4, i32 5, i32 6, i32 3>
319-
// CHECK-NEXT: store <4 x i32> [[L5]], ptr [[B]], align 16
318+
// CHECK-NEXT: [[V1:%.*]] = extractelement <3 x i32> [[L2]], i32 0
319+
// CHECK-NEXT: store i32 [[V1]], ptr %B, align 4
320+
// CHECK-NEXT: [[V2:%.*]] = extractelement <3 x i32> [[L2]], i32 1
321+
// CHECK-NEXT: [[B2:%.*]] = getelementptr <4 x i32>, ptr %B, i32 0, i32 1
322+
// CHECK-NEXT: store i32 [[V2]], ptr [[B2]], align 4
323+
// CHECK-NEXT: [[V3:%.*]] = extractelement <3 x i32> [[L2]], i32 2
324+
// CHECK-NEXT: [[B3:%.*]] = getelementptr <4 x i32>, ptr %B, i32 0, i32 2
320325
void BoolSizeMismatch() {
321326
bool4 B = {true,true,true,true};
322327
B.xyz = false.xxx;
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
// RUN: %clang_cc1 -finclude-default-header -fnative-half-type \
2+
// RUN: -triple dxil-pc-shadermodel6.3-library %s -disable-llvm-passes \
3+
// RUN: -emit-llvm -o - | FileCheck %s
4+
5+
// CHECK-LABEL: Single
6+
7+
// Setup local vars.
8+
// CHECK: [[VecAddr:%.*]] = alloca <3 x i64>, align 32
9+
// CHECK-NEXT: [[AAddr:%.*]] = alloca i64, align 8
10+
// CHECK-NEXT: store <3 x i64> %vec, ptr [[VecAddr]], align 32
11+
// CHECK-NEXT: store i64 %a, ptr [[AAddr]], align 8
12+
13+
// Update single element of the vector.
14+
// CHECK-NEXT: [[A:%.*]] = load i64, ptr [[AAddr]], align 8
15+
// CHECK-NEXT: [[Vy:%.*]] = getelementptr <3 x i64>, ptr [[VecAddr]], i32 0, i32 1
16+
// CHECK-NEXT: store i64 [[A]], ptr [[Vy]], align 8
17+
18+
// Return.
19+
// CHECK-NEXT: [[RetVal:%.*]] = load <3 x i64>, ptr [[VecAddr]], align 32
20+
// CHECK-NEXT: ret <3 x i64> [[RetVal]]
21+
uint64_t3 Single(uint64_t3 vec, uint64_t a){
22+
vec.y = a;
23+
return vec;
24+
}
25+
26+
// CHECK-LABEL: Double
27+
28+
// Setup local vars.
29+
// CHECK: [[VecAddr:%.*]] = alloca <3 x float>, align 16
30+
// CHECK-NEXT: [[AAddr:%.*]] = alloca float, align 4
31+
// CHECK-NEXT: [[BAddr:%.*]] = alloca float, align 4
32+
// CHECK-NEXT: store <3 x float> %vec, ptr [[VecAddr]], align 16
33+
// CHECK-NEXT: store float %a, ptr [[AAddr]], align 4
34+
// CHECK-NEXT: store float %b, ptr [[BAddr]], align 4
35+
36+
// Create temporary vector {a, b}.
37+
// CHECK-NEXT: [[A:%.*]] = load float, ptr [[AAddr]], align 4
38+
// CHECK-NEXT: [[TmpVec0:%.*]] = insertelement <2 x float> poison, float [[A]], i32 0
39+
// CHECK-NEXT: [[B:%.*]] = load float, ptr [[BAddr]], align 4
40+
// CHECK-NEXT: [[TmpVec1:%.*]] = insertelement <2 x float> [[TmpVec0]], float [[B]], i32 1
41+
42+
// Update two elements of the vector from temporary vector.
43+
// CHECK-NEXT: [[TmpX:%.*]] = extractelement <2 x float> [[TmpVec1]], i32 0
44+
// CHECK-NEXT: [[VecZ:%.*]] = getelementptr <3 x float>, ptr [[VecAddr]], i32 0, i32 2
45+
// CHECK-NEXT: store float [[TmpX]], ptr [[VecZ]], align 4
46+
// CHECK-NEXT: [[TmpY:%.*]] = extractelement <2 x float> [[TmpVec1]], i32 1
47+
// CHECK-NEXT: [[VecY:%.*]] = getelementptr <3 x float>, ptr [[VecAddr]], i32 0, i32 1
48+
// CHECK-NEXT: store float [[TmpY]], ptr [[VecY]], align 4
49+
50+
// Return.
51+
// CHECK-NEXT: [[RetVal:%.*]] = load <3 x float>, ptr [[VecAddr]], align 16
52+
// CHECK-NEXT: ret <3 x float> [[RetVal]]
53+
float3 Double(float3 vec, float a, float b) {
54+
vec.zy = {a, b};
55+
return vec;
56+
}
57+
58+
// CHECK-LABEL: Shuffle
59+
60+
// Setup local vars.
61+
// CHECK: [[VecAddr:%.*]] = alloca <4 x half>, align 8
62+
// CHECK-NEXT: [[AAddr:%.*]] = alloca half, align 2
63+
// CHECK-NEXT: [[BAddr:%.*]] = alloca half, align 2
64+
// CHECK-NEXT: store <4 x half> %vec, ptr [[VecAddr]], align 8
65+
// CHECK-NEXT: store half %a, ptr [[AAddr]], align 2
66+
// CHECK-NEXT: store half %b, ptr [[BAddr]], align 2
67+
68+
// Create temporary vector {a, b, 13.74, a}.
69+
// CHECK-NEXT: [[A:%.*]] = load half, ptr [[AAddr]], align 2
70+
// CHECK-NEXT: [[TmpVec0:%.*]] = insertelement <4 x half> poison, half [[A]], i32 0
71+
// CHECK-NEXT: [[B:%.*]] = load half, ptr [[BAddr]], align 2
72+
// CHECK-NEXT: [[TmpVec1:%.*]] = insertelement <4 x half> [[TmpVec0]], half [[B]], i32 1
73+
// CHECK-NEXT: [[TmpVec2:%.*]] = insertelement <4 x half> %vecinit1, half 0xH4ADF, i32 2
74+
// CHECK-NEXT: [[A:%.*]] = load half, ptr [[AAddr]], align 2
75+
// CHECK-NEXT: [[TmpVec3:%.*]] = insertelement <4 x half> [[TmpVec2]], half [[A]], i32 3
76+
77+
// Update four elements of the vector via mixed up swizzle from the temporary vector.
78+
// CHECK-NEXT: [[TmpX:%.*]] = extractelement <4 x half> [[TmpVec3]], i32 0
79+
// CHECK-NEXT: [[VecZ:%.*]] = getelementptr <4 x half>, ptr [[VecAddr]], i32 0, i32 2
80+
// CHECK-NEXT: store half [[TmpX]], ptr [[VecZ]], align 2
81+
// CHECK-NEXT: [[TmpY:%.*]] = extractelement <4 x half> [[TmpVec3]], i32 1
82+
// CHECK-NEXT: [[VecW:%.*]] = getelementptr <4 x half>, ptr [[VecAddr]], i32 0, i32 3
83+
// CHECK-NEXT: store half [[TmpY]], ptr [[VecW]], align 2
84+
// CHECK-NEXT: [[TmpZ:%.*]] = extractelement <4 x half> [[TmpVec3]], i32 2
85+
// CHECK-NEXT: store half [[TmpZ]], ptr [[VecAddr]], align 2
86+
// CHECK-NEXT: [[TmpW:%.*]] = extractelement <4 x half> [[TmpVec3]], i32 3
87+
// CHECK-NEXT: [[VecY:%.*]] = getelementptr <4 x half>, ptr [[VecAddr]], i32 0, i32 1
88+
// CHECK-NEXT: store half [[TmpW]], ptr [[VecY]], align 2
89+
90+
// Return.
91+
// CHECK-NEXT: [[RetVal:%.*]] = load <4 x half>, ptr [[VecAddr]], align 8
92+
// CHECK-NEXT: ret <4 x half> [[RetVal]]
93+
half4 Shuffle(half4 vec, half a, half b) {
94+
vec.zwxy = {a, b, 13.74, a};
95+
return vec;
96+
}

0 commit comments

Comments
 (0)