Skip to content

Commit 2baa055

Browse files
add more tests for wave size 64
1 parent ad956d6 commit 2baa055

File tree

636 files changed

+103881
-0
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

636 files changed

+103881
-0
lines changed
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
#--- source.hlsl
2+
RWStructuredBuffer<uint> _participant_check_sum : register(u1);
3+
4+
[numthreads(64, 1, 1)]
5+
void main(uint3 tid : SV_DispatchThreadID) {
6+
_participant_check_sum[tid.x] = 0;
7+
uint result = 0;
8+
switch ((WaveGetLaneIndex() % 2)) {
9+
case 0: {
10+
if ((WaveGetLaneIndex() == 51)) {
11+
for (uint i0 = 0; (i0 < 3); i0 = (i0 + 1)) {
12+
if (((WaveGetLaneIndex() & 1) == 1)) {
13+
result = (result + WaveActiveMin(result));
14+
uint _participantCount = WaveActiveSum(1);
15+
bool _isCorrect = (_participantCount == 0);
16+
_participant_check_sum[tid.x] = (_participant_check_sum[tid.x] + (_isCorrect ? 1 : 0));
17+
}
18+
if ((i0 == 2)) {
19+
break;
20+
}
21+
}
22+
if ((WaveGetLaneIndex() == 56)) {
23+
result = (result + WaveActiveMax(result));
24+
uint _participantCount = WaveActiveSum(1);
25+
bool _isCorrect = (_participantCount == 0);
26+
_participant_check_sum[tid.x] = (_participant_check_sum[tid.x] + (_isCorrect ? 1 : 0));
27+
}
28+
} else {
29+
if ((WaveGetLaneIndex() < 14)) {
30+
result = (result + WaveActiveMin(result));
31+
uint _participantCount = WaveActiveSum(1);
32+
bool _isCorrect = (_participantCount == 7);
33+
_participant_check_sum[tid.x] = (_participant_check_sum[tid.x] + (_isCorrect ? 1 : 0));
34+
}
35+
for (uint i1 = 0; (i1 < 2); i1 = (i1 + 1)) {
36+
if ((WaveGetLaneIndex() == 41)) {
37+
result = (result + WaveActiveMin((WaveGetLaneIndex() + 4)));
38+
uint _participantCount = WaveActiveSum(1);
39+
bool _isCorrect = (_participantCount == 0);
40+
_participant_check_sum[tid.x] = (_participant_check_sum[tid.x] + (_isCorrect ? 1 : 0));
41+
}
42+
if ((i1 == 1)) {
43+
continue;
44+
}
45+
}
46+
}
47+
}
48+
case 1: {
49+
if (((WaveGetLaneIndex() % 2) == 0)) {
50+
result = (result + WaveActiveSum(2));
51+
uint _participantCount = WaveActiveSum(1);
52+
bool _isCorrect = (_participantCount == 32);
53+
_participant_check_sum[tid.x] = (_participant_check_sum[tid.x] + (_isCorrect ? 1 : 0));
54+
}
55+
break;
56+
}
57+
}
58+
switch ((WaveGetLaneIndex() % 2)) {
59+
case 0: {
60+
if ((WaveGetLaneIndex() < 8)) {
61+
result = (result + WaveActiveSum(1));
62+
uint _participantCount = WaveActiveSum(1);
63+
bool _isCorrect = (_participantCount == 4);
64+
_participant_check_sum[tid.x] = (_participant_check_sum[tid.x] + (_isCorrect ? 1 : 0));
65+
}
66+
break;
67+
}
68+
case 1: {
69+
if (((WaveGetLaneIndex() % 2) == 0)) {
70+
result = (result + WaveActiveSum(2));
71+
uint _participantCount = WaveActiveSum(1);
72+
bool _isCorrect = (_participantCount == 0);
73+
_participant_check_sum[tid.x] = (_participant_check_sum[tid.x] + (_isCorrect ? 1 : 0));
74+
}
75+
break;
76+
}
77+
}
78+
}
79+
80+
#--- pipeline.yaml
81+
---
82+
Shaders:
83+
- Stage: Compute
84+
Entry: main
85+
DispatchSize: [1, 1, 1] # Single dispatch for 64 threads
86+
Buffers:
87+
- Name: _participant_check_sum
88+
Format: UInt32
89+
Stride: 4
90+
Fill: 0
91+
Size: 64
92+
- Name: expected_participants
93+
Format: UInt32
94+
Stride: 4
95+
Data: [3, 0, 3, 0, 3, 0, 3, 0, 2, 0, 2, 0, 2, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0]
96+
Results:
97+
- Result: WaveOpValidation
98+
Rule: BufferExact
99+
Actual: _participant_check_sum
100+
Expected: expected_participants
101+
DescriptorSets:
102+
- Resources:
103+
- Name: _participant_check_sum
104+
Kind: RWStructuredBuffer
105+
DirectXBinding:
106+
Register: 1
107+
Space: 0
108+
VulkanBinding:
109+
Binding: 1
110+
...
111+
#--- end
112+
113+
# RUN: split-file %s %t
114+
# RUN: %dxc_target -T cs_6_0 -Fo %t.o %t/source.hlsl
115+
# RUN: %offloader %t/pipeline.yaml %t.o
Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
#--- source.hlsl
2+
RWStructuredBuffer<uint> _participant_check_sum : register(u1);
3+
4+
[numthreads(64, 1, 1)]
5+
void main(uint3 tid : SV_DispatchThreadID) {
6+
_participant_check_sum[tid.x] = 0;
7+
uint result = 0;
8+
switch ((WaveGetLaneIndex() % 4)) {
9+
case 0: {
10+
if ((((WaveGetLaneIndex() == 5) || (WaveGetLaneIndex() == 25)) || (WaveGetLaneIndex() == 43))) {
11+
if (((((WaveGetLaneIndex() == 1) || (WaveGetLaneIndex() == 27)) || (WaveGetLaneIndex() == 42)) || (WaveGetLaneIndex() == 60))) {
12+
result = (result + WaveActiveSum(result));
13+
uint _participantCount = WaveActiveSum(1);
14+
bool _isCorrect = (_participantCount == 0);
15+
_participant_check_sum[tid.x] = (_participant_check_sum[tid.x] + (_isCorrect ? 1 : 0));
16+
}
17+
uint counter0 = 0;
18+
while ((counter0 < 2)) {
19+
counter0 = (counter0 + 1);
20+
if (((WaveGetLaneIndex() < 5) || (WaveGetLaneIndex() >= 49))) {
21+
result = (result + WaveActiveSum((WaveGetLaneIndex() + 3)));
22+
uint _participantCount = WaveActiveSum(1);
23+
bool _isCorrect = (_participantCount == 0);
24+
_participant_check_sum[tid.x] = (_participant_check_sum[tid.x] + (_isCorrect ? 1 : 0));
25+
}
26+
if ((WaveGetLaneIndex() == 2)) {
27+
if ((WaveGetLaneIndex() == 19)) {
28+
result = (result + WaveActiveMin(WaveGetLaneIndex()));
29+
uint _participantCount = WaveActiveSum(1);
30+
bool _isCorrect = (_participantCount == 0);
31+
_participant_check_sum[tid.x] = (_participant_check_sum[tid.x] + (_isCorrect ? 1 : 0));
32+
}
33+
if ((WaveGetLaneIndex() == 59)) {
34+
result = (result + WaveActiveSum(result));
35+
uint _participantCount = WaveActiveSum(1);
36+
bool _isCorrect = (_participantCount == 0);
37+
_participant_check_sum[tid.x] = (_participant_check_sum[tid.x] + (_isCorrect ? 1 : 0));
38+
}
39+
} else {
40+
if (((WaveGetLaneIndex() < 16) || (WaveGetLaneIndex() >= 59))) {
41+
result = (result + WaveActiveSum(8));
42+
uint _participantCount = WaveActiveSum(1);
43+
bool _isCorrect = (_participantCount == 0);
44+
_participant_check_sum[tid.x] = (_participant_check_sum[tid.x] + (_isCorrect ? 1 : 0));
45+
}
46+
}
47+
if ((counter0 == 1)) {
48+
break;
49+
}
50+
}
51+
if ((((WaveGetLaneIndex() == 4) || (WaveGetLaneIndex() == 41)) || (WaveGetLaneIndex() == 53))) {
52+
result = (result + WaveActiveMin(result));
53+
uint _participantCount = WaveActiveSum(1);
54+
bool _isCorrect = (_participantCount == 0);
55+
_participant_check_sum[tid.x] = (_participant_check_sum[tid.x] + (_isCorrect ? 1 : 0));
56+
}
57+
} else {
58+
if (((WaveGetLaneIndex() < 5) || (WaveGetLaneIndex() >= 53))) {
59+
result = (result + WaveActiveMin(result));
60+
uint _participantCount = WaveActiveSum(1);
61+
bool _isCorrect = (_participantCount == 4);
62+
_participant_check_sum[tid.x] = (_participant_check_sum[tid.x] + (_isCorrect ? 1 : 0));
63+
}
64+
for (uint i1 = 0; (i1 < 3); i1 = (i1 + 1)) {
65+
if (((((WaveGetLaneIndex() == 4) || (WaveGetLaneIndex() == 21)) || (WaveGetLaneIndex() == 34)) || (WaveGetLaneIndex() == 54))) {
66+
result = (result + WaveActiveMax(8));
67+
uint _participantCount = WaveActiveSum(1);
68+
bool _isCorrect = (_participantCount == 1);
69+
_participant_check_sum[tid.x] = (_participant_check_sum[tid.x] + (_isCorrect ? 1 : 0));
70+
}
71+
if ((WaveGetLaneIndex() == 10)) {
72+
if ((WaveGetLaneIndex() == 4)) {
73+
result = (result + WaveActiveMax(result));
74+
uint _participantCount = WaveActiveSum(1);
75+
bool _isCorrect = (_participantCount == 0);
76+
_participant_check_sum[tid.x] = (_participant_check_sum[tid.x] + (_isCorrect ? 1 : 0));
77+
}
78+
if ((WaveGetLaneIndex() == 42)) {
79+
result = (result + WaveActiveMax(2));
80+
uint _participantCount = WaveActiveSum(1);
81+
bool _isCorrect = (_participantCount == 0);
82+
_participant_check_sum[tid.x] = (_participant_check_sum[tid.x] + (_isCorrect ? 1 : 0));
83+
}
84+
}
85+
if ((((WaveGetLaneIndex() == 4) || (WaveGetLaneIndex() == 34)) || (WaveGetLaneIndex() == 52))) {
86+
result = (result + WaveActiveMin(result));
87+
uint _participantCount = WaveActiveSum(1);
88+
bool _isCorrect = (_participantCount == 2);
89+
_participant_check_sum[tid.x] = (_participant_check_sum[tid.x] + (_isCorrect ? 1 : 0));
90+
}
91+
}
92+
if (((WaveGetLaneIndex() < 4) || (WaveGetLaneIndex() >= 56))) {
93+
result = (result + WaveActiveSum((WaveGetLaneIndex() + 5)));
94+
uint _participantCount = WaveActiveSum(1);
95+
bool _isCorrect = (_participantCount == 3);
96+
_participant_check_sum[tid.x] = (_participant_check_sum[tid.x] + (_isCorrect ? 1 : 0));
97+
}
98+
}
99+
}
100+
case 1: {
101+
switch ((WaveGetLaneIndex() % 3)) {
102+
case 0: {
103+
uint counter2 = 0;
104+
while ((counter2 < 3)) {
105+
counter2 = (counter2 + 1);
106+
if (((((WaveGetLaneIndex() == 8) || (WaveGetLaneIndex() == 28)) || (WaveGetLaneIndex() == 41)) || (WaveGetLaneIndex() == 59))) {
107+
result = (result + WaveActiveSum(result));
108+
uint _participantCount = WaveActiveSum(1);
109+
bool _isCorrect = (_participantCount == 0);
110+
_participant_check_sum[tid.x] = (_participant_check_sum[tid.x] + (_isCorrect ? 1 : 0));
111+
}
112+
uint counter3 = 0;
113+
while ((counter3 < 3)) {
114+
counter3 = (counter3 + 1);
115+
if ((WaveGetLaneIndex() < 8)) {
116+
result = (result + WaveActiveMax(result));
117+
uint _participantCount = WaveActiveSum(1);
118+
bool _isCorrect = (_participantCount == 1);
119+
_participant_check_sum[tid.x] = (_participant_check_sum[tid.x] + (_isCorrect ? 1 : 0));
120+
}
121+
if ((counter3 == 2)) {
122+
break;
123+
}
124+
}
125+
}
126+
break;
127+
}
128+
case 1: {
129+
for (uint i4 = 0; (i4 < 3); i4 = (i4 + 1)) {
130+
if ((WaveGetLaneIndex() == 38)) {
131+
result = (result + WaveActiveSum(2));
132+
uint _participantCount = WaveActiveSum(1);
133+
bool _isCorrect = (_participantCount == 0);
134+
_participant_check_sum[tid.x] = (_participant_check_sum[tid.x] + (_isCorrect ? 1 : 0));
135+
}
136+
if (((WaveGetLaneIndex() < 10) || (WaveGetLaneIndex() >= 42))) {
137+
if (((WaveGetLaneIndex() < 16) || (WaveGetLaneIndex() >= 58))) {
138+
result = (result + WaveActiveMax(result));
139+
uint _participantCount = WaveActiveSum(1);
140+
bool _isCorrect = (_participantCount == 3);
141+
_participant_check_sum[tid.x] = (_participant_check_sum[tid.x] + (_isCorrect ? 1 : 0));
142+
}
143+
if (((WaveGetLaneIndex() < 2) || (WaveGetLaneIndex() >= 54))) {
144+
result = (result + WaveActiveSum(result));
145+
uint _participantCount = WaveActiveSum(1);
146+
bool _isCorrect = (_participantCount == 2);
147+
_participant_check_sum[tid.x] = (_participant_check_sum[tid.x] + (_isCorrect ? 1 : 0));
148+
}
149+
} else {
150+
if (((WaveGetLaneIndex() < 8) || (WaveGetLaneIndex() >= 43))) {
151+
result = (result + WaveActiveMax(result));
152+
uint _participantCount = WaveActiveSum(1);
153+
bool _isCorrect = (_participantCount == 0);
154+
_participant_check_sum[tid.x] = (_participant_check_sum[tid.x] + (_isCorrect ? 1 : 0));
155+
}
156+
}
157+
if ((WaveGetLaneIndex() == 13)) {
158+
result = (result + WaveActiveSum(WaveGetLaneIndex()));
159+
uint _participantCount = WaveActiveSum(1);
160+
bool _isCorrect = (_participantCount == 1);
161+
_participant_check_sum[tid.x] = (_participant_check_sum[tid.x] + (_isCorrect ? 1 : 0));
162+
}
163+
}
164+
break;
165+
}
166+
case 2: {
167+
if (true) {
168+
result = (result + WaveActiveSum(3));
169+
uint _participantCount = WaveActiveSum(1);
170+
bool _isCorrect = (_participantCount == 10);
171+
_participant_check_sum[tid.x] = (_participant_check_sum[tid.x] + (_isCorrect ? 1 : 0));
172+
}
173+
break;
174+
}
175+
}
176+
}
177+
case 2: {
178+
if (true) {
179+
result = (result + WaveActiveSum(3));
180+
uint _participantCount = WaveActiveSum(1);
181+
bool _isCorrect = (_participantCount == 48);
182+
_participant_check_sum[tid.x] = (_participant_check_sum[tid.x] + (_isCorrect ? 1 : 0));
183+
}
184+
}
185+
case 3: {
186+
if ((WaveGetLaneIndex() < 20)) {
187+
result = (result + WaveActiveSum(4));
188+
uint _participantCount = WaveActiveSum(1);
189+
bool _isCorrect = (_participantCount == 20);
190+
_participant_check_sum[tid.x] = (_participant_check_sum[tid.x] + (_isCorrect ? 1 : 0));
191+
}
192+
break;
193+
}
194+
default: {
195+
result = (result + WaveActiveSum(99));
196+
uint _participantCount = WaveActiveSum(1);
197+
bool _isCorrect = (_participantCount == 0);
198+
_participant_check_sum[tid.x] = (_participant_check_sum[tid.x] + (_isCorrect ? 1 : 0));
199+
break;
200+
}
201+
}
202+
}
203+
204+
#--- pipeline.yaml
205+
---
206+
Shaders:
207+
- Stage: Compute
208+
Entry: main
209+
DispatchSize: [1, 1, 1] # Single dispatch for 64 threads
210+
Buffers:
211+
- Name: _participant_check_sum
212+
Format: UInt32
213+
Stride: 4
214+
Fill: 0
215+
Size: 64
216+
- Name: expected_participants
217+
Format: UInt32
218+
Stride: 4
219+
Data: [10, 8, 2, 1, 12, 3, 2, 1, 3, 2, 2, 1, 2, 5, 2, 1, 2, 3, 2, 1, 2, 1, 1, 0, 1, 1, 1, 0, 1, 2, 1, 0, 2, 1, 1, 0, 1, 1, 1, 0, 1, 2, 1, 0, 2, 1, 1, 0, 1, 1, 1, 0, 4, 2, 1, 0, 4, 1, 1, 0, 3, 7, 1, 0]
220+
Results:
221+
- Result: WaveOpValidation
222+
Rule: BufferExact
223+
Actual: _participant_check_sum
224+
Expected: expected_participants
225+
DescriptorSets:
226+
- Resources:
227+
- Name: _participant_check_sum
228+
Kind: RWStructuredBuffer
229+
DirectXBinding:
230+
Register: 1
231+
Space: 0
232+
VulkanBinding:
233+
Binding: 1
234+
...
235+
#--- end
236+
237+
# RUN: split-file %s %t
238+
# RUN: %dxc_target -T cs_6_0 -Fo %t.o %t/source.hlsl
239+
# RUN: %offloader %t/pipeline.yaml %t.o

0 commit comments

Comments
 (0)