Skip to content

Commit c54be3c

Browse files
authored
Add OrtExternalResourceImporter API for D3D12 shared resource import (#26828)
1 parent a14231a commit c54be3c

20 files changed

+1996
-0
lines changed

include/onnxruntime/core/session/onnxruntime_c_api.h

Lines changed: 337 additions & 0 deletions
Large diffs are not rendered by default.

include/onnxruntime/core/session/onnxruntime_cxx_api.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,20 @@ inline const OrtCompileApi& GetCompileApi() {
236236
return *api;
237237
}
238238

239+
/// <summary>
240+
/// This returns a reference to the ORT C Interop API. Used for external resource import with EPs.
241+
/// </summary>
242+
/// <returns>ORT C Interop API reference</returns>
243+
inline const OrtInteropApi& GetInteropApi() {
244+
auto* api = GetApi().GetInteropApi();
245+
if (api == nullptr) {
246+
// minimal build
247+
ORT_CXX_API_THROW("Interop API is not available in this build", ORT_FAIL);
248+
}
249+
250+
return *api;
251+
}
252+
239253
/// <summary>
240254
/// This returns a reference to the ORT C EP API. Used if authoring a plugin execution provider.
241255
/// </summary>
@@ -1610,6 +1624,7 @@ struct ConstSessionImpl : Base<T> {
16101624
std::vector<ConstMemoryInfo> GetMemoryInfoForInputs() const; ///< Wrapper for OrtApi::SessionGetMemoryInfoForInputs
16111625
std::vector<ConstMemoryInfo> GetMemoryInfoForOutputs() const; ///< Wrapper for OrtApi::SessionGetMemoryInfoForOutputs
16121626
std::vector<ConstEpDevice> GetEpDeviceForInputs() const; ///< Wrapper for OrtApi::SessionGetEpDeviceForInputs
1627+
std::vector<ConstEpDevice> GetEpDeviceForOutputs() const; ///< Wrapper for OrtApi::SessionGetEpDeviceForOutputs
16131628

16141629
/** \brief Returns a copy of input name at the specified index.
16151630
*

include/onnxruntime/core/session/onnxruntime_cxx_inline.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1660,6 +1660,19 @@ inline std::vector<ConstEpDevice> ConstSessionImpl<T>::GetEpDeviceForInputs() co
16601660
return input_devices;
16611661
}
16621662

1663+
template <typename T>
1664+
inline std::vector<ConstEpDevice> ConstSessionImpl<T>::GetEpDeviceForOutputs() const {
1665+
auto num_outputs = GetOutputCount();
1666+
std::vector<ConstEpDevice> output_devices;
1667+
if (num_outputs > 0) {
1668+
output_devices.resize(num_outputs);
1669+
ThrowOnError(GetApi().SessionGetEpDeviceForOutputs(this->p_,
1670+
reinterpret_cast<const OrtEpDevice**>(output_devices.data()),
1671+
num_outputs));
1672+
}
1673+
return output_devices;
1674+
}
1675+
16631676
template <typename T>
16641677
inline uint64_t ConstSessionImpl<T>::GetProfilingStartTimeNs() const {
16651678
uint64_t out;

include/onnxruntime/core/session/onnxruntime_ep_c_api.h

Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,66 @@ ORT_RUNTIME_CLASS(DataTransferImpl);
2424
ORT_RUNTIME_CLASS(SyncNotificationImpl);
2525
ORT_RUNTIME_CLASS(SyncStreamImpl);
2626

27+
ORT_RUNTIME_CLASS(ExternalResourceImporterImpl);
28+
29+
/** \brief Base struct for imported external memory handles.
30+
*
31+
* EPs derive from this struct to add EP-specific fields (e.g., CUdeviceptr for CUDA).
32+
* EP is responsible for creating and releasing instances of the derived type.
* The Release member below is the hook through which a handle is destroyed via its base pointer:
* the EP's callback casts back to the EP's derived type and deletes it.
33+
*
34+
* Example derived type for CUDA EP:
35+
* \code
36+
* struct MyCudaExternalMemoryHandle : OrtExternalMemoryHandle {
37+
* CUexternalMemory ext_memory;
38+
* CUdeviceptr mapped_ptr;
39+
* bool is_dedicated;
40+
* };
41+
* \endcode
42+
*
43+
* \since Version 1.24.
44+
*/
45+
struct OrtExternalMemoryHandle {
46+
uint32_t version; ///< Must be ORT_API_VERSION (set by the EP, which creates the handle)
47+
const OrtEpDevice* ep_device; ///< EP device that created this handle
48+
OrtExternalMemoryHandleType handle_type; ///< Original handle type for tracking
49+
size_t size_bytes; ///< Size of the imported memory, in bytes
50+
size_t offset_bytes; ///< Offset into the imported memory, in bytes. NOTE(review): confirm whether size_bytes is measured before or after this offset
51+
52+
/** \brief Release callback for this handle. EP sets this to its release function.
53+
*
54+
* ORT calls this when ReleaseExternalMemoryHandle is invoked. The EP's callback
55+
* should cast the handle to its derived type and delete it.
*
* \param[in] handle The handle being released (pointer to this base struct).
56+
*/
57+
void(ORT_API_CALL* Release)(_In_ OrtExternalMemoryHandle* handle);
58+
};
59+
60+
/** \brief Base struct for imported external semaphore handles.
61+
*
62+
* EPs derive from this struct to add EP-specific fields (e.g., CUexternalSemaphore for CUDA).
63+
* EP is responsible for creating and releasing instances of the derived type.
* The Release member below is the hook through which a handle is destroyed via its base pointer:
* the EP's callback casts back to the EP's derived type and deletes it.
64+
*
65+
* Example derived type for CUDA EP:
66+
* \code
67+
* struct MyCudaExternalSemaphoreHandle : OrtExternalSemaphoreHandle {
68+
* CUexternalSemaphore ext_semaphore;
69+
* };
70+
* \endcode
71+
*
72+
* \since Version 1.24.
73+
*/
74+
struct OrtExternalSemaphoreHandle {
75+
uint32_t version; ///< Must be ORT_API_VERSION (set by the EP, which creates the handle)
76+
const OrtEpDevice* ep_device; ///< EP device that created this handle
77+
OrtExternalSemaphoreType type; ///< Original semaphore type
78+
79+
/** \brief Release callback for this handle. EP sets this to its release function.
80+
*
81+
* ORT calls this when ReleaseExternalSemaphoreHandle is invoked. The EP's callback
82+
* should cast the handle to its derived type and delete it.
*
* \param[in] handle The handle being released (pointer to this base struct).
83+
*/
84+
void(ORT_API_CALL* Release)(_In_ OrtExternalSemaphoreHandle* handle);
85+
};
86+
2787
// Opaque types for kernel-based EPs
2888
ORT_RUNTIME_CLASS(KernelRegistry);
2989
ORT_RUNTIME_CLASS(KernelDefBuilder);
@@ -191,6 +251,180 @@ struct OrtSyncStreamImpl {
191251
ORT_API2_STATUS(OnSessionRunEnd, _In_ OrtSyncStreamImpl* this_ptr);
192252
};
193253

254+
/** \brief Struct that an EP implements for external resource import (memory + semaphore import).
255+
*
256+
* This capability object provides methods for importing external GPU memory and semaphores
257+
* for zero-copy import. EPs that support D3D12, CUDA, HIP, or Vulkan external resource APIs
258+
* can implement this interface.
*
* The members fall into three groups: memory import (no stream involved), semaphore import and
* stream-ordered wait/signal, and release of the importer object itself.
259+
*
260+
* \since Version 1.24.
261+
*/
262+
struct OrtExternalResourceImporterImpl {
263+
uint32_t ort_version_supported; ///< Must be initialized to ORT_API_VERSION
264+
265+
// Memory operations (stream-independent: none of these functions take an OrtSyncStream)
266+
267+
/** \brief Check if the implementation can import external memory of the given handle type.
268+
*
269+
* \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance.
270+
* \param[in] handle_type The type of external memory handle to check.
271+
* \return True if the handle type is supported.
272+
*
273+
* \since Version 1.24.
274+
*/
275+
ORT_API_T(bool, CanImportMemory,
276+
_In_ const OrtExternalResourceImporterImpl* this_ptr,
277+
_In_ OrtExternalMemoryHandleType handle_type);
278+
279+
/** \brief Import external memory.
280+
*
281+
* The EP creates a derived type of OrtExternalMemoryHandle and returns a pointer to the base.
282+
* EP is responsible for the lifetime of the handle (release via ReleaseMemory).
283+
*
284+
* \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance.
285+
* \param[in] desc Descriptor containing the external memory handle and properties.
286+
* \param[out] out_handle Output parameter set to the created OrtExternalMemoryHandle (EP's derived type).
287+
*
288+
* \snippet{doc} snippets.dox OrtStatus Return Value
289+
*
290+
* \since Version 1.24.
291+
*/
292+
ORT_API2_STATUS(ImportMemory,
293+
_In_ OrtExternalResourceImporterImpl* this_ptr,
294+
_In_ const OrtExternalMemoryDescriptor* desc,
295+
_Outptr_ OrtExternalMemoryHandle** out_handle);
296+
297+
/** \brief Release an imported external memory handle.
298+
*
299+
* The EP deletes its derived type instance.
*
* NOTE(review): OrtExternalMemoryHandle also carries its own Release callback. Confirm which of the
* two release paths ORT invokes, and that a given handle is never released through both.
300+
*
301+
* \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance.
302+
* \param[in] handle The OrtExternalMemoryHandle to release (EP casts to its derived type).
303+
*
304+
* \since Version 1.24.
305+
*/
306+
ORT_API_T(void, ReleaseMemory,
307+
_In_ OrtExternalResourceImporterImpl* this_ptr,
308+
_In_ OrtExternalMemoryHandle* handle);
309+
310+
/** \brief Create a tensor backed by imported external memory.
311+
*
312+
* The created tensor is a view over the imported memory and does not copy data.
313+
*
314+
* \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance.
315+
* \param[in] mem_handle The imported external memory handle (EP casts to its derived type).
316+
* \param[in] tensor_desc Descriptor specifying tensor element type, shape, and optional offset.
317+
* \param[out] out_tensor Output parameter set to the created OrtValue containing the tensor.
318+
*
319+
* \snippet{doc} snippets.dox OrtStatus Return Value
320+
*
321+
* \since Version 1.24.
322+
*/
323+
ORT_API2_STATUS(CreateTensorFromMemory,
324+
_In_ OrtExternalResourceImporterImpl* this_ptr,
325+
_In_ const OrtExternalMemoryHandle* mem_handle,
326+
_In_ const OrtExternalTensorDescriptor* tensor_desc,
327+
_Outptr_ OrtValue** out_tensor);
328+
329+
// Semaphore operations (WaitSemaphore and SignalSemaphore take an OrtSyncStream;
// CanImportSemaphore, ImportSemaphore and ReleaseSemaphore do not)
330+
331+
/** \brief Check if the implementation can import external semaphores of the given type.
332+
*
333+
* \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance.
334+
* \param[in] type The type of external semaphore to check.
335+
* \return True if the semaphore type is supported.
336+
*
337+
* \since Version 1.24.
338+
*/
339+
ORT_API_T(bool, CanImportSemaphore,
340+
_In_ const OrtExternalResourceImporterImpl* this_ptr,
341+
_In_ OrtExternalSemaphoreType type);
342+
343+
/** \brief Import an external semaphore.
344+
*
345+
* The EP creates a derived type of OrtExternalSemaphoreHandle and returns a pointer to the base.
346+
* EP is responsible for the lifetime of the handle (release via ReleaseSemaphore).
347+
*
348+
* \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance.
349+
* \param[in] desc Descriptor containing the external semaphore handle and type.
350+
* \param[out] out_handle Output parameter set to the created OrtExternalSemaphoreHandle (EP's derived type).
351+
*
352+
* \snippet{doc} snippets.dox OrtStatus Return Value
353+
*
354+
* \since Version 1.24.
355+
*/
356+
ORT_API2_STATUS(ImportSemaphore,
357+
_In_ OrtExternalResourceImporterImpl* this_ptr,
358+
_In_ const OrtExternalSemaphoreDescriptor* desc,
359+
_Outptr_ OrtExternalSemaphoreHandle** out_handle);
360+
361+
/** \brief Release an imported external semaphore handle.
362+
*
363+
* The EP deletes its derived type instance.
*
* NOTE(review): OrtExternalSemaphoreHandle also carries its own Release callback. Confirm which of the
* two release paths ORT invokes, and that a given handle is never released through both.
364+
*
365+
* \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance.
366+
* \param[in] handle The OrtExternalSemaphoreHandle to release (EP casts to its derived type).
367+
*
368+
* \since Version 1.24.
369+
*/
370+
ORT_API_T(void, ReleaseSemaphore,
371+
_In_ OrtExternalResourceImporterImpl* this_ptr,
372+
_In_ OrtExternalSemaphoreHandle* handle);
373+
374+
/** \brief Wait on an external semaphore on the EP's stream.
375+
*
376+
* Inserts a wait operation into the EP's stream that blocks until the semaphore
377+
* reaches the specified value.
378+
*
379+
* \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance.
380+
* \param[in] handle The imported external semaphore (EP casts to its derived type).
381+
* \param[in] stream The OrtSyncStream to wait on.
382+
* \param[in] value The fence/semaphore value to wait for.
383+
*
384+
* \snippet{doc} snippets.dox OrtStatus Return Value
385+
*
386+
* \since Version 1.24.
387+
*/
388+
ORT_API2_STATUS(WaitSemaphore,
389+
_In_ OrtExternalResourceImporterImpl* this_ptr,
390+
_In_ OrtExternalSemaphoreHandle* handle,
391+
_In_ OrtSyncStream* stream,
392+
_In_ uint64_t value);
393+
394+
/** \brief Signal an external semaphore from the EP's stream.
395+
*
396+
* Inserts a signal operation into the EP's stream that sets the semaphore
397+
* to the specified value when reached.
398+
*
399+
* \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance.
400+
* \param[in] handle The imported external semaphore (EP casts to its derived type).
401+
* \param[in] stream The OrtSyncStream to signal from.
402+
* \param[in] value The fence/semaphore value to signal.
403+
*
404+
* \snippet{doc} snippets.dox OrtStatus Return Value
405+
*
406+
* \since Version 1.24.
407+
*/
408+
ORT_API2_STATUS(SignalSemaphore,
409+
_In_ OrtExternalResourceImporterImpl* this_ptr,
410+
_In_ OrtExternalSemaphoreHandle* handle,
411+
_In_ OrtSyncStream* stream,
412+
_In_ uint64_t value);
413+
414+
// Release the capability object itself (as opposed to the handles it produced)
415+
416+
/** \brief Release the OrtExternalResourceImporterImpl instance.
417+
*
418+
* This is called by ORT when the OrtExternalResourceImporterImpl instance is no longer needed.
419+
* The implementation should release any resources held by the instance.
420+
*
421+
* \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance.
422+
*
423+
* \since Version 1.24.
424+
*/
425+
ORT_API_T(void, Release, _In_ OrtExternalResourceImporterImpl* this_ptr);
426+
};
427+
194428
struct OrtNodeFusionOptions;
195429
typedef struct OrtNodeFusionOptions OrtNodeFusionOptions;
196430

@@ -1564,6 +1798,32 @@ struct OrtEpFactory {
15641798
* \since Version 1.24.
15651799
*/
15661800
ORT_API2_STATUS(SetEnvironmentOptions, _In_ OrtEpFactory* this_ptr, _In_ const OrtKeyValuePairs* options);
1801+
1802+
/** \brief Create an OrtExternalResourceImporterImpl for external resource import.
1803+
*
1804+
* This is used to create an external resource importer that enables zero-copy import of
1805+
* external GPU memory (e.g., D3D12 shared resources) and synchronization primitives
1806+
* (e.g., D3D12 timeline fences).
1807+
*
1808+
* EPs that support external resource import (via CUDA, HIP, Vulkan, or D3D12 APIs) can
1809+
* implement this to allow applications to share GPU resources without copies.
1810+
*
1811+
* \param[in] this_ptr The OrtEpFactory instance.
1812+
* \param[in] ep_device The OrtEpDevice to create the external resource importer for.
1813+
* \param[out] out_importer The created OrtExternalResourceImporterImpl instance.
1814+
* Set to nullptr if external resource import is not supported.
* ORT disposes of a returned instance through
* OrtExternalResourceImporterImpl::Release once it is no longer needed.
1815+
*
1816+
* \snippet{doc} snippets.dox OrtStatus Return Value
1817+
*
1818+
* \note Implementation of this function is optional.
1819+
* An EP factory should only implement this if it supports external resource import.
1820+
* If not implemented or not supported, return ORT_NOT_IMPLEMENTED or set out_importer to nullptr.
* Callers therefore have to treat both a failure status and a successful call that leaves
* out_importer null as "import not available" (hence the maybenull annotation on out_importer).
1821+
*
1822+
* \since Version 1.24.
1823+
*/
1824+
ORT_API2_STATUS(CreateExternalResourceImporterForDevice, _In_ OrtEpFactory* this_ptr,
1825+
_In_ const OrtEpDevice* ep_device,
1826+
_Outptr_result_maybenull_ OrtExternalResourceImporterImpl** out_importer);
15671827
};
15681828

15691829
#ifdef __cplusplus

onnxruntime/core/session/inference_session.cc

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3480,6 +3480,48 @@ common::Status InferenceSession::GetEpDeviceForInputs(InlinedVector<const OrtEpD
34803480
#endif
34813481
}
34823482

3483+
// Fills `ep_devices` with one entry per model output, in model output order.
//
// Each entry is the first registered OrtEpDevice whose ep_name matches the execution provider type
// of the node producing that output. The entry is nullptr when the output has no producing node or
// when no registered OrtEpDevice matches.
//
// `ep_devices` is always cleared first. Returns FAIL in a minimal build, INVALID_ARGUMENT if the
// session has not been initialized, and otherwise propagates any error from GetModelOutputs() or
// SessionState::GetOutputNodeInfo().
common::Status InferenceSession::GetEpDeviceForOutputs(InlinedVector<const OrtEpDevice*>& ep_devices) const {
  ep_devices.clear();

#if defined(ORT_MINIMAL_BUILD)
  return common::Status(common::ONNXRUNTIME, common::FAIL,
                        "GetEpDeviceForOutputs is not available in a minimal build.");
#else
  if (!is_inited_) {
    return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Session has not been initialized.");
  }

  std::pair<common::Status, const OutputDefList*> outputs = GetModelOutputs();

  ORT_RETURN_IF_ERROR(outputs.first);

  const auto& def_list = *outputs.second;
  ep_devices.reserve(def_list.size());

  const auto& available_eps = environment_.GetOrtEpDevices();

  for (const auto* def : def_list) {
    InlinedVector<SessionState::NodeInfo> node_info_vec;
    ORT_RETURN_IF_ERROR(session_state_->GetOutputNodeInfo(def->Name(), node_info_vec));

    // An output that is not produced by any node maps to nullptr. An empty node info list is
    // handled the same way at runtime: the previous assert-only guard vanished under NDEBUG and
    // left front() on an empty vector as undefined behaviour.
    const auto* p_node = node_info_vec.empty() ? nullptr : node_info_vec.front().p_node;
    if (p_node == nullptr) {
      ep_devices.push_back(nullptr);
      continue;
    }

    // Bind by reference: the EP type is only compared, so there is no need to copy the string.
    const auto& ep_name = p_node->GetExecutionProviderType();
    const auto it = std::find_if(available_eps.begin(), available_eps.end(),
                                 [&ep_name](const OrtEpDevice* entry) {
                                   return entry->ep_name == ep_name;
                                 });
    ep_devices.push_back(it != available_eps.end() ? *it : nullptr);
  }

  return Status::OK();
#endif
}
3524+
34833525
common::Status InferenceSession::NewIOBinding(std::unique_ptr<IOBinding>* io_binding) {
34843526
{
34853527
std::lock_guard<std::mutex> l(session_mutex_);

onnxruntime/core/session/inference_session.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,15 @@ class InferenceSession {
484484
* This is required for a user to know the location of the input/output when autoep selection is enabled.
485485
*/
486486
common::Status GetEpDeviceForInputs(InlinedVector<const OrtEpDevice*>& memory_info) const;
487+
488+
/**
489+
* Get the OrtEpDevice (if available) for the outputs of the model.
490+
*
491+
* This is required for a user to validate that outputs will be placed on the expected device
492+
* for external resource sharing.
493+
*/
494+
common::Status GetEpDeviceForOutputs(InlinedVector<const OrtEpDevice*>& memory_info) const;
495+
487496
/**
488497
* Get the current number of in-progress concurrent Run calls.
489498
*/

0 commit comments

Comments
 (0)