Skip to content

Commit 43ff6cb

Browse files
authored
Add wave active sum tests (#384)
Adds waveactivesum tests. Fixes #161
1 parent 1e4b424 commit 43ff6cb

File tree

6 files changed

+1532
-0
lines changed

6 files changed

+1532
-0
lines changed
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
#--- source.hlsl
2+
StructuredBuffer<half4> In : register(t0);
3+
RWStructuredBuffer<half4> Out1 : register(u1); // test scalar
4+
RWStructuredBuffer<half4> Out2 : register(u2); // test half2
5+
RWStructuredBuffer<half4> Out3 : register(u3); // test half3
6+
RWStructuredBuffer<half4> Out4 : register(u4); // test half4
7+
RWStructuredBuffer<half4> Out5 : register(u5); // constant folding
8+
9+
[numthreads(4,1,1)]
10+
void main(uint3 tid : SV_GroupThreadID)
11+
{
12+
half4 v = In[0];
13+
14+
// Mask per "active lane set": only <=N lanes contribute
15+
half s1 = tid.x <= 0 ? WaveActiveSum( v.x ) : 0;
16+
half s2 = tid.x <= 1 ? WaveActiveSum( v.x ) : 0;
17+
half s3 = tid.x <= 2 ? WaveActiveSum( v.x ) : 0;
18+
half s4 = tid.x <= 3 ? WaveActiveSum( v.x ) : 0;
19+
20+
half2 v2_1 = tid.x <= 0 ? WaveActiveSum( v.xy ) : half2(0,0);
21+
half2 v2_2 = tid.x <= 1 ? WaveActiveSum( v.xy ) : half2(0,0);
22+
half2 v2_3 = tid.x <= 2 ? WaveActiveSum( v.xy ) : half2(0,0);
23+
half2 v2_4 = tid.x <= 3 ? WaveActiveSum( v.xy ) : half2(0,0);
24+
25+
half3 v3_1 = tid.x <= 0 ? WaveActiveSum( v.xyz ) : half3(0,0,0);
26+
half3 v3_2 = tid.x <= 1 ? WaveActiveSum( v.xyz ) : half3(0,0,0);
27+
half3 v3_3 = tid.x <= 2 ? WaveActiveSum( v.xyz ) : half3(0,0,0);
28+
half3 v3_4 = tid.x <= 3 ? WaveActiveSum( v.xyz ) : half3(0,0,0);
29+
30+
half4 v4_1 = tid.x <= 0 ? WaveActiveSum( v ) : half4(0,0,0,0);
31+
half4 v4_2 = tid.x <= 1 ? WaveActiveSum( v ) : half4(0,0,0,0);
32+
half4 v4_3 = tid.x <= 2 ? WaveActiveSum( v ) : half4(0,0,0,0);
33+
half4 v4_4 = tid.x <= 3 ? WaveActiveSum( v ) : half4(0,0,0,0);
34+
35+
half scalars[4] = { s1, s2, s3, s4 };
36+
half2 vec2s [4] = { v2_1, v2_2, v2_3, v2_4 };
37+
half3 vec3s [4] = { v3_1, v3_2, v3_3, v3_4 };
38+
half4 vec4s [4] = { v4_1, v4_2, v4_3, v4_4 };
39+
40+
41+
Out1[tid.x].x = scalars[tid.x];
42+
Out2[tid.x].xy = vec2s[tid.x];
43+
Out3[tid.x].xyz = vec3s[tid.x];
44+
Out4[tid.x] = vec4s[tid.x];
45+
46+
// constant folding case
47+
Out5[0] = WaveActiveSum(half4(1,2,3,4));
48+
}
49+
50+
//--- pipeline.yaml
51+
52+
---
53+
Shaders:
54+
- Stage: Compute
55+
Entry: main
56+
DispatchSize: [1, 1, 1]
57+
Buffers:
58+
- Name: In
59+
Format: Float16
60+
Stride: 2
61+
Data: [ 0x3c00, 0x4900, 0x5640, 0x63d0]
62+
- Name: Out1
63+
Format: Float16
64+
Stride: 8
65+
ZeroInitSize: 32
66+
- Name: Out2
67+
Format: Float16
68+
Stride: 8
69+
ZeroInitSize: 32
70+
- Name: Out3
71+
Format: Float16
72+
Stride: 8
73+
ZeroInitSize: 32
74+
- Name: Out4
75+
Format: Float16
76+
Stride: 8
77+
ZeroInitSize: 32
78+
- Name: Out5
79+
Format: Float16
80+
Stride: 8
81+
ZeroInitSize: 8
82+
- Name: ExpectedOut1
83+
Format: Float16
84+
Stride: 8
85+
Data: [ 0x3c00, 0x0, 0x0, 0x0, 0x4000, 0x0, 0x0, 0x0, 0x4200, 0x0, 0x0, 0x0, 0x4400, 0x0, 0x0, 0x0 ]
86+
- Name: ExpectedOut2
87+
Format: Float16
88+
Stride: 8
89+
Data: [ 0x3c00, 0x4900, 0x0, 0x0, 0x4000, 0x4d00, 0x0, 0x0, 0x4200, 0x4f80, 0x0, 0x0, 0x4400, 0x5100, 0x0, 0x0 ]
90+
- Name: ExpectedOut3
91+
Format: Float16
92+
Stride: 8
93+
Data: [ 0x3c00, 0x4900, 0x5640, 0x0, 0x4000, 0x4d00, 0x5a40, 0x0, 0x4200, 0x4f80, 0x5cb0, 0x0, 0x4400, 0x5100, 0x5e40, 0x0 ]
94+
- Name: ExpectedOut4
95+
Format: Float16
96+
Stride: 8
97+
Data: [ 0x3c00, 0x4900, 0x5640, 0x63d0, 0x4000, 0x4d00, 0x5a40, 0x67d0, 0x4200, 0x4f80, 0x5cb0, 0x69dc, 0x4400, 0x5100, 0x5e40, 0x6bd0 ]
98+
- Name: ExpectedOut5
99+
Format: Float16
100+
Stride: 8
101+
Data: [ 0x4400, 0x4800, 0x4a00, 0x4c00 ]
102+
Results:
103+
- Result: ExpectedOut1
104+
Rule: BufferExact
105+
Actual: Out1
106+
Expected: ExpectedOut1
107+
- Result: ExpectedOut2
108+
Rule: BufferExact
109+
Actual: Out2
110+
Expected: ExpectedOut2
111+
- Result: ExpectedOut3
112+
Rule: BufferExact
113+
Actual: Out3
114+
Expected: ExpectedOut3
115+
- Result: ExpectedOut4
116+
Rule: BufferExact
117+
Actual: Out4
118+
Expected: ExpectedOut4
119+
- Result: ExpectedOut5
120+
Rule: BufferExact
121+
Actual: Out5
122+
Expected: ExpectedOut5
123+
DescriptorSets:
124+
- Resources:
125+
- Name: In
126+
Kind: StructuredBuffer
127+
DirectXBinding:
128+
Register: 0
129+
Space: 0
130+
VulkanBinding:
131+
Binding: 0
132+
- Name: Out1
133+
Kind: RWStructuredBuffer
134+
DirectXBinding:
135+
Register: 1
136+
Space: 0
137+
VulkanBinding:
138+
Binding: 1
139+
- Name: Out2
140+
Kind: RWStructuredBuffer
141+
DirectXBinding:
142+
Register: 2
143+
Space: 0
144+
VulkanBinding:
145+
Binding: 2
146+
- Name: Out3
147+
Kind: RWStructuredBuffer
148+
DirectXBinding:
149+
Register: 3
150+
Space: 0
151+
VulkanBinding:
152+
Binding: 3
153+
- Name: Out4
154+
Kind: RWStructuredBuffer
155+
DirectXBinding:
156+
Register: 4
157+
Space: 0
158+
VulkanBinding:
159+
Binding: 4
160+
- Name: Out5
161+
Kind: RWStructuredBuffer
162+
DirectXBinding:
163+
Register: 5
164+
Space: 0
165+
VulkanBinding:
166+
Binding: 5
167+
168+
...
169+
#--- end
170+
171+
# Bug https://github.com/llvm/llvm-project/issues/156775
172+
# XFAIL: Clang
173+
174+
# Bug https://github.com/llvm/offload-test-suite/issues/393
175+
# XFAIL: Metal
176+
177+
# REQUIRES: Half
178+
179+
# RUN: split-file %s %t
180+
# RUN: %dxc_target -enable-16bit-types -T cs_6_5 -Fo %t.o %t/source.hlsl
181+
# RUN: %offloader %t/pipeline.yaml %t.o
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
#--- source.hlsl
2+
StructuredBuffer<float4> In : register(t0);
3+
RWStructuredBuffer<float4> Out1 : register(u1); // test scalar
4+
RWStructuredBuffer<float4> Out2 : register(u2); // test float2
5+
RWStructuredBuffer<float4> Out3 : register(u3); // test float3
6+
RWStructuredBuffer<float4> Out4 : register(u4); // test float4
7+
RWStructuredBuffer<float4> Out5 : register(u5); // constant folding
8+
9+
[numthreads(4,1,1)]
10+
void main(uint3 tid : SV_GroupThreadID)
11+
{
12+
float4 v = In[0];
13+
14+
// Mask per "active lane set": only <=N lanes contribute
15+
float s1 = tid.x <= 0 ? WaveActiveSum( v.x ) : 0;
16+
float s2 = tid.x <= 1 ? WaveActiveSum( v.x ) : 0;
17+
float s3 = tid.x <= 2 ? WaveActiveSum( v.x ) : 0;
18+
float s4 = tid.x <= 3 ? WaveActiveSum( v.x ) : 0;
19+
20+
float2 v2_1 = tid.x <= 0 ? WaveActiveSum( v.xy ) : float2(0,0);
21+
float2 v2_2 = tid.x <= 1 ? WaveActiveSum( v.xy ) : float2(0,0);
22+
float2 v2_3 = tid.x <= 2 ? WaveActiveSum( v.xy ) : float2(0,0);
23+
float2 v2_4 = tid.x <= 3 ? WaveActiveSum( v.xy ) : float2(0,0);
24+
25+
float3 v3_1 = tid.x <= 0 ? WaveActiveSum( v.xyz ) : float3(0,0,0);
26+
float3 v3_2 = tid.x <= 1 ? WaveActiveSum( v.xyz ) : float3(0,0,0);
27+
float3 v3_3 = tid.x <= 2 ? WaveActiveSum( v.xyz ) : float3(0,0,0);
28+
float3 v3_4 = tid.x <= 3 ? WaveActiveSum( v.xyz ) : float3(0,0,0);
29+
30+
float4 v4_1 = tid.x <= 0 ? WaveActiveSum( v ) : float4(0,0,0,0);
31+
float4 v4_2 = tid.x <= 1 ? WaveActiveSum( v ) : float4(0,0,0,0);
32+
float4 v4_3 = tid.x <= 2 ? WaveActiveSum( v ) : float4(0,0,0,0);
33+
float4 v4_4 = tid.x <= 3 ? WaveActiveSum( v ) : float4(0,0,0,0);
34+
35+
float scalars[4] = { s1, s2, s3, s4 };
36+
float2 vec2s [4] = { v2_1, v2_2, v2_3, v2_4 };
37+
float3 vec3s [4] = { v3_1, v3_2, v3_3, v3_4 };
38+
float4 vec4s [4] = { v4_1, v4_2, v4_3, v4_4 };
39+
40+
41+
Out1[tid.x].x = scalars[tid.x];
42+
Out2[tid.x].xy = vec2s[tid.x];
43+
Out3[tid.x].xyz = vec3s[tid.x];
44+
Out4[tid.x] = vec4s[tid.x];
45+
46+
// constant folding case
47+
Out5[0] = WaveActiveSum(float4(1,2,3,4));
48+
}
49+
50+
//--- pipeline.yaml
51+
52+
---
53+
Shaders:
54+
- Stage: Compute
55+
Entry: main
56+
DispatchSize: [1, 1, 1]
57+
Buffers:
58+
- Name: In
59+
Format: Float32
60+
Stride: 4
61+
Data: [ 1.0, 10.0, 100.0, 1000.0]
62+
- Name: Out1
63+
Format: Float32
64+
Stride: 16
65+
ZeroInitSize: 64
66+
- Name: Out2
67+
Format: Float32
68+
Stride: 16
69+
ZeroInitSize: 64
70+
- Name: Out3
71+
Format: Float32
72+
Stride: 16
73+
ZeroInitSize: 64
74+
- Name: Out4
75+
Format: Float32
76+
Stride: 16
77+
ZeroInitSize: 64
78+
- Name: Out5
79+
Format: Float32
80+
Stride: 16
81+
ZeroInitSize: 16
82+
- Name: ExpectedOut1
83+
Format: Float32
84+
Stride: 16
85+
Data: [ 1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0 ]
86+
- Name: ExpectedOut2
87+
Format: Float32
88+
Stride: 16
89+
Data: [ 1.0, 10.0, 0.0, 0.0, 2.0, 20.0, 0.0, 0.0, 3.0, 30.0, 0.0, 0.0, 4.0, 40.0, 0.0, 0.0 ]
90+
- Name: ExpectedOut3
91+
Format: Float32
92+
Stride: 16
93+
Data: [ 1.0, 10.0, 100.0, 0.0, 2.0, 20.0, 200.0, 0.0, 3.0, 30.0, 300.0, 0.0, 4.0, 40.0, 400.0, 0.0 ]
94+
- Name: ExpectedOut4
95+
Format: Float32
96+
Stride: 16
97+
Data: [ 1.0, 10.0, 100.0, 1000.0, 2.0, 20.0, 200.0, 2000.0, 3.0, 30.0, 300.0, 3000.0, 4.0, 40.0, 400.0, 4000.0 ]
98+
- Name: ExpectedOut5
99+
Format: Float32
100+
Stride: 16
101+
Data: [ 4, 8, 12, 16 ]
102+
Results:
103+
- Result: ExpectedOut1
104+
Rule: BufferExact
105+
Actual: Out1
106+
Expected: ExpectedOut1
107+
- Result: ExpectedOut2
108+
Rule: BufferExact
109+
Actual: Out2
110+
Expected: ExpectedOut2
111+
- Result: ExpectedOut3
112+
Rule: BufferExact
113+
Actual: Out3
114+
Expected: ExpectedOut3
115+
- Result: ExpectedOut4
116+
Rule: BufferExact
117+
Actual: Out4
118+
Expected: ExpectedOut4
119+
- Result: ExpectedOut5
120+
Rule: BufferExact
121+
Actual: Out5
122+
Expected: ExpectedOut5
123+
DescriptorSets:
124+
- Resources:
125+
- Name: In
126+
Kind: StructuredBuffer
127+
DirectXBinding:
128+
Register: 0
129+
Space: 0
130+
VulkanBinding:
131+
Binding: 0
132+
- Name: Out1
133+
Kind: RWStructuredBuffer
134+
DirectXBinding:
135+
Register: 1
136+
Space: 0
137+
VulkanBinding:
138+
Binding: 1
139+
- Name: Out2
140+
Kind: RWStructuredBuffer
141+
DirectXBinding:
142+
Register: 2
143+
Space: 0
144+
VulkanBinding:
145+
Binding: 2
146+
- Name: Out3
147+
Kind: RWStructuredBuffer
148+
DirectXBinding:
149+
Register: 3
150+
Space: 0
151+
VulkanBinding:
152+
Binding: 3
153+
- Name: Out4
154+
Kind: RWStructuredBuffer
155+
DirectXBinding:
156+
Register: 4
157+
Space: 0
158+
VulkanBinding:
159+
Binding: 4
160+
- Name: Out5
161+
Kind: RWStructuredBuffer
162+
DirectXBinding:
163+
Register: 5
164+
Space: 0
165+
VulkanBinding:
166+
Binding: 5
167+
168+
...
169+
#--- end
170+
171+
# Bug https://github.com/llvm/llvm-project/issues/156775
172+
# XFAIL: Clang
173+
174+
# RUN: split-file %s %t
175+
# RUN: %dxc_target -T cs_6_5 -Fo %t.o %t/source.hlsl
176+
# RUN: %offloader %t/pipeline.yaml %t.o

0 commit comments

Comments
 (0)