Skip to content

Commit 4a48740

Browse files
authored
[HLSL] Update indexed vector elements individually (#169144)
When an individual element of a vector is updated via indexing into the vector, it needs to be handled as a store operation on that one vector element. Clang treats vectors as one unit, so a vector element needs to be updated, the whole vector is loaded, the element is modified, and then the whole vector is stored. In HLSL vector elements are handled individually. We need to avoid this load/modify/store sequence to prevent overwriting other vector elements that might be getting updated in parallel. Fixes #167729 Contributes to #160208.
1 parent 56d061c commit 4a48740

File tree

4 files changed

+77
-11
lines changed

4 files changed

+77
-11
lines changed

clang/lib/CodeGen/CGExpr.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2575,6 +2575,32 @@ void CodeGenFunction::EmitStoreThroughLValue(RValue Src, LValue Dst,
25752575
bool isInit) {
25762576
if (!Dst.isSimple()) {
25772577
if (Dst.isVectorElt()) {
2578+
if (getLangOpts().HLSL) {
2579+
// HLSL allows direct access to vector elements, so storing to
2580+
// individual elements of a vector through VectorElt is handled as
2581+
// separate store instructions.
2582+
Address DstAddr = Dst.getVectorAddress();
2583+
llvm::Type *DestAddrTy = DstAddr.getElementType();
2584+
llvm::Type *ElemTy = DestAddrTy->getScalarType();
2585+
CharUnits ElemAlign = CharUnits::fromQuantity(
2586+
CGM.getDataLayout().getPrefTypeAlign(ElemTy));
2587+
2588+
assert(ElemTy->getScalarSizeInBits() >= 8 &&
2589+
"vector element type must be at least byte-sized");
2590+
2591+
llvm::Value *Val = Src.getScalarVal();
2592+
if (Val->getType()->getPrimitiveSizeInBits() <
2593+
ElemTy->getScalarSizeInBits())
2594+
Val = Builder.CreateZExt(Val, ElemTy->getScalarType());
2595+
2596+
llvm::Value *Idx = Dst.getVectorIdx();
2597+
llvm::Value *Zero = llvm::ConstantInt::get(Int32Ty, 0);
2598+
Address DstElemAddr =
2599+
Builder.CreateGEP(DstAddr, {Zero, Idx}, DestAddrTy, ElemAlign);
2600+
Builder.CreateStore(Val, DstElemAddr, Dst.isVolatileQualified());
2601+
return;
2602+
}
2603+
25782604
// Read/modify/write the vector, inserting the new element.
25792605
llvm::Value *Vec = Builder.CreateLoad(Dst.getVectorAddress(),
25802606
Dst.isVolatileQualified());

clang/test/CodeGenHLSL/BoolVector.hlsl

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,8 @@ bool fn4() {
6969
// CHECK-LABEL: define hidden void {{.*}}fn5{{.*}}
7070
// CHECK: [[Arr:%.*]] = alloca <2 x i32>, align 8
7171
// CHECK-NEXT: store <2 x i32> splat (i32 1), ptr [[Arr]], align 8
72-
// CHECK-NEXT: [[L:%.*]] = load <2 x i32>, ptr [[Arr]], align 8
73-
// CHECK-NEXT: [[V:%.*]] = insertelement <2 x i32> [[L]], i32 0, i32 1
74-
// CHECK-NEXT: store <2 x i32> [[V]], ptr [[Arr]], align 8
72+
// CHECK-NEXT: [[Ptr:%.*]] = getelementptr <2 x i32>, ptr [[Arr]]
73+
// CHECK-NEXT: store i32 0, ptr [[Ptr]], align 4
7574
// CHECK-NEXT: ret void
7675
void fn5() {
7776
bool2 Arr = {true,true};
@@ -86,10 +85,9 @@ void fn5() {
8685
// CHECK-NEXT: [[Y:%.*]] = load i32, ptr [[V]], align 4
8786
// CHECK-NEXT: [[LV:%.*]] = trunc i32 [[Y]] to i1
8887
// CHECK-NEXT: [[BV:%.*]] = getelementptr inbounds nuw %struct.S, ptr [[S]], i32 0, i32 0
89-
// CHECK-NEXT: [[X:%.*]] = load <2 x i32>, ptr [[BV]], align 1
9088
// CHECK-NEXT: [[Z:%.*]] = zext i1 [[LV]] to i32
91-
// CHECK-NEXT: [[VI:%.*]] = insertelement <2 x i32> [[X]], i32 [[Z]], i32 1
92-
// CHECK-NEXT: store <2 x i32> [[VI]], ptr [[BV]], align 1
89+
// CHECK-NEXT: [[Ptr:%.*]] = getelementptr <2 x i32>, ptr [[BV]], i32 0, i32 1
90+
// CHECK-NEXT: store i32 [[Z]], ptr [[Ptr]], align 4
9391
// CHECK-NEXT: ret void
9492
void fn6() {
9593
bool V = false;
@@ -101,9 +99,8 @@ void fn6() {
10199
// CHECK: [[Arr:%.*]] = alloca [2 x <2 x i32>], align 8
102100
// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 8 [[Arr]], ptr align 8 {{.*}}, i32 16, i1 false)
103101
// CHECK-NEXT: [[Idx:%.*]] = getelementptr inbounds [2 x <2 x i32>], ptr [[Arr]], i32 0, i32 0
104-
// CHECK-NEXT: [[X:%.*]] = load <2 x i32>, ptr [[Idx]], align 8
105-
// CHECK-NEXT: [[VI:%.*]] = insertelement <2 x i32> [[X]], i32 0, i32 1
106-
// CHECK-NEXT: store <2 x i32> [[VI]], ptr [[Idx]], align 8
102+
// CHECK-NEXT: %[[Ptr:.*]] = getelementptr <2 x i32>, ptr [[Idx]], i32 0, i32 1
103+
// CHECK-NEXT: store i32 0, ptr %[[Ptr]], align 4
107104
// CHECK-NEXT: ret void
108105
void fn7() {
109106
bool2 Arr[2] = {{true,true}, {false,false}};
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// RUN: %clang_cc1 -finclude-default-header -emit-llvm -disable-llvm-passes \
2+
// RUN: -triple dxil-pc-shadermodel6.3-library %s -o - | FileCheck %s
3+
4+
// Test groupshared vector element store for uint.
5+
// CHECK-LABEL: test_uint4
6+
// CHECK: [[VAL:%.*]] = load i32, ptr %Val.addr, align 4
7+
// CHECK: [[IDX:%.*]] = load i32, ptr %Idx.addr, align 4
8+
// CHECK: [[PTR:%.*]] = getelementptr <4 x i32>, ptr addrspace(3) @SMem, i32 0, i32 [[IDX]]
9+
// CHECK: store i32 [[VAL]], ptr addrspace(3) [[PTR]], align 4
10+
// CHECK-: ret void
11+
groupshared uint4 SMem;
12+
void test_uint4(uint Idx, uint Val) {
13+
SMem[Idx] = Val;
14+
}
15+
16+
// Test local vector element store for bool.
17+
// CHECK: [[COND1:%.*]] = load i32, ptr addrspace(3) @Cond, align 4
18+
// CHECK: [[COND2:%.*]] = trunc i32 [[COND1]] to i1
19+
// CHECK: [[IDX:%.*]] = load i32, ptr %Idx.addr, align 4
20+
// CHECK: [[COND3:%.*]] = zext i1 [[COND2]] to i32
21+
// CHECK: [[PTR:%.*]] = getelementptr <3 x i32>, ptr %Val, i32 0, i32 [[IDX]]
22+
// CHECK: store i32 [[COND3]], ptr [[PTR]], align 4
23+
// CHECK: ret
24+
groupshared bool Cond;
25+
bool3 test_bool(uint Idx) {
26+
bool3 Val = { false, false, false};
27+
Val[Idx] = Cond;
28+
return Val;
29+
}
30+
31+
// Test resource vector element store for float.
32+
// CHECK: [[VAL:%.*]] = load float, ptr %Val.addr, align 4
33+
// CHECK: [[RES_PTR:%.*]] = call {{.*}} ptr @_ZN4hlsl18RWStructuredBufferIDv4_fEixEj(ptr {{.*}} @_ZL3Buf, i32 noundef 0)
34+
// CHECK: [[IDX:%.*]] = load i32, ptr %Idx.addr, align 4
35+
// CHECK: [[PTR:%.*]] = getelementptr <4 x float>, ptr [[RES_PTR]], i32 0, i32 [[IDX]]
36+
// CHECK: store float [[VAL]], ptr [[PTR]], align 4
37+
// CHECK: ret void
38+
RWStructuredBuffer<float4> Buf : register(u0);
39+
void test_float(uint Idx, float Val) {
40+
Buf[0][Idx] = Val;
41+
}

clang/test/CodeGenHLSL/builtins/lit.hlsl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
// CHECK: %mul.i = fmul reassoc nnan ninf nsz arcp afn half [[LOG]], %{{.*}}
1212
// CHECK: [[EXP:%.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.exp.f16(half %mul.i)
1313
// CHECK: %hlsl.select7.i = select reassoc nnan ninf nsz arcp afn i1 %{{.*}}, half 0xH0000, half %{{.*}}
14-
// CHECK: %vecins.i = insertelement <4 x half> %{{.*}}, half %hlsl.select7.i, i32 2
14+
// CHECK: [[PTR:%.*]] = getelementptr <4 x half>, ptr %Result.i, i32 0, i32 2
15+
// CHECK: store half %hlsl.select7.i, ptr [[PTR]], align 2
1516
// CHECK: ret <4 x half> %{{.*}}
1617
half4 test_lit_half(half NDotL, half NDotH, half M) { return lit(NDotL, NDotH, M); }
1718

@@ -26,6 +27,7 @@ half4 test_lit_half(half NDotL, half NDotH, half M) { return lit(NDotL, NDotH, M
2627
// CHECK: %mul.i = fmul reassoc nnan ninf nsz arcp afn float [[LOG]], %{{.*}}
2728
// CHECK: [[EXP:%.*]] = call reassoc nnan ninf nsz arcp afn float @llvm.exp.f32(float %mul.i)
2829
// CHECK: %hlsl.select7.i = select reassoc nnan ninf nsz arcp afn i1 %{{.*}}, float 0.000000e+00, float %{{.*}}
29-
// CHECK: %vecins.i = insertelement <4 x float> %{{.*}}, float %hlsl.select7.i, i32 2
30+
// CHECK: [[PTR:%.*]] = getelementptr <4 x float>, ptr %Result.i, i32 0, i32 2
31+
// CHECK: store float %hlsl.select7.i, ptr [[PTR]], align 4
3032
// CHECK: ret <4 x float> %{{.*}}
3133
float4 test_lit_float(float NDotL, float NDotH, float M) { return lit(NDotL, NDotH, M); }

0 commit comments

Comments
 (0)