From 283bd5b190e774cb45afa841b78ac0852493953f Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Mon, 15 Sep 2025 17:10:03 -0700 Subject: [PATCH 01/11] First draft (untested) --- .../onnxruntime/core/session/environment.h | 1 + .../core/session/onnxruntime_ep_c_api.h | 16 ++++ onnxruntime/core/session/environment.cc | 85 ++++++++++++++++++- onnxruntime/core/session/plugin_ep/ep_api.cc | 26 ++++++ onnxruntime/core/session/plugin_ep/ep_api.h | 7 ++ 5 files changed, 134 insertions(+), 1 deletion(-) diff --git a/include/onnxruntime/core/session/environment.h b/include/onnxruntime/core/session/environment.h index 59ca1a1df762e..51b13a26c5caa 100644 --- a/include/onnxruntime/core/session/environment.h +++ b/include/onnxruntime/core/session/environment.h @@ -220,6 +220,7 @@ class Environment { ~EpInfo(); std::unique_ptr library; + std::vector> additional_hw_devices; std::vector> execution_devices; std::vector factories; std::vector internal_factories; // factories that can create IExecutionProvider instances diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 975f6b453a88d..be9e21e625aef 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -465,6 +465,15 @@ struct OrtEpApi { */ ORT_API_T(uint64_t, GetSyncIdForLastWaitOnSyncStream, _In_ const OrtSyncStream* producer_stream, _In_ const OrtSyncStream* consumer_stream); + + ORT_API2_STATUS(CreateHardwareDevice, _In_ OrtHardwareDeviceType type, + _In_ uint32_t vendor_id, + _In_ uint32_t device_id, + _In_ const char* vendor_name, + _In_opt_ const OrtKeyValuePairs* metadata, + _Out_ OrtHardwareDevice** hardware_device); + + ORT_CLASS_RELEASE(HardwareDevice); }; /** @@ -981,6 +990,13 @@ struct OrtEpFactory { _In_ const OrtMemoryDevice* memory_device, _In_opt_ const OrtKeyValuePairs* stream_options, _Outptr_ OrtSyncStreamImpl** stream); + + 
ORT_API2_STATUS(GetAdditionalHardwareDevices, _In_ OrtEpFactory* this_ptr, + _In_reads_(num_devices) const OrtHardwareDevice* const* found_devices, + _In_ size_t num_found_devices, + _Inout_ OrtHardwareDevice** additional_devices, + _In_ size_t max_additional_devices, + _Out_ size_t* num_additional_devices); }; #ifdef __cplusplus diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index 9c40eb75780ee..89f8c2bdc8825 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -746,6 +746,63 @@ std::vector SortDevicesByType() { return sorted_devices; } + +std::vector FilterEpHardwareDevices(const OrtEpFactory& ep_factory, + gsl::span ort_hw_devices, + gsl::span ep_hw_devices, + const char* lib_registration_name) { + // ORT is not required to use all hw devices provided by the EP factory. + // This function filters out the following hw devices: + // - HW devices that were already found during ORT's device discovery. + // - HW devices with vendor information that does not match the EP factory. + + if (ep_hw_devices.empty()) { + return {}; + } + + auto have_ort_hw_device = [&ort_hw_devices](const OrtHardwareDevice* candidate) -> bool { + return std::find_if(ort_hw_devices.begin(), ort_hw_devices.end(), + [&candidate](const OrtHardwareDevice* ort_device) { + return candidate->device_id == ort_device->device_id && + candidate->vendor_id == ort_device->vendor_id && + candidate->type == ort_device->type; + }) != ort_hw_devices.end(); + }; + + std::vector result; + result.reserve(ep_hw_devices.size()); + + const char* ep_factory_name = ep_factory.GetName(&ep_factory); + const uint32_t ep_vendor_id = ep_factory.GetVendorId(&ep_factory); + const std::string ep_vendor = ep_factory.GetVendor(&ep_factory); + + for (OrtHardwareDevice* candidate : ep_hw_devices) { + if (candidate == nullptr) { + continue; // EP library provided a NULL hw device. Skip it. 
+ } + + if (candidate->vendor_id != ep_vendor_id || + candidate->vendor != ep_vendor) { + LOGS_DEFAULT(WARNING) << "EP library registered under '" << lib_registration_name << "' with OrtEpFactory '" + << ep_factory_name << "' attempted to register a OrtHardwareDevice with non-matching " + << "vendor information. Expected " << ep_vendor << "(" << ep_vendor_id << ") but got " + << candidate->vendor << "(" << candidate->vendor_id << ")."; + continue; + } + + if (have_ort_hw_device(candidate)) { + LOGS_DEFAULT(VERBOSE) << "EP library registered under '" << lib_registration_name << "' with OrtEpFactory '" + << ep_factory_name << "' attempted to register a OrtHardwareDevice that has already been " + << "found by ONNX Runtime. OrtHardwareDevice info: vendor_id=" << ep_vendor_id + << ", device_id=" << candidate->device_id << ", type=" << candidate->type; + continue; + } + + result.push_back(candidate); + } + + return result; +} } // namespace Status Environment::EpInfo::Create(std::unique_ptr library_in, std::unique_ptr& out, @@ -772,10 +829,36 @@ Status Environment::EpInfo::Create(std::unique_ptr library_in, std::u auto& factory = *factory_ptr; + // Allow EP factory to provide additional OrtHardwareDevice instances to: + // - Support offline/off-target model compilation. EP may provide a virtual OrtHardwareDevice that represents the + // compilation target. + // - Enable EP library to provide hardware devices not discovered by ORT. 
+ std::array ep_hw_devices{nullptr}; + size_t num_ep_hw_devices = 0; + + if (factory.GetAdditionalHardwareDevices != nullptr) { + ORT_RETURN_IF_ERROR(ToStatusAndRelease( + factory.GetAdditionalHardwareDevices(&factory, sorted_devices.data(), sorted_devices.size(), + ep_hw_devices.data(), ep_hw_devices.size(), &num_ep_hw_devices))); + } + + std::vector all_hw_devices = sorted_devices; + + if (num_ep_hw_devices > 0) { + std::vector valid_hw_devices = FilterEpHardwareDevices( + factory, sorted_devices, gsl::span(ep_hw_devices.data(), num_ep_hw_devices), + instance.library->RegistrationName()); + + for (OrtHardwareDevice* ep_hw_device : valid_hw_devices) { + instance.additional_hw_devices.emplace_back(ep_hw_device); // take ownership + all_hw_devices.push_back(ep_hw_device); // Add EP-specific HW devices to the end + } + } + std::array ep_devices{nullptr}; size_t num_ep_devices = 0; ORT_RETURN_IF_ERROR(ToStatusAndRelease( - factory.GetSupportedDevices(&factory, sorted_devices.data(), sorted_devices.size(), + factory.GetSupportedDevices(&factory, all_hw_devices.data(), all_hw_devices.size(), ep_devices.data(), ep_devices.size(), &num_ep_devices))); for (size_t i = 0; i < num_ep_devices; ++i) { diff --git a/onnxruntime/core/session/plugin_ep/ep_api.cc b/onnxruntime/core/session/plugin_ep/ep_api.cc index cae0b086af66c..9d1b668dde962 100644 --- a/onnxruntime/core/session/plugin_ep/ep_api.cc +++ b/onnxruntime/core/session/plugin_ep/ep_api.cc @@ -205,6 +205,32 @@ ORT_API(uint64_t, GetSyncIdForLastWaitOnSyncStream, _In_ const OrtSyncStream* pr return id; } +ORT_API_STATUS_IMPL(CreateHardwareDevice, _In_ OrtHardwareDeviceType type, + _In_ uint32_t vendor_id, + _In_ uint32_t device_id, + _In_ const char* vendor_name, + _In_opt_ const OrtKeyValuePairs* metadata, + _Out_ OrtHardwareDevice** hardware_device) { + API_IMPL_BEGIN + auto device = std::make_unique(); + device->type = type; + device->vendor_id = vendor_id; + device->device_id = device_id; + device->vendor = 
std::string(vendor_name); + + if (metadata) { + device->metadata = *metadata; + } + + *hardware_device = device.release(); + return nullptr; + API_IMPL_END +} + +ORT_API(void, ReleaseHardwareDevice, _Frees_ptr_opt_ OrtHardwareDevice* device) { + delete device; +} + static constexpr OrtEpApi ort_ep_api = { // NOTE: ABI compatibility depends on the order within this struct so all additions must be at the end, // and no functions can be removed (the implementation needs to change to return an error). diff --git a/onnxruntime/core/session/plugin_ep/ep_api.h b/onnxruntime/core/session/plugin_ep/ep_api.h index c0dc79f3fb333..129230be4f618 100644 --- a/onnxruntime/core/session/plugin_ep/ep_api.h +++ b/onnxruntime/core/session/plugin_ep/ep_api.h @@ -40,4 +40,11 @@ ORT_API(const OrtSyncStreamImpl*, SyncStream_GetImpl, _In_ const OrtSyncStream* ORT_API(uint64_t, SyncStream_GetSyncId, _In_ const OrtSyncStream* stream); ORT_API(uint64_t, GetSyncIdForLastWaitOnSyncStream, _In_ const OrtSyncStream* producer_stream, _In_ const OrtSyncStream* consumer_stream); +ORT_API_STATUS_IMPL(CreateHardwareDevice, _In_ OrtHardwareDeviceType type, + _In_ uint32_t vendor_id, + _In_ uint32_t device_id, + _In_ const char* vendor_name, + _In_opt_ const OrtKeyValuePairs* metadata, + _Out_ OrtHardwareDevice** hardware_device); +ORT_API(void, ReleaseHardwareDevice, _Frees_ptr_opt_ OrtHardwareDevice* device); } // namespace OrtExecutionProviderApi From bc4df99da5118b49e6a24794918cd972b66e71ee Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Tue, 16 Sep 2025 14:30:01 -0700 Subject: [PATCH 02/11] Properly discard invalid hw devices provided by the EP; Add hw device metadata entry to indicate who created the device instance --- onnxruntime/core/platform/device_discovery.h | 2 ++ .../core/platform/windows/device_discovery.cc | 2 ++ onnxruntime/core/session/environment.cc | 33 +++++++++++-------- onnxruntime/core/session/plugin_ep/ep_api.cc | 3 ++ 4 files changed, 26 insertions(+), 14 deletions(-) diff 
--git a/onnxruntime/core/platform/device_discovery.h b/onnxruntime/core/platform/device_discovery.h index b49e63b90236a..b90dfbc8702b3 100644 --- a/onnxruntime/core/platform/device_discovery.h +++ b/onnxruntime/core/platform/device_discovery.h @@ -9,6 +9,8 @@ namespace onnxruntime { +constexpr const char* kHardwareDeviceKey_DiscoveredBy = "DiscoveredBy"; + class DeviceDiscovery { public: static const std::unordered_set& GetDevices(); diff --git a/onnxruntime/core/platform/windows/device_discovery.cc b/onnxruntime/core/platform/windows/device_discovery.cc index cf761f587ad0b..7b8498c6e7d3c 100644 --- a/onnxruntime/core/platform/windows/device_discovery.cc +++ b/onnxruntime/core/platform/windows/device_discovery.cc @@ -618,6 +618,8 @@ std::unordered_set DeviceDiscovery::DiscoverDevicesForPlatfor std::unordered_map* extra_metadata = nullptr) { OrtHardwareDevice ortdevice{device.type, device.vendor_id, device.device_id, to_safe_string(device.vendor)}; + ortdevice.metadata.Add(kHardwareDeviceKey_DiscoveredBy, "ONNX Runtime"); + if (!device.description.empty()) { ortdevice.metadata.Add("Description", to_safe_string(device.description)); } diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index 89f8c2bdc8825..ad8fcadd81d3c 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -727,9 +727,9 @@ Status Environment::ReleaseSharedAllocator(const OrtEpDevice& ep_device, OrtDevi } namespace { -std::vector SortDevicesByType() { +InlinedVector SortDevicesByType() { auto& devices = DeviceDiscovery::GetDevices(); - std::vector sorted_devices; + InlinedVector sorted_devices; sorted_devices.reserve(devices.size()); const auto select_by_type = [&](OrtHardwareDeviceType type) { @@ -747,10 +747,11 @@ std::vector SortDevicesByType() { return sorted_devices; } -std::vector FilterEpHardwareDevices(const OrtEpFactory& ep_factory, - gsl::span ort_hw_devices, - gsl::span ep_hw_devices, - const char* 
lib_registration_name) { +InlinedVector> FilterEpHardwareDevices( + const OrtEpFactory& ep_factory, + gsl::span ort_hw_devices, + gsl::span ep_hw_devices, + const char* lib_registration_name) { // ORT is not required to use all hw devices provided by the EP factory. // This function filters out the following hw devices: // - HW devices that were already found during ORT's device discovery. @@ -769,7 +770,8 @@ std::vector FilterEpHardwareDevices(const OrtEpFactory& ep_f }) != ort_hw_devices.end(); }; - std::vector result; + InlinedVector> devices_to_discard; + InlinedVector> result; result.reserve(ep_hw_devices.size()); const char* ep_factory_name = ep_factory.GetName(&ep_factory); @@ -787,6 +789,7 @@ std::vector FilterEpHardwareDevices(const OrtEpFactory& ep_f << ep_factory_name << "' attempted to register a OrtHardwareDevice with non-matching " << "vendor information. Expected " << ep_vendor << "(" << ep_vendor_id << ") but got " << candidate->vendor << "(" << candidate->vendor_id << ")."; + devices_to_discard.emplace_back(candidate); // take ownership to discard on function return continue; } @@ -795,10 +798,12 @@ std::vector FilterEpHardwareDevices(const OrtEpFactory& ep_f << ep_factory_name << "' attempted to register a OrtHardwareDevice that has already been " << "found by ONNX Runtime. OrtHardwareDevice info: vendor_id=" << ep_vendor_id << ", device_id=" << candidate->device_id << ", type=" << candidate->type; + devices_to_discard.emplace_back(candidate); // take ownership to discard on function return continue; } - result.push_back(candidate); + candidate->metadata.Add(kHardwareDeviceKey_DiscoveredBy, ep_factory_name); + result.emplace_back(candidate); } return result; @@ -821,7 +826,7 @@ Status Environment::EpInfo::Create(std::unique_ptr library_in, std::u // OrtHardwareDevice instances to pass to GetSupportedDevices. sorted by type to be slightly more structured. // the set of hardware devices is static so this can also be static. 
- const static std::vector sorted_devices = SortDevicesByType(); + const static InlinedVector sorted_devices = SortDevicesByType(); for (auto* factory_ptr : instance.factories) { ORT_ENFORCE(factory_ptr != nullptr, "Factory pointer was null. EpLibrary should prevent this. Library:", @@ -842,16 +847,16 @@ Status Environment::EpInfo::Create(std::unique_ptr library_in, std::u ep_hw_devices.data(), ep_hw_devices.size(), &num_ep_hw_devices))); } - std::vector all_hw_devices = sorted_devices; + InlinedVector all_hw_devices = sorted_devices; if (num_ep_hw_devices > 0) { - std::vector valid_hw_devices = FilterEpHardwareDevices( + InlinedVector> valid_hw_devices = FilterEpHardwareDevices( factory, sorted_devices, gsl::span(ep_hw_devices.data(), num_ep_hw_devices), instance.library->RegistrationName()); - for (OrtHardwareDevice* ep_hw_device : valid_hw_devices) { - instance.additional_hw_devices.emplace_back(ep_hw_device); // take ownership - all_hw_devices.push_back(ep_hw_device); // Add EP-specific HW devices to the end + for (std::unique_ptr& ep_hw_device : valid_hw_devices) { + all_hw_devices.push_back(ep_hw_device.get()); // Add EP-specific HW devices to the end + instance.additional_hw_devices.push_back(std::move(ep_hw_device)); // take ownership } } diff --git a/onnxruntime/core/session/plugin_ep/ep_api.cc b/onnxruntime/core/session/plugin_ep/ep_api.cc index 9d1b668dde962..bca856efb5446 100644 --- a/onnxruntime/core/session/plugin_ep/ep_api.cc +++ b/onnxruntime/core/session/plugin_ep/ep_api.cc @@ -256,6 +256,9 @@ static constexpr OrtEpApi ort_ep_api = { &OrtExecutionProviderApi::SyncStream_GetImpl, &OrtExecutionProviderApi::SyncStream_GetSyncId, &OrtExecutionProviderApi::GetSyncIdForLastWaitOnSyncStream, + + &OrtExecutionProviderApi::CreateHardwareDevice, + &OrtExecutionProviderApi::ReleaseHardwareDevice, }; // checks that we don't violate the rule that the functions must remain in the slots they were originally assigned From fb7441ca70b22debb2cd03fa274e009e6dc2336f 
Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Tue, 16 Sep 2025 14:56:56 -0700 Subject: [PATCH 03/11] Add example code to example EP factory --- onnxruntime/test/autoep/library/ep_factory.cc | 57 +++++++++++++++++++ onnxruntime/test/autoep/library/ep_factory.h | 7 +++ 2 files changed, 64 insertions(+) diff --git a/onnxruntime/test/autoep/library/ep_factory.cc b/onnxruntime/test/autoep/library/ep_factory.cc index 4da7d722a5e0b..85b1977667359 100644 --- a/onnxruntime/test/autoep/library/ep_factory.cc +++ b/onnxruntime/test/autoep/library/ep_factory.cc @@ -19,6 +19,7 @@ ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis, const OrtL GetVendorId = GetVendorIdImpl; GetVersion = GetVersionImpl; + GetAdditionalHardwareDevices = GetAdditionalHardwareDevicesImpl; GetSupportedDevices = GetSupportedDevicesImpl; CreateEp = CreateEpImpl; @@ -97,6 +98,62 @@ const char* ORT_API_CALL ExampleEpFactory::GetVersionImpl(const OrtEpFactory* th return factory->ep_version_.c_str(); } +/*static*/ +OrtStatus* ORT_API_CALL ExampleEpFactory::GetAdditionalHardwareDevicesImpl(OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* found_devices, + size_t num_found_devices, + OrtHardwareDevice** additional_devices, + size_t max_additional_devices, + size_t* num_additional_devices) noexcept { + // EP factory can provide ORT with additional hardware devices that ORT did not find, or more likely, that are not + // available on the target machine but could serve as compilation targets. + + // As an example, this example EP factory will first look for a GPU device among the devices found by ORT. If there + // is no GPU available, then this EP will create a virtual GPU device that the application can use a compilation target. 
+ + auto* factory = static_cast(this_ptr); + bool found_gpu = false; + + for (size_t i = 0; i < num_found_devices; ++i) { + const OrtHardwareDevice& device = *found_devices[i]; + + if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + found_gpu = true; + break; + } + } + + *num_additional_devices = 0; + + if (!found_gpu && max_additional_devices >= 1) { + // Add some hw metadata for this GPU + OrtKeyValuePairs* hw_metadata = nullptr; + factory->ort_api.CreateKeyValuePairs(&hw_metadata); + factory->ort_api.AddKeyValuePair(hw_metadata, "Discrete", "1"); + factory->ort_api.AddKeyValuePair(hw_metadata, "CompileTargetOnly", "1"); + + // Create a new HW device. Must have the same vendor information as this factory. Otherwise, ORT will not use it. + OrtHardwareDevice* new_device = nullptr; + auto* status = factory->ort_api.GetEpApi()->CreateHardwareDevice(OrtHardwareDeviceType::OrtHardwareDeviceType_GPU, + factory->vendor_id_, + /*device_id*/ 0, + factory->vendor_.c_str(), + hw_metadata, + &new_device); + factory->ort_api.ReleaseKeyValuePairs(hw_metadata); // Release since ORT makes a copy. + + if (status != nullptr) { + return status; + } + + // ORT will take ownership of the new HW device. 
+ additional_devices[0] = new_device; + *num_additional_devices = 1; + } + + return nullptr; +} + /*static*/ OrtStatus* ORT_API_CALL ExampleEpFactory::GetSupportedDevicesImpl(OrtEpFactory* this_ptr, const OrtHardwareDevice* const* devices, diff --git a/onnxruntime/test/autoep/library/ep_factory.h b/onnxruntime/test/autoep/library/ep_factory.h index 088deda1fe9d2..2d6146b753255 100644 --- a/onnxruntime/test/autoep/library/ep_factory.h +++ b/onnxruntime/test/autoep/library/ep_factory.h @@ -33,6 +33,13 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* this_ptr) noexcept; + static OrtStatus* ORT_API_CALL GetAdditionalHardwareDevicesImpl(OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* found_devices, + size_t num_found_devices, + OrtHardwareDevice** additional_devices, + size_t max_additional_devices, + size_t* num_additional_devices) noexcept; + static OrtStatus* ORT_API_CALL GetSupportedDevicesImpl(OrtEpFactory* this_ptr, const OrtHardwareDevice* const* devices, size_t num_devices, From d23bd3af69b74d3e1c9fed18e2c2afb628420b2e Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Wed, 8 Oct 2025 13:24:27 -0700 Subject: [PATCH 04/11] Stub out default implementation of new factory function for internal eps --- .../core/session/plugin_ep/ep_factory_internal.cc | 4 +++- .../core/session/plugin_ep/ep_factory_internal.h | 10 ++++++++++ .../core/session/plugin_ep/ep_factory_internal_impl.h | 11 +++++++++++ .../core/session/plugin_ep/forward_to_factory_impl.h | 11 +++++++++++ onnxruntime/test/autoep/library/ep_factory.cc | 2 +- 5 files changed, 36 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc b/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc index f3e30caf07e81..053e6ed0a621c 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc @@ -13,7 +13,8 
@@ namespace onnxruntime { using Forward = ForwardToFactoryImpl; EpFactoryInternal::EpFactoryInternal(std::unique_ptr impl) - : impl_{std::move(impl)} { + : OrtEpFactory{}, // Ensure optional functions are default initialized to nullptr + impl_{std::move(impl)} { ort_version_supported = ORT_API_VERSION; OrtEpFactory::GetName = Forward::GetFactoryName; @@ -29,6 +30,7 @@ EpFactoryInternal::EpFactoryInternal(std::unique_ptr impl OrtEpFactory::CreateDataTransfer = Forward::CreateDataTransfer; OrtEpFactory::IsStreamAware = Forward::IsStreamAware; OrtEpFactory::CreateSyncStreamForDevice = Forward::CreateSyncStreamForDevice; + OrtEpFactory::GetAdditionalHardwareDevices = Forward::GetAdditionalHardwareDevices; } InternalExecutionProviderFactory::InternalExecutionProviderFactory(EpFactoryInternal& ep_factory, diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h index 093bfce462d32..c794230dff789 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h @@ -87,6 +87,16 @@ class EpFactoryInternal : public OrtEpFactory { return impl_->ValidateCompiledModelCompatibilityInfo(devices, num_devices, compatibility_info, model_compatibility); } + OrtStatus* GetAdditionalHardwareDevices( + _In_reads_(num_found_devices) const OrtHardwareDevice* const* found_devices, + _In_ size_t num_found_devices, + _Inout_ OrtHardwareDevice** additional_devices, + _In_ size_t max_additional_devices, + _Out_ size_t* num_additional_devices) noexcept { + return impl_->GetAdditionalHardwareDevices(found_devices, num_found_devices, additional_devices, + max_additional_devices, num_additional_devices); + } + // Function ORT calls to release an EP instance. 
void ReleaseEp(OrtEp* /*ep*/) noexcept { // we never create an OrtEp so we should never be trying to release one diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h index f29154d19c53c..0888c18d5a2f0 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h @@ -83,6 +83,17 @@ class EpFactoryInternalImpl { "CreateSyncStreamForDevice is not implemented for this EP factory."); } + virtual OrtStatus* GetAdditionalHardwareDevices( + _In_reads_(num_found_devices) const OrtHardwareDevice* const* /*found_devices*/, + _In_ size_t /*num_found_devices*/, + _Inout_ OrtHardwareDevice** /*additional_devices*/, + _In_ size_t /*max_additional_devices*/, + _Out_ size_t* num_additional_devices) noexcept { + // Default implementation does not return any additional hw devices. + *num_additional_devices = 0; + return nullptr; + } + // Function ORT calls to release an EP instance. 
void ReleaseEp(OrtEp* ep); diff --git a/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h index 2cceb1d08d536..54db086e8c10a 100644 --- a/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h +++ b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h @@ -82,6 +82,17 @@ struct ForwardToFactoryImpl { return static_cast(this_ptr)->CreateSyncStreamForDevice(memory_device, stream_options, stream); } + static OrtStatus* ORT_API_CALL GetAdditionalHardwareDevices(OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* found_devices, + size_t num_found_devices, + OrtHardwareDevice** additional_devices, + size_t max_additional_devices, + size_t* num_additional_devices) noexcept { + return static_cast(this_ptr)->GetAdditionalHardwareDevices(found_devices, num_found_devices, + additional_devices, max_additional_devices, + num_additional_devices); + } + static void ORT_API_CALL ReleaseEp(OrtEpFactory* this_ptr, OrtEp* ep) noexcept { static_cast(this_ptr)->ReleaseEp(ep); } diff --git a/onnxruntime/test/autoep/library/ep_factory.cc b/onnxruntime/test/autoep/library/ep_factory.cc index 85b1977667359..545dc7aac79a9 100644 --- a/onnxruntime/test/autoep/library/ep_factory.cc +++ b/onnxruntime/test/autoep/library/ep_factory.cc @@ -12,7 +12,7 @@ #include "ep_stream_support.h" ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis, const OrtLogger& default_logger) - : ApiPtrs(apis), default_logger_{default_logger}, ep_name_{ep_name} { + : OrtEpFactory{}, ApiPtrs(apis), default_logger_{default_logger}, ep_name_{ep_name} { ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. 
GetName = GetNameImpl; GetVendor = GetVendorImpl; From f86a4e915ab9aecd83f3657267ef61e94ce83063 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Wed, 8 Oct 2025 20:51:00 -0700 Subject: [PATCH 05/11] Check for vendor too --- onnxruntime/test/autoep/library/ep_factory.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxruntime/test/autoep/library/ep_factory.cc b/onnxruntime/test/autoep/library/ep_factory.cc index 545dc7aac79a9..1eaa7b244ae23 100644 --- a/onnxruntime/test/autoep/library/ep_factory.cc +++ b/onnxruntime/test/autoep/library/ep_factory.cc @@ -117,7 +117,8 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::GetAdditionalHardwareDevicesImpl(OrtEp for (size_t i = 0; i < num_found_devices; ++i) { const OrtHardwareDevice& device = *found_devices[i]; - if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU && + factory->ort_api.HardwareDevice_Vendor(&device) == factory->vendor_) { found_gpu = true; break; } From f711df727af10d6528a1515cfc86a661ba2e5085 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Thu, 9 Oct 2025 00:08:07 -0700 Subject: [PATCH 06/11] Formalize hardware metadata keys --- .../core/session/onnxruntime_ep_c_api.h | 46 ++++++++++++++++++- .../onnxruntime_ep_device_ep_metadata_keys.h | 21 ++++++++- .../core/platform/apple/device_discovery.cc | 5 ++ onnxruntime/core/platform/device_discovery.h | 2 - .../core/platform/device_discovery_common.cc | 3 ++ .../core/platform/linux/device_discovery.cc | 5 ++ .../core/platform/windows/device_discovery.cc | 4 +- onnxruntime/core/session/environment.cc | 32 +++++++------ .../core/session/provider_policy_context.cc | 3 +- onnxruntime/test/autoep/library/ep_factory.cc | 22 +++++---- onnxruntime/test/autoep/test_selection.cc | 20 ++++++++ 11 files changed, 134 insertions(+), 29 deletions(-) diff --git 
a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index be9e21e625aef..427de5e404ae7 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -466,6 +466,24 @@ struct OrtEpApi { ORT_API_T(uint64_t, GetSyncIdForLastWaitOnSyncStream, _In_ const OrtSyncStream* producer_stream, _In_ const OrtSyncStream* consumer_stream); + /** \brief Create an OrtHardwareDevice. + * + * \note Called within OrtEpFactory::GetAdditionalHardwareDevices to augment the list of devices discovered by ORT. + * + * \param[in] type The hardware device type. + * \param[in] vendor_id The hardware device's vendor identifier. + * \param[in] device_id The hardware device's identifier. + * \param[in] vendor_name The hardware device's vendor name as a null-terminated string. Copied by ORT. + * \param[in] metadata Optional OrtKeyValuePairs instance for hardware device metadata that may be queried by + * applications via OrtApi::GetEpDevices() or the EP factory that receives this hardware device + * instance as input to OrtEpFactory::GetSupportedDevices(). + * Refer to onnxruntime_ep_device_ep_metadata_keys.h for common OrtHardwareDevice metadata keys. + * \param[out] hardware_device Output parameter set to the new OrtHardwareDevice instance that is created. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ ORT_API2_STATUS(CreateHardwareDevice, _In_ OrtHardwareDeviceType type, _In_ uint32_t vendor_id, _In_ uint32_t device_id, @@ -991,8 +1009,34 @@ struct OrtEpFactory { _In_opt_ const OrtKeyValuePairs* stream_options, _Outptr_ OrtSyncStreamImpl** stream); + /** \brief Get additional hardware devices from the execution provider to augment the devices discovered by ORT. + * + * \note Any returned devices that have already been found by ORT are ignored. 
+ * + * \note New additional devices created by this EP factory are not provided to other EP factories. Only this + * EP factory receives the new additional hardware devices via OrtEpFactory::GetSupportedDevices(). + * Any OrtEpDevice instances that this EP factory creates with an additional hardware device are visible to + * applications that call OrtApi::GetEpDevices(). + * + * \param[in] this_ptr The OrtEpFactory instance. + * \param[in] found_devices Array of hardware devices that have already been found by ORT during device discovery. + * \param[in] num_found_devices Number of hardware devices that have already been found by ORT. + * \param[out] additional_devices Additional OrtHardwareDevice instances that the EP can use. + * The implementation should call OrtEpApi::CreateHardwareDevice to create the devices, + * and then add the new OrtHardwareDevice instances to this pre-allocated array. + * ORT will take ownership of the values returned. i.e. usage is: + * `additional_devices[0] = ;` + * \param[in] max_additional_devices The maximum number of OrtHardwareDevice instances that can be added to + * `additional_devices`. Current default is 8. This can be increased if needed. + * \param[out] num_additional_devices The number of additional hardware devices actually added + * to `additional_devices`. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. 
+ */ ORT_API2_STATUS(GetAdditionalHardwareDevices, _In_ OrtEpFactory* this_ptr, - _In_reads_(num_devices) const OrtHardwareDevice* const* found_devices, + _In_reads_(num_found_devices) const OrtHardwareDevice* const* found_devices, _In_ size_t num_found_devices, _Inout_ OrtHardwareDevice** additional_devices, _In_ size_t max_additional_devices, diff --git a/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h b/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h index bbd6a43bb7a41..8f86d62bafa3b 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h @@ -3,7 +3,7 @@ #pragma once -// This file contains well-known keys for OrtEpDevice EP metadata entries. +// This file contains well-known keys for OrtEpDevice and OrtHardwareDevice metadata entries. // It does NOT specify all available metadata keys. // Key for the execution provider version string. This should be available for all plugin EPs. @@ -16,3 +16,22 @@ static const char* const kOrtModelMetadata_EpCompatibilityInfoPrefix = "ep_compa // Key for the execution provider library path (for dynamically loaded EPs) static const char* const kOrtEpDevice_EpMetadataKey_LibraryPath = "library_path"; + +// Key to retrieve the identity of the entity that discovered and initialized the OrtHardwareDevice. +// Possible values: +// - "ONNX Runtime" (devices discovered by ONNX Runtime). +// - (devices discovered by a plugin EP library registered with the OrtEnv). +static const char* kOrtHardwareDevice_MetadataKey_DiscoveredBy = "DiscoveredBy"; + +// Key to determine if a OrtHardwareDevice represents a virtual (non-hardware) device. +// Possible values: +// - "0": OrtHardwareDevice is not virtual; represents an actual hardware device. +// - "1": OrtHardwareDevice is virtual. 
+static const char* kOrtHardwareDevice_MetadataKey_IsVirtual = "IsVirtual"; + +// Key to determine if a OrtHardwareDevice represents a discrete hardware device, for example, +// a discrete GPU vs an integrated GPU. +// Possible values: +// - "0": Not discrete +// - "1": Discrete +static const char* kOrtHardwareDevice_MetadataKey_IsDiscrete = "Discrete"; diff --git a/onnxruntime/core/platform/apple/device_discovery.cc b/onnxruntime/core/platform/apple/device_discovery.cc index 767b834e38756..c7ea26b7d3c00 100644 --- a/onnxruntime/core/platform/apple/device_discovery.cc +++ b/onnxruntime/core/platform/apple/device_discovery.cc @@ -7,6 +7,7 @@ #include #include "core/common/logging/logging.h" +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" namespace onnxruntime { @@ -27,6 +28,8 @@ std::vector GetGpuDevices() { gpu_device.type = OrtHardwareDeviceType_GPU; gpu_device.vendor_id = kApplePciVendorId; gpu_device.vendor = kAppleVendorName; + gpu_device.metadata.Add(kOrtHardwareDevice_MetadataKey_DiscoveredBy, "ONNX Runtime"); + gpu_device.metadata.Add(kOrtHardwareDevice_MetadataKey_IsVirtual, "0"); result.emplace_back(std::move(gpu_device)); } @@ -74,6 +77,8 @@ std::vector GetNpuDevices() { npu_device.type = OrtHardwareDeviceType_NPU; npu_device.vendor_id = kApplePciVendorId; npu_device.vendor = kAppleVendorName; + npu_device.metadata.Add(kOrtHardwareDevice_MetadataKey_DiscoveredBy, "ONNX Runtime"); + npu_device.metadata.Add(kOrtHardwareDevice_MetadataKey_IsVirtual, "0"); result.emplace_back(std::move(npu_device)); } diff --git a/onnxruntime/core/platform/device_discovery.h b/onnxruntime/core/platform/device_discovery.h index b90dfbc8702b3..b49e63b90236a 100644 --- a/onnxruntime/core/platform/device_discovery.h +++ b/onnxruntime/core/platform/device_discovery.h @@ -9,8 +9,6 @@ namespace onnxruntime { -constexpr const char* kHardwareDeviceKey_DiscoveredBy = "DiscoveredBy"; - class DeviceDiscovery { public: static const std::unordered_set& GetDevices(); diff 
--git a/onnxruntime/core/platform/device_discovery_common.cc b/onnxruntime/core/platform/device_discovery_common.cc index dcba31aed6fec..61d5d2d560367 100644 --- a/onnxruntime/core/platform/device_discovery_common.cc +++ b/onnxruntime/core/platform/device_discovery_common.cc @@ -9,6 +9,7 @@ #include "core/common/cpuid_info.h" #include "core/common/logging/logging.h" +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" namespace onnxruntime { @@ -48,6 +49,8 @@ OrtHardwareDevice DeviceDiscovery::GetCpuDeviceFromCPUIDInfo() { cpu_device.vendor_id = cpuid_info.GetCPUVendorId(); cpu_device.device_id = 0; cpu_device.type = OrtHardwareDeviceType_CPU; + cpu_device.metadata.Add(kOrtHardwareDevice_MetadataKey_DiscoveredBy, "ONNX Runtime"); + cpu_device.metadata.Add(kOrtHardwareDevice_MetadataKey_IsVirtual, "0"); return cpu_device; } diff --git a/onnxruntime/core/platform/linux/device_discovery.cc b/onnxruntime/core/platform/linux/device_discovery.cc index 6a02a1b46028f..9eb4209b172ba 100644 --- a/onnxruntime/core/platform/linux/device_discovery.cc +++ b/onnxruntime/core/platform/linux/device_discovery.cc @@ -12,6 +12,7 @@ #include "core/common/logging/logging.h" #include "core/common/parse_string.h" #include "core/common/string_utils.h" +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" namespace fs = std::filesystem; @@ -138,6 +139,10 @@ Status GetGpuDevices(std::vector& gpu_devices_out) { for (const auto& gpu_sysfs_path_info : gpu_sysfs_path_infos) { OrtHardwareDevice gpu_device{}; ORT_RETURN_IF_ERROR(GetGpuDeviceFromSysfs(gpu_sysfs_path_info, gpu_device)); + + gpu_device.metadata.Add(kOrtHardwareDevice_MetadataKey_DiscoveredBy, "ONNX Runtime"); + gpu_device.metadata.Add(kOrtHardwareDevice_MetadataKey_IsVirtual, "0"); + gpu_devices.emplace_back(std::move(gpu_device)); } diff --git a/onnxruntime/core/platform/windows/device_discovery.cc b/onnxruntime/core/platform/windows/device_discovery.cc index 7b8498c6e7d3c..e1f7d20db6c76 100644 --- 
a/onnxruntime/core/platform/windows/device_discovery.cc +++ b/onnxruntime/core/platform/windows/device_discovery.cc @@ -14,6 +14,7 @@ #include "core/common/logging/logging.h" #include "core/platform/env.h" #include "core/session/abi_devices.h" +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" //// For SetupApi info #include @@ -618,7 +619,8 @@ std::unordered_set DeviceDiscovery::DiscoverDevicesForPlatfor std::unordered_map* extra_metadata = nullptr) { OrtHardwareDevice ortdevice{device.type, device.vendor_id, device.device_id, to_safe_string(device.vendor)}; - ortdevice.metadata.Add(kHardwareDeviceKey_DiscoveredBy, "ONNX Runtime"); + ortdevice.metadata.Add(kOrtHardwareDevice_MetadataKey_DiscoveredBy, "ONNX Runtime"); + ortdevice.metadata.Add(kOrtHardwareDevice_MetadataKey_IsVirtual, "0"); if (!device.description.empty()) { ortdevice.metadata.Add("Description", to_safe_string(device.description)); diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index ad8fcadd81d3c..699a7f392cfef 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -16,6 +16,7 @@ #include "core/session/abi_session_options_impl.h" #include "core/session/allocator_adapters.h" #include "core/session/inference_session.h" +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" #include "core/session/plugin_ep/ep_factory_internal.h" #include "core/session/plugin_ep/ep_library_internal.h" #include "core/session/plugin_ep/ep_library_plugin.h" @@ -776,33 +777,38 @@ InlinedVector> FilterEpHardwareDevices( const char* ep_factory_name = ep_factory.GetName(&ep_factory); const uint32_t ep_vendor_id = ep_factory.GetVendorId(&ep_factory); - const std::string ep_vendor = ep_factory.GetVendor(&ep_factory); for (OrtHardwareDevice* candidate : ep_hw_devices) { if (candidate == nullptr) { continue; // EP library provided a NULL hw device. Skip it. 
} - if (candidate->vendor_id != ep_vendor_id || - candidate->vendor != ep_vendor) { - LOGS_DEFAULT(WARNING) << "EP library registered under '" << lib_registration_name << "' with OrtEpFactory '" - << ep_factory_name << "' attempted to register a OrtHardwareDevice with non-matching " - << "vendor information. Expected " << ep_vendor << "(" << ep_vendor_id << ") but got " - << candidate->vendor << "(" << candidate->vendor_id << ")."; - devices_to_discard.emplace_back(candidate); // take ownership to discard on function return - continue; - } - if (have_ort_hw_device(candidate)) { LOGS_DEFAULT(VERBOSE) << "EP library registered under '" << lib_registration_name << "' with OrtEpFactory '" << ep_factory_name << "' attempted to register a OrtHardwareDevice that has already been " << "found by ONNX Runtime. OrtHardwareDevice info: vendor_id=" << ep_vendor_id - << ", device_id=" << candidate->device_id << ", type=" << candidate->type; + << ", device_id=" << candidate->device_id << ", type=" << candidate->type << ". " + << "ORT will not use this device."; devices_to_discard.emplace_back(candidate); // take ownership to discard on function return continue; } - candidate->metadata.Add(kHardwareDeviceKey_DiscoveredBy, ep_factory_name); + const std::map& metadata = candidate->metadata.Entries(); + + // Always set the "DiscoveredBy" metadata entry to the EP name. + if (auto discovered_by_iter = metadata.find(kOrtHardwareDevice_MetadataKey_DiscoveredBy); + discovered_by_iter == metadata.end()) { + candidate->metadata.Add(kOrtHardwareDevice_MetadataKey_DiscoveredBy, ep_factory_name); + } else if (discovered_by_iter->second != ep_factory_name) { + LOGS_DEFAULT(WARNING) << "EP library registered under '" << lib_registration_name << "' with OrtEpFactory '" + << ep_factory_name << "' attempted to register a OrtHardwareDevice with an invalid entry " + << "for metadata key '" << kOrtHardwareDevice_MetadataKey_DiscoveredBy << "'. 
" + << "Expected '" << ep_factory_name << "' but got '" << discovered_by_iter->second << "'. " + << "ORT will use the device but overwrite the metadata entry to the expected value."; + + candidate->metadata.Add(kOrtHardwareDevice_MetadataKey_DiscoveredBy, ep_factory_name); // overwrite + } + result.emplace_back(candidate); } diff --git a/onnxruntime/core/session/provider_policy_context.cc b/onnxruntime/core/session/provider_policy_context.cc index 6bcbda0f13b92..ad847a3af2590 100644 --- a/onnxruntime/core/session/provider_policy_context.cc +++ b/onnxruntime/core/session/provider_policy_context.cc @@ -16,6 +16,7 @@ #include "core/session/inference_session.h" #include "core/session/inference_session_utils.h" #include "core/session/onnxruntime_c_api.h" +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/ort_apis.h" @@ -38,7 +39,7 @@ bool IsDiscreteDevice(const OrtEpDevice* d) { } const auto& entries = d->device->metadata.Entries(); - if (auto it = entries.find("Discrete"); it != entries.end()) { + if (auto it = entries.find(kOrtHardwareDevice_MetadataKey_IsDiscrete); it != entries.end()) { return it->second == "1"; } diff --git a/onnxruntime/test/autoep/library/ep_factory.cc b/onnxruntime/test/autoep/library/ep_factory.cc index 1eaa7b244ae23..968c6b68d5140 100644 --- a/onnxruntime/test/autoep/library/ep_factory.cc +++ b/onnxruntime/test/autoep/library/ep_factory.cc @@ -5,6 +5,8 @@ #include +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" + #include "ep.h" #include "ep_allocator.h" #include "ep_arena.h" @@ -127,20 +129,20 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::GetAdditionalHardwareDevicesImpl(OrtEp *num_additional_devices = 0; if (!found_gpu && max_additional_devices >= 1) { - // Add some hw metadata for this GPU + // Create a new HW device. 
OrtKeyValuePairs* hw_metadata = nullptr; factory->ort_api.CreateKeyValuePairs(&hw_metadata); - factory->ort_api.AddKeyValuePair(hw_metadata, "Discrete", "1"); - factory->ort_api.AddKeyValuePair(hw_metadata, "CompileTargetOnly", "1"); + factory->ort_api.AddKeyValuePair(hw_metadata, kOrtHardwareDevice_MetadataKey_DiscoveredBy, + factory->ep_name_.c_str()); + factory->ort_api.AddKeyValuePair(hw_metadata, kOrtHardwareDevice_MetadataKey_IsVirtual, "1"); - // Create a new HW device. Must have the same vendor information as this factory. Otherwise, ORT will not use it. OrtHardwareDevice* new_device = nullptr; - auto* status = factory->ort_api.GetEpApi()->CreateHardwareDevice(OrtHardwareDeviceType::OrtHardwareDeviceType_GPU, - factory->vendor_id_, - /*device_id*/ 0, - factory->vendor_.c_str(), - hw_metadata, - &new_device); + auto* status = factory->ep_api.CreateHardwareDevice(OrtHardwareDeviceType::OrtHardwareDeviceType_GPU, + factory->vendor_id_, + /*device_id*/ 0, + factory->vendor_.c_str(), + hw_metadata, + &new_device); factory->ort_api.ReleaseKeyValuePairs(hw_metadata); // Release since ORT makes a copy. 
if (status != nullptr) { diff --git a/onnxruntime/test/autoep/test_selection.cc b/onnxruntime/test/autoep/test_selection.cc index 72f39be917f90..dc3eda7a9c69a 100644 --- a/onnxruntime/test/autoep/test_selection.cc +++ b/onnxruntime/test/autoep/test_selection.cc @@ -13,6 +13,7 @@ #include "core/session/abi_key_value_pairs.h" #include "core/session/abi_session_options_impl.h" #include "core/session/onnxruntime_cxx_api.h" +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" #include "test_allocator.h" #include "test/autoep/test_autoep_utils.h" @@ -35,6 +36,17 @@ void DefaultDeviceSelection(const std::string& ep_name, std::vectorEpDevice_EpName(device) == ep_name) { + const auto* hw_device = c_api->EpDevice_Device(device); + const OrtKeyValuePairs* hw_kvps = c_api->HardwareDevice_Metadata(hw_device); + + const char* discovered_by = c_api->GetKeyValue(hw_kvps, kOrtHardwareDevice_MetadataKey_DiscoveredBy); + ASSERT_NE(discovered_by, nullptr); + ASSERT_STREQ(discovered_by, "ONNX Runtime"); + + const char* is_virtual = c_api->GetKeyValue(hw_kvps, kOrtHardwareDevice_MetadataKey_IsVirtual); + ASSERT_NE(is_virtual, nullptr); + ASSERT_STREQ(is_virtual, "0"); + devices.push_back(device); break; } @@ -193,6 +205,14 @@ TEST(AutoEpSelection, DmlEP) { const auto* device = c_api->EpDevice_Device(ep_device); const OrtKeyValuePairs* kvps = c_api->HardwareDevice_Metadata(device); + const char* discovered_by = c_api->GetKeyValue(kvps, kOrtHardwareDevice_MetadataKey_DiscoveredBy); + ASSERT_NE(discovered_by, nullptr); + ASSERT_STREQ(discovered_by, "ONNX Runtime"); + + const char* is_virtual = c_api->GetKeyValue(kvps, kOrtHardwareDevice_MetadataKey_IsVirtual); + ASSERT_NE(is_virtual, nullptr); + ASSERT_STREQ(is_virtual, "0"); + if (devices.empty()) { // add the first device devices.push_back(ep_device); From 02ced2470cfcf1cd64459c1083578dcacdca216b Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Thu, 9 Oct 2025 00:38:41 -0700 Subject: [PATCH 07/11] Remove discrete key --- 
.../core/session/onnxruntime_ep_device_ep_metadata_keys.h | 7 ------- onnxruntime/core/session/environment.cc | 5 ++--- onnxruntime/core/session/provider_policy_context.cc | 3 +-- 3 files changed, 3 insertions(+), 12 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h b/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h index 8f86d62bafa3b..b4940a06bb784 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h @@ -28,10 +28,3 @@ static const char* kOrtHardwareDevice_MetadataKey_DiscoveredBy = "DiscoveredBy"; // - "0": OrtHardwareDevice is not virtual; represents an actual hardware device. // - "1": OrtHardwareDevice is virtual. static const char* kOrtHardwareDevice_MetadataKey_IsVirtual = "IsVirtual"; - -// Key to determine if a OrtHardwareDevice represents a discrete hardware device, for example, -// a discrete GPU vs an integrated GPU. -// Possible values: -// - "0": Not discrete -// - "1": Discrete -static const char* kOrtHardwareDevice_MetadataKey_IsDiscrete = "Discrete"; diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index 699a7f392cfef..bacfe62ae1e05 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -756,7 +756,6 @@ InlinedVector> FilterEpHardwareDevices( // ORT is not required to use all hw devices provided by the EP factory. // This function filters out the following hw devices: // - HW devices that were already found during ORT's device discovery. - // - HW devices with vendor information that does not match the EP factory. if (ep_hw_devices.empty()) { return {}; @@ -783,6 +782,7 @@ InlinedVector> FilterEpHardwareDevices( continue; // EP library provided a NULL hw device. Skip it. } + // Skip hw device already found by ORT. 
if (have_ort_hw_device(candidate)) { LOGS_DEFAULT(VERBOSE) << "EP library registered under '" << lib_registration_name << "' with OrtEpFactory '" << ep_factory_name << "' attempted to register a OrtHardwareDevice that has already been " @@ -841,9 +841,8 @@ Status Environment::EpInfo::Create(std::unique_ptr library_in, std::u auto& factory = *factory_ptr; // Allow EP factory to provide additional OrtHardwareDevice instances to: - // - Support offline/off-target model compilation. EP may provide a virtual OrtHardwareDevice that represents the - // compilation target. // - Enable EP library to provide hardware devices not discovered by ORT. + // - EP may provide a virtual OrtHardwareDevice that represents a cross-compilation target. std::array ep_hw_devices{nullptr}; size_t num_ep_hw_devices = 0; diff --git a/onnxruntime/core/session/provider_policy_context.cc b/onnxruntime/core/session/provider_policy_context.cc index ad847a3af2590..6bcbda0f13b92 100644 --- a/onnxruntime/core/session/provider_policy_context.cc +++ b/onnxruntime/core/session/provider_policy_context.cc @@ -16,7 +16,6 @@ #include "core/session/inference_session.h" #include "core/session/inference_session_utils.h" #include "core/session/onnxruntime_c_api.h" -#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/ort_apis.h" @@ -39,7 +38,7 @@ bool IsDiscreteDevice(const OrtEpDevice* d) { } const auto& entries = d->device->metadata.Entries(); - if (auto it = entries.find(kOrtHardwareDevice_MetadataKey_IsDiscrete); it != entries.end()) { + if (auto it = entries.find("Discrete"); it != entries.end()) { return it->second == "1"; } From 7d30fa7acc5a6e9b9de251bd0e038b9d143f6811 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Thu, 9 Oct 2025 10:10:38 -0700 Subject: [PATCH 08/11] Make global keys const --- .../core/session/onnxruntime_ep_device_ep_metadata_keys.h | 4 ++-- 1 file changed, 2 insertions(+), 2 
deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h b/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h index b4940a06bb784..fb8808326096c 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h @@ -21,10 +21,10 @@ static const char* const kOrtEpDevice_EpMetadataKey_LibraryPath = "library_path" // Possible values: // - "ONNX Runtime" (devices discovered by ONNX Runtime). // - (devices discovered by a plugin EP library registered with the OrtEnv). -static const char* kOrtHardwareDevice_MetadataKey_DiscoveredBy = "DiscoveredBy"; +static const char* const kOrtHardwareDevice_MetadataKey_DiscoveredBy = "DiscoveredBy"; // Key to determine if a OrtHardwareDevice represents a virtual (non-hardware) device. // Possible values: // - "0": OrtHardwareDevice is not virtual; represents an actual hardware device. // - "1": OrtHardwareDevice is virtual. 
-static const char* kOrtHardwareDevice_MetadataKey_IsVirtual = "IsVirtual"; +static const char* const kOrtHardwareDevice_MetadataKey_IsVirtual = "IsVirtual"; From f9514a3b3b79266d42a0abb5e7eafa3cbd844062 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Thu, 9 Oct 2025 15:44:00 -0700 Subject: [PATCH 09/11] Add a new EP for testing --- cmake/onnxruntime_unittests.cmake | 49 ++- .../autoep/library/ep_lib_virtual_gpu/ep.cc | 320 ++++++++++++++++++ .../autoep/library/ep_lib_virtual_gpu/ep.h | 50 +++ .../library/ep_lib_virtual_gpu/ep_factory.cc | 226 +++++++++++++ .../library/ep_lib_virtual_gpu/ep_factory.h | 80 +++++ .../ep_lib_virtual_gpu/ep_lib_virtual_gpu.cc | 50 +++ .../ep_lib_virtual_gpu/ep_lib_virtual_gpu.def | 5 + .../ep_lib_virtual_gpu.lds} | 0 .../library/{ => example_plugin_ep}/ep.cc | 0 .../library/{ => example_plugin_ep}/ep.h | 2 +- .../{ => example_plugin_ep}/ep_allocator.h | 2 +- .../{ => example_plugin_ep}/ep_arena.cc | 0 .../{ => example_plugin_ep}/ep_arena.h | 2 +- .../ep_data_transfer.cc | 0 .../ep_data_transfer.h | 2 +- .../{ => example_plugin_ep}/ep_factory.cc | 57 +--- .../{ => example_plugin_ep}/ep_factory.h | 2 +- .../ep_stream_support.cc | 0 .../ep_stream_support.h | 2 +- .../example_plugin_ep.cc | 0 .../example_plugin_ep_library.def | 0 .../example_plugin_ep_library.lds | 7 + .../autoep/library/example_plugin_ep_utils.cc | 48 --- ...le_plugin_ep_utils.h => plugin_ep_utils.h} | 44 ++- 24 files changed, 834 insertions(+), 114 deletions(-) create mode 100644 onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep.cc create mode 100644 onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep.h create mode 100644 onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep_factory.cc create mode 100644 onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep_factory.h create mode 100644 onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep_lib_virtual_gpu.cc create mode 100644 onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep_lib_virtual_gpu.def rename 
onnxruntime/test/autoep/library/{example_plugin_ep_library.lds => ep_lib_virtual_gpu/ep_lib_virtual_gpu.lds} (100%) rename onnxruntime/test/autoep/library/{ => example_plugin_ep}/ep.cc (100%) rename onnxruntime/test/autoep/library/{ => example_plugin_ep}/ep.h (98%) rename onnxruntime/test/autoep/library/{ => example_plugin_ep}/ep_allocator.h (99%) rename onnxruntime/test/autoep/library/{ => example_plugin_ep}/ep_arena.cc (100%) rename onnxruntime/test/autoep/library/{ => example_plugin_ep}/ep_arena.h (99%) rename onnxruntime/test/autoep/library/{ => example_plugin_ep}/ep_data_transfer.cc (100%) rename onnxruntime/test/autoep/library/{ => example_plugin_ep}/ep_data_transfer.h (97%) rename onnxruntime/test/autoep/library/{ => example_plugin_ep}/ep_factory.cc (87%) rename onnxruntime/test/autoep/library/{ => example_plugin_ep}/ep_factory.h (99%) rename onnxruntime/test/autoep/library/{ => example_plugin_ep}/ep_stream_support.cc (100%) rename onnxruntime/test/autoep/library/{ => example_plugin_ep}/ep_stream_support.h (98%) rename onnxruntime/test/autoep/library/{ => example_plugin_ep}/example_plugin_ep.cc (100%) rename onnxruntime/test/autoep/library/{ => example_plugin_ep}/example_plugin_ep_library.def (100%) create mode 100644 onnxruntime/test/autoep/library/example_plugin_ep/example_plugin_ep_library.lds delete mode 100644 onnxruntime/test/autoep/library/example_plugin_ep_utils.cc rename onnxruntime/test/autoep/library/{example_plugin_ep_utils.h => plugin_ep_utils.h} (75%) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 460736ff8506e..350e1e6844daa 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1982,9 +1982,13 @@ endif() if (WIN32 AND onnxruntime_BUILD_SHARED_LIB AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND NOT onnxruntime_MINIMAL_BUILD) + + # # example_plugin_ep - file(GLOB onnxruntime_autoep_test_library_src "${TEST_SRC_DIR}/autoep/library/*.h" - 
"${TEST_SRC_DIR}/autoep/library/*.cc") + # + file(GLOB onnxruntime_autoep_test_library_src "${TEST_SRC_DIR}/autoep/library/example_plugin_ep/*.h" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep/*.cc" + "${TEST_SRC_DIR}/autoep/library/plugin_ep_utils.h") onnxruntime_add_shared_library_module(example_plugin_ep ${onnxruntime_autoep_test_library_src}) target_include_directories(example_plugin_ep PRIVATE ${REPO_ROOT}/include/onnxruntime/core/session) target_link_libraries(example_plugin_ep PRIVATE onnxruntime) @@ -1994,12 +1998,12 @@ if (WIN32 AND onnxruntime_BUILD_SHARED_LIB AND set(ONNXRUNTIME_AUTOEP_LIB_LINK_FLAG "-Xlinker -dead_strip") elseif (NOT CMAKE_SYSTEM_NAME MATCHES "AIX") string(CONCAT ONNXRUNTIME_AUTOEP_LIB_LINK_FLAG - "-Xlinker --version-script=${TEST_SRC_DIR}/autoep/library/example_plugin_ep_library.lds " + "-Xlinker --version-script=${TEST_SRC_DIR}/autoep/library/example_plugin_ep/example_plugin_ep_library.lds " "-Xlinker --no-undefined -Xlinker --gc-sections -z noexecstack") endif() else() set(ONNXRUNTIME_AUTOEP_LIB_LINK_FLAG - "-DEF:${TEST_SRC_DIR}/autoep/library/example_plugin_ep_library.def") + "-DEF:${TEST_SRC_DIR}/autoep/library/example_plugin_ep/example_plugin_ep_library.def") endif() set_property(TARGET example_plugin_ep APPEND_STRING PROPERTY LINK_FLAGS @@ -2008,7 +2012,42 @@ if (WIN32 AND onnxruntime_BUILD_SHARED_LIB AND set_target_properties(example_plugin_ep PROPERTIES FOLDER "ONNXRuntimeTest") source_group(TREE ${TEST_SRC_DIR} FILES ${onnxruntime_autoep_test_library_src}) + # + # ep_lib_virtual_gpu + # + set(onnxruntime_autoep_test_ep_lib_virtual_gpu_src + "${TEST_SRC_DIR}/autoep/library/plugin_ep_utils.h" + "${TEST_SRC_DIR}/autoep/library/ep_lib_virtual_gpu/ep_lib_virtual_gpu.cc" + "${TEST_SRC_DIR}/autoep/library/ep_lib_virtual_gpu/ep_factory.h" + "${TEST_SRC_DIR}/autoep/library/ep_lib_virtual_gpu/ep_factory.cc" + "${TEST_SRC_DIR}/autoep/library/ep_lib_virtual_gpu/ep.h" + "${TEST_SRC_DIR}/autoep/library/ep_lib_virtual_gpu/ep.cc") + 
onnxruntime_add_shared_library_module(ep_lib_virtual_gpu ${onnxruntime_autoep_test_ep_lib_virtual_gpu_src}) + target_include_directories(ep_lib_virtual_gpu PRIVATE ${REPO_ROOT}/include/onnxruntime/core/session) + target_link_libraries(ep_lib_virtual_gpu PRIVATE onnxruntime) + + if(UNIX) + if (APPLE) + set(ONNXRUNTIME_AUTOEP_EP_LIB_VIRT_GPU_LINK_FLAG "-Xlinker -dead_strip") + elseif (NOT CMAKE_SYSTEM_NAME MATCHES "AIX") + string(CONCAT ONNXRUNTIME_AUTOEP_EP_LIB_VIRT_GPU_LINK_FLAG + "-Xlinker --version-script=${TEST_SRC_DIR}/autoep/library/ep_lib_virtual_gpu/ep_lib_virtual_gpu.lds " + "-Xlinker --no-undefined -Xlinker --gc-sections -z noexecstack") + endif() + else() + set(ONNXRUNTIME_AUTOEP_EP_LIB_VIRT_GPU_LINK_FLAG + "-DEF:${TEST_SRC_DIR}/autoep/library/ep_lib_virtual_gpu/ep_lib_virtual_gpu.def") + endif() + + set_property(TARGET ep_lib_virtual_gpu APPEND_STRING PROPERTY LINK_FLAGS + ${ONNXRUNTIME_AUTOEP_EP_LIB_VIRT_GPU_LINK_FLAG}) + + set_target_properties(ep_lib_virtual_gpu PROPERTIES FOLDER "ONNXRuntimeTest") + source_group(TREE ${TEST_SRC_DIR} FILES ${onnxruntime_autoep_test_ep_lib_virtual_gpu_src}) + + # # test library + # file(GLOB onnxruntime_autoep_test_SRC "${ONNXRUNTIME_AUTOEP_TEST_SRC_DIR}/*.h" "${ONNXRUNTIME_AUTOEP_TEST_SRC_DIR}/*.cc") @@ -2041,7 +2080,7 @@ if (WIN32 AND onnxruntime_BUILD_SHARED_LIB AND TARGET onnxruntime_autoep_test SOURCES ${onnxruntime_autoep_test_SRC} ${onnxruntime_unittest_main_src} LIBS ${onnxruntime_autoep_test_LIBS} - DEPENDS ${all_dependencies} example_plugin_ep + DEPENDS ${all_dependencies} example_plugin_ep ep_lib_virtual_gpu ) endif() diff --git a/onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep.cc b/onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep.cc new file mode 100644 index 0000000000000..4df26f7ecb795 --- /dev/null +++ b/onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep.cc @@ -0,0 +1,320 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "ep.h" + +#include +#include +#include +#include +#include +#include + +#include "ep_factory.h" +#include "../plugin_ep_utils.h" + +/// +/// Example implementation of ONNX Add. Does not handle many things like broadcasting. +/// +struct AddImpl { + AddImpl(const OrtApi& ort_api, const OrtLogger& logger) : ort_api(ort_api), logger(logger) {} + + void GetInputDataAndShape(Ort::KernelContext kernel_context, size_t index, + /*out*/ gsl::span& data, + /*out*/ std::vector& shape) const { + Ort::ConstValue input = kernel_context.GetInput(index); + auto type_shape = input.GetTensorTypeAndShapeInfo(); + + ONNXTensorElementDataType elem_type = type_shape.GetElementType(); + if (elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) + throw Ort::Exception("EP Expected float32 inputs", ORT_EP_FAIL); + + const float* float_data = input.GetTensorData(); + size_t num_elems = type_shape.GetElementCount(); + data = gsl::span(float_data, num_elems); + shape = type_shape.GetShape(); + } + + OrtStatus* Compute(OrtKernelContext* kernel_ctx) { + RETURN_IF_ERROR(ort_api.Logger_LogMessage(&logger, + OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, + "MulKernel::Compute", ORT_FILE, __LINE__, __FUNCTION__)); + Ort::KernelContext kernel_context(kernel_ctx); + try { + gsl::span input0; + gsl::span input1; + std::vector shape0; + std::vector shape1; + + size_t num_inputs = kernel_context.GetInputCount(); + if (num_inputs != 2) { + throw Ort::Exception("Expected 2 inputs for AddImpl", ORT_INVALID_ARGUMENT); + } + + GetInputDataAndShape(kernel_context, 0, input0, shape0); + GetInputDataAndShape(kernel_context, 1, input1, shape1); + + if (shape0 != shape1) { + throw Ort::Exception("Expected same dimensions for both inputs", ORT_INVALID_ARGUMENT); + } + + size_t num_outputs = kernel_context.GetOutputCount(); + if (num_outputs != 1) { + throw Ort::Exception("Expected 1 output for AddImpl", ORT_INVALID_ARGUMENT); + } + + auto output = kernel_context.GetOutput(0, shape0); + float* output_data = 
output.GetTensorMutableData(); + + for (size_t i = 0; i < input0.size(); ++i) { + output_data[i] = input0[i] + input1[i]; + } + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status.release(); + } + + return nullptr; + } + + const OrtApi& ort_api; + const OrtLogger& logger; +}; + +/// +/// Example OrtNodeComputeInfo that represents the computation function for a compiled OrtGraph. +/// +struct ExampleNodeComputeInfo : OrtNodeComputeInfo { + explicit ExampleNodeComputeInfo(EpVirtualGpu& ep); + + static OrtStatus* ORT_API_CALL CreateStateImpl(OrtNodeComputeInfo* this_ptr, + OrtNodeComputeContext* compute_context, + void** compute_state); + static OrtStatus* ORT_API_CALL ComputeImpl(OrtNodeComputeInfo* this_ptr, void* compute_state, + OrtKernelContext* kernel_context); + static void ORT_API_CALL ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void* compute_state); + + EpVirtualGpu& ep; +}; + +EpVirtualGpu::EpVirtualGpu(EpFactoryVirtualGpu& factory, const OrtLogger& logger) + : OrtEp{}, // explicitly call the struct ctor to ensure all optional values are default initialized + factory_{factory}, + ort_api_{factory.GetOrtApi()}, + ep_api_{factory.GetEpApi()}, + name_{factory.GetEpName()}, + logger_{logger} { + ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. 
+ + // Initialize the execution provider's function table + GetName = GetNameImpl; + GetCapability = GetCapabilityImpl; + Compile = CompileImpl; + ReleaseNodeComputeInfos = ReleaseNodeComputeInfosImpl; + + auto status = ort_api_.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, + ("EpVirtualGpu has been created with name " + name_).c_str(), + ORT_FILE, __LINE__, __FUNCTION__); + // ignore status for now + (void)status; +} + +EpVirtualGpu::~EpVirtualGpu() = default; + +/*static*/ +const char* ORT_API_CALL EpVirtualGpu ::GetNameImpl(const OrtEp* this_ptr) noexcept { + const auto* ep = static_cast(this_ptr); + return ep->name_.c_str(); +} + +/*static*/ +OrtStatus* ORT_API_CALL EpVirtualGpu::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* ort_graph, + OrtEpGraphSupportInfo* graph_support_info) noexcept { + try { + EpVirtualGpu* ep = static_cast(this_ptr); + + Ort::ConstGraph graph{ort_graph}; + std::vector nodes = graph.GetNodes(); + if (nodes.empty()) { + return nullptr; // No nodes to process + } + + std::vector supported_nodes; + + for (const auto& node : nodes) { + auto op_type = node.GetOperatorType(); + + if (op_type == "Add") { + // Check that Add has inputs/output of type float + std::vector inputs = node.GetInputs(); + std::vector outputs = node.GetOutputs(); + + RETURN_IF(inputs.size() != 2 || outputs.size() != 1, ep->ort_api_, "Add should have 2 inputs and 1 output"); + + std::array is_float = {false, false, false}; + IsFloatTensor(inputs[0], is_float[0]); + IsFloatTensor(inputs[1], is_float[1]); + IsFloatTensor(outputs[0], is_float[2]); + if (!is_float[0] || !is_float[1] || !is_float[2]) { + continue; // Input or output is not of type float + } + + { + const auto input_0_shape = GetTensorShape(inputs[0]), + input_1_shape = GetTensorShape(inputs[1]); + + if (!input_0_shape.has_value() || !input_1_shape.has_value()) { + continue; // unable to get input shape + } + + const auto is_static_shape = [](gsl::span shape) -> bool { + return 
std::all_of(shape.begin(), shape.end(), [](int64_t dim) { return dim >= 0; }); + }; + + if (!is_static_shape(*input_0_shape) || !is_static_shape(*input_1_shape)) { + continue; // input shape has dynamic dimensions + } + + if (*input_0_shape != *input_1_shape) { + continue; // input shapes do not match (no broadcasting support for now) + } + } + + supported_nodes.push_back(node); // Only support a single Add for now. + break; + } + } + + if (supported_nodes.empty()) { + return nullptr; + } + + // Create (optional) fusion options for the supported nodes to fuse. + OrtNodeFusionOptions node_fusion_options = {}; + node_fusion_options.ort_version_supported = ORT_API_VERSION; + node_fusion_options.drop_constant_initializers = false; + + RETURN_IF_ERROR(ep->ep_api_.EpGraphSupportInfo_AddNodesToFuse(graph_support_info, + reinterpret_cast(supported_nodes.data()), + supported_nodes.size(), + &node_fusion_options)); + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status.release(); + } + + return nullptr; +} + +/*static*/ +OrtStatus* ORT_API_CALL EpVirtualGpu::CompileImpl(_In_ OrtEp* this_ptr, _In_ const OrtGraph** ort_graphs, + _In_ const OrtNode** fused_nodes, _In_ size_t count, + _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, + _Out_writes_(count) OrtNode** /*ep_context_nodes*/) noexcept { + try { + if (count != 1) { + Ort::Status status("Expected to compile a single graph", ORT_EP_FAIL); + return status.release(); + } + + EpVirtualGpu* ep = static_cast(this_ptr); + + Ort::ConstGraph graph{ort_graphs[0]}; + + std::vector nodes = graph.GetNodes(); + if (nodes.size() != 1) { + Ort::Status status("Expected to compile a single Add node", ORT_EP_FAIL); + return status.release(); + } + + auto node_op_type = nodes[0].GetOperatorType(); + if (node_op_type != "Add") { + Ort::Status status("Expected to compile a single Add node", 
ORT_EP_FAIL); + return status.release(); + } + + // Now we know we're compiling a single Add node. Create a computation kernel. + Ort::ConstNode fused_node{fused_nodes[0]}; + auto ep_name = fused_node.GetEpName(); + if (ep_name != ep->name_) { + Ort::Status status("The fused node is expected to assigned to this EP to run on", ORT_EP_FAIL); + return status.release(); + } + + // Associate the name of the fused node with our AddImpl. + auto fused_node_name = fused_node.GetName(); + ep->compiled_subgraphs_.emplace(std::move(fused_node_name), + std::make_unique(ep->ort_api_, ep->logger_)); + + // Update the OrtNodeComputeInfo associated with the graph. + auto node_compute_info = std::make_unique(*ep); + node_compute_infos[0] = node_compute_info.release(); + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status.release(); + } + + return nullptr; +} + +/*static*/ +void ORT_API_CALL EpVirtualGpu::ReleaseNodeComputeInfosImpl(OrtEp* this_ptr, + OrtNodeComputeInfo** node_compute_infos, + size_t num_node_compute_infos) noexcept { + (void)this_ptr; + for (size_t i = 0; i < num_node_compute_infos; i++) { + delete node_compute_infos[i]; + } +} + +// +// Implementation of ExampleNodeComputeInfo +// +ExampleNodeComputeInfo::ExampleNodeComputeInfo(EpVirtualGpu& ep) : ep(ep) { + ort_version_supported = ORT_API_VERSION; + CreateState = CreateStateImpl; + Compute = ComputeImpl; + ReleaseState = ReleaseStateImpl; +} + +OrtStatus* ExampleNodeComputeInfo::CreateStateImpl(OrtNodeComputeInfo* this_ptr, + OrtNodeComputeContext* compute_context, + void** compute_state) { + auto* node_compute_info = static_cast(this_ptr); + EpVirtualGpu& ep = node_compute_info->ep; + + std::string fused_node_name = ep.GetEpApi().NodeComputeContext_NodeName(compute_context); + auto subgraph_it = ep.GetCompiledSubgraphs().find(fused_node_name); + if (subgraph_it == 
ep.GetCompiledSubgraphs().end()) { + std::string message = "Unable to get compiled subgraph for fused node with name " + fused_node_name; + return ep.GetOrtApi().CreateStatus(ORT_EP_FAIL, message.c_str()); + } + + AddImpl& add_impl = *subgraph_it->second; + *compute_state = &add_impl; + return nullptr; +} + +OrtStatus* ExampleNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* compute_state, + OrtKernelContext* kernel_context) { + (void)this_ptr; + AddImpl& add_impl = *reinterpret_cast(compute_state); + return add_impl.Compute(kernel_context); +} + +void ExampleNodeComputeInfo::ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void* compute_state) { + (void)this_ptr; + AddImpl& add_impl = *reinterpret_cast(compute_state); + (void)add_impl; + // Do nothing for this example. +} diff --git a/onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep.h b/onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep.h new file mode 100644 index 0000000000000..538f2467f42ba --- /dev/null +++ b/onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep.h @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#define ORT_API_MANUAL_INIT +#include "onnxruntime_cxx_api.h" +#undef ORT_API_MANUAL_INIT + +class EpFactoryVirtualGpu; +struct AddImpl; + +/// +/// Example EP for a virtual GPU OrtHardwareDevice that was created by the EP factory itself (not ORT). +/// Does not currently execute any nodes. Only used to test that an EP can provide ORT additional hardware devices. 
+/// +class EpVirtualGpu : public OrtEp { + public: + EpVirtualGpu(EpFactoryVirtualGpu& factory, const OrtLogger& logger); + ~EpVirtualGpu(); + + const OrtApi& GetOrtApi() const { return ort_api_; } + const OrtEpApi& GetEpApi() const { return ep_api_; } + + std::unordered_map>& GetCompiledSubgraphs() { + return compiled_subgraphs_; + } + + private: + static const char* ORT_API_CALL GetNameImpl(const OrtEp* this_ptr) noexcept; + + static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, + OrtEpGraphSupportInfo* graph_support_info) noexcept; + + static OrtStatus* ORT_API_CALL CompileImpl(_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs, + _In_ const OrtNode** fused_nodes, _In_ size_t count, + _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, + _Out_writes_(count) OrtNode** ep_context_nodes) noexcept; + + static void ORT_API_CALL ReleaseNodeComputeInfosImpl(OrtEp* this_ptr, + OrtNodeComputeInfo** node_compute_infos, + size_t num_node_compute_infos) noexcept; + + EpFactoryVirtualGpu& factory_; + const OrtApi& ort_api_; + const OrtEpApi& ep_api_; + std::string name_; + const OrtLogger& logger_; + std::unordered_map> compiled_subgraphs_; +}; diff --git a/onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep_factory.cc b/onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep_factory.cc new file mode 100644 index 0000000000000..a130918b88d79 --- /dev/null +++ b/onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep_factory.cc @@ -0,0 +1,226 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "ep_factory.h" + +#include + +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" + +#include "ep.h" + +EpFactoryVirtualGpu::EpFactoryVirtualGpu(const OrtApi& ort_api, const OrtEpApi& ep_api, + const OrtLogger& default_logger) + : OrtEpFactory{}, ort_api_(ort_api), ep_api_(ep_api), default_logger_{default_logger} { + ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. + GetName = GetNameImpl; + GetVendor = GetVendorImpl; + GetVendorId = GetVendorIdImpl; + GetVersion = GetVersionImpl; + + GetAdditionalHardwareDevices = GetAdditionalHardwareDevicesImpl; + GetSupportedDevices = GetSupportedDevicesImpl; + + CreateEp = CreateEpImpl; + ReleaseEp = ReleaseEpImpl; + + CreateAllocator = CreateAllocatorImpl; + ReleaseAllocator = ReleaseAllocatorImpl; + + CreateDataTransfer = CreateDataTransferImpl; + + IsStreamAware = IsStreamAwareImpl; + CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; +} + +/*static*/ +const char* ORT_API_CALL EpFactoryVirtualGpu::GetNameImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->ep_name_.c_str(); +} + +/*static*/ +const char* ORT_API_CALL EpFactoryVirtualGpu::GetVendorImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->vendor_.c_str(); +} + +/*static*/ +uint32_t ORT_API_CALL EpFactoryVirtualGpu::GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->vendor_id_; +} + +/*static*/ +const char* ORT_API_CALL EpFactoryVirtualGpu::GetVersionImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->ep_version_.c_str(); +} + +/*static*/ +OrtStatus* ORT_API_CALL EpFactoryVirtualGpu::GetAdditionalHardwareDevicesImpl(OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* found_devices, + size_t num_found_devices, + OrtHardwareDevice** 
additional_devices, + size_t max_additional_devices, + size_t* num_additional_devices) noexcept { + // EP factory can provide ORT with additional hardware devices that ORT did not find, or more likely, that are not + // available on the target machine but could serve as compilation targets. + + // As an example, this example EP factory will first look for a GPU device among the devices found by ORT. If there + // is no GPU available, then this EP will create a virtual GPU device that the application can use a compilation target. + + auto* factory = static_cast(this_ptr); + bool found_gpu = false; + + for (size_t i = 0; i < num_found_devices; ++i) { + const OrtHardwareDevice& device = *found_devices[i]; + + if (factory->ort_api_.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU && + factory->ort_api_.HardwareDevice_Vendor(&device) == factory->vendor_) { + found_gpu = true; + break; + } + } + + *num_additional_devices = 0; + + if (!found_gpu && max_additional_devices >= 1) { + // Create a new HW device. + OrtKeyValuePairs* hw_metadata = nullptr; + factory->ort_api_.CreateKeyValuePairs(&hw_metadata); + factory->ort_api_.AddKeyValuePair(hw_metadata, kOrtHardwareDevice_MetadataKey_DiscoveredBy, + factory->ep_name_.c_str()); + factory->ort_api_.AddKeyValuePair(hw_metadata, kOrtHardwareDevice_MetadataKey_IsVirtual, "1"); + + OrtHardwareDevice* new_device = nullptr; + auto* status = factory->ep_api_.CreateHardwareDevice(OrtHardwareDeviceType::OrtHardwareDeviceType_GPU, + factory->vendor_id_, + /*device_id*/ 0, + factory->vendor_.c_str(), + hw_metadata, + &new_device); + factory->ort_api_.ReleaseKeyValuePairs(hw_metadata); // Release since ORT makes a copy. + + if (status != nullptr) { + return status; + } + + // ORT will take ownership of the new HW device. 
+ additional_devices[0] = new_device; + *num_additional_devices = 1; + } + + return nullptr; +} + +/*static*/ +OrtStatus* ORT_API_CALL EpFactoryVirtualGpu::GetSupportedDevicesImpl(OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept { + size_t& num_ep_devices = *p_num_ep_devices; + auto* factory = static_cast(this_ptr); + + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + const OrtHardwareDevice& device = *devices[i]; + if (factory->ort_api_.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU && + factory->ort_api_.HardwareDevice_Vendor(&device) == factory->vendor_) { + // these can be returned as nullptr if you have nothing to add. + OrtKeyValuePairs* ep_metadata = nullptr; + OrtKeyValuePairs* ep_options = nullptr; + factory->ort_api_.CreateKeyValuePairs(&ep_metadata); + factory->ort_api_.CreateKeyValuePairs(&ep_options); + + // random example using made up values + factory->ort_api_.AddKeyValuePair(ep_metadata, "ex_key", "ex_value"); + factory->ort_api_.AddKeyValuePair(ep_options, "compile_optimization", "O3"); + + // OrtEpDevice copies ep_metadata and ep_options. 
+ OrtEpDevice* ep_device = nullptr; + auto* status = factory->ort_api_.GetEpApi()->CreateEpDevice(factory, &device, ep_metadata, ep_options, + &ep_device); + + factory->ort_api_.ReleaseKeyValuePairs(ep_metadata); + factory->ort_api_.ReleaseKeyValuePairs(ep_options); + + if (status != nullptr) { + return status; + } + + ep_devices[num_ep_devices++] = ep_device; + } + } + + return nullptr; +} + +/*static*/ +OrtStatus* ORT_API_CALL EpFactoryVirtualGpu::CreateEpImpl(OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* /*ep_metadata*/, + size_t num_devices, + const OrtSessionOptions* /*session_options*/, + const OrtLogger* logger, + OrtEp** ep) noexcept { + auto* factory = static_cast(this_ptr); + *ep = nullptr; + + if (num_devices != 1) { + // we only registered for GPU and only expected to be selected for one GPU + return factory->ort_api_.CreateStatus(ORT_INVALID_ARGUMENT, + "EpFactoryVirtualGpu only supports selection for one device."); + } + + auto actual_ep = std::make_unique(*factory, *logger); + + *ep = actual_ep.release(); + return nullptr; +} + +/*static*/ +void ORT_API_CALL EpFactoryVirtualGpu::ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* ep) noexcept { + EpVirtualGpu* dummy_ep = static_cast(ep); + delete dummy_ep; +} + +/*static*/ +OrtStatus* ORT_API_CALL EpFactoryVirtualGpu::CreateAllocatorImpl(OrtEpFactory* /*this_ptr*/, + const OrtMemoryInfo* /*memory_info*/, + const OrtKeyValuePairs* /*allocator_options*/, + OrtAllocator** allocator) noexcept { + // Don't support custom allocators in this example for simplicity. A GPU EP would normally support allocators. + *allocator = nullptr; + return nullptr; +} + +/*static*/ +void ORT_API_CALL EpFactoryVirtualGpu::ReleaseAllocatorImpl(OrtEpFactory* /*this_ptr*/, + OrtAllocator* /*allocator*/) noexcept { + // Do nothing. 
+} + +/*static*/ +OrtStatus* ORT_API_CALL EpFactoryVirtualGpu::CreateDataTransferImpl(OrtEpFactory* /*this_ptr*/, + OrtDataTransferImpl** data_transfer) noexcept { + // Don't support data transfer in this example for simplicity. A GPU EP would normally support it. + *data_transfer = nullptr; + return nullptr; +} + +/*static*/ +bool ORT_API_CALL EpFactoryVirtualGpu::IsStreamAwareImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return false; +} + +/*static*/ +OrtStatus* ORT_API_CALL EpFactoryVirtualGpu::CreateSyncStreamForDeviceImpl(OrtEpFactory* /*this_ptr*/, + const OrtMemoryDevice* /*memory_device*/, + const OrtKeyValuePairs* /*stream_options*/, + OrtSyncStreamImpl** stream) noexcept { + // Don't support sync streams in this example. A GPU EP would normally support it. + *stream = nullptr; + return nullptr; +} diff --git a/onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep_factory.h b/onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep_factory.h new file mode 100644 index 0000000000000..a62d9f5300544 --- /dev/null +++ b/onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep_factory.h @@ -0,0 +1,80 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#define ORT_API_MANUAL_INIT +#include "onnxruntime_cxx_api.h" +#undef ORT_API_MANUAL_INIT + +/// +/// EP factory that creates an OrtEp instance that supports a virtual GPU OrtHardwareDevice +/// created by the factory itself (not ORT). 
+/// +class EpFactoryVirtualGpu : public OrtEpFactory { + public: + EpFactoryVirtualGpu(const OrtApi& ort_api, const OrtEpApi& ep_api, const OrtLogger& default_logger); + + const OrtApi& GetOrtApi() const { return ort_api_; } + const OrtEpApi& GetEpApi() const { return ep_api_; } + const std::string& GetEpName() const { return ep_name_; } + + private: + static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) noexcept; + + static const char* ORT_API_CALL GetVendorImpl(const OrtEpFactory* this_ptr) noexcept; + static uint32_t ORT_API_CALL GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept; + + static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* this_ptr) noexcept; + + static OrtStatus* ORT_API_CALL GetAdditionalHardwareDevicesImpl(OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* found_devices, + size_t num_found_devices, + OrtHardwareDevice** additional_devices, + size_t max_additional_devices, + size_t* num_additional_devices) noexcept; + + static OrtStatus* ORT_API_CALL GetSupportedDevicesImpl(OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept; + + static OrtStatus* ORT_API_CALL CreateEpImpl(OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* /*ep_metadata*/, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* logger, + OrtEp** ep) noexcept; + + static void ORT_API_CALL ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* ep) noexcept; + + static OrtStatus* ORT_API_CALL CreateAllocatorImpl(OrtEpFactory* this_ptr, + const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* /*allocator_options*/, + OrtAllocator** allocator) noexcept; + + static void ORT_API_CALL ReleaseAllocatorImpl(OrtEpFactory* /*this*/, OrtAllocator* allocator) noexcept; + + static OrtStatus* ORT_API_CALL CreateDataTransferImpl(OrtEpFactory* 
this_ptr, + OrtDataTransferImpl** data_transfer) noexcept; + + static bool ORT_API_CALL IsStreamAwareImpl(const OrtEpFactory* this_ptr) noexcept; + + static OrtStatus* ORT_API_CALL CreateSyncStreamForDeviceImpl(OrtEpFactory* this_ptr, + const OrtMemoryDevice* memory_device, + const OrtKeyValuePairs* stream_options, + OrtSyncStreamImpl** stream) noexcept; + + const OrtApi& ort_api_; + const OrtEpApi& ep_api_; + const OrtLogger& default_logger_; + const std::string ep_name_{"EpVirtualGpu"}; + const std::string vendor_{"Contoso2"}; // EP vendor name + const uint32_t vendor_id_{0xB358}; // EP vendor ID + const std::string ep_version_{"0.1.0"}; // EP version +}; diff --git a/onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep_lib_virtual_gpu.cc b/onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep_lib_virtual_gpu.cc new file mode 100644 index 0000000000000..3a326aa956ae9 --- /dev/null +++ b/onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep_lib_virtual_gpu.cc @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#define ORT_API_MANUAL_INIT +#include "onnxruntime_cxx_api.h" +#undef ORT_API_MANUAL_INIT + +#include "ep_factory.h" + +// To make symbols visible on macOS/iOS +#ifdef __APPLE__ +#define EXPORT_SYMBOL __attribute__((visibility("default"))) +#else +#define EXPORT_SYMBOL +#endif + +extern "C" { +// +// Public symbols +// +EXPORT_SYMBOL OrtStatus* CreateEpFactories(const char* /*registration_name*/, const OrtApiBase* ort_api_base, + const OrtLogger* default_logger, + OrtEpFactory** factories, size_t max_factories, size_t* num_factories) { + const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION); + const OrtEpApi* ep_api = ort_api->GetEpApi(); + + // Manual init for the C++ API + Ort::InitApi(ort_api); + + // Factory could use registration_name or define its own EP name. 
+ std::unique_ptr factory = std::make_unique(*ort_api, *ep_api, + *default_logger); + + if (max_factories < 1) { + return ort_api->CreateStatus(ORT_INVALID_ARGUMENT, + "Not enough space to return EP factory. Need at least one."); + } + + factories[0] = factory.release(); + *num_factories = 1; + + return nullptr; +} + +EXPORT_SYMBOL OrtStatus* ReleaseEpFactory(OrtEpFactory* factory) { + delete static_cast(factory); + return nullptr; +} + +} // extern "C" diff --git a/onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep_lib_virtual_gpu.def b/onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep_lib_virtual_gpu.def new file mode 100644 index 0000000000000..ac3a951684683 --- /dev/null +++ b/onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep_lib_virtual_gpu.def @@ -0,0 +1,5 @@ +LIBRARY "ep_lib_virtual_gpu.dll" +EXPORTS + CreateEpFactories @1 + ReleaseEpFactory @2 + diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_library.lds b/onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep_lib_virtual_gpu.lds similarity index 100% rename from onnxruntime/test/autoep/library/example_plugin_ep_library.lds rename to onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep_lib_virtual_gpu.lds diff --git a/onnxruntime/test/autoep/library/ep.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc similarity index 100% rename from onnxruntime/test/autoep/library/ep.cc rename to onnxruntime/test/autoep/library/example_plugin_ep/ep.cc diff --git a/onnxruntime/test/autoep/library/ep.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep.h similarity index 98% rename from onnxruntime/test/autoep/library/ep.h rename to onnxruntime/test/autoep/library/example_plugin_ep/ep.h index 279925a7ec3e1..7e96a523cf285 100644 --- a/onnxruntime/test/autoep/library/ep.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep.h @@ -5,7 +5,7 @@ #include -#include "example_plugin_ep_utils.h" +#include "../plugin_ep_utils.h" class ExampleEpFactory; struct MulKernel; diff --git 
a/onnxruntime/test/autoep/library/ep_allocator.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_allocator.h similarity index 99% rename from onnxruntime/test/autoep/library/ep_allocator.h rename to onnxruntime/test/autoep/library/example_plugin_ep/ep_allocator.h index e46c03dfc8f14..febf8c7dbd8c1 100644 --- a/onnxruntime/test/autoep/library/ep_allocator.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_allocator.h @@ -3,7 +3,7 @@ #pragma once -#include "example_plugin_ep_utils.h" +#include "../plugin_ep_utils.h" #include diff --git a/onnxruntime/test/autoep/library/ep_arena.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep_arena.cc similarity index 100% rename from onnxruntime/test/autoep/library/ep_arena.cc rename to onnxruntime/test/autoep/library/example_plugin_ep/ep_arena.cc diff --git a/onnxruntime/test/autoep/library/ep_arena.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_arena.h similarity index 99% rename from onnxruntime/test/autoep/library/ep_arena.h rename to onnxruntime/test/autoep/library/example_plugin_ep/ep_arena.h index caa2c61db835f..c8fd1db5dc007 100644 --- a/onnxruntime/test/autoep/library/ep_arena.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_arena.h @@ -26,7 +26,7 @@ limitations under the License. 
#undef ORT_API_MANUAL_INIT #include "ep_allocator.h" -#include "example_plugin_ep_utils.h" +#include "../plugin_ep_utils.h" #if defined(PLATFORM_WINDOWS) #include diff --git a/onnxruntime/test/autoep/library/ep_data_transfer.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep_data_transfer.cc similarity index 100% rename from onnxruntime/test/autoep/library/ep_data_transfer.cc rename to onnxruntime/test/autoep/library/example_plugin_ep/ep_data_transfer.cc diff --git a/onnxruntime/test/autoep/library/ep_data_transfer.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_data_transfer.h similarity index 97% rename from onnxruntime/test/autoep/library/ep_data_transfer.h rename to onnxruntime/test/autoep/library/example_plugin_ep/ep_data_transfer.h index da74d42b4affe..f1dad784ff84b 100644 --- a/onnxruntime/test/autoep/library/ep_data_transfer.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_data_transfer.h @@ -3,7 +3,7 @@ #pragma once -#include "example_plugin_ep_utils.h" +#include "../plugin_ep_utils.h" struct ExampleDataTransfer : OrtDataTransferImpl, ApiPtrs { ExampleDataTransfer(ApiPtrs api_ptrs, diff --git a/onnxruntime/test/autoep/library/ep_factory.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc similarity index 87% rename from onnxruntime/test/autoep/library/ep_factory.cc rename to onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc index 968c6b68d5140..69293df7cda32 100644 --- a/onnxruntime/test/autoep/library/ep_factory.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc @@ -101,59 +101,17 @@ const char* ORT_API_CALL ExampleEpFactory::GetVersionImpl(const OrtEpFactory* th } /*static*/ -OrtStatus* ORT_API_CALL ExampleEpFactory::GetAdditionalHardwareDevicesImpl(OrtEpFactory* this_ptr, - const OrtHardwareDevice* const* found_devices, - size_t num_found_devices, - OrtHardwareDevice** additional_devices, - size_t max_additional_devices, +OrtStatus* ORT_API_CALL 
ExampleEpFactory::GetAdditionalHardwareDevicesImpl(OrtEpFactory* /*this_ptr*/, + const OrtHardwareDevice* const* /*found_devices*/, + size_t /*num_found_devices*/, + OrtHardwareDevice** /*additional_devices*/, + size_t /*max_additional_devices*/, size_t* num_additional_devices) noexcept { // EP factory can provide ORT with additional hardware devices that ORT did not find, or more likely, that are not // available on the target machine but could serve as compilation targets. - // As an example, this example EP factory will first look for a GPU device among the devices found by ORT. If there - // is no GPU available, then this EP will create a virtual GPU device that the application can use a compilation target. - - auto* factory = static_cast(this_ptr); - bool found_gpu = false; - - for (size_t i = 0; i < num_found_devices; ++i) { - const OrtHardwareDevice& device = *found_devices[i]; - - if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU && - factory->ort_api.HardwareDevice_Vendor(&device) == factory->vendor_) { - found_gpu = true; - break; - } - } - + // This example EP does not provide any additional hardware devices. *num_additional_devices = 0; - - if (!found_gpu && max_additional_devices >= 1) { - // Create a new HW device. - OrtKeyValuePairs* hw_metadata = nullptr; - factory->ort_api.CreateKeyValuePairs(&hw_metadata); - factory->ort_api.AddKeyValuePair(hw_metadata, kOrtHardwareDevice_MetadataKey_DiscoveredBy, - factory->ep_name_.c_str()); - factory->ort_api.AddKeyValuePair(hw_metadata, kOrtHardwareDevice_MetadataKey_IsVirtual, "1"); - - OrtHardwareDevice* new_device = nullptr; - auto* status = factory->ep_api.CreateHardwareDevice(OrtHardwareDeviceType::OrtHardwareDeviceType_GPU, - factory->vendor_id_, - /*device_id*/ 0, - factory->vendor_.c_str(), - hw_metadata, - &new_device); - factory->ort_api.ReleaseKeyValuePairs(hw_metadata); // Release since ORT makes a copy. 
- - if (status != nullptr) { - return status; - } - - // ORT will take ownership of the new HW device. - additional_devices[0] = new_device; - *num_additional_devices = 1; - } - return nullptr; } @@ -250,8 +208,7 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateEpImpl(OrtEpFactory* this_ptr, // Create EP configuration from session options, if needed. // Note: should not store a direct reference to the session options object as its lifespan is not guaranteed. std::string ep_context_enable; - RETURN_IF_ERROR(GetSessionConfigEntryOrDefault(factory->ort_api, *session_options, - "ep.context_enable", "0", ep_context_enable)); + RETURN_IF_ERROR(GetSessionConfigEntryOrDefault(*session_options, "ep.context_enable", "0", ep_context_enable)); ExampleEp::Config config = {}; config.enable_ep_context = ep_context_enable == "1"; diff --git a/onnxruntime/test/autoep/library/ep_factory.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h similarity index 99% rename from onnxruntime/test/autoep/library/ep_factory.h rename to onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h index 2d6146b753255..8a7389f9c1239 100644 --- a/onnxruntime/test/autoep/library/ep_factory.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h @@ -7,7 +7,7 @@ #include "ep_arena.h" #include "ep_data_transfer.h" -#include "example_plugin_ep_utils.h" +#include "../plugin_ep_utils.h" /// /// Example EP factory that can create an OrtEp and return information about the supported hardware devices. 
diff --git a/onnxruntime/test/autoep/library/ep_stream_support.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep_stream_support.cc similarity index 100% rename from onnxruntime/test/autoep/library/ep_stream_support.cc rename to onnxruntime/test/autoep/library/example_plugin_ep/ep_stream_support.cc diff --git a/onnxruntime/test/autoep/library/ep_stream_support.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_stream_support.h similarity index 98% rename from onnxruntime/test/autoep/library/ep_stream_support.h rename to onnxruntime/test/autoep/library/example_plugin_ep/ep_stream_support.h index a825e5afd2250..847ed708c5ca7 100644 --- a/onnxruntime/test/autoep/library/ep_stream_support.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_stream_support.h @@ -5,7 +5,7 @@ #include "onnxruntime_c_api.h" #include "ep_factory.h" -#include "example_plugin_ep_utils.h" +#include "../plugin_ep_utils.h" class ExampleEpFactory; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep.cc b/onnxruntime/test/autoep/library/example_plugin_ep/example_plugin_ep.cc similarity index 100% rename from onnxruntime/test/autoep/library/example_plugin_ep.cc rename to onnxruntime/test/autoep/library/example_plugin_ep/example_plugin_ep.cc diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_library.def b/onnxruntime/test/autoep/library/example_plugin_ep/example_plugin_ep_library.def similarity index 100% rename from onnxruntime/test/autoep/library/example_plugin_ep_library.def rename to onnxruntime/test/autoep/library/example_plugin_ep/example_plugin_ep_library.def diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/example_plugin_ep_library.lds b/onnxruntime/test/autoep/library/example_plugin_ep/example_plugin_ep_library.lds new file mode 100644 index 0000000000000..a6d2ef09a7b16 --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep/example_plugin_ep_library.lds @@ -0,0 +1,7 @@ +VERS_1.0.0 { + global: + CreateEpFactories; + 
ReleaseEpFactory; + local: + *; +}; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_utils.cc b/onnxruntime/test/autoep/library/example_plugin_ep_utils.cc deleted file mode 100644 index 8b36f5f4e9a13..0000000000000 --- a/onnxruntime/test/autoep/library/example_plugin_ep_utils.cc +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "example_plugin_ep_utils.h" - -#include - -OrtStatus* GetSessionConfigEntryOrDefault(const OrtApi& /* ort_api */, const OrtSessionOptions& session_options, - const char* config_key, const std::string& default_val, - /*out*/ std::string& config_val) { - try { - Ort::ConstSessionOptions sess_opt{&session_options}; - config_val = sess_opt.GetConfigEntryOrDefault(config_key, default_val); - } catch (const Ort::Exception& ex) { - Ort::Status status(ex); - return status.release(); - } - - return nullptr; -} - -void IsFloatTensor(Ort::ConstValueInfo value_info, bool& result) { - result = false; - - auto type_info = value_info.TypeInfo(); - ONNXType onnx_type = type_info.GetONNXType(); - if (onnx_type != ONNX_TYPE_TENSOR) { - return; - } - - auto type_shape = type_info.GetTensorTypeAndShapeInfo(); - ONNXTensorElementDataType elem_type = type_shape.GetElementType(); - if (elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) { - return; - } - result = true; -} - -std::optional> GetTensorShape(Ort::ConstValueInfo value_info) { - const auto type_info = value_info.TypeInfo(); - const auto onnx_type = type_info.GetONNXType(); - if (onnx_type != ONNX_TYPE_TENSOR) { - return std::nullopt; - } - - const auto type_shape = type_info.GetTensorTypeAndShapeInfo(); - return type_shape.GetShape(); -} diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_utils.h b/onnxruntime/test/autoep/library/plugin_ep_utils.h similarity index 75% rename from onnxruntime/test/autoep/library/example_plugin_ep_utils.h rename to 
onnxruntime/test/autoep/library/plugin_ep_utils.h index decc89251dc7b..2024c5185b0d6 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_utils.h +++ b/onnxruntime/test/autoep/library/plugin_ep_utils.h @@ -104,12 +104,46 @@ struct FloatInitializer { }; // Returns an entry in the session option configurations, or a default value if not present. -OrtStatus* GetSessionConfigEntryOrDefault(const OrtApi& ort_api, const OrtSessionOptions& session_options, - const char* config_key, const std::string& default_val, - /*out*/ std::string& config_val); +inline OrtStatus* GetSessionConfigEntryOrDefault(const OrtSessionOptions& session_options, + const char* config_key, const std::string& default_val, + /*out*/ std::string& config_val) { + try { + Ort::ConstSessionOptions sess_opt{&session_options}; + config_val = sess_opt.GetConfigEntryOrDefault(config_key, default_val); + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } + + return nullptr; +} // Returns true (via output parameter) if the given OrtValueInfo represents a float tensor. -void IsFloatTensor(Ort::ConstValueInfo value_info, bool& result); +inline void IsFloatTensor(Ort::ConstValueInfo value_info, bool& result) { + result = false; + + auto type_info = value_info.TypeInfo(); + ONNXType onnx_type = type_info.GetONNXType(); + if (onnx_type != ONNX_TYPE_TENSOR) { + return; + } + + auto type_shape = type_info.GetTensorTypeAndShapeInfo(); + ONNXTensorElementDataType elem_type = type_shape.GetElementType(); + if (elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) { + return; + } + result = true; +} // Gets the tensor shape from `value_info`. Returns std::nullopt if `value_info` is not a tensor. 
-std::optional> GetTensorShape(Ort::ConstValueInfo value_info); +inline std::optional> GetTensorShape(Ort::ConstValueInfo value_info) { + const auto type_info = value_info.TypeInfo(); + const auto onnx_type = type_info.GetONNXType(); + if (onnx_type != ONNX_TYPE_TENSOR) { + return std::nullopt; + } + + const auto type_shape = type_info.GetTensorTypeAndShapeInfo(); + return type_shape.GetShape(); +} From b5538583606c38ac5d12edeb7098d4dc9768545a Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Thu, 9 Oct 2025 15:44:26 -0700 Subject: [PATCH 10/11] Rename class --- .../autoep/library/ep_lib_virtual_gpu/ep.cc | 25 ++++++++++--------- .../library/example_plugin_ep/ep_factory.cc | 4 +-- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep.cc b/onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep.cc index 4df26f7ecb795..a82556767466a 100644 --- a/onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep.cc +++ b/onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep.cc @@ -15,6 +15,7 @@ /// /// Example implementation of ONNX Add. Does not handle many things like broadcasting. +/// Used as the implementation of a compiled subgraph with a single Add node. /// struct AddImpl { AddImpl(const OrtApi& ort_api, const OrtLogger& logger) : ort_api(ort_api), logger(logger) {} @@ -87,8 +88,8 @@ struct AddImpl { /// /// Example OrtNodeComputeInfo that represents the computation function for a compiled OrtGraph. 
/// -struct ExampleNodeComputeInfo : OrtNodeComputeInfo { - explicit ExampleNodeComputeInfo(EpVirtualGpu& ep); +struct AddNodeComputeInfo : OrtNodeComputeInfo { + explicit AddNodeComputeInfo(EpVirtualGpu& ep); static OrtStatus* ORT_API_CALL CreateStateImpl(OrtNodeComputeInfo* this_ptr, OrtNodeComputeContext* compute_context, @@ -254,7 +255,7 @@ OrtStatus* ORT_API_CALL EpVirtualGpu::CompileImpl(_In_ OrtEp* this_ptr, _In_ con std::make_unique(ep->ort_api_, ep->logger_)); // Update the OrtNodeComputeInfo associated with the graph. - auto node_compute_info = std::make_unique(*ep); + auto node_compute_info = std::make_unique(*ep); node_compute_infos[0] = node_compute_info.release(); } catch (const Ort::Exception& ex) { Ort::Status status(ex); @@ -278,19 +279,19 @@ void ORT_API_CALL EpVirtualGpu::ReleaseNodeComputeInfosImpl(OrtEp* this_ptr, } // -// Implementation of ExampleNodeComputeInfo +// Implementation of AddNodeComputeInfo // -ExampleNodeComputeInfo::ExampleNodeComputeInfo(EpVirtualGpu& ep) : ep(ep) { +AddNodeComputeInfo::AddNodeComputeInfo(EpVirtualGpu& ep) : ep(ep) { ort_version_supported = ORT_API_VERSION; CreateState = CreateStateImpl; Compute = ComputeImpl; ReleaseState = ReleaseStateImpl; } -OrtStatus* ExampleNodeComputeInfo::CreateStateImpl(OrtNodeComputeInfo* this_ptr, - OrtNodeComputeContext* compute_context, - void** compute_state) { - auto* node_compute_info = static_cast(this_ptr); +OrtStatus* AddNodeComputeInfo::CreateStateImpl(OrtNodeComputeInfo* this_ptr, + OrtNodeComputeContext* compute_context, + void** compute_state) { + auto* node_compute_info = static_cast(this_ptr); EpVirtualGpu& ep = node_compute_info->ep; std::string fused_node_name = ep.GetEpApi().NodeComputeContext_NodeName(compute_context); @@ -305,14 +306,14 @@ OrtStatus* ExampleNodeComputeInfo::CreateStateImpl(OrtNodeComputeInfo* this_ptr, return nullptr; } -OrtStatus* ExampleNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* compute_state, - OrtKernelContext* 
kernel_context) { +OrtStatus* AddNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* compute_state, + OrtKernelContext* kernel_context) { (void)this_ptr; AddImpl& add_impl = *reinterpret_cast(compute_state); return add_impl.Compute(kernel_context); } -void ExampleNodeComputeInfo::ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void* compute_state) { +void AddNodeComputeInfo::ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void* compute_state) { (void)this_ptr; AddImpl& add_impl = *reinterpret_cast(compute_state); (void)add_impl; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc index 69293df7cda32..69634faea6870 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc @@ -21,7 +21,7 @@ ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis, const OrtL GetVendorId = GetVendorIdImpl; GetVersion = GetVersionImpl; - GetAdditionalHardwareDevices = GetAdditionalHardwareDevicesImpl; + GetAdditionalHardwareDevices = GetAdditionalHardwareDevicesImpl; // optional. can be null. GetSupportedDevices = GetSupportedDevicesImpl; CreateEp = CreateEpImpl; @@ -107,7 +107,7 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::GetAdditionalHardwareDevicesImpl(OrtEp OrtHardwareDevice** /*additional_devices*/, size_t /*max_additional_devices*/, size_t* num_additional_devices) noexcept { - // EP factory can provide ORT with additional hardware devices that ORT did not find, or more likely, that are not + // EP factory can provide ORT with additional hardware devices that ORT did not find, or that are not // available on the target machine but could serve as compilation targets. // This example EP does not provide any additional hardware devices. 
From 24316d97dcf41d9708035e617c4c1c730c003ca4 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Thu, 9 Oct 2025 17:05:56 -0700 Subject: [PATCH 11/11] Add test for new test EP --- cmake/onnxruntime_unittests.cmake | 32 +++++----- .../ep.cc | 0 .../ep.h | 2 +- .../ep_factory.cc | 0 .../ep_factory.h | 0 .../ep_lib.def} | 2 +- .../ep_lib.lds} | 0 .../ep_lib_entry.cc} | 1 - onnxruntime/test/autoep/test_allocators.cc | 2 +- onnxruntime/test/autoep/test_autoep_utils.cc | 36 +++++++++-- onnxruntime/test/autoep/test_autoep_utils.h | 18 +++--- onnxruntime/test/autoep/test_data_transfer.cc | 2 +- onnxruntime/test/autoep/test_execution.cc | 8 +-- onnxruntime/test/autoep/test_registration.cc | 64 ++++++++++++++++--- 14 files changed, 117 insertions(+), 50 deletions(-) rename onnxruntime/test/autoep/library/{ep_lib_virtual_gpu => example_plugin_ep_virt_gpu}/ep.cc (100%) rename onnxruntime/test/autoep/library/{ep_lib_virtual_gpu => example_plugin_ep_virt_gpu}/ep.h (94%) rename onnxruntime/test/autoep/library/{ep_lib_virtual_gpu => example_plugin_ep_virt_gpu}/ep_factory.cc (100%) rename onnxruntime/test/autoep/library/{ep_lib_virtual_gpu => example_plugin_ep_virt_gpu}/ep_factory.h (100%) rename onnxruntime/test/autoep/library/{ep_lib_virtual_gpu/ep_lib_virtual_gpu.def => example_plugin_ep_virt_gpu/ep_lib.def} (56%) rename onnxruntime/test/autoep/library/{ep_lib_virtual_gpu/ep_lib_virtual_gpu.lds => example_plugin_ep_virt_gpu/ep_lib.lds} (100%) rename onnxruntime/test/autoep/library/{ep_lib_virtual_gpu/ep_lib_virtual_gpu.cc => example_plugin_ep_virt_gpu/ep_lib_entry.cc} (95%) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 350e1e6844daa..57ad1e597a205 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -2013,37 +2013,37 @@ if (WIN32 AND onnxruntime_BUILD_SHARED_LIB AND source_group(TREE ${TEST_SRC_DIR} FILES ${onnxruntime_autoep_test_library_src}) # - # ep_lib_virtual_gpu + # 
example_plugin_ep_virt_gpu # - set(onnxruntime_autoep_test_ep_lib_virtual_gpu_src + set(onnxruntime_autoep_test_example_plugin_ep_virt_gpu_src "${TEST_SRC_DIR}/autoep/library/plugin_ep_utils.h" - "${TEST_SRC_DIR}/autoep/library/ep_lib_virtual_gpu/ep_lib_virtual_gpu.cc" - "${TEST_SRC_DIR}/autoep/library/ep_lib_virtual_gpu/ep_factory.h" - "${TEST_SRC_DIR}/autoep/library/ep_lib_virtual_gpu/ep_factory.cc" - "${TEST_SRC_DIR}/autoep/library/ep_lib_virtual_gpu/ep.h" - "${TEST_SRC_DIR}/autoep/library/ep_lib_virtual_gpu/ep.cc") - onnxruntime_add_shared_library_module(ep_lib_virtual_gpu ${onnxruntime_autoep_test_ep_lib_virtual_gpu_src}) - target_include_directories(ep_lib_virtual_gpu PRIVATE ${REPO_ROOT}/include/onnxruntime/core/session) - target_link_libraries(ep_lib_virtual_gpu PRIVATE onnxruntime) + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_virt_gpu/ep_lib_entry.cc" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_virt_gpu/ep_factory.h" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_virt_gpu/ep_factory.cc" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_virt_gpu/ep.h" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_virt_gpu/ep.cc") + onnxruntime_add_shared_library_module(example_plugin_ep_virt_gpu ${onnxruntime_autoep_test_example_plugin_ep_virt_gpu_src}) + target_include_directories(example_plugin_ep_virt_gpu PRIVATE ${REPO_ROOT}/include/onnxruntime/core/session) + target_link_libraries(example_plugin_ep_virt_gpu PRIVATE onnxruntime) if(UNIX) if (APPLE) set(ONNXRUNTIME_AUTOEP_EP_LIB_VIRT_GPU_LINK_FLAG "-Xlinker -dead_strip") elseif (NOT CMAKE_SYSTEM_NAME MATCHES "AIX") string(CONCAT ONNXRUNTIME_AUTOEP_EP_LIB_VIRT_GPU_LINK_FLAG - "-Xlinker --version-script=${TEST_SRC_DIR}/autoep/library/ep_lib_virtual_gpu/ep_lib_virtual_gpu.lds " + "-Xlinker --version-script=${TEST_SRC_DIR}/autoep/library/example_plugin_ep_virt_gpu/ep_lib.lds " "-Xlinker --no-undefined -Xlinker --gc-sections -z noexecstack") endif() else() 
set(ONNXRUNTIME_AUTOEP_EP_LIB_VIRT_GPU_LINK_FLAG - "-DEF:${TEST_SRC_DIR}/autoep/library/ep_lib_virtual_gpu/ep_lib_virtual_gpu.def") + "-DEF:${TEST_SRC_DIR}/autoep/library/example_plugin_ep_virt_gpu/ep_lib.def") endif() - set_property(TARGET ep_lib_virtual_gpu APPEND_STRING PROPERTY LINK_FLAGS + set_property(TARGET example_plugin_ep_virt_gpu APPEND_STRING PROPERTY LINK_FLAGS ${ONNXRUNTIME_AUTOEP_EP_LIB_VIRT_GPU_LINK_FLAG}) - set_target_properties(ep_lib_virtual_gpu PROPERTIES FOLDER "ONNXRuntimeTest") - source_group(TREE ${TEST_SRC_DIR} FILES ${onnxruntime_autoep_test_ep_lib_virtual_gpu_src}) + set_target_properties(example_plugin_ep_virt_gpu PROPERTIES FOLDER "ONNXRuntimeTest") + source_group(TREE ${TEST_SRC_DIR} FILES ${onnxruntime_autoep_test_example_plugin_ep_virt_gpu_src}) # # test library @@ -2080,7 +2080,7 @@ if (WIN32 AND onnxruntime_BUILD_SHARED_LIB AND TARGET onnxruntime_autoep_test SOURCES ${onnxruntime_autoep_test_SRC} ${onnxruntime_unittest_main_src} LIBS ${onnxruntime_autoep_test_LIBS} - DEPENDS ${all_dependencies} example_plugin_ep ep_lib_virtual_gpu + DEPENDS ${all_dependencies} example_plugin_ep example_plugin_ep_virt_gpu ) endif() diff --git a/onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep.cc b/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep.cc similarity index 100% rename from onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep.cc rename to onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep.cc diff --git a/onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep.h b/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep.h similarity index 94% rename from onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep.h rename to onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep.h index 538f2467f42ba..736b6dd1fdb45 100644 --- a/onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep.h @@ -12,7 +12,7 @@ struct AddImpl; /// /// Example EP for a 
virtual GPU OrtHardwareDevice that was created by the EP factory itself (not ORT). -/// Does not currently execute any nodes. Only used to test that an EP can provide ORT additional hardware devices. +/// Can only compile/execute a single Add node. Only used to test that an EP can provide additional hardware devices. /// class EpVirtualGpu : public OrtEp { public: diff --git a/onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep_factory.cc b/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep_factory.cc similarity index 100% rename from onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep_factory.cc rename to onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep_factory.cc diff --git a/onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep_factory.h b/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep_factory.h similarity index 100% rename from onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep_factory.h rename to onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep_factory.h diff --git a/onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep_lib_virtual_gpu.def b/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep_lib.def similarity index 56% rename from onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep_lib_virtual_gpu.def rename to onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep_lib.def index ac3a951684683..e9481d0d60b28 100644 --- a/onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep_lib_virtual_gpu.def +++ b/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep_lib.def @@ -1,4 +1,4 @@ -LIBRARY "ep_lib_virtual_gpu.dll" +LIBRARY "example_plugin_ep_virt_gpu.dll" EXPORTS CreateEpFactories @1 ReleaseEpFactory @2 diff --git a/onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep_lib_virtual_gpu.lds b/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep_lib.lds similarity index 100% rename from onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep_lib_virtual_gpu.lds rename to 
onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep_lib.lds diff --git a/onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep_lib_virtual_gpu.cc b/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep_lib_entry.cc similarity index 95% rename from onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep_lib_virtual_gpu.cc rename to onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep_lib_entry.cc index 3a326aa956ae9..1b88b77280e88 100644 --- a/onnxruntime/test/autoep/library/ep_lib_virtual_gpu/ep_lib_virtual_gpu.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep_lib_entry.cc @@ -27,7 +27,6 @@ EXPORT_SYMBOL OrtStatus* CreateEpFactories(const char* /*registration_name*/, co // Manual init for the C++ API Ort::InitApi(ort_api); - // Factory could use registration_name or define its own EP name. std::unique_ptr factory = std::make_unique(*ort_api, *ep_api, *default_logger); diff --git a/onnxruntime/test/autoep/test_allocators.cc b/onnxruntime/test/autoep/test_allocators.cc index 88b522eb10dca..3c73237708828 100644 --- a/onnxruntime/test/autoep/test_allocators.cc +++ b/onnxruntime/test/autoep/test_allocators.cc @@ -61,7 +61,7 @@ struct DummyAllocator : OrtAllocator { // validate CreateSharedAllocator allows adding an arena to the shared allocator TEST(SharedAllocators, AddArenaToSharedAllocator) { RegisteredEpDeviceUniquePtr example_ep; - Utils::RegisterAndGetExampleEp(*ort_env, example_ep); + Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep); Ort::ConstEpDevice example_ep_device{example_ep.get()}; diff --git a/onnxruntime/test/autoep/test_autoep_utils.cc b/onnxruntime/test/autoep/test_autoep_utils.cc index 7045ccca2f576..af7d9e1853c0f 100644 --- a/onnxruntime/test/autoep/test_autoep_utils.cc +++ b/onnxruntime/test/autoep/test_autoep_utils.cc @@ -15,7 +15,28 @@ namespace onnxruntime { namespace test { -Utils::ExamplePluginInfo Utils::example_ep_info; +Utils::ExamplePluginInfo::ExamplePluginInfo(const 
ORTCHAR_T* lib_path, const char* reg_name, const char* ep_name) + : library_path(lib_path), registration_name(reg_name), ep_name(ep_name) {} + +const Utils::ExamplePluginInfo Utils::example_ep_info( +#if _WIN32 + ORT_TSTR("example_plugin_ep.dll"), +#else + ORT_TSTR("libexample_plugin_ep.so"), +#endif + // The example_plugin_ep always uses the registration name as the EP name. + "example_ep", + "example_ep"); + +const Utils::ExamplePluginInfo Utils::example_ep_virt_gpu_info( +#if _WIN32 + ORT_TSTR("example_plugin_ep_virt_gpu.dll"), +#else + "libexample_plugin_ep_virt_gpu.so", +#endif + "example_plugin_ep_virt_gpu", + // This EP's name is hardcoded to the following + "EpVirtualGpu"); void Utils::GetEp(Ort::Env& env, const std::string& ep_name, const OrtEpDevice*& ep_device) { const OrtApi& c_api = Ort::GetApi(); @@ -36,18 +57,19 @@ void Utils::GetEp(Ort::Env& env, const std::string& ep_name, const OrtEpDevice*& } } -void Utils::RegisterAndGetExampleEp(Ort::Env& env, RegisteredEpDeviceUniquePtr& registered_ep) { +void Utils::RegisterAndGetExampleEp(Ort::Env& env, const ExamplePluginInfo& ep_info, + RegisteredEpDeviceUniquePtr& registered_ep) { const OrtApi& c_api = Ort::GetApi(); // this should load the library and create OrtEpDevice ASSERT_ORTSTATUS_OK(c_api.RegisterExecutionProviderLibrary(env, - example_ep_info.registration_name.c_str(), - example_ep_info.library_path.c_str())); + ep_info.registration_name.c_str(), + ep_info.library_path.c_str())); const OrtEpDevice* example_ep = nullptr; - GetEp(env, example_ep_info.registration_name, example_ep); + GetEp(env, ep_info.ep_name, example_ep); ASSERT_NE(example_ep, nullptr); - registered_ep = RegisteredEpDeviceUniquePtr(example_ep, [&env, c_api](const OrtEpDevice* /*ep*/) { - c_api.UnregisterExecutionProviderLibrary(env, example_ep_info.registration_name.c_str()); + registered_ep = RegisteredEpDeviceUniquePtr(example_ep, [&env, &ep_info, c_api](const OrtEpDevice* /*ep*/) { + 
c_api.UnregisterExecutionProviderLibrary(env, ep_info.registration_name.c_str()); }); } diff --git a/onnxruntime/test/autoep/test_autoep_utils.h b/onnxruntime/test/autoep/test_autoep_utils.h index 2dd7b5f0428e2..f6b5e3623505f 100644 --- a/onnxruntime/test/autoep/test_autoep_utils.h +++ b/onnxruntime/test/autoep/test_autoep_utils.h @@ -15,23 +15,23 @@ using RegisteredEpDeviceUniquePtr = std::unique_ptrEpDevice_EpName(device) == registration_name; + [&ep_name, &c_api](const OrtEpDevice* device) { + return c_api->EpDevice_EpName(device) == ep_name; }); ASSERT_EQ(num_test_ep_devices, 1) << "Expected an OrtEpDevice to have been created by the test library."; @@ -50,6 +49,7 @@ TEST(OrtEpLibrary, LoadUnloadPluginLibrary) { TEST(OrtEpLibrary, LoadUnloadPluginLibraryCxxApi) { const std::filesystem::path& library_path = Utils::example_ep_info.library_path; const std::string& registration_name = Utils::example_ep_info.registration_name; + const std::string& ep_name = Utils::example_ep_info.ep_name; // this should load the library and create OrtEpDevice ort_env->RegisterExecutionProviderLibrary(registration_name.c_str(), library_path.c_str()); @@ -58,14 +58,13 @@ TEST(OrtEpLibrary, LoadUnloadPluginLibraryCxxApi) { // should be one device for the example EP auto test_ep_device = std::find_if(ep_devices.begin(), ep_devices.end(), - [&registration_name](Ort::ConstEpDevice& device) { - // the example uses the registration name for the EP name - // but that is not a requirement and the two can differ. - return device.EpName() == registration_name; + [&ep_name](Ort::ConstEpDevice& device) { + return device.EpName() == ep_name; }); ASSERT_NE(test_ep_device, ep_devices.end()) << "Expected an OrtEpDevice to have been created by the test library."; - // test all the C++ getters. expected values are from \onnxruntime\test\autoep\library\example_plugin_ep.cc + // test all the C++ getters.
+ // expected values are from \onnxruntime\test\autoep\library\example_plugin_ep\*.cc ASSERT_STREQ(test_ep_device->EpVendor(), "Contoso"); auto metadata = test_ep_device->EpMetadata(); @@ -89,6 +88,53 @@ TEST(OrtEpLibrary, LoadUnloadPluginLibraryCxxApi) { ort_env->UnregisterExecutionProviderLibrary(registration_name.c_str()); } +// Test loading example_plugin_ep_virt_gpu and its associated OrtEpDevice/OrtHardwareDevice. +// This EP creates a new OrtHardwareDevice instance that represents a virtual GPU and gives it to ORT. +TEST(OrtEpLibrary, LoadUnloadPluginVirtGpuLibraryCxxApi) { + const std::filesystem::path& library_path = Utils::example_ep_virt_gpu_info.library_path; + const std::string& registration_name = Utils::example_ep_virt_gpu_info.registration_name; + const std::string& ep_name = Utils::example_ep_virt_gpu_info.ep_name; + + // this should load the library and create OrtEpDevice + ort_env->RegisterExecutionProviderLibrary(registration_name.c_str(), library_path.c_str()); + + std::vector ep_devices = ort_env->GetEpDevices(); + + // should be one device for the example EP + auto test_ep_device = std::find_if(ep_devices.begin(), ep_devices.end(), + [&ep_name](Ort::ConstEpDevice& device) { + return device.EpName() == ep_name; + }); + ASSERT_NE(test_ep_device, ep_devices.end()) << "Expected an OrtEpDevice to have been created by " << ep_name; + + // test all the C++ getters. + // expected values are from \onnxruntime\test\autoep\library\example_plugin_ep_virt_gpu\*.cc + ASSERT_STREQ(test_ep_device->EpVendor(), "Contoso2"); + + auto metadata = test_ep_device->EpMetadata(); + ASSERT_STREQ(metadata.GetValue(kOrtEpDevice_EpMetadataKey_Version), "0.1.0"); + ASSERT_STREQ(metadata.GetValue("ex_key"), "ex_value"); + + auto options = test_ep_device->EpOptions(); + ASSERT_STREQ(options.GetValue("compile_optimization"), "O3"); + + // Check the virtual GPU device info.
+ Ort::ConstHardwareDevice virt_gpu_device = test_ep_device->Device(); + ASSERT_EQ(virt_gpu_device.Type(), OrtHardwareDeviceType_GPU); + ASSERT_EQ(virt_gpu_device.VendorId(), 0xB358); + ASSERT_EQ(virt_gpu_device.DeviceId(), 0); + ASSERT_STREQ(virt_gpu_device.Vendor(), test_ep_device->EpVendor()); + + // OrtHardwareDevice should have 2 metadata entries ("DiscoveredBy" and "IsVirtual") + Ort::ConstKeyValuePairs device_metadata = virt_gpu_device.Metadata(); + std::unordered_map metadata_entries = device_metadata.GetKeyValuePairs(); + ASSERT_EQ(metadata_entries.size(), 2); + ASSERT_EQ(metadata_entries[kOrtHardwareDevice_MetadataKey_DiscoveredBy], ep_name); + ASSERT_EQ(metadata_entries[kOrtHardwareDevice_MetadataKey_IsVirtual], "1"); + + // and this should unload it without throwing + ort_env->UnregisterExecutionProviderLibrary(registration_name.c_str()); +} } // namespace test } // namespace onnxruntime