Skip to content

Commit f528875

Browse files
authored
Add select (basic types) tests (#388)
Closes #132. Adds tests for `select` (basic types) testing 16 bit int types, half, 32 bit types, 64 bit int types, and double. For the `Cond` buffers (and `TrueVal`, `FalseVal` for the bool test), I had to use `bool` instead of `bool4` because select gives an "Invalid overload type" error if it receives a vec1 bool instead of a scalar as input.
1 parent d7f6470 commit f528875

File tree

5 files changed

+891
-0
lines changed

5 files changed

+891
-0
lines changed
Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,306 @@
1+
#--- source.hlsl
2+
3+
// This test tests all the following scenarios for select:
4+
// - Scalar condition, scalar true/false values
5+
// - Vector condition, vector true/false values
6+
// - Vector condition, scalar true value, vector false value
7+
// - Vector condition, vector true value, scalar false value
8+
// - Vector condition, scalar true/false values
9+
// For each vector condition scenario, there are tests for vec4, vec3, and vec2.
10+
// For the scalar condition scenario, there are four tests. One uses the buffers
11+
// for inputs and the other three use constants.
12+
13+
StructuredBuffer<bool> Cond : register(t0);
14+
StructuredBuffer<float4> TrueVal0 : register(t1);
15+
StructuredBuffer<float4> FalseVal0 : register(t2);
16+
StructuredBuffer<int4> TrueVal1 : register(t3);
17+
StructuredBuffer<int4> FalseVal1 : register(t4);
18+
StructuredBuffer<uint4> TrueVal2 : register(t5);
19+
StructuredBuffer<uint4> FalseVal2 : register(t6);
20+
StructuredBuffer<bool> TrueVal3 : register(t7);
21+
StructuredBuffer<bool> FalseVal3 : register(t8);
22+
23+
RWStructuredBuffer<float4> Out0 : register(u9);
24+
RWStructuredBuffer<int4> Out1 : register(u10);
25+
RWStructuredBuffer<uint4> Out2 : register(u11);
26+
RWStructuredBuffer<bool4> Out3 : register(u12);
27+
28+
29+
[numthreads(1,1,1)]
30+
void main() {
31+
bool4 Cond0 = bool4(Cond[0], Cond[1], Cond[2], Cond[3]);
32+
bool3 Cond1 = bool3(Cond[4], Cond[5], Cond[6]);
33+
bool2 Cond2 = bool2(Cond[8], Cond[9]);
34+
bool2 Cond3 = bool2(Cond[10], Cond[11]);
35+
36+
// float
37+
// vec4
38+
Out0[0] = select(Cond0, TrueVal0[0], FalseVal0[0]);
39+
Out0[1] = select(Cond0, TrueVal0[0].x, FalseVal0[0]);
40+
Out0[2] = select(Cond0, TrueVal0[0], FalseVal0[0].x);
41+
Out0[3] = select(Cond0, TrueVal0[0].x, FalseVal0[0].x);
42+
// vec3 + scalar
43+
Out0[4] = float4(select(Cond1, TrueVal0[1].xyz, FalseVal0[1].xyz), select(Cond[7], TrueVal0[1].w, FalseVal0[1].w));
44+
Out0[5] = float4(select(Cond1, TrueVal0[1].x, FalseVal0[1].xyz), select(bool(1), float(1), float(-1)));
45+
Out0[6] = float4(select(Cond1, TrueVal0[1].xyz, FalseVal0[1].x), select(bool(0), float(2), float(-2)));
46+
Out0[7] = float4(select(Cond1, TrueVal0[1].x, FalseVal0[1].x), select(bool(1), float(3), float(-3)));
47+
// vec2
48+
Out0[8] = float4(select(Cond2, TrueVal0[2].xy, FalseVal0[2].xy), select(Cond3, TrueVal0[2].z, FalseVal0[2].zw));
49+
Out0[9] = float4(select(Cond2, TrueVal0[2].xy, FalseVal0[2].x), select(Cond3, TrueVal0[2].z, FalseVal0[2].z));
50+
51+
// int
52+
// vec4
53+
Out1[0] = select(Cond0, TrueVal1[0], FalseVal1[0]);
54+
Out1[1] = select(Cond0, TrueVal1[0].x, FalseVal1[0]);
55+
Out1[2] = select(Cond0, TrueVal1[0], FalseVal1[0].x);
56+
Out1[3] = select(Cond0, TrueVal1[0].x, FalseVal1[0].x);
57+
// vec3 + scalar
58+
Out1[4] = int4(select(Cond1, TrueVal1[1].xyz, FalseVal1[1].xyz), select(Cond[7], TrueVal1[1].w, FalseVal1[1].w));
59+
Out1[5] = int4(select(Cond1, TrueVal1[1].x, FalseVal1[1].xyz), select(bool(1), int(1), int(-1)));
60+
Out1[6] = int4(select(Cond1, TrueVal1[1].xyz, FalseVal1[1].x), select(bool(0), int(2), int(-2)));
61+
Out1[7] = int4(select(Cond1, TrueVal1[1].x, FalseVal1[1].x), select(bool(1), int(3), int(-3)));
62+
// vec2
63+
Out1[8] = int4(select(Cond2, TrueVal1[2].xy, FalseVal1[2].xy), select(Cond3, TrueVal1[2].z, FalseVal1[2].zw));
64+
Out1[9] = int4(select(Cond2, TrueVal1[2].xy, FalseVal1[2].x), select(Cond3, TrueVal1[2].z, FalseVal1[2].z));
65+
66+
// uint
67+
// vec4
68+
Out2[0] = select(Cond0, TrueVal2[0], FalseVal2[0]);
69+
Out2[1] = select(Cond0, TrueVal2[0].x, FalseVal2[0]);
70+
Out2[2] = select(Cond0, TrueVal2[0], FalseVal2[0].x);
71+
Out2[3] = select(Cond0, TrueVal2[0].x, FalseVal2[0].x);
72+
// vec3 + scalar
73+
Out2[4] = uint4(select(Cond1, TrueVal2[1].xyz, FalseVal2[1].xyz), select(Cond[7], TrueVal2[1].w, FalseVal2[1].w));
74+
Out2[5] = uint4(select(Cond1, TrueVal2[1].x, FalseVal2[1].xyz), select(bool(1), uint(1), uint(10)));
75+
Out2[6] = uint4(select(Cond1, TrueVal2[1].xyz, FalseVal2[1].x), select(bool(0), uint(2), uint(20)));
76+
Out2[7] = uint4(select(Cond1, TrueVal2[1].x, FalseVal2[1].x), select(bool(1), uint(3), uint(30)));
77+
// vec2
78+
Out2[8] = uint4(select(Cond2, TrueVal2[2].xy, FalseVal2[2].xy), select(Cond3, TrueVal2[2].z, FalseVal2[2].zw));
79+
Out2[9] = uint4(select(Cond2, TrueVal2[2].xy, FalseVal2[2].x), select(Cond3, TrueVal2[2].z, FalseVal2[2].z));
80+
81+
// bool
82+
// vec4
83+
bool4 TrueVal3Tmp0 = bool4(TrueVal3[0], TrueVal3[1], TrueVal3[2], TrueVal3[3]);
84+
bool4 FalseVal3Tmp0 = bool4(FalseVal3[0], FalseVal3[1], FalseVal3[2], FalseVal3[3]);
85+
Out3[0] = select(Cond0, TrueVal3Tmp0, FalseVal3Tmp0);
86+
Out3[1] = select(Cond0, TrueVal3[0], FalseVal3Tmp0);
87+
Out3[2] = select(Cond0, TrueVal3Tmp0, FalseVal3[0]);
88+
Out3[3] = select(Cond0, TrueVal3[0], FalseVal3[0]);
89+
// vec3 + scalar
90+
bool3 TrueVal3Tmp1 = bool3(TrueVal3[4], TrueVal3[5], TrueVal3[6]);
91+
bool3 FalseVal3Tmp1 = bool3(FalseVal3[4], FalseVal3[5], FalseVal3[6]);
92+
Out3[4] = bool4(select(Cond1, TrueVal3Tmp1, FalseVal3Tmp1), select(Cond[7], TrueVal3[7], FalseVal3[7]));
93+
Out3[5] = bool4(select(Cond1, TrueVal3[4], FalseVal3Tmp1), select(bool(1), bool(1), bool(0)));
94+
Out3[6] = bool4(select(Cond1, TrueVal3Tmp1, FalseVal3[4]), select(bool(0), bool(1), bool(0)));
95+
Out3[7] = bool4(select(Cond1, TrueVal3[4], FalseVal3[4]), select(bool(1), bool(1), bool(0)));
96+
// vec2
97+
Out3[8] = bool4(select(Cond2, bool2(TrueVal3[8], TrueVal3[9]), bool2(FalseVal3[8], FalseVal3[9])), select(Cond3, TrueVal3[10], bool2(FalseVal3[10], FalseVal3[11])));
98+
Out3[9] = bool4(select(Cond2, bool2(TrueVal3[8], TrueVal3[9]), FalseVal3[8]), select(Cond3, TrueVal3[10], FalseVal3[10]));
99+
}
100+
//--- pipeline.yaml
101+
102+
---
103+
Shaders:
104+
- Stage: Compute
105+
Entry: main
106+
DispatchSize: [1, 1, 1]
107+
Buffers:
108+
- Name: Cond
109+
Format: Bool
110+
Stride: 4
111+
Data: [ 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1 ]
112+
- Name: TrueVal0
113+
Format: Float32
114+
Stride: 16
115+
Data: [ 1, 2, 3, 4, 4.4, -5.5, 6.6, 3.1415, -10, -20, 15, -25 ]
116+
- Name: FalseVal0
117+
Format: Float32
118+
Stride: 16
119+
Data: [ -1, -2, -3, -4, 7.7, 8.8, -9.9, 0.01, 100, 200, -15, 25 ]
120+
- Name: TrueVal1
121+
Format: Int32
122+
Stride: 16
123+
Data: [ 1, 2, 3, 4, 4, -5, 6, 9, -10, -20, 15, -25 ]
124+
- Name: FalseVal1
125+
Format: Int32
126+
Stride: 16
127+
Data: [ -1, -2, -3, -4, 7, 8, -9, 1, 100, 200, -15, 25 ]
128+
- Name: TrueVal2
129+
Format: UInt32
130+
Stride: 16
131+
Data: [ 1, 2, 3, 4, 4, 5, 6, 9, 10, 20, 15, 250 ]
132+
- Name: FalseVal2
133+
Format: UInt32
134+
Stride: 16
135+
Data: [ 10, 20, 30, 40, 7, 8, 9, 1, 100, 200, 150, 25 ]
136+
- Name: TrueVal3
137+
Format: Bool
138+
Stride: 4
139+
Data: [ 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0 ]
140+
- Name: FalseVal3
141+
Format: Bool
142+
Stride: 4
143+
Data: [ 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1 ]
144+
- Name: Out0
145+
Format: Float32
146+
Stride: 16
147+
ZeroInitSize: 160
148+
- Name: ExpectedOut0
149+
Format: Float32
150+
Stride: 16
151+
Data: [
152+
1, -2, 3, -4, 1, -2, 1, -4, 1, -1, 3, -1, 1, -1, 1, -1,
153+
4.4, -5.5, -9.9, 0.01, 4.4, 4.4, -9.9, 1, 4.4, -5.5, 7.7, -2, 4.4, 4.4, 7.7, 3,
154+
-10, 200, -15, 15, -10, 100, -15, 15
155+
]
156+
- Name: Out1
157+
Format: Int32
158+
Stride: 16
159+
ZeroInitSize: 160
160+
- Name: ExpectedOut1
161+
Format: Int32
162+
Stride: 16
163+
Data: [
164+
1, -2, 3, -4, 1, -2, 1, -4, 1, -1, 3, -1, 1, -1, 1, -1,
165+
4, -5, -9, 1, 4, 4, -9, 1, 4, -5, 7, -2, 4, 4, 7, 3,
166+
-10, 200, -15, 15, -10, 100, -15, 15
167+
]
168+
- Name: Out2
169+
Format: UInt32
170+
Stride: 16
171+
ZeroInitSize: 160
172+
- Name: ExpectedOut2
173+
Format: UInt32
174+
Stride: 16
175+
Data: [
176+
1, 20, 3, 40, 1, 20, 1, 40, 1, 10, 3, 10, 1, 10, 1, 10,
177+
4, 5, 9, 1, 4, 4, 9, 1, 4, 5, 7, 20, 4, 4, 7, 3,
178+
10, 200, 150, 15, 10, 100, 150, 15
179+
]
180+
- Name: Out3
181+
Format: Bool
182+
Stride: 16
183+
ZeroInitSize: 160
184+
- Name: ExpectedOut3
185+
Format: Bool
186+
Stride: 16
187+
Data: [
188+
1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0,
189+
1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1,
190+
0, 0, 1, 1, 0, 1, 1, 1
191+
]
192+
Results:
193+
- Result: Test0
194+
Rule: BufferExact
195+
Actual: Out0
196+
Expected: ExpectedOut0
197+
- Result: Test1
198+
Rule: BufferExact
199+
Actual: Out1
200+
Expected: ExpectedOut1
201+
- Result: Test2
202+
Rule: BufferExact
203+
Actual: Out2
204+
Expected: ExpectedOut2
205+
- Result: Test3
206+
Rule: BufferExact
207+
Actual: Out3
208+
Expected: ExpectedOut3
209+
DescriptorSets:
210+
- Resources:
211+
- Name: Cond
212+
Kind: StructuredBuffer
213+
DirectXBinding:
214+
Register: 0
215+
Space: 0
216+
VulkanBinding:
217+
Binding: 0
218+
- Name: TrueVal0
219+
Kind: StructuredBuffer
220+
DirectXBinding:
221+
Register: 1
222+
Space: 0
223+
VulkanBinding:
224+
Binding: 1
225+
- Name: FalseVal0
226+
Kind: StructuredBuffer
227+
DirectXBinding:
228+
Register: 2
229+
Space: 0
230+
VulkanBinding:
231+
Binding: 2
232+
- Name: TrueVal1
233+
Kind: StructuredBuffer
234+
DirectXBinding:
235+
Register: 3
236+
Space: 0
237+
VulkanBinding:
238+
Binding: 3
239+
- Name: FalseVal1
240+
Kind: StructuredBuffer
241+
DirectXBinding:
242+
Register: 4
243+
Space: 0
244+
VulkanBinding:
245+
Binding: 4
246+
- Name: TrueVal2
247+
Kind: StructuredBuffer
248+
DirectXBinding:
249+
Register: 5
250+
Space: 0
251+
VulkanBinding:
252+
Binding: 5
253+
- Name: FalseVal2
254+
Kind: StructuredBuffer
255+
DirectXBinding:
256+
Register: 6
257+
Space: 0
258+
VulkanBinding:
259+
Binding: 6
260+
- Name: TrueVal3
261+
Kind: StructuredBuffer
262+
DirectXBinding:
263+
Register: 7
264+
Space: 0
265+
VulkanBinding:
266+
Binding: 7
267+
- Name: FalseVal3
268+
Kind: StructuredBuffer
269+
DirectXBinding:
270+
Register: 8
271+
Space: 0
272+
VulkanBinding:
273+
Binding: 8
274+
- Name: Out0
275+
Kind: RWStructuredBuffer
276+
DirectXBinding:
277+
Register: 9
278+
Space: 0
279+
VulkanBinding:
280+
Binding: 9
281+
- Name: Out1
282+
Kind: RWStructuredBuffer
283+
DirectXBinding:
284+
Register: 10
285+
Space: 0
286+
VulkanBinding:
287+
Binding: 10
288+
- Name: Out2
289+
Kind: RWStructuredBuffer
290+
DirectXBinding:
291+
Register: 11
292+
Space: 0
293+
VulkanBinding:
294+
Binding: 11
295+
- Name: Out3
296+
Kind: RWStructuredBuffer
297+
DirectXBinding:
298+
Register: 12
299+
Space: 0
300+
VulkanBinding:
301+
Binding: 12
302+
#--- end
303+
304+
# RUN: split-file %s %t
305+
# RUN: %dxc_target -HV 202x -T cs_6_5 -Fo %t.o %t/source.hlsl
306+
# RUN: %offloader %t/pipeline.yaml %t.o

0 commit comments

Comments
 (0)