From f624aa8f5796aa370ee22ef86e26e5d59941419d Mon Sep 17 00:00:00 2001 From: Chris Bieneman Date: Tue, 2 Sep 2025 15:39:24 -0500 Subject: [PATCH 1/3] Don't use switch fallthrough In SPIRV even trivial switch fallthrough is undefined behavior, so we shouldn't rely on it in wave operations tests. --- test/Feature/WaveOps/WaveActiveAnyTrue.test | 14 +++++--------- test/Feature/WaveOps/WaveActiveCountBits.test | 16 ++++++---------- 2 files changed, 11 insertions(+), 19 deletions(-) diff --git a/test/Feature/WaveOps/WaveActiveAnyTrue.test b/test/Feature/WaveOps/WaveActiveAnyTrue.test index 07af08c1d..3d8da2472 100644 --- a/test/Feature/WaveOps/WaveActiveAnyTrue.test +++ b/test/Feature/WaveOps/WaveActiveAnyTrue.test @@ -5,15 +5,11 @@ RWStructuredBuffer Out : register(u1); [numthreads(4, 1, 1)] void main(uint3 threadID : SV_DispatchThreadID) { bool B1 = false; - switch (value[threadID.x]) { - case 0: - case 2: // threads 0, 1, 2; result for each false - Out[threadID.x] = WaveActiveAnyTrue(B1); - B1 = true; - break; - default: // thread 3; result is false - Out[threadID.x] = WaveActiveAnyTrue(B1); - break; + if (value[threadID.x] == 0 || value[threadID.x] == 2) { + Out[threadID.x] = WaveActiveAnyTrue(B1); + B1 = true; + } else { + Out[threadID.x] = WaveActiveAnyTrue(B1); } // result for all threads is true because B1 is true for threads 0-2 Out[threadID.x + 4] = WaveActiveAnyTrue(B1); diff --git a/test/Feature/WaveOps/WaveActiveCountBits.test b/test/Feature/WaveOps/WaveActiveCountBits.test index a8e289c92..78021a3fd 100644 --- a/test/Feature/WaveOps/WaveActiveCountBits.test +++ b/test/Feature/WaveOps/WaveActiveCountBits.test @@ -6,16 +6,12 @@ RWStructuredBuffer Out : register(u1); void main(uint3 threadID : SV_DispatchThreadID) { bool B1 = false; - switch (value[threadID.x]) { - case 0: // threads 0 and 1; result is number of active lanes (2) - Out[threadID.x + 4] = WaveActiveCountBits(true); // threads 0 and 1 - case 2: - B1 = true; // set b1 to true for thread 3 - break; - default: - Out[threadID.x + 4] = WaveActiveCountBits(false); // thread 2; expect 0 - break; - } + if (value[threadID.x] == 0) + Out[threadID.x + 4] = WaveActiveCountBits(true); // threads 0 and 1 + if (value[threadID.x] == 0 || value[threadID.x] == 1) + B1 = true; // set b1 to true for thread 3 + else + Out[threadID.x + 4] = WaveActiveCountBits(false); // thread 2; expect 0 // should be 3 because B1 set to true for threads 0,1, and 3. uint Count = WaveActiveCountBits(B1); Out[threadID.x] = Count; From 706b13ed166edddb0ce3a5dff747068e949002a5 Mon Sep 17 00:00:00 2001 From: Chris Bieneman Date: Tue, 2 Sep 2025 17:53:06 -0500 Subject: [PATCH 2/3] Fix comments and don't truncate buffers --- test/Feature/WaveOps/WaveActiveCountBits.test | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/test/Feature/WaveOps/WaveActiveCountBits.test b/test/Feature/WaveOps/WaveActiveCountBits.test index 78021a3fd..231c56a96 100644 --- a/test/Feature/WaveOps/WaveActiveCountBits.test +++ b/test/Feature/WaveOps/WaveActiveCountBits.test @@ -6,13 +6,16 @@ RWStructuredBuffer Out : register(u1); void main(uint3 threadID : SV_DispatchThreadID) { bool B1 = false; - if (value[threadID.x] == 0) - Out[threadID.x + 4] = WaveActiveCountBits(true); // threads 0 and 1 + if (value[threadID.x] == 0) // thread 0 and 1. + Out[threadID.x + 4] = WaveActiveCountBits(true); + // thread 0-3 set b1 to true. if (value[threadID.x] == 0 || value[threadID.x] == 1) - B1 = true; // set b1 to true for thread 3 - else - Out[threadID.x + 4] = WaveActiveCountBits(false); // thread 2; expect 0 - // should be 3 because B1 set to true for threads 0,1, and 3. + B1 = true; + else // thread 4 + Out[threadID.x + 4] = WaveActiveCountBits(false); + + // Out[threadID.x + 4] on thread 3 is never written to and should remain 0. + // All threads count 3 because B1 set to true for threads 0,1, and 3. uint Count = WaveActiveCountBits(B1); Out[threadID.x] = Count; } @@ -31,11 +34,11 @@ Buffers: - Name: Out Format: UInt32 Stride: 4 - ZeroInitSize: 28 + ZeroInitSize: 32 - Name: ExpectedOut Format: UInt32 Stride: 4 - Data: [3, 3, 3, 3, 2, 2, 0] + Data: [3, 3, 3, 3, 2, 2, 0, 0] Results: - Result: Test Rule: BufferExact From 1a4c00ff2bf2662cf4111fcdf58f4236deb1d30a Mon Sep 17 00:00:00 2001 From: Chris Bieneman Date: Tue, 2 Sep 2025 19:49:10 -0500 Subject: [PATCH 3/3] Bleh --- test/Feature/WaveOps/WaveActiveCountBits.test | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/Feature/WaveOps/WaveActiveCountBits.test b/test/Feature/WaveOps/WaveActiveCountBits.test index 231c56a96..cd250d95c 100644 --- a/test/Feature/WaveOps/WaveActiveCountBits.test +++ b/test/Feature/WaveOps/WaveActiveCountBits.test @@ -9,12 +9,12 @@ void main(uint3 threadID : SV_DispatchThreadID) { if (value[threadID.x] == 0) // thread 0 and 1. Out[threadID.x + 4] = WaveActiveCountBits(true); // thread 0-3 set b1 to true. - if (value[threadID.x] == 0 || value[threadID.x] == 1) + if (value[threadID.x] == 0 || value[threadID.x] == 2) B1 = true; - else // thread 4 + else // thread 3 Out[threadID.x + 4] = WaveActiveCountBits(false); - // Out[threadID.x + 4] on thread 3 is never written to and should remain 0. + // Out[threadID.x + 4] on thread 4 is never written to and should remain 0. // All threads count 3 because B1 set to true for threads 0,1, and 3. uint Count = WaveActiveCountBits(B1); Out[threadID.x] = Count;