Skip to content

Commit 63af5fb

Browse files
Add device index support to DML provider option (microsoft#1495)
Add device index support to the DML provider options. It is a useful feature for forcing a particular device on multi-GPU systems; GenAI was missing device index support.
1 parent b4740f1 commit 63af5fb

File tree

5 files changed

+46
-15
lines changed

5 files changed

+46
-15
lines changed

src/dml/dml_helpers.cpp

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,15 @@ static bool IsSoftwareAdapter(IDXGIAdapter1* adapter) {
2121
return desc.Flags & DXGI_ADAPTER_FLAG_SOFTWARE || (is_basic_render_driver_vendor_id && is_basic_render_driver_device_id);
2222
};
2323

24-
static std::vector<ComPtr<IDXGIAdapter1>> EnumerateAdapters(PLUID device_luid = nullptr) {
24+
static std::vector<ComPtr<IDXGIAdapter1>> EnumerateAdapters(PLUID device_luid = nullptr,
25+
uint32_t* deviceIndex = nullptr) {
2526
ComPtr<IDXGIFactory4> dxgi_factory;
2627
THROW_IF_FAILED(CreateDXGIFactory(IID_PPV_ARGS(&dxgi_factory)));
2728

2829
std::vector<ComPtr<IDXGIAdapter1>> adapter_infos;
2930

3031
ComPtr<IDXGIFactory6> dxgi_factory6;
31-
if (SUCCEEDED(dxgi_factory.As(&dxgi_factory6)) && !device_luid) {
32+
if (SUCCEEDED(dxgi_factory.As(&dxgi_factory6)) && !device_luid && !deviceIndex) {
3233
// Enumerate adapters by performance. This only works in Windows 10 Version 1803 and later.
3334
ComPtr<IDXGIAdapter1> adapter;
3435
for (uint32_t adapter_index = 0;
@@ -52,7 +53,7 @@ static std::vector<ComPtr<IDXGIAdapter1>> EnumerateAdapters(PLUID device_luid =
5253
adapter_infos.emplace_back(std::move(adapter));
5354
}
5455
}
55-
} else {
56+
} else if (device_luid) {
5657
// Enumerate adapters without ordering.
5758
ComPtr<IDXGIAdapter1> adapter;
5859
for (uint32_t adapter_index = 0; dxgi_factory->EnumAdapters1(adapter_index, &adapter) != DXGI_ERROR_NOT_FOUND; adapter_index++) {
@@ -66,7 +67,7 @@ static std::vector<ComPtr<IDXGIAdapter1>> EnumerateAdapters(PLUID device_luid =
6667
ComPtr<ID3D12Device> d3d12_device;
6768
THROW_IF_FAILED(D3D12CreateDevice(adapter.Get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&d3d12_device)));
6869

69-
if (d3d12_device && device_luid) {
70+
if (d3d12_device) {
7071
DXGI_ADAPTER_DESC1 description = {};
7172
THROW_IF_FAILED(adapter->GetDesc1(&description));
7273

@@ -75,24 +76,49 @@ static std::vector<ComPtr<IDXGIAdapter1>> EnumerateAdapters(PLUID device_luid =
7576
adapter_infos.emplace_back(std::move(adapter));
7677
break;
7778
}
79+
}
80+
}
81+
} else {
82+
// Enumerate adapters without ordering.
83+
ComPtr<IDXGIAdapter1> adapter;
84+
uint32_t hwAdapterIndex = 0;
85+
for (uint32_t adapter_index = 0; dxgi_factory->EnumAdapters1(adapter_index, &adapter) != DXGI_ERROR_NOT_FOUND; adapter_index++) {
86+
// We can't assume the ordering of hardware and software adapters, so keep looping. This path should only execute on Windows 10
87+
// version 1709 or earlier; IDD (e.g. remote desktop) adapters do not exist when taking this code path.
88+
if (IsSoftwareAdapter(adapter.Get())) {
89+
continue;
90+
}
91+
92+
// Make sure that we are able to create the device
93+
ComPtr<ID3D12Device> d3d12_device;
94+
THROW_IF_FAILED(D3D12CreateDevice(adapter.Get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&d3d12_device)));
95+
96+
if (d3d12_device && deviceIndex) {
97+
// get device specified by the deviceIndex
98+
if (*deviceIndex == hwAdapterIndex) {
99+
adapter_infos.emplace_back(std::move(adapter));
100+
break;
101+
}
78102
} else if (d3d12_device) {
79103
adapter_infos.emplace_back(std::move(adapter));
80104
}
105+
106+
hwAdapterIndex++;
81107
}
82108
}
83109

84110
return adapter_infos;
85111
}
86112

87-
static ComPtr<IDXGIAdapter1> CreateAdapter(PLUID device_luid = nullptr) {
88-
auto filtered_adapters = EnumerateAdapters(device_luid);
113+
static ComPtr<IDXGIAdapter1> CreateAdapter(PLUID device_luid = nullptr, uint32_t* deviceIndex = nullptr) {
114+
auto filtered_adapters = EnumerateAdapters(device_luid, deviceIndex);
89115
if (filtered_adapters.empty()) {
90116
throw std::runtime_error("No adapter is available for DML.");
91117
}
92118
return filtered_adapters.front();
93119
}
94120

95-
DmlObjects CreateDmlObjects(const std::string& current_module_path, PLUID device_luid) {
121+
DmlObjects CreateDmlObjects(const std::string& current_module_path, PLUID device_luid, uint32_t* deviceIndex) {
96122
D3D12_COMMAND_QUEUE_DESC command_queue_description = {
97123
D3D12_COMMAND_LIST_TYPE_COMPUTE,
98124
0,
@@ -102,7 +128,7 @@ DmlObjects CreateDmlObjects(const std::string& current_module_path, PLUID device
102128

103129
DmlObjects dml_objects;
104130

105-
auto adapter = CreateAdapter(device_luid);
131+
auto adapter = CreateAdapter(device_luid, deviceIndex);
106132
ComPtr<ID3D12SDKConfiguration1> d3d12_sdk_config;
107133
ComPtr<ID3D12DeviceFactory> d3d12_factory;
108134

src/dml/dml_helpers.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ struct DmlObjects {
3131
};
3232

3333
namespace DmlHelpers {
34-
DmlObjects CreateDmlObjects(const std::string& current_module_path, PLUID device_luid = nullptr);
34+
DmlObjects CreateDmlObjects(const std::string& current_module_path, PLUID device_luid = nullptr, uint32_t* p_device_index = nullptr);
3535

3636
DmlReusedCommandListState BuildReusableCommandList(
3737
IDMLDevice* dml_device,

src/dml/interface.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,13 @@ struct GpuMemory final : DeviceBuffer {
9696
};
9797

9898
struct InterfaceImpl : DeviceInterface {
99-
InterfaceImpl(LUID* p_device_luid) {
99+
InterfaceImpl(LUID* p_device_luid, uint32_t* p_device_index) {
100100
Ort::ThrowOnError(Ort::api->GetExecutionProviderApi("DML", ORT_API_VERSION, reinterpret_cast<const void**>(&dml_api_)));
101101
if (!dml_api_) {
102102
throw std::runtime_error("Unexpected nullptr getting OrtDmlApi");
103103
}
104104

105-
dml_objects_ = DmlHelpers::CreateDmlObjects(CurrentModulePath(), p_device_luid);
105+
dml_objects_ = DmlHelpers::CreateDmlObjects(CurrentModulePath(), p_device_luid, p_device_index);
106106

107107
constexpr auto directml_dll = "DirectML.dll";
108108
smart_directml_dll_ = wil::unique_hmodule{LoadLibraryEx(directml_dll, nullptr, 0)};
@@ -213,9 +213,9 @@ struct InterfaceImpl : DeviceInterface {
213213

214214
std::unique_ptr<Dml::InterfaceImpl> g_dml_device;
215215

216-
void InitDmlInterface(LUID* p_device_luid) {
216+
void InitDmlInterface(LUID* p_device_luid, uint32_t* p_device_index) {
217217
if (!g_dml_device)
218-
g_dml_device = std::make_unique<Dml::InterfaceImpl>(p_device_luid);
218+
g_dml_device = std::make_unique<Dml::InterfaceImpl>(p_device_luid, p_device_index);
219219
}
220220

221221
void SetDmlProvider(OrtSessionOptions& session_options) {

src/dml/interface.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ typedef struct _LUID {
1010

1111
namespace Generators {
1212

13-
void InitDmlInterface(LUID* p_device_luid);
13+
void InitDmlInterface(LUID* p_device_luid, uint32_t* p_device_index);
1414
void SetDmlProvider(OrtSessionOptions& options);
1515

1616
DeviceInterface* GetDmlInterface();

src/models/model.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -439,17 +439,22 @@ DeviceInterface* SetProviderSessionOptions(OrtSessionOptions& session_options,
439439
if (!GetDmlInterface()) {
440440
LUID device_luid{};
441441
LUID* p_device_luid{};
442+
uint32_t device_index{};
443+
uint32_t* p_device_index{};
442444
for (const auto& [name, value] : provider_options.options) {
443445
if (name == "luid") {
444446
if (auto separator_position = value.find(":"); separator_position != std::string::npos) {
445447
device_luid.HighPart = std::stol(value.substr(0, separator_position));
446448
device_luid.LowPart = std::stol(value.substr(separator_position + 1));
447449
p_device_luid = &device_luid;
448450
}
451+
} else if (name == "device_index") {
452+
device_index = std::stoi(value);
453+
p_device_index = &device_index;
449454
}
450455
}
451456

452-
InitDmlInterface(p_device_luid);
457+
InitDmlInterface(p_device_luid, p_device_index);
453458
}
454459

455460
if (!disable_graph_capture) {

0 commit comments

Comments
 (0)