Skip to content

Commit e2f0b98

Browse files
authored
Don't reuse WaveGetLaneIndex result across thread repacking points (microsoft#5607)
Wave intrinsics such as `WaveGetLaneIndex()` are invalidated at DXR thread repacking points such as `CallShader()`. We were however reusing the result of `WaveGetLaneIndex()`. Fix this by marking it as `Readonly` instead of `Readnone`. Add a test case that also covers other wave intrisics, which are handled correctly. Fixes microsoft#5034. --------- Co-authored-by: Jannik Silvanus <[email protected]>
1 parent 07ce880 commit e2f0b98

File tree

4 files changed

+303
-3
lines changed

4 files changed

+303
-3
lines changed

lib/DXIL/DxilOperations.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1161,7 +1161,7 @@ const OP::OpCodeProperty OP::m_OpCodeProps[(unsigned)OP::OpCode::NumOpCodes] = {
11611161
"waveGetLaneIndex",
11621162
{true, false, false, false, false, false, false, false, false, false,
11631163
false},
1164-
Attribute::ReadNone,
1164+
Attribute::ReadOnly,
11651165
},
11661166
{
11671167
OC::WaveGetLaneCount,
Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
// RUN: %dxc -DREPACK_POINT_KIND=1 -T lib_6_5 %s | FileCheck %s
2+
// RUN: %dxc -DREPACK_POINT_KIND=2 -T lib_6_5 %s | FileCheck %s
3+
// RUN: %dxc -DREPACK_POINT_KIND=3 -T lib_6_5 %s | FileCheck %s
4+
5+
// Check that results of wave intrinsics are not re-used cross DXR repacking points.
6+
7+
#define REPACK_POINT_KIND_TRACERAY 1
8+
#define REPACK_POINT_KIND_CALLSHADER 2
9+
#define REPACK_POINT_KIND_REPORTHIT 3
10+
11+
struct Payload {
12+
unsigned int value;
13+
};
14+
15+
struct HitAttributes {
16+
unsigned int value;
17+
};
18+
19+
RaytracingAccelerationStructure myAccelerationStructure : register(t3);
20+
21+
// Helper to introduce a repacking point, passing the identifier as argument
22+
// so we can find it in the generated DXIL.
23+
// dep is used to introduce a dependency of the packing point
24+
// on the passed value, and the returned value is guaranteed to depend
25+
// on the result of the repacking point.
26+
unsigned int RepackingPoint(unsigned int dependency, int identifier) {
27+
unsigned int result = dependency;
28+
#if REPACK_POINT_KIND == REPACK_POINT_KIND_CALLSHADER
29+
Payload p;
30+
p.value = dependency;
31+
CallShader(identifier, p);
32+
result += p.value;
33+
#elif REPACK_POINT_KIND == REPACK_POINT_KIND_TRACERAY
34+
Payload p;
35+
p.value = dependency;
36+
RayDesc myRay = { float3(0., 0., 0.), 0., float3(0., 0., 0.), 1.0};
37+
TraceRay(myAccelerationStructure, 0, -1, 0, 0, identifier, myRay, p);
38+
result += p.value;
39+
#elif REPACK_POINT_KIND == REPACK_POINT_KIND_REPORTHIT
40+
HitAttributes attrs;
41+
attrs.value = dependency;
42+
bool didAccept = ReportHit(0.0, identifier, attrs);
43+
if (didAccept) {
44+
result += 1;
45+
}
46+
#else
47+
#error "Unknown repack point kind"
48+
#endif
49+
return result;
50+
}
51+
52+
RWBuffer<unsigned int> output : register(u0, space0);
53+
54+
// Calls wave intrinsics before and after repacking points, and checks
55+
// that both calls remain, as re-using the result from before the repacking
56+
// point is invalid, because threads may be re-packed in between.
57+
#if (REPACK_POINT_KIND == REPACK_POINT_KIND_TRACERAY) || \
58+
(REPACK_POINT_KIND == REPACK_POINT_KIND_CALLSHADER)
59+
[shader("miss")] void Miss(inout Payload p) {
60+
#else // REPACK_POINT_KIND_REPORTHIT
61+
[shader("intersection")] void Intersection() {
62+
#endif
63+
// Opaque value the compiler cannot reason about to prevent optimizations.
64+
// At the end we store the resulting value back to the buffer so the
65+
// test code cannot be optimized out.
66+
unsigned int opaque = output[DispatchRaysIndex().x];
67+
// Passed as argument to wave intrinsics taking an argument to
68+
// ensure repeated calls to intrinsics use the same argument.
69+
// Otherwise the argument being different already prevents re-use,
70+
// rendering the test pointless.
71+
unsigned int commonArg = opaque;
72+
73+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 0
74+
// CHECK: @dx.op.waveIsFirstLane(i32 110
75+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 1
76+
// CHECK: @dx.op.waveIsFirstLane(i32 110
77+
opaque += RepackingPoint(opaque, 0);
78+
opaque += WaveIsFirstLane();
79+
opaque += RepackingPoint(opaque, 1);
80+
opaque += WaveIsFirstLane();
81+
82+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 2
83+
// CHECK: @dx.op.waveGetLaneIndex(i32 111
84+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 3
85+
// CHECK: @dx.op.waveGetLaneIndex(i32 111
86+
opaque += RepackingPoint(opaque, 2);
87+
opaque += WaveGetLaneIndex();
88+
opaque += RepackingPoint(opaque, 3);
89+
opaque += WaveGetLaneIndex();
90+
91+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 4
92+
// CHECK: @dx.op.waveAnyTrue(i32 113, i1
93+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 5
94+
// CHECK: @dx.op.waveAnyTrue(i32 113, i1
95+
opaque += RepackingPoint(opaque, 4);
96+
opaque += WaveActiveAnyTrue(commonArg == 17);
97+
opaque += RepackingPoint(opaque, 5);
98+
opaque += WaveActiveAnyTrue(commonArg == 17);
99+
100+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 6
101+
// CHECK: @dx.op.waveAllTrue(i32 114, i1
102+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 7
103+
// CHECK: @dx.op.waveAllTrue(i32 114, i1
104+
opaque += RepackingPoint(opaque, 6);
105+
opaque += WaveActiveAllTrue(commonArg == 17);
106+
opaque += RepackingPoint(opaque, 7);
107+
opaque += WaveActiveAllTrue(commonArg == 17);
108+
109+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 8
110+
// CHECK: @dx.op.waveActiveAllEqual.i32(i32 115, i32
111+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 9
112+
// CHECK: @dx.op.waveActiveAllEqual.i32(i32 115, i32
113+
opaque += RepackingPoint(opaque, 8);
114+
opaque += WaveActiveAllEqual(commonArg);
115+
opaque += RepackingPoint(opaque, 9);
116+
opaque += WaveActiveAllEqual(commonArg);
117+
118+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 10
119+
// CHECK: call %dx.types.fouri32 @dx.op.waveActiveBallot(i32 116, i1
120+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 11
121+
// CHECK: call %dx.types.fouri32 @dx.op.waveActiveBallot(i32 116, i1
122+
opaque += RepackingPoint(opaque, 10);
123+
opaque += WaveActiveBallot(commonArg).x;
124+
opaque += RepackingPoint(opaque, 11);
125+
opaque += WaveActiveBallot(commonArg).x;
126+
127+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 12
128+
// CHECK: @dx.op.waveReadLaneAt.i32(i32 117, i32
129+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 13
130+
// CHECK: @dx.op.waveReadLaneAt.i32(i32 117, i32
131+
opaque += RepackingPoint(opaque, 12);
132+
opaque += WaveReadLaneAt(commonArg, 1);
133+
opaque += RepackingPoint(opaque, 13);
134+
opaque += WaveReadLaneAt(commonArg, 1);
135+
136+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 14
137+
// CHECK: @dx.op.waveReadLaneFirst.i32(i32 118, i32
138+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 15
139+
// CHECK: @dx.op.waveReadLaneFirst.i32(i32 118, i32
140+
opaque += RepackingPoint(opaque, 14);
141+
opaque += WaveReadLaneFirst(commonArg);
142+
opaque += RepackingPoint(opaque, 15);
143+
opaque += WaveReadLaneFirst(commonArg);
144+
145+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 16
146+
// CHECK: @dx.op.waveActiveOp.i32(i32 119, i32
147+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 17
148+
// CHECK: @dx.op.waveActiveOp.i32(i32 119, i32
149+
opaque += RepackingPoint(opaque, 16);
150+
opaque += WaveActiveSum(commonArg);
151+
opaque += RepackingPoint(opaque, 17);
152+
opaque += WaveActiveSum(commonArg);
153+
154+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 18
155+
// CHECK: @dx.op.waveActiveOp.i64(i32 119, i64
156+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 19
157+
// CHECK: @dx.op.waveActiveOp.i64(i32 119, i64
158+
opaque += RepackingPoint(opaque, 18);
159+
opaque += WaveActiveProduct(commonArg == 17 ? 1 : 0);
160+
opaque += RepackingPoint(opaque, 19);
161+
opaque += WaveActiveProduct(commonArg == 17 ? 1 : 0);
162+
163+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 20
164+
// CHECK: @dx.op.waveActiveBit.i32(i32 120, i32
165+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 21
166+
// CHECK: @dx.op.waveActiveBit.i32(i32 120, i32
167+
opaque += RepackingPoint(opaque, 20);
168+
opaque += WaveActiveBitAnd(commonArg);
169+
opaque += RepackingPoint(opaque, 21);
170+
opaque += WaveActiveBitAnd(commonArg);
171+
172+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 22
173+
// CHECK: @dx.op.waveActiveBit.i32(i32 120, i32
174+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 23
175+
// CHECK: @dx.op.waveActiveBit.i32(i32 120, i32
176+
opaque += RepackingPoint(opaque, 22);
177+
opaque += WaveActiveBitXor(commonArg);
178+
opaque += RepackingPoint(opaque, 23);
179+
opaque += WaveActiveBitXor(commonArg);
180+
181+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 24
182+
// CHECK: @dx.op.waveActiveOp.i32(i32 119, i32
183+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 25
184+
// CHECK: @dx.op.waveActiveOp.i32(i32 119, i32
185+
opaque += RepackingPoint(opaque, 24);
186+
opaque += WaveActiveMin(commonArg);
187+
opaque += RepackingPoint(opaque, 25);
188+
opaque += WaveActiveMin(commonArg);
189+
190+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 26
191+
// CHECK: @dx.op.waveActiveOp.i32(i32 119, i32
192+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 27
193+
// CHECK: @dx.op.waveActiveOp.i32(i32 119, i32
194+
opaque += RepackingPoint(opaque, 26);
195+
opaque += WaveActiveMax(commonArg);
196+
opaque += RepackingPoint(opaque, 27);
197+
opaque += WaveActiveMax(commonArg);
198+
199+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 28
200+
// CHECK: @dx.op.wavePrefixOp.i32(i32 121, i32
201+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 29
202+
// CHECK: @dx.op.wavePrefixOp.i32(i32 121, i32
203+
opaque += RepackingPoint(opaque, 28);
204+
opaque += WavePrefixSum(commonArg);
205+
opaque += RepackingPoint(opaque, 29);
206+
opaque += WavePrefixSum(commonArg);
207+
208+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 30
209+
// CHECK: @dx.op.wavePrefixOp.i64(i32 121, i64
210+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 31
211+
// CHECK: @dx.op.wavePrefixOp.i64(i32 121, i64
212+
opaque += RepackingPoint(opaque, 30);
213+
opaque += WavePrefixProduct(commonArg == 17 ? 1 : 0);
214+
opaque += RepackingPoint(opaque, 31);
215+
opaque += WavePrefixProduct(commonArg == 17 ? 1 : 0);
216+
217+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 32
218+
// CHECK: @dx.op.waveAllOp(i32 135, i1
219+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 33
220+
// CHECK: @dx.op.waveAllOp(i32 135, i1
221+
opaque += RepackingPoint(opaque, 32);
222+
opaque += WaveActiveCountBits(commonArg == 17);
223+
opaque += RepackingPoint(opaque, 33);
224+
opaque += WaveActiveCountBits(commonArg == 17);
225+
226+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 34
227+
// CHECK: @dx.op.wavePrefixOp(i32 136, i1
228+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 35
229+
// CHECK: @dx.op.wavePrefixOp(i32 136, i1
230+
opaque += RepackingPoint(opaque, 34);
231+
opaque += WavePrefixCountBits(commonArg == 17);
232+
opaque += RepackingPoint(opaque, 35);
233+
opaque += WavePrefixCountBits(commonArg == 17);
234+
235+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 36
236+
// CHECK: call %dx.types.fouri32 @dx.op.waveMatch.i32(i32 165, i32
237+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 37
238+
// CHECK: call %dx.types.fouri32 @dx.op.waveMatch.i32(i32 165, i32
239+
opaque += RepackingPoint(opaque, 36);
240+
uint4 mask = WaveMatch(commonArg);
241+
opaque += mask.x;
242+
opaque += RepackingPoint(opaque, 37);
243+
opaque += WaveMatch(commonArg).x;
244+
245+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 38
246+
// CHECK: @dx.op.waveMultiPrefixOp.i32(i32 166, i32
247+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 39
248+
// CHECK: @dx.op.waveMultiPrefixOp.i32(i32 166, i32
249+
opaque += RepackingPoint(opaque, 38);
250+
opaque += WaveMultiPrefixBitAnd(commonArg, mask);
251+
opaque += RepackingPoint(opaque, 39);
252+
opaque += WaveMultiPrefixBitAnd(commonArg, mask);
253+
254+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 40
255+
// CHECK: @dx.op.waveMultiPrefixOp.i32(i32 166, i32
256+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 41
257+
// CHECK: @dx.op.waveMultiPrefixOp.i32(i32 166, i32
258+
opaque += RepackingPoint(opaque, 40);
259+
opaque += WaveMultiPrefixBitOr(commonArg, mask);
260+
opaque += RepackingPoint(opaque, 41);
261+
opaque += WaveMultiPrefixBitOr(commonArg, mask);
262+
263+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 42
264+
// CHECK: @dx.op.waveMultiPrefixOp.i32(i32 166, i32
265+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 43
266+
// CHECK: @dx.op.waveMultiPrefixOp.i32(i32 166, i32
267+
opaque += RepackingPoint(opaque, 42);
268+
opaque += WaveMultiPrefixBitXor(commonArg, mask);
269+
opaque += RepackingPoint(opaque, 43);
270+
opaque += WaveMultiPrefixBitXor(commonArg, mask);
271+
272+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 44
273+
// CHECK: @dx.op.waveMultiPrefixBitCount(i32 167, i1
274+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 45
275+
// CHECK: @dx.op.waveMultiPrefixBitCount(i32 167, i1
276+
opaque += RepackingPoint(opaque, 44);
277+
opaque += WaveMultiPrefixCountBits(commonArg == 17, mask);
278+
opaque += RepackingPoint(opaque, 45);
279+
opaque += WaveMultiPrefixCountBits(commonArg == 17, mask);
280+
281+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 46
282+
// CHECK: @dx.op.waveMultiPrefixOp.i64(i32 166, i64
283+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 47
284+
// CHECK: @dx.op.waveMultiPrefixOp.i64(i32 166, i64
285+
opaque += RepackingPoint(opaque, 46);
286+
opaque += WaveMultiPrefixProduct(commonArg == 17 ? 1 : 0, mask);
287+
opaque += RepackingPoint(opaque, 47);
288+
opaque += WaveMultiPrefixProduct(commonArg == 17 ? 1 : 0, mask);
289+
290+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 48
291+
// CHECK: @dx.op.waveMultiPrefixOp.i64(i32 166, i64
292+
// CHECK: @dx.op.{{traceRay|callShader|reportHit}}{{.*}} i32 49
293+
// CHECK: @dx.op.waveMultiPrefixOp.i64(i32 166, i64
294+
opaque += RepackingPoint(opaque, 48);
295+
opaque += WaveMultiPrefixSum(commonArg == 17 ? 1 : 0, mask);
296+
opaque += RepackingPoint(opaque, 49);
297+
opaque += WaveMultiPrefixSum(commonArg == 17 ? 1 : 0, mask);
298+
299+
output[DispatchRaysIndex().x] = opaque;
300+
}

utils/hct/gen_intrin_main.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ $type1 [[rn]] NonUniformResourceIndex(in any<> index) : nonuniform_resource_inde
270270

271271
// Wave intrinsics. Only those that depend on the exec mask are marked as wave-sensitive
272272
bool [[wv]] WaveIsFirstLane();
273-
uint [[rn]] WaveGetLaneIndex();
273+
uint [[ro]] WaveGetLaneIndex();
274274
uint [[rn]] WaveGetLaneCount();
275275
bool [[wv]] WaveActiveAnyTrue(in bool cond);
276276
bool [[wv]] WaveActiveAllTrue(in bool cond);

utils/hct/hctdb.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2847,7 +2847,7 @@ def UFI(name, **mappings):
28472847
"WaveGetLaneIndex",
28482848
"returns the index of the current lane in the wave",
28492849
"v",
2850-
"rn",
2850+
"ro",
28512851
[db_dxil_param(0, "i32", "", "operation result")],
28522852
)
28532853
next_op_idx += 1

0 commit comments

Comments
 (0)