diff --git a/test/Feature/WaveOps/WaveActiveAnyTrue.test b/test/Feature/WaveOps/WaveActiveAnyTrue.test index 07af08c1d..3d8da2472 100644 --- a/test/Feature/WaveOps/WaveActiveAnyTrue.test +++ b/test/Feature/WaveOps/WaveActiveAnyTrue.test @@ -5,15 +5,11 @@ RWStructuredBuffer Out : register(u1); [numthreads(4, 1, 1)] void main(uint3 threadID : SV_DispatchThreadID) { bool B1 = false; - switch (value[threadID.x]) { - case 0: - case 2: // threads 0, 1, 2; result for each false - Out[threadID.x] = WaveActiveAnyTrue(B1); - B1 = true; - break; - default: // thread 3; result is false - Out[threadID.x] = WaveActiveAnyTrue(B1); - break; + if (value[threadID.x] == 0 || value[threadID.x] == 2) { + Out[threadID.x] = WaveActiveAnyTrue(B1); + B1 = true; + } else { + Out[threadID.x] = WaveActiveAnyTrue(B1); } // result for all threads is true because B1 is true for threads 0-2 Out[threadID.x + 4] = WaveActiveAnyTrue(B1); diff --git a/test/Feature/WaveOps/WaveActiveCountBits.test b/test/Feature/WaveOps/WaveActiveCountBits.test index a8e289c92..cd250d95c 100644 --- a/test/Feature/WaveOps/WaveActiveCountBits.test +++ b/test/Feature/WaveOps/WaveActiveCountBits.test @@ -6,17 +6,16 @@ RWStructuredBuffer Out : register(u1); void main(uint3 threadID : SV_DispatchThreadID) { bool B1 = false; - switch (value[threadID.x]) { - case 0: // threads 0 and 1; result is number of active lanes (2) - Out[threadID.x + 4] = WaveActiveCountBits(true); // threads 0 and 1 - case 2: - B1 = true; // set b1 to true for thread 3 - break; - default: - Out[threadID.x + 4] = WaveActiveCountBits(false); // thread 2; expect 0 - break; - } - // should be 3 because B1 set to true for threads 0,1, and 3. + if (value[threadID.x] == 0) // thread 0 and 1. + Out[threadID.x + 4] = WaveActiveCountBits(true); + // thread 0-3 set b1 to true. + if (value[threadID.x] == 0 || value[threadID.x] == 2) + B1 = true; + else // thread 3 + Out[threadID.x + 4] = WaveActiveCountBits(false); + + // Out[threadID.x + 4] on thread 4 is never written to and should remain 0. + // All threads count 3 because B1 set to true for threads 0,1, and 3. uint Count = WaveActiveCountBits(B1); Out[threadID.x] = Count; } @@ -35,11 +34,11 @@ Buffers: - Name: Out Format: UInt32 Stride: 4 - ZeroInitSize: 28 + ZeroInitSize: 32 - Name: ExpectedOut Format: UInt32 Stride: 4 - Data: [3, 3, 3, 3, 2, 2, 0] + Data: [3, 3, 3, 3, 2, 2, 0, 0] Results: - Result: Test Rule: BufferExact