Skip to content

Commit 82f88f8

Browse files
azhirnov authored and TheMostDiligent committed
Metal: added WaveOp test
1 parent 568135a commit 82f88f8

File tree

1 file changed

+138
-1
lines changed

1 file changed

+138
-1
lines changed

Tests/DiligentCoreAPITest/src/WaveOpTest.cpp

Lines changed: 138 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -264,7 +264,7 @@ void main()
264264
{
265265
vec4 temp = vec4(float(DTid));
266266
vec4 blendWith = subgroupShuffle(temp, (DTid + 5) & 7);
267-
Accum += (dot(temp, temp) < 0.0 ? 1 : 0);
267+
Accum += (dot(blendWith, blendWith) < 0.0 ? 1 : 0);
268268
}
269269
#endif
270270
#if WAVE_FEATURE_SHUFFLE_RELATIVE
@@ -320,4 +320,141 @@ void main()
320320
ASSERT_NE(pPSO, nullptr);
321321
}
322322

323+
324+
// Checks that a Metal Shading Language compute kernel using the SIMD-group and
// quad-group operations reported by the adapter can be compiled into a shader
// and then into a compute pipeline state.
//
// The test is skipped when the device is not a Metal device or when the
// WaveOp device feature is disabled. It only creates the shader and the PSO;
// the kernel is never dispatched, so no results are validated.
TEST(WaveOpTest, CompileShader_MSL)
{
    auto* const pEnv       = TestingEnvironment::GetInstance();
    auto* const pDevice    = pEnv->GetDevice();
    const auto& DeviceInfo = pDevice->GetDeviceInfo();

    if (!DeviceInfo.IsMetalDevice())
    {
        // State the reason so that skipped runs are self-explanatory in the report.
        GTEST_SKIP() << "MSL shaders can only be compiled on a Metal device";
    }
    if (!DeviceInfo.Features.WaveOp)
    {
        GTEST_SKIP() << "Wave operations are not supported by this device";
    }

    TestingEnvironment::ScopedReset EnvironmentAutoReset;

    const auto& WaveOpProps = pDevice->GetAdapterInfo().WaveOp;

    // Sanity-check the reported wave-op capabilities before relying on them:
    // the basic feature set and the compute stage are required by the kernel below.
    ASSERT_NE(WaveOpProps.Features, WAVE_FEATURE_UNKNOWN);
    ASSERT_TRUE((WaveOpProps.Features & WAVE_FEATURE_BASIC) != 0);

    ASSERT_NE(WaveOpProps.SupportedStages, SHADER_TYPE_UNKNOWN);
    ASSERT_TRUE((WaveOpProps.SupportedStages & SHADER_TYPE_COMPUTE) != 0);

    ASSERT_GT(WaveOpProps.MinSize, 0u);
    ASSERT_GE(WaveOpProps.MaxSize, WaveOpProps.MinSize);

    // Emit one 0/1 macro per wave feature; the kernel uses them to include only
    // the code paths the adapter reports as supported.
    std::stringstream ShaderSourceStream;
    // clang-format off
    ShaderSourceStream << "#define WAVE_FEATURE_BASIC " << int{(WaveOpProps.Features & WAVE_FEATURE_BASIC) != 0} << "\n";
    ShaderSourceStream << "#define WAVE_FEATURE_VOTE " << int{(WaveOpProps.Features & WAVE_FEATURE_VOTE) != 0} << "\n";
    ShaderSourceStream << "#define WAVE_FEATURE_ARITHMETIC " << int{(WaveOpProps.Features & WAVE_FEATURE_ARITHMETIC) != 0} << "\n";
    ShaderSourceStream << "#define WAVE_FEATURE_BALLOUT " << int{(WaveOpProps.Features & WAVE_FEATURE_BALLOUT) != 0} << "\n";
    ShaderSourceStream << "#define WAVE_FEATURE_SHUFFLE " << int{(WaveOpProps.Features & WAVE_FEATURE_SHUFFLE) != 0} << "\n";
    ShaderSourceStream << "#define WAVE_FEATURE_SHUFFLE_RELATIVE " << int{(WaveOpProps.Features & WAVE_FEATURE_SHUFFLE_RELATIVE) != 0} << "\n";
    ShaderSourceStream << "#define WAVE_FEATURE_CLUSTERED " << int{(WaveOpProps.Features & WAVE_FEATURE_CLUSTERED) != 0} << "\n";
    ShaderSourceStream << "#define WAVE_FEATURE_QUAD " << int{(WaveOpProps.Features & WAVE_FEATURE_QUAD) != 0} << "\n";
    // clang-format on

    // NOTE(review): the kernel declares QuadId but never reads it, and both
    // simd_broadcast and quad_broadcast are given the per-thread LaneIndex as
    // the lane id. This is harmless for a compile-only test, but confirm the
    // lane-id requirements in the MSL specification before dispatching it.
    static const char ShaderBody[] = R"(
#include <metal_stdlib>
#include <simd/simd.h>
#include <metal_simdgroup>
using namespace metal;

kernel void CSMain(
#if WAVE_FEATURE_BASIC
    uint LaneIndex [[thread_index_in_simdgroup]],
    uint WaveSize [[threads_per_simdgroup]],
#endif
#if WAVE_FEATURE_QUAD
    uint QuadId [[thread_index_in_quadgroup]],
#endif
    device uint* g_WBuffer [[buffer(0)]],
    uint DTid [[thread_index_in_threadgroup]]
)
{
    uint Accum = 0;
#if WAVE_FEATURE_BASIC
    {
        Accum += (LaneIndex % WaveSize);
    }
#endif
#if WAVE_FEATURE_VOTE
    {
        if (simd_all(Accum > 0xFFFF))
            Accum += 1;
    }
#endif
#if WAVE_FEATURE_ARITHMETIC
    {
        uint sum = simd_sum(DTid);
        Accum += (sum & 1);
    }
#endif
#if WAVE_FEATURE_BALLOUT
    {
        float val = simd_broadcast(float(DTid) * 0.1f, ushort(LaneIndex));
        Accum += (val > 3.5f);
    }
#endif
#if WAVE_FEATURE_SHUFFLE
    {
        float4 temp = float4(float(DTid));
        float4 blendWith = simd_shuffle(temp, ushort((DTid + 5) & 7));
        Accum += (dot(blendWith, blendWith) < 0.0 ? 1 : 0);
    }
#endif
#if WAVE_FEATURE_SHUFFLE_RELATIVE
    {
        float4 temp = float4(float(DTid));
        for (uint i = 2; i < WaveSize; i *= 2)
        {
            float4 other = simd_shuffle_up(temp, ushort(i));

            if (i <= LaneIndex)
                temp = temp * other;
        }
        Accum += (dot(temp, temp) > 0.5 ? 1 : 0);
    }
#endif
#if WAVE_FEATURE_QUAD
    {
        float val = quad_broadcast(float(DTid) * 0.1f, ushort(LaneIndex));
        Accum += (val > 2.5f);
    }
#endif

    g_WBuffer[DTid] = Accum;
}
)";

    // The feature macros must precede the kernel text that tests them.
    ShaderSourceStream << ShaderBody;
    const String Source = ShaderSourceStream.str();

    ShaderCreateInfo ShaderCI;
    ShaderCI.SourceLanguage  = SHADER_SOURCE_LANGUAGE_MSL;
    ShaderCI.Desc.ShaderType = SHADER_TYPE_COMPUTE;
    ShaderCI.Desc.Name       = "Wave op test - CS";
    ShaderCI.EntryPoint      = "CSMain";
    ShaderCI.Source          = Source.c_str();

    RefCntAutoPtr<IShader> pCS;
    pDevice->CreateShader(ShaderCI, &pCS);
    ASSERT_NE(pCS, nullptr);

    ComputePipelineStateCreateInfo PSOCreateInfo;
    PSOCreateInfo.PSODesc.Name = "Wave op test";
    PSOCreateInfo.pCS          = pCS;

    RefCntAutoPtr<IPipelineState> pPSO;
    pDevice->CreateComputePipelineState(PSOCreateInfo, &pPSO);
    ASSERT_NE(pPSO, nullptr);
}
459+
323460
} // namespace Diligent

0 commit comments

Comments
 (0)