Skip to content

Commit 27c591f

Browse files
vortex-captain, Yi Ren, jiaxuwu2021
authored
ProviderOptions level device filtering and APIs to configure model level device filtering (microsoft#1744)
# Goal

This PR replaces the existing model-level device filtering configuration with device filtering parameters at the ProviderOptions level, as a first step towards running different parts of a GenAI model (e.g. one configured by a decoder pipeline) on different CPU/GPU/NPU devices in the future.

In addition, this PR adds the following APIs for setting/clearing device filtering options at the model level, which effectively update `config.model.decoder.session_options.provider_options.device_filtering_options`, based on previous work by @jiaxuwu2021. This capability is crucial for running the same GenAI model in different hardware environments by selecting the inference device programmatically, as well as for automated testing of EPs.

- `SetDecoderProviderOptionsHardwareDeviceType` (native and C#), `set_decoder_provider_options_hardware_device_type` (Python)
- `SetDecoderProviderOptionsHardwareDeviceId` (native and C#), `set_decoder_provider_options_hardware_device_id` (Python)
- `SetDecoderProviderOptionsHardwareVendorId` (native and C#), `set_decoder_provider_options_hardware_vendor_id` (Python)
- `ClearDecoderProviderOptionsHardwareDeviceType` (native and C#), `clear_decoder_provider_options_hardware_device_type` (Python)
- `ClearDecoderProviderOptionsHardwareDeviceId` (native and C#), `clear_decoder_provider_options_hardware_device_id` (Python)
- `ClearDecoderProviderOptionsHardwareVendorId` (native and C#), `clear_decoder_provider_options_hardware_vendor_id` (Python)

# Usage

Old `genai_config.json` format (to be replaced by this PR):

```json
{
  "model": {
    ...
    "hardware_device_type": "CPU",
    "hardware_device_id": 7,
    "hardware_vendor_id": 32902,
    ...
  }
}
```

New `genai_config.json` format:

```json
{
  "model": {
    ...
    "decoder": {
      "session_options": {
        ...
        "provider_options": [
          {
            "OpenVINO": {
              ...
              "device_filtering_options": {
                "hardware_device_type": "CPU",
                "hardware_device_id": 7,
                "hardware_vendor_id": 32902
              }
            }
          }
        ]
      },
      ...
    },
    ...
  },
  ...
}
```

# Unit tests

The new APIs are tested in the `CAPITests` test case of `test/c_api_tests.cpp`.

---------

Co-authored-by: Yi Ren <[email protected]>
Co-authored-by: jiaxuwu2021 <[email protected]>
1 parent 04c8090 commit 27c591f

File tree

10 files changed

+325
-33
lines changed

10 files changed

+325
-33
lines changed

src/config.cpp

Lines changed: 121 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,21 @@ ONNXTensorElementDataType TranslateTensorType(std::string_view value) {
4141
throw std::runtime_error("Invalid tensor type: " + std::string(value));
4242
}
4343

44+
// Converts a textual hardware device type into its OrtHardwareDeviceType value.
// "cpu", "gpu" and "npu" are recognised regardless of letter case. Any other
// string raises std::runtime_error, with the caller's original spelling echoed
// in the message.
OrtHardwareDeviceType ParseHardwareDeviceType(std::string_view value) {
  std::string folded;
  folded.reserve(value.size());
  for (const char ch : value) {
    // std::tolower must be given a value representable as unsigned char.
    folded.push_back(static_cast<char>(std::tolower(static_cast<unsigned char>(ch))));
  }

  if (folded == "cpu") return OrtHardwareDeviceType_CPU;
  if (folded == "gpu") return OrtHardwareDeviceType_GPU;
  if (folded == "npu") return OrtHardwareDeviceType_NPU;

  throw std::runtime_error("Unsupported hardware device type: " + std::string(value));
}
58+
4459
struct NamedStrings_Element : JSON::Element {
4560
explicit NamedStrings_Element(std::vector<Config::NamedString>& v) : v_{v} {}
4661

@@ -63,26 +78,66 @@ struct Int_Array_Element : JSON::Element {
6378
std::vector<int>& v_;
6479
};
6580

81+
struct DeviceFilteringOptions_Element : JSON::Element {
82+
explicit DeviceFilteringOptions_Element(Config::DeviceFilteringOptions& v) : v_{v} {}
83+
84+
void OnValue(std::string_view name, JSON::Value value) override {
85+
if (name == "hardware_device_type") {
86+
v_.hardware_device_type = ParseHardwareDeviceType(JSON::Get<std::string_view>(value));
87+
} else if (name == "hardware_device_id") {
88+
v_.hardware_device_id = static_cast<uint32_t>(JSON::Get<double>(value));
89+
} else if (name == "hardware_vendor_id") {
90+
v_.hardware_vendor_id = static_cast<uint32_t>(JSON::Get<double>(value));
91+
} else {
92+
throw JSON::unknown_value_error{};
93+
}
94+
}
95+
96+
private:
97+
Config::DeviceFilteringOptions& v_;
98+
};
99+
100+
struct ProviderOptions_Element : JSON::Element {
101+
explicit ProviderOptions_Element(Config::ProviderOptions& v) : v_{v} {}
102+
103+
void OnValue(std::string_view name, JSON::Value value) override {
104+
v_.options.emplace_back(name, JSON::Get<std::string_view>(value));
105+
}
106+
107+
JSON::Element& OnObject(std::string_view name) override {
108+
if (name == "device_filtering_options") {
109+
v_.device_filtering_options = Config::DeviceFilteringOptions{};
110+
device_filtering_options_element_ = std::make_unique<DeviceFilteringOptions_Element>(*v_.device_filtering_options);
111+
return *device_filtering_options_element_;
112+
}
113+
throw JSON::unknown_value_error{};
114+
}
115+
116+
private:
117+
Config::ProviderOptions& v_;
118+
std::unique_ptr<DeviceFilteringOptions_Element> device_filtering_options_element_;
119+
};
120+
66121
// JSON handler for an object whose member names are provider names.
// For each named child object it returns a handler bound to the matching
// Config::ProviderOptions entry, reusing an existing entry with that name or
// appending a new one.
// Diff note: the removed lines ("-") bound a NamedStrings_Element to the
// entry's `options` vector only; the added lines ("+") bind a
// ProviderOptions_Element to the whole entry.
struct ProviderOptionsObject_Element : JSON::Element {
67122
explicit ProviderOptionsObject_Element(std::vector<Config::ProviderOptions>& v) : v_{v} {}
68123

69124
JSON::Element& OnObject(std::string_view name) override {
70125
// Reuse the entry if a provider with this name is already present.
for (auto& v : v_) {
71126
if (v.name == name) {
72-
options_element_ = std::make_unique<NamedStrings_Element>(v.options);
127+
options_element_ = std::make_unique<ProviderOptions_Element>(v);
73128
return *options_element_;
74129
}
75130
}
76131

77132
// No entry with this name yet: append one and parse into it.
auto& options = v_.emplace_back();
78133
options.name = name;
79-
options_element_ = std::make_unique<NamedStrings_Element>(options.options);
134+
options_element_ = std::make_unique<ProviderOptions_Element>(options);
80135
return *options_element_;
81136
}
82137

83138
private:
84139
std::vector<Config::ProviderOptions>& v_;
85-
std::unique_ptr<NamedStrings_Element> options_element_;
140+
// Owns the handler most recently returned from OnObject.
std::unique_ptr<ProviderOptions_Element> options_element_;
86141
};
87142

88143
struct ProviderOptionsArray_Element : JSON::Element {
@@ -717,12 +772,6 @@ struct Model_Element : JSON::Element {
717772
v_.decoder_start_token_id = static_cast<int>(JSON::Get<double>(value));
718773
} else if (name == "sep_token_id") {
719774
v_.sep_token_id = static_cast<int>(JSON::Get<double>(value));
720-
} else if (name == "hardware_device_type") {
721-
v_.hardware_device_type = JSON::Get<std::string_view>(value);
722-
} else if (name == "hardware_device_id") {
723-
v_.hardware_device_id = static_cast<uint32_t>(JSON::Get<double>(value));
724-
} else if (name == "hardware_vendor_id") {
725-
v_.hardware_vendor_id = static_cast<uint32_t>(JSON::Get<double>(value));
726775
} else {
727776
throw JSON::unknown_value_error{};
728777
}
@@ -920,6 +969,69 @@ bool IsMultiProfileEnabled(const Config::SessionOptions& session_options) {
920969
return false;
921970
}
922971

972+
// Sets the hardware device type filter on every decoder provider entry whose
// name equals NormalizeProviderName(provider_name). A DeviceFilteringOptions
// object is created on the entry first if it has none. Entries with other
// names are left untouched, and the call has no effect if nothing matches.
// ParseHardwareDeviceType throws std::runtime_error for an unsupported type
// string; this happens only once a matching entry is reached, after its
// filter object has been created.
void SetDecoderProviderOptionsHardwareDeviceType(Config& config, std::string_view provider_name, std::string_view hardware_device_type) {
  const auto target_name = NormalizeProviderName(provider_name);
  auto& providers = config.model.decoder.session_options.provider_options;
  for (auto& entry : providers) {
    if (entry.name == target_name) {
      auto& filtering = entry.device_filtering_options;
      if (!filtering.has_value()) {
        filtering.emplace();
      }
      filtering->hardware_device_type = ParseHardwareDeviceType(hardware_device_type);
    }
  }
}
983+
984+
// Sets the hardware device id filter on every decoder provider entry whose
// name equals NormalizeProviderName(provider_name). A DeviceFilteringOptions
// object is created on the entry first if it has none. Entries with other
// names are left untouched, and the call has no effect if nothing matches.
void SetDecoderProviderOptionsHardwareDeviceId(Config& config, std::string_view provider_name, uint32_t hardware_device_id) {
  const auto target_name = NormalizeProviderName(provider_name);
  auto& providers = config.model.decoder.session_options.provider_options;
  for (auto& entry : providers) {
    if (entry.name == target_name) {
      auto& filtering = entry.device_filtering_options;
      if (!filtering.has_value()) {
        filtering.emplace();
      }
      filtering->hardware_device_id = hardware_device_id;
    }
  }
}
995+
996+
// Sets the hardware vendor id filter on every decoder provider entry whose
// name equals NormalizeProviderName(provider_name). A DeviceFilteringOptions
// object is created on the entry first if it has none. Entries with other
// names are left untouched, and the call has no effect if nothing matches.
void SetDecoderProviderOptionsHardwareVendorId(Config& config, std::string_view provider_name, uint32_t hardware_vendor_id) {
  const auto target_name = NormalizeProviderName(provider_name);
  auto& providers = config.model.decoder.session_options.provider_options;
  for (auto& entry : providers) {
    if (entry.name == target_name) {
      auto& filtering = entry.device_filtering_options;
      if (!filtering.has_value()) {
        filtering.emplace();
      }
      filtering->hardware_vendor_id = hardware_vendor_id;
    }
  }
}
1007+
1008+
// Removes the hardware device type filter from every decoder provider entry
// whose name equals NormalizeProviderName(provider_name) and which already has
// a DeviceFilteringOptions object. The object itself and its other fields are
// kept. The call has no effect if nothing matches.
void ClearDecoderProviderOptionsHardwareDeviceType(Config& config, std::string_view provider_name) {
  const auto target_name = NormalizeProviderName(provider_name);
  auto& providers = config.model.decoder.session_options.provider_options;
  for (auto& entry : providers) {
    auto& filtering = entry.device_filtering_options;
    if (filtering.has_value() && entry.name == target_name) {
      filtering->hardware_device_type.reset();
    }
  }
}
1016+
1017+
// Removes the hardware device id filter from every decoder provider entry
// whose name equals NormalizeProviderName(provider_name) and which already has
// a DeviceFilteringOptions object. The object itself and its other fields are
// kept. The call has no effect if nothing matches.
void ClearDecoderProviderOptionsHardwareDeviceId(Config& config, std::string_view provider_name) {
  const auto target_name = NormalizeProviderName(provider_name);
  auto& providers = config.model.decoder.session_options.provider_options;
  for (auto& entry : providers) {
    auto& filtering = entry.device_filtering_options;
    if (filtering.has_value() && entry.name == target_name) {
      filtering->hardware_device_id.reset();
    }
  }
}
1025+
1026+
// Removes the hardware vendor id filter from every decoder provider entry
// whose name equals NormalizeProviderName(provider_name) and which already has
// a DeviceFilteringOptions object. The object itself and its other fields are
// kept. The call has no effect if nothing matches.
void ClearDecoderProviderOptionsHardwareVendorId(Config& config, std::string_view provider_name) {
  const auto target_name = NormalizeProviderName(provider_name);
  auto& providers = config.model.decoder.session_options.provider_options;
  for (auto& entry : providers) {
    auto& filtering = entry.device_filtering_options;
    if (filtering.has_value() && entry.name == target_name) {
      filtering->hardware_vendor_id.reset();
    }
  }
}
1034+
9231035
struct Root_Element : JSON::Element {
9241036
explicit Root_Element(Config& config) : config_{config} {}
9251037

src/config.h

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,16 @@ struct Config {
6464
fs::path config_path; // Path of the config directory
6565

6666
using NamedString = std::pair<std::string, std::string>;
67+
// Optional per-provider device filters. Each field is independent; an empty
// optional means that filter is not set.
struct DeviceFilteringOptions {
68+
std::optional<OrtHardwareDeviceType> hardware_device_type; // OrtHardwareDeviceType_CPU, OrtHardwareDeviceType_GPU, OrtHardwareDeviceType_NPU
69+
std::optional<uint32_t> hardware_device_id;  // Read from "hardware_device_id" in genai_config.json
70+
std::optional<uint32_t> hardware_vendor_id;  // Read from "hardware_vendor_id" in genai_config.json
71+
};
72+
6773
// Settings for one provider entry under session_options.provider_options.
struct ProviderOptions {
6874
std::string name;  // Provider name, e.g. "OpenVINO"
6975
std::vector<NamedString> options;  // Scalar provider options as (name, value) string pairs
76+
std::optional<DeviceFilteringOptions> device_filtering_options;  // Empty when no device filtering is configured
7077
};
7178

7279
struct SessionOptions {
@@ -247,11 +254,6 @@ struct Config {
247254

248255
} decoder;
249256

250-
// EP device filters
251-
std::optional<std::string> hardware_device_type; // CPU, GPU, NPU
252-
std::optional<uint32_t> hardware_device_id;
253-
std::optional<uint32_t> hardware_vendor_id;
254-
255257
} model;
256258

257259
struct Search {
@@ -290,4 +292,11 @@ void OverlayConfig(Config& config, std::string_view json);
290292
bool IsGraphCaptureEnabled(const Config::SessionOptions& session_options);
291293
bool IsMultiProfileEnabled(const Config::SessionOptions& session_options);
292294

295+
// Device filtering setters and clearers for
// config.model.decoder.session_options.provider_options. Each call acts on the
// entries whose name matches the normalized provider_name and does nothing if
// no entry matches.

// Sets the device type filter from a string ("cpu", "gpu" or "npu", case-insensitive).
void SetDecoderProviderOptionsHardwareDeviceType(Config& config, std::string_view provider_name, std::string_view hardware_device_type);
296+
// Sets the device id filter.
void SetDecoderProviderOptionsHardwareDeviceId(Config& config, std::string_view provider_name, uint32_t hardware_device_id);
297+
// Sets the vendor id filter.
void SetDecoderProviderOptionsHardwareVendorId(Config& config, std::string_view provider_name, uint32_t hardware_vendor_id);
298+
// Clears the device type filter, leaving the other filters in place.
void ClearDecoderProviderOptionsHardwareDeviceType(Config& config, std::string_view provider_name);
299+
// Clears the device id filter, leaving the other filters in place.
void ClearDecoderProviderOptionsHardwareDeviceId(Config& config, std::string_view provider_name);
300+
// Clears the vendor id filter, leaving the other filters in place.
void ClearDecoderProviderOptionsHardwareVendorId(Config& config, std::string_view provider_name);
301+
293302
} // namespace Generators

src/csharp/Config.cs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,36 @@ public void RemoveModelData(string modelFilename)
4646
Result.VerifySuccess(NativeMethods.OgaConfigRemoveModelData(_configHandle, StringUtils.ToUtf8(modelFilename)));
4747
}
4848

49+
/// <summary>
/// Sets the hardware device type filter for the named provider by calling the native
/// OgaConfigSetDecoderProviderOptionsHardwareDeviceType function on this config handle.
/// The native result is passed to Result.VerifySuccess.
/// </summary>
/// <param name="provider">Provider name; sent to the native layer as UTF-8.</param>
/// <param name="hardware_device_type">Device type string; sent to the native layer as UTF-8.</param>
public void SetDecoderProviderOptionsHardwareDeviceType(string provider, string hardware_device_type)
{
    var providerUtf8 = StringUtils.ToUtf8(provider);
    var deviceTypeUtf8 = StringUtils.ToUtf8(hardware_device_type);
    var status = NativeMethods.OgaConfigSetDecoderProviderOptionsHardwareDeviceType(_configHandle, providerUtf8, deviceTypeUtf8);
    Result.VerifySuccess(status);
}
53+
54+
/// <summary>
/// Sets the hardware device id filter for the named provider by calling the native
/// OgaConfigSetDecoderProviderOptionsHardwareDeviceId function on this config handle.
/// The native result is passed to Result.VerifySuccess.
/// </summary>
/// <param name="provider">Provider name; sent to the native layer as UTF-8.</param>
/// <param name="hardware_device_id">Device id to filter on.</param>
public void SetDecoderProviderOptionsHardwareDeviceId(string provider, uint hardware_device_id)
{
    var providerUtf8 = StringUtils.ToUtf8(provider);
    var status = NativeMethods.OgaConfigSetDecoderProviderOptionsHardwareDeviceId(_configHandle, providerUtf8, hardware_device_id);
    Result.VerifySuccess(status);
}
58+
59+
/// <summary>
/// Sets the hardware vendor id filter for the named provider by calling the native
/// OgaConfigSetDecoderProviderOptionsHardwareVendorId function on this config handle.
/// The native result is passed to Result.VerifySuccess.
/// </summary>
/// <param name="provider">Provider name; sent to the native layer as UTF-8.</param>
/// <param name="hardware_vendor_id">Vendor id to filter on.</param>
public void SetDecoderProviderOptionsHardwareVendorId(string provider, uint hardware_vendor_id)
{
    var providerUtf8 = StringUtils.ToUtf8(provider);
    var status = NativeMethods.OgaConfigSetDecoderProviderOptionsHardwareVendorId(_configHandle, providerUtf8, hardware_vendor_id);
    Result.VerifySuccess(status);
}
63+
64+
/// <summary>
/// Clears the hardware device type filter for the named provider by calling the native
/// OgaConfigClearDecoderProviderOptionsHardwareDeviceType function on this config handle.
/// The native result is passed to Result.VerifySuccess.
/// </summary>
/// <param name="provider">Provider name; sent to the native layer as UTF-8.</param>
public void ClearDecoderProviderOptionsHardwareDeviceType(string provider)
{
    var providerUtf8 = StringUtils.ToUtf8(provider);
    var status = NativeMethods.OgaConfigClearDecoderProviderOptionsHardwareDeviceType(_configHandle, providerUtf8);
    Result.VerifySuccess(status);
}
68+
69+
/// <summary>
/// Clears the hardware device id filter for the named provider by calling the native
/// OgaConfigClearDecoderProviderOptionsHardwareDeviceId function on this config handle.
/// The native result is passed to Result.VerifySuccess.
/// </summary>
/// <param name="provider">Provider name; sent to the native layer as UTF-8.</param>
public void ClearDecoderProviderOptionsHardwareDeviceId(string provider)
{
    var providerUtf8 = StringUtils.ToUtf8(provider);
    var status = NativeMethods.OgaConfigClearDecoderProviderOptionsHardwareDeviceId(_configHandle, providerUtf8);
    Result.VerifySuccess(status);
}
73+
74+
/// <summary>
/// Clears the hardware vendor id filter for the named provider by calling the native
/// OgaConfigClearDecoderProviderOptionsHardwareVendorId function on this config handle.
/// The native result is passed to Result.VerifySuccess.
/// </summary>
/// <param name="provider">Provider name; sent to the native layer as UTF-8.</param>
public void ClearDecoderProviderOptionsHardwareVendorId(string provider)
{
    var providerUtf8 = StringUtils.ToUtf8(provider);
    var status = NativeMethods.OgaConfigClearDecoderProviderOptionsHardwareVendorId(_configHandle, providerUtf8);
    Result.VerifySuccess(status);
}
78+
4979
~Config()
5080
{
5181
Dispose(false);

src/csharp/NativeMethods.cs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,24 @@ internal class NativeLib
6161
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
6262
public static extern IntPtr /* OgaResult* */ OgaConfigRemoveModelData(IntPtr /* OgaConfig* */ config, byte[] /* const char* */ model_filename);
6363

64+
// P/Invoke bindings for the decoder provider device-filtering functions.
// Strings are passed as byte arrays (const char* on the native side) and every
// function returns an OgaResult* as an IntPtr.

// Sets the hardware device type filter for the named provider.
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
65+
public static extern IntPtr /* OgaResult* */ OgaConfigSetDecoderProviderOptionsHardwareDeviceType(IntPtr /* OgaConfig* */ config, byte[] /* const char* */ provider_name, byte[] /* const char* */ hardware_device_type);
66+
67+
// Sets the hardware device id filter for the named provider.
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
68+
public static extern IntPtr /* OgaResult* */ OgaConfigSetDecoderProviderOptionsHardwareDeviceId(IntPtr /* OgaConfig* */ config, byte[] /* const char* */ provider_name, uint /* uint32_t */ hardware_device_id);
69+
70+
// Sets the hardware vendor id filter for the named provider.
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
71+
public static extern IntPtr /* OgaResult* */ OgaConfigSetDecoderProviderOptionsHardwareVendorId(IntPtr /* OgaConfig* */ config, byte[] /* const char* */ provider_name, uint /* uint32_t */ hardware_vendor_id);
72+
73+
// Clears the hardware device type filter for the named provider.
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
74+
public static extern IntPtr /* OgaResult* */ OgaConfigClearDecoderProviderOptionsHardwareDeviceType(IntPtr /* OgaConfig* */ config, byte[] /* const char* */ provider_name);
75+
76+
// Clears the hardware device id filter for the named provider.
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
77+
public static extern IntPtr /* OgaResult* */ OgaConfigClearDecoderProviderOptionsHardwareDeviceId(IntPtr /* OgaConfig* */ config, byte[] /* const char* */ provider_name);
78+
79+
// Clears the hardware vendor id filter for the named provider.
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
80+
public static extern IntPtr /* OgaResult* */ OgaConfigClearDecoderProviderOptionsHardwareVendorId(IntPtr /* OgaConfig* */ config, byte[] /* const char* */ provider_name);
81+
6482
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
6583
public static extern IntPtr /* OgaResult* */ OgaCreateModel(byte[] /* const char* */ configPath,
6684
out IntPtr /* OgaModel** */ model);

src/models/model.cpp

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -594,10 +594,15 @@ DeviceInterface* SetProviderSessionOptions(OrtSessionOptions& session_options,
594594
}
595595

596596
#if USE_WINML
597-
// Get model device config
598-
std::optional<uint32_t> config_device_id = config.model.hardware_device_id;
599-
std::optional<uint32_t> config_vendor_id = config.model.hardware_vendor_id;
600-
std::optional<std::string> config_device_type = config.model.hardware_device_type;
597+
// Get device filtering config
598+
Config::DeviceFilteringOptions resolved_device_filtering;
599+
if (provider_options.device_filtering_options.has_value()) {
600+
resolved_device_filtering = provider_options.device_filtering_options.value();
601+
}
602+
603+
std::optional<uint32_t> config_device_id = resolved_device_filtering.hardware_device_id;
604+
std::optional<uint32_t> config_vendor_id = resolved_device_filtering.hardware_vendor_id;
605+
std::optional<OrtHardwareDeviceType> config_device_type_enum = resolved_device_filtering.hardware_device_type;
601606
// for OpenVINO, use "device_type" in provider_options exclusively if it's provided
602607
std::optional<std::string> config_ov_device_type = std::nullopt;
603608
if (provider_options.name == "OpenVINO") {
@@ -609,23 +614,11 @@ DeviceInterface* SetProviderSessionOptions(OrtSessionOptions& session_options,
609614
if (config_ov_device_type.has_value()) {
610615
config_device_id = std::nullopt;
611616
config_vendor_id = std::nullopt;
612-
config_device_type = std::nullopt;
613-
} else if (!(config_device_id.has_value() || config_vendor_id.has_value() || config_device_type.has_value())) {
617+
config_device_type_enum = std::nullopt;
618+
} else if (!(config_device_id.has_value() || config_vendor_id.has_value() || config_device_type_enum.has_value())) {
614619
config_ov_device_type = "CPU";
615620
}
616621
}
617-
std::optional<OrtHardwareDeviceType> config_device_type_enum;
618-
if (config_device_type.has_value()) {
619-
if (*config_device_type == "CPU") {
620-
config_device_type_enum = OrtHardwareDeviceType_CPU;
621-
} else if (*config_device_type == "GPU") {
622-
config_device_type_enum = OrtHardwareDeviceType_GPU;
623-
} else if (*config_device_type == "NPU") {
624-
config_device_type_enum = OrtHardwareDeviceType_NPU;
625-
} else {
626-
throw std::runtime_error("Unsupported hardware device type: " + *config_device_type);
627-
}
628-
}
629622

630623
// Match EP device with EP name in provider options and model device config
631624
// include\onnxruntime\core\graph\constants.h

src/ort_genai.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,30 @@ struct OgaConfig : OgaAbstract {
197197
OgaCheckResult(OgaConfigRemoveModelData(this, model_filename.c_str()));
198198
}
199199

200+
// Sets the hardware device type filter for the named provider on this config.
// Forwards to OgaConfigSetDecoderProviderOptionsHardwareDeviceType and hands
// the returned result to OgaCheckResult.
void SetDecoderProviderOptionsHardwareDeviceType(const char* provider, const char* hardware_device_type) {
  auto status = OgaConfigSetDecoderProviderOptionsHardwareDeviceType(this, provider, hardware_device_type);
  OgaCheckResult(status);
}
203+
204+
// Sets the hardware device id filter for the named provider on this config.
// Forwards to OgaConfigSetDecoderProviderOptionsHardwareDeviceId and hands the
// returned result to OgaCheckResult.
void SetDecoderProviderOptionsHardwareDeviceId(const char* provider, uint32_t hardware_device_id) {
  auto status = OgaConfigSetDecoderProviderOptionsHardwareDeviceId(this, provider, hardware_device_id);
  OgaCheckResult(status);
}
207+
208+
// Sets the hardware vendor id filter for the named provider on this config.
// Forwards to OgaConfigSetDecoderProviderOptionsHardwareVendorId and hands the
// returned result to OgaCheckResult.
void SetDecoderProviderOptionsHardwareVendorId(const char* provider, uint32_t hardware_vendor_id) {
  auto status = OgaConfigSetDecoderProviderOptionsHardwareVendorId(this, provider, hardware_vendor_id);
  OgaCheckResult(status);
}
211+
212+
// Clears the hardware device type filter for the named provider on this config.
// Forwards to OgaConfigClearDecoderProviderOptionsHardwareDeviceType and hands
// the returned result to OgaCheckResult.
void ClearDecoderProviderOptionsHardwareDeviceType(const char* provider) {
  auto status = OgaConfigClearDecoderProviderOptionsHardwareDeviceType(this, provider);
  OgaCheckResult(status);
}
215+
216+
// Clears the hardware device id filter for the named provider on this config.
// Forwards to OgaConfigClearDecoderProviderOptionsHardwareDeviceId and hands
// the returned result to OgaCheckResult.
void ClearDecoderProviderOptionsHardwareDeviceId(const char* provider) {
  auto status = OgaConfigClearDecoderProviderOptionsHardwareDeviceId(this, provider);
  OgaCheckResult(status);
}
219+
220+
// Clears the hardware vendor id filter for the named provider on this config.
// Forwards to OgaConfigClearDecoderProviderOptionsHardwareVendorId and hands
// the returned result to OgaCheckResult.
void ClearDecoderProviderOptionsHardwareVendorId(const char* provider) {
  auto status = OgaConfigClearDecoderProviderOptionsHardwareVendorId(this, provider);
  OgaCheckResult(status);
}
223+
200224
static void operator delete(void* p) { OgaDestroyConfig(reinterpret_cast<OgaConfig*>(p)); }
201225
};
202226

0 commit comments

Comments
 (0)