Skip to content

Commit 0abe45d

Browse files
authored
Check WaveMatrix support; locally declare feature struct for now (microsoft#6089)
WaveMatrix wasn't checking WaveMMATier before running tests (assuming it was always supported for SM 6.8). This change fixes this. D3D12_SDK_VERSION check for WAVE_MMA feature structure and enum definitions assumed they will be defined in SDK version 613. This isn't accurate, so this block of local definitions will always be enabled until we have the correct version in the future.
1 parent 79d2b8d commit 0abe45d

File tree

1 file changed

+21
-1
lines changed

1 file changed

+21
-1
lines changed

tools/clang/unittests/HLSLExec/ExecutionTest.cpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -758,6 +758,12 @@ class ExecutionTest {
758758
return nullptr;
759759
}
760760

761+
if (!DoesDeviceSupportWaveMatrix(pDevice)) {
762+
LogCommentFmt(L"WaveMatrix not supported on this device.");
763+
WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped);
764+
return nullptr;
765+
}
766+
761767
CComPtr<IStream> pStream;
762768
ReadHlslDataIntoNewStream(L"ShaderOpArith.xml", &pStream);
763769

@@ -1630,6 +1636,19 @@ class ExecutionTest {
16301636
#endif
16311637
}
16321638

1639+
bool DoesDeviceSupportWaveMatrix(ID3D12Device *pDevice) {
1640+
#if defined(NTDDI_WIN10_FE) && WDK_NTDDI_VERSION >= NTDDI_WIN10_FE
1641+
D3D12_FEATURE_DATA_D3D12_OPTIONS9 O9;
1642+
if (FAILED(pDevice->CheckFeatureSupport(
1643+
(D3D12_FEATURE)D3D12_FEATURE_D3D12_OPTIONS9, &O9, sizeof(O9))))
1644+
return false;
1645+
return O9.WaveMMATier >= D3D12_WAVE_MMA_TIER_1_0;
1646+
#else
1647+
UNREFERENCED_PARAMETER(pDevice);
1648+
return false;
1649+
#endif
1650+
}
1651+
16331652
bool DoesDeviceSupportAdvancedTexOps(ID3D12Device *pDevice) {
16341653
#if defined(NTDDI_WIN10_CU) && WDK_NTDDI_VERSION >= NTDDI_WIN10_CU
16351654
D3D12_FEATURE_DATA_D3D12_OPTIONS14 O14;
@@ -9002,7 +9021,8 @@ void LoadStoreMat(int M, int N, bool LEFT, int MEM_TYPE, uint32_t K, uint32_t k,
90029021
}
90039022

90049023
// define WAVE_MMA types if building with SDK that does not support it yet
9005-
#if !defined(D3D12_SDK_VERSION) || (D3D12_SDK_VERSION < 613)
9024+
// For now: Force this on, until we know the version.
9025+
#if 1 // !defined(D3D12_SDK_VERSION) || (D3D12_SDK_VERSION < 613)
90069026
typedef enum D3D12_WAVE_MMA_INPUT_DATATYPE {
90079027
D3D12_WAVE_MMA_INPUT_DATATYPE_INVALID = 0,
90089028
D3D12_WAVE_MMA_INPUT_DATATYPE_BYTE =

0 commit comments

Comments
 (0)