Skip to content

Commit 9636d23

Browse files
author
Ryan Lai
authored
Separate Device Creation into separate .h and .cpp file (#266)
* moved fetch types to commandlineargs * Separate device creation to outside of run method * small change for return code * CHange to auto in loop for consistency * Remove commas from metadata
1 parent 78c4b65 commit 9636d23

File tree

10 files changed

+456
-389
lines changed

10 files changed

+456
-389
lines changed

Tools/WinMLRunner/WinMLRunnerStaticLib.vcxproj

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,14 @@
3535
<ClInclude Include="src/Run.h" />
3636
<ClInclude Include="src/TimerHelper.h" />
3737
<ClInclude Include="src/TypeHelper.h" />
38+
<ClInclude Include="src\LearningModelDeviceHelper.h" />
3839
</ItemGroup>
3940
<ItemGroup>
4041
<ClCompile Include="src/CommandLineArgs.cpp" />
4142
<ClCompile Include="src/dllload.cpp" />
4243
<ClCompile Include="src/Filehelper.cpp" />
4344
<ClCompile Include="src/Run.cpp" />
45+
<ClCompile Include="src\LearningModelDeviceHelper.cpp" />
4446
</ItemGroup>
4547
<PropertyGroup Label="Globals">
4648
<VCProjectVersion>15.0</VCProjectVersion>

Tools/WinMLRunner/WinMLRunnerStaticLib.vcxproj.filters

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
<ClCompile Include="src/Run.cpp">
1414
<Filter>Source Files</Filter>
1515
</ClCompile>
16+
<ClCompile Include="src\LearningModelDeviceHelper.cpp">
17+
<Filter>Source Files</Filter>
18+
</ClCompile>
1619
</ItemGroup>
1720
<ItemGroup>
1821
<ClInclude Include="src/BindingUtilities.h">
@@ -39,6 +42,9 @@
3942
<ClInclude Include="src/Run.h">
4043
<Filter>Header Files</Filter>
4144
</ClInclude>
45+
<ClInclude Include="src\LearningModelDeviceHelper.h">
46+
<Filter>Header Files</Filter>
47+
</ClInclude>
4248
</ItemGroup>
4349
<ItemGroup>
4450
<Filter Include="Header Files">

Tools/WinMLRunner/src/CommandLineArgs.cpp

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,3 +504,93 @@ void CommandLineArgs::CheckForInvalidArguments()
504504
throw hresult_not_implemented(L"Saving tensor output for multiple images isn't implemented.");
505505
}
506506
}
507+
508+
std::vector<InputDataType> CommandLineArgs::FetchInputDataTypes()
509+
{
510+
std::vector<InputDataType> inputDataTypes;
511+
512+
if (this->UseTensor())
513+
{
514+
inputDataTypes.push_back(InputDataType::Tensor);
515+
}
516+
517+
if (this->UseRGB())
518+
{
519+
inputDataTypes.push_back(InputDataType::ImageRGB);
520+
}
521+
522+
if (this->UseBGR())
523+
{
524+
inputDataTypes.push_back(InputDataType::ImageBGR);
525+
}
526+
527+
return inputDataTypes;
528+
}
529+
530+
std::vector<DeviceType> CommandLineArgs::FetchDeviceTypes()
531+
{
532+
std::vector<DeviceType> deviceTypes;
533+
534+
if (this->UseCPU())
535+
{
536+
deviceTypes.push_back(DeviceType::CPU);
537+
}
538+
539+
if (this->UseGPU())
540+
{
541+
deviceTypes.push_back(DeviceType::DefaultGPU);
542+
}
543+
544+
if (this->IsUsingGPUHighPerformance())
545+
{
546+
deviceTypes.push_back(DeviceType::HighPerfGPU);
547+
}
548+
549+
if (this->IsUsingGPUMinPower())
550+
{
551+
deviceTypes.push_back(DeviceType::MinPowerGPU);
552+
}
553+
554+
return deviceTypes;
555+
}
556+
557+
std::vector<InputBindingType> CommandLineArgs::FetchInputBindingTypes()
558+
{
559+
std::vector<InputBindingType> inputBindingTypes;
560+
561+
if (this->UseCPUBoundInput())
562+
{
563+
inputBindingTypes.push_back(InputBindingType::CPU);
564+
}
565+
566+
if (this->IsUsingGPUBoundInput())
567+
{
568+
inputBindingTypes.push_back(InputBindingType::GPU);
569+
}
570+
571+
return inputBindingTypes;
572+
}
573+
574+
std::vector<DeviceCreationLocation> CommandLineArgs::FetchDeviceCreationLocations()
575+
{
576+
std::vector<DeviceCreationLocation> deviceCreationLocations;
577+
578+
if (this->CreateDeviceInWinML())
579+
{
580+
deviceCreationLocations.push_back(DeviceCreationLocation::WinML);
581+
}
582+
583+
if (this->IsCreateDeviceOnClient())
584+
{
585+
deviceCreationLocations.push_back(DeviceCreationLocation::UserD3DDevice);
586+
}
587+
588+
return deviceCreationLocations;
589+
}
590+
void CommandLineArgs::AddPerformanceFileMetadata(const std::string& key, const std::string& value)
591+
{
592+
// remove commas that may affect processing of CSV
593+
std::string cleanedValue(value.size(), '0');
594+
cleanedValue.erase(std::remove_copy(value.begin(), value.end(), cleanedValue.begin(), ','), cleanedValue.end());
595+
m_perfFileMetadata.push_back(std::make_pair(key, cleanedValue));
596+
}

Tools/WinMLRunner/src/CommandLineArgs.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,7 @@ class TensorizeArgs
1818
std::vector<float> StdDevs;
1919
} Normalize;
2020

21-
TensorizeArgs() : Func(TensorizeFuncs::Identity)
22-
{
23-
Normalize.Scale = 1.0f;
24-
};
21+
TensorizeArgs() : Func(TensorizeFuncs::Identity) { Normalize.Scale = 1.0f; };
2522
};
2623

2724
class CommandLineArgs
@@ -140,10 +137,8 @@ class CommandLineArgs
140137
void SetRunIterations(const uint32_t iterations) { m_numIterations = iterations; }
141138
void SetSessionCreationIterations(const uint32_t iterations) { m_numSessionIterations = iterations; }
142139
void SetLoadIterations(const uint32_t iterations) { m_numLoadIterations = iterations; }
143-
void AddPerformanceFileMetadata(const std::string& key, const std::string& value)
144-
{
145-
m_perfFileMetadata.push_back(std::make_pair(key, value));
146-
}
140+
void AddPerformanceFileMetadata(const std::string& key, const std::string& value);
141+
147142
// Stop iterating when total time of iterations after the first iteration exceeds time limit.
148143
void SetIterationTimeLimit(const double milliseconds)
149144
{
@@ -152,6 +147,11 @@ class CommandLineArgs
152147
}
153148
std::wstring SaveTensorMode() const { return m_saveTensorMode; }
154149

150+
std::vector<InputBindingType> FetchInputBindingTypes();
151+
std::vector<DeviceType> FetchDeviceTypes();
152+
std::vector<DeviceCreationLocation> FetchDeviceCreationLocations();
153+
std::vector<InputDataType> FetchInputDataTypes();
154+
155155
private:
156156
bool m_perfCapture = false;
157157
bool m_perfConsoleOutputAll = false;
Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,233 @@
1+
#include "LearningModelDeviceHelper.h"
2+
#include "TypeHelper.h"
3+
#include "Common.h"
4+
#include "d3d11.h"
5+
#include "d3dx12.h"
6+
#include <Windows.Graphics.DirectX.Direct3D11.interop.h>
7+
#include "Windows.AI.MachineLearning.Native.h"
8+
#include <codecvt>
9+
using namespace winrt::Windows::Graphics::DirectX::Direct3D11;
10+
11+
#ifdef DXCORE_SUPPORTED_BUILD
12+
HRESULT CreateDXGIFactory2SEH(void** dxgiFactory)
13+
{
14+
// Recover from delay-load module failure.
15+
HRESULT hr;
16+
__try
17+
{
18+
hr = CreateDXGIFactory2(0, __uuidof(IDXGIFactory4), dxgiFactory);
19+
}
20+
__except (GetExceptionCode() == VcppException(ERROR_SEVERITY_ERROR, ERROR_MOD_NOT_FOUND)
21+
? EXCEPTION_EXECUTE_HANDLER
22+
: EXCEPTION_CONTINUE_SEARCH)
23+
{
24+
hr = HRESULT_FROM_WIN32(ERROR_MOD_NOT_FOUND);
25+
}
26+
return hr;
27+
}
28+
#endif
29+
30+
void PopulateLearningModelDeviceList(CommandLineArgs& args, std::vector<LearningModelDeviceWithMetadata>& deviceList)
31+
{
32+
std::vector<DeviceType> deviceTypes = args.FetchDeviceTypes();
33+
std::vector<DeviceCreationLocation> deviceCreationLocations = args.FetchDeviceCreationLocations();
34+
for (auto deviceType : deviceTypes)
35+
{
36+
for (auto deviceCreationLocation : deviceCreationLocations)
37+
{
38+
try
39+
{
40+
#ifdef DXCORE_SUPPORTED_BUILD
41+
const std::wstring& adapterName = args.GetGPUAdapterName();
42+
#endif
43+
if (deviceCreationLocation == DeviceCreationLocation::UserD3DDevice && deviceType != DeviceType::CPU)
44+
{
45+
// Enumerate Adapters to pick the requested one.
46+
com_ptr<IDXGIFactory6> factory;
47+
HRESULT hr = CreateDXGIFactory(__uuidof(IDXGIFactory6), factory.put_void());
48+
THROW_IF_FAILED(hr);
49+
50+
com_ptr<IDXGIAdapter> adapter;
51+
switch (deviceType)
52+
{
53+
case DeviceType::DefaultGPU:
54+
hr = factory->EnumAdapterByGpuPreference(0, DXGI_GPU_PREFERENCE_UNSPECIFIED,
55+
__uuidof(IDXGIAdapter), adapter.put_void());
56+
break;
57+
case DeviceType::MinPowerGPU:
58+
hr = factory->EnumAdapterByGpuPreference(0, DXGI_GPU_PREFERENCE_MINIMUM_POWER,
59+
__uuidof(IDXGIAdapter), adapter.put_void());
60+
break;
61+
case DeviceType::HighPerfGPU:
62+
hr = factory->EnumAdapterByGpuPreference(0, DXGI_GPU_PREFERENCE_HIGH_PERFORMANCE,
63+
__uuidof(IDXGIAdapter), adapter.put_void());
64+
break;
65+
default:
66+
throw hresult(E_INVALIDARG);
67+
}
68+
THROW_IF_FAILED(hr);
69+
70+
// Creating the device on the client and using it to create the video frame and initialize the
71+
// session makes sure that everything is on the same device. This usually avoids an expensive
72+
// cross-device and cross-videoframe copy via the VideoFrame pipeline.
73+
com_ptr<ID3D11Device> d3d11Device;
74+
hr = D3D11CreateDevice(adapter.get(), D3D_DRIVER_TYPE_UNKNOWN, nullptr,
75+
D3D11_CREATE_DEVICE_BGRA_SUPPORT, nullptr, 0, D3D11_SDK_VERSION,
76+
d3d11Device.put(), nullptr, nullptr);
77+
THROW_IF_FAILED(hr);
78+
79+
com_ptr<IDXGIDevice> dxgiDevice;
80+
hr = d3d11Device->QueryInterface(__uuidof(IDXGIDevice), dxgiDevice.put_void());
81+
THROW_IF_FAILED(hr);
82+
83+
com_ptr<IInspectable> inspectableDevice;
84+
hr = CreateDirect3D11DeviceFromDXGIDevice(dxgiDevice.get(), inspectableDevice.put());
85+
THROW_IF_FAILED(hr);
86+
deviceList.push_back({
87+
LearningModelDevice::CreateFromDirect3D11Device(inspectableDevice.as<IDirect3DDevice>()),
88+
deviceType,
89+
deviceCreationLocation
90+
});
91+
}
92+
#ifdef DXCORE_SUPPORTED_BUILD
93+
else if ((TypeHelper::GetWinmlDeviceKind(deviceType) != LearningModelDeviceKind::Cpu) &&
94+
!adapterName.empty())
95+
{
96+
com_ptr<IDXCoreAdapterFactory> spFactory;
97+
THROW_IF_FAILED(DXCoreCreateAdapterFactory(IID_PPV_ARGS(spFactory.put())));
98+
99+
com_ptr<IDXCoreAdapterList> spAdapterList;
100+
const GUID dxGUIDs[] = { DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE };
101+
102+
THROW_IF_FAILED(
103+
spFactory->CreateAdapterList(ARRAYSIZE(dxGUIDs), dxGUIDs, IID_PPV_ARGS(spAdapterList.put())));
104+
105+
std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
106+
std::string adapterNameStr = converter.to_bytes(adapterName);
107+
com_ptr<IDXCoreAdapter> spAdapter = nullptr;
108+
com_ptr<IDXCoreAdapter> currAdapter = nullptr;
109+
bool chosenAdapterFound = false;
110+
printf("Printing available adapters..\n");
111+
for (UINT i = 0; i < spAdapterList->GetAdapterCount(); i++)
112+
{
113+
THROW_IF_FAILED(spAdapterList->GetAdapter(i, currAdapter.put()));
114+
115+
// If the adapter is a software adapter then don't consider it for index selection
116+
bool isHardware;
117+
size_t driverDescriptionSize;
118+
THROW_IF_FAILED(currAdapter->GetPropertySize(DXCoreAdapterProperty::DriverDescription,
119+
&driverDescriptionSize));
120+
CHAR* driverDescription = new CHAR[driverDescriptionSize];
121+
THROW_IF_FAILED(currAdapter->GetProperty(DXCoreAdapterProperty::IsHardware, &isHardware));
122+
THROW_IF_FAILED(currAdapter->GetProperty(DXCoreAdapterProperty::DriverDescription,
123+
driverDescriptionSize, driverDescription));
124+
if (isHardware)
125+
{
126+
printf("Description: %s\n", driverDescription);
127+
}
128+
if (!adapterName.empty() && !chosenAdapterFound)
129+
{
130+
std::string driverDescriptionStr = std::string(driverDescription);
131+
std::transform(driverDescriptionStr.begin(), driverDescriptionStr.end(),
132+
driverDescriptionStr.begin(), ::tolower);
133+
std::transform(adapterNameStr.begin(), adapterNameStr.end(), adapterNameStr.begin(),
134+
::tolower);
135+
if (strstr(driverDescriptionStr.c_str(), adapterNameStr.c_str()))
136+
{
137+
chosenAdapterFound = true;
138+
spAdapter = currAdapter;
139+
}
140+
}
141+
currAdapter = nullptr;
142+
free(driverDescription);
143+
}
144+
145+
if (spAdapter == nullptr)
146+
{
147+
throw hresult_invalid_argument(L"ERROR: No matching adapter with given adapter name: " +
148+
adapterName);
149+
}
150+
size_t driverDescriptionSize;
151+
THROW_IF_FAILED(
152+
spAdapter->GetPropertySize(DXCoreAdapterProperty::DriverDescription, &driverDescriptionSize));
153+
CHAR* driverDescription = new CHAR[driverDescriptionSize];
154+
spAdapter->GetProperty(DXCoreAdapterProperty::DriverDescription, driverDescriptionSize,
155+
driverDescription);
156+
printf("Using adapter : %s\n", driverDescription);
157+
free(driverDescription);
158+
IUnknown* pAdapter = spAdapter.get();
159+
com_ptr<IDXGIAdapter> spDxgiAdapter;
160+
D3D_FEATURE_LEVEL d3dFeatureLevel = D3D_FEATURE_LEVEL_1_0_CORE;
161+
D3D12_COMMAND_LIST_TYPE commandQueueType = D3D12_COMMAND_LIST_TYPE_COMPUTE;
162+
163+
// Check if adapter selected has DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS attribute selected. If
164+
// so, then GPU was selected that has D3D12 and D3D11 capabilities. It would be the most stable
165+
// to use DXGI to enumerate GPU and use D3D_FEATURE_LEVEL_11_0 so that image tensorization for
166+
// video frames would be able to happen on the GPU.
167+
if (spAdapter->IsAttributeSupported(DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS))
168+
{
169+
d3dFeatureLevel = D3D_FEATURE_LEVEL::D3D_FEATURE_LEVEL_11_0;
170+
com_ptr<IDXGIFactory4> dxgiFactory4;
171+
HRESULT hr;
172+
try
173+
{
174+
hr = CreateDXGIFactory2SEH(dxgiFactory4.put_void());
175+
}
176+
catch (...)
177+
{
178+
hr = E_FAIL;
179+
}
180+
if (hr == S_OK)
181+
{
182+
// If DXGI factory creation was successful then get the IDXGIAdapter from the LUID
183+
// acquired from the selectedAdapter
184+
std::cout << "Using DXGI for adapter creation.." << std::endl;
185+
LUID adapterLuid;
186+
THROW_IF_FAILED(spAdapter->GetProperty(DXCoreAdapterProperty::InstanceLuid, &adapterLuid));
187+
THROW_IF_FAILED(dxgiFactory4->EnumAdapterByLuid(adapterLuid, __uuidof(IDXGIAdapter),
188+
spDxgiAdapter.put_void()));
189+
pAdapter = spDxgiAdapter.get();
190+
}
191+
}
192+
193+
// create D3D12Device
194+
com_ptr<ID3D12Device> d3d12Device;
195+
THROW_IF_FAILED(
196+
D3D12CreateDevice(pAdapter, d3dFeatureLevel, __uuidof(ID3D12Device), d3d12Device.put_void()));
197+
198+
// create D3D12 command queue from device
199+
com_ptr<ID3D12CommandQueue> d3d12CommandQueue;
200+
D3D12_COMMAND_QUEUE_DESC commandQueueDesc = {};
201+
commandQueueDesc.Type = commandQueueType;
202+
THROW_IF_FAILED(d3d12Device->CreateCommandQueue(&commandQueueDesc, __uuidof(ID3D12CommandQueue),
203+
d3d12CommandQueue.put_void()));
204+
205+
// create LearningModelDevice from command queue
206+
auto factory = get_activation_factory<LearningModelDevice, ILearningModelDeviceFactoryNative>();
207+
com_ptr<::IUnknown> spUnkLearningModelDevice;
208+
THROW_IF_FAILED(
209+
factory->CreateFromD3D12CommandQueue(d3d12CommandQueue.get(), spUnkLearningModelDevice.put()));
210+
deviceList.push_back({
211+
spUnkLearningModelDevice.as<LearningModelDevice>(),
212+
deviceType,
213+
deviceCreationLocation
214+
});
215+
}
216+
#endif
217+
else
218+
{
219+
deviceList.push_back({
220+
LearningModelDevice(TypeHelper::GetWinmlDeviceKind(deviceType)),
221+
deviceType,
222+
deviceCreationLocation
223+
});
224+
}
225+
}
226+
catch (...)
227+
{
228+
printf("Creating LearningModelDevice failed!");
229+
throw;
230+
}
231+
}
232+
}
233+
}

0 commit comments

Comments
 (0)