@@ -264,7 +264,7 @@ void main()
264264 {
265265 vec4 temp = vec4(float(DTid));
266266 vec4 blendWith = subgroupShuffle(temp, (DTid + 5) & 7);
267- Accum += (dot(temp, temp ) < 0.0 ? 1 : 0);
267+ Accum += (dot(blendWith, blendWith ) < 0.0 ? 1 : 0);
268268 }
269269 #endif
270270 #if WAVE_FEATURE_SHUFFLE_RELATIVE
@@ -320,4 +320,141 @@ void main()
320320 ASSERT_NE (pPSO, nullptr );
321321}
322322
323+
324+ TEST (WaveOpTest, CompileShader_MSL)
325+ {
326+ auto * const pEnv = TestingEnvironment::GetInstance ();
327+ auto * const pDevice = pEnv->GetDevice ();
328+ const auto & DeviceInfo = pDevice->GetDeviceInfo ();
329+
330+ if (!DeviceInfo.IsMetalDevice ())
331+ {
332+ GTEST_SKIP ();
333+ }
334+ if (!DeviceInfo.Features .WaveOp )
335+ {
336+ GTEST_SKIP () << " Wave operations are not supported by this device" ;
337+ }
338+
339+ TestingEnvironment::ScopedReset EnvironmentAutoReset;
340+
341+ const auto & WaveOpProps = pDevice->GetAdapterInfo ().WaveOp ;
342+
343+ ASSERT_NE (WaveOpProps.Features , WAVE_FEATURE_UNKNOWN);
344+ ASSERT_TRUE ((WaveOpProps.Features & WAVE_FEATURE_BASIC) != 0 );
345+
346+ ASSERT_NE (WaveOpProps.SupportedStages , SHADER_TYPE_UNKNOWN);
347+ ASSERT_TRUE ((WaveOpProps.SupportedStages & SHADER_TYPE_COMPUTE) != 0 );
348+
349+ ASSERT_GT (WaveOpProps.MinSize , 0u );
350+ ASSERT_GE (WaveOpProps.MaxSize , WaveOpProps.MinSize );
351+
352+ std::stringstream ShaderSourceStream;
353+ // clang-format off
354+ ShaderSourceStream << " #define WAVE_FEATURE_BASIC " << int {(WaveOpProps.Features & WAVE_FEATURE_BASIC) != 0 } << " \n " ;
355+ ShaderSourceStream << " #define WAVE_FEATURE_VOTE " << int {(WaveOpProps.Features & WAVE_FEATURE_VOTE) != 0 } << " \n " ;
356+ ShaderSourceStream << " #define WAVE_FEATURE_ARITHMETIC " << int {(WaveOpProps.Features & WAVE_FEATURE_ARITHMETIC) != 0 } << " \n " ;
357+ ShaderSourceStream << " #define WAVE_FEATURE_BALLOUT " << int {(WaveOpProps.Features & WAVE_FEATURE_BALLOUT) != 0 } << " \n " ;
358+ ShaderSourceStream << " #define WAVE_FEATURE_SHUFFLE " << int {(WaveOpProps.Features & WAVE_FEATURE_SHUFFLE) != 0 } << " \n " ;
359+ ShaderSourceStream << " #define WAVE_FEATURE_SHUFFLE_RELATIVE " << int {(WaveOpProps.Features & WAVE_FEATURE_SHUFFLE_RELATIVE) != 0 } << " \n " ;
360+ ShaderSourceStream << " #define WAVE_FEATURE_CLUSTERED " << int {(WaveOpProps.Features & WAVE_FEATURE_CLUSTERED) != 0 } << " \n " ;
361+ ShaderSourceStream << " #define WAVE_FEATURE_QUAD " << int {(WaveOpProps.Features & WAVE_FEATURE_QUAD) != 0 } << " \n " ;
362+ // clang-format on
363+
364+ static const char ShaderBody[] = R"(
365+ #include <metal_stdlib>
366+ #include <simd/simd.h>
367+ #include <metal_simdgroup>
368+ using namespace metal;
369+
370+ kernel void CSMain(
371+ #if WAVE_FEATURE_BASIC
372+ uint LaneIndex [[thread_index_in_simdgroup]],
373+ uint WaveSize [[threads_per_simdgroup]],
374+ #endif
375+ #if WAVE_FEATURE_QUAD
376+ uint QuadId [[thread_index_in_quadgroup]],
377+ #endif
378+ device uint* g_WBuffer [[buffer(0)]],
379+ uint DTid [[thread_index_in_threadgroup]]
380+ )
381+ {
382+ uint Accum = 0;
383+ #if WAVE_FEATURE_BASIC
384+ {
385+ Accum += (LaneIndex % WaveSize);
386+ }
387+ #endif
388+ #if WAVE_FEATURE_VOTE
389+ {
390+ if (simd_all(Accum > 0xFFFF))
391+ Accum += 1;
392+ }
393+ #endif
394+ #if WAVE_FEATURE_ARITHMETIC
395+ {
396+ uint sum = simd_sum(DTid);
397+ Accum += (sum & 1);
398+ }
399+ #endif
400+ #if WAVE_FEATURE_BALLOUT
401+ {
402+ float val = simd_broadcast(float(DTid) * 0.1f, ushort(LaneIndex));
403+ Accum += (val > 3.5f);
404+ }
405+ #endif
406+ #if WAVE_FEATURE_SHUFFLE
407+ {
408+ float4 temp = float4(float(DTid));
409+ float4 blendWith = simd_shuffle(temp, ushort((DTid + 5) & 7));
410+ Accum += (dot(blendWith, blendWith) < 0.0 ? 1 : 0);
411+ }
412+ #endif
413+ #if WAVE_FEATURE_SHUFFLE_RELATIVE
414+ {
415+ float4 temp = float4(float(DTid));
416+ for (uint i = 2; i < WaveSize; i *= 2)
417+ {
418+ float4 other = simd_shuffle_up(temp, ushort(i));
419+
420+ if (i <= LaneIndex)
421+ temp = temp * other;
422+ }
423+ Accum += (dot(temp, temp) > 0.5 ? 1 : 0);
424+ }
425+ #endif
426+ #if WAVE_FEATURE_QUAD
427+ {
428+ float val = quad_broadcast(float(DTid) * 0.1f, ushort(LaneIndex));
429+ Accum += (val > 2.5f);
430+ }
431+ #endif
432+
433+ g_WBuffer[DTid] = Accum;
434+ }
435+ )" ;
436+
437+ ShaderSourceStream << ShaderBody;
438+ const String Source = ShaderSourceStream.str ();
439+
440+ ShaderCreateInfo ShaderCI;
441+ ShaderCI.SourceLanguage = SHADER_SOURCE_LANGUAGE_MSL;
442+ ShaderCI.Desc .ShaderType = SHADER_TYPE_COMPUTE;
443+ ShaderCI.Desc .Name = " Wave op test - CS" ;
444+ ShaderCI.EntryPoint = " CSMain" ;
445+ ShaderCI.Source = Source.c_str ();
446+
447+ RefCntAutoPtr<IShader> pCS;
448+ pDevice->CreateShader (ShaderCI, &pCS);
449+ ASSERT_NE (pCS, nullptr );
450+
451+ ComputePipelineStateCreateInfo PSOCreateInfo;
452+ PSOCreateInfo.PSODesc .Name = " Wave op test" ;
453+ PSOCreateInfo.pCS = pCS;
454+
455+ RefCntAutoPtr<IPipelineState> pPSO;
456+ pDevice->CreateComputePipelineState (PSOCreateInfo, &pPSO);
457+ ASSERT_NE (pPSO, nullptr );
458+ }
459+
323460} // namespace Diligent
0 commit comments