Skip to content

Commit abc63e8

Browse files
authored
Implement new Python APIs (#25999)
### Description <!-- Describe your changes. --> This pull request introduces several enhancements to ONNX Runtime's Python and C++ APIs, focusing on improved device and memory information handling, synchronization stream support, and tensor copy functionality. It adds new Python bindings for device/memory types, exposes more detailed session input/output metadata, and provides a Python-accessible tensor copy API. The changes also refactor and extend the C++ API for better stream and memory info management. Key changes include: ### Device and Memory Information Enhancements * Added Python bindings for `OrtMemoryInfoDeviceType`, `OrtDeviceMemoryType`, and expanded `OrtDevice` to expose the memory type via a new `mem_type` method. The `OrtMemoryInfo` Python class now supports both legacy and new V2 constructors and exposes additional properties such as device memory type and vendor ID. [[1]](diffhunk://#diff-c46fc0e05521f706449c04aed599ac0229012c007a78b584519e71a57601d63eR1801-R1810) [[2]](diffhunk://#diff-c46fc0e05521f706449c04aed599ac0229012c007a78b584519e71a57601d63eR1839) [[3]](diffhunk://#diff-c46fc0e05521f706449c04aed599ac0229012c007a78b584519e71a57601d63eL1941-R2005) * Extended the Python `InferenceSession` object to provide access to input/output `OrtMemoryInfo` and `OrtEpDevice` objects through new properties and methods. [[1]](diffhunk://#diff-c46fc0e05521f706449c04aed599ac0229012c007a78b584519e71a57601d63eR2702-R2729) [[2]](diffhunk://#diff-f0e8ba8cb8cb07b51b3be675bf62cec07e2eae1461341ce5801d33a57c8f57fdR202-R213) [[3]](diffhunk://#diff-f0e8ba8cb8cb07b51b3be675bf62cec07e2eae1461341ce5801d33a57c8f57fdR591-R593) [[4]](diffhunk://#diff-f0e8ba8cb8cb07b51b3be675bf62cec07e2eae1461341ce5801d33a57c8f57fdR607-R609) ### Synchronization Stream and Execution Provider Device Support * Introduced Python bindings for `OrtSyncStream`, including creation via `OrtEpDevice.create_sync_stream()` and retrieval of device-specific `OrtMemoryInfo` via `OrtEpDevice.memory_info()`. 
[[1]](diffhunk://#diff-c46fc0e05521f706449c04aed599ac0229012c007a78b584519e71a57601d63eR1890-R1938) [[2]](diffhunk://#diff-44e70fbe60cba71c94f1a46ec2b1facaa8e9475232dad6df5ecbea301e76d475R34-R44) * Refactored the C++ API to generalize `SyncStream` handling, allowing for unowned streams and improved type safety. [[1]](diffhunk://#diff-17f64e8b38fcdcd25e90abcabeec4b420956b15fe63868a5d0b270c376bde209L1066-R1084) [[2]](diffhunk://#diff-cc93f5f9d8078d3d3af14c9bb4c0c59e25a99f3ec75d7772ea20111ed7eb6ddeL672-R677) ### Tensor Copy Functionality * Added a new Python-level `copy_tensors` function and corresponding C++ binding, enabling efficient copying of tensor data between `OrtValue` objects, optionally using a synchronization stream. [[1]](diffhunk://#diff-c46fc0e05521f706449c04aed599ac0229012c007a78b584519e71a57601d63eR1588-R1599) [[2]](diffhunk://#diff-f0e8ba8cb8cb07b51b3be675bf62cec07e2eae1461341ce5801d33a57c8f57fdR1155-R1163) [[3]](diffhunk://#diff-44e70fbe60cba71c94f1a46ec2b1facaa8e9475232dad6df5ecbea301e76d475R84) ### Miscellaneous Improvements and Fixes * Changed the return type of the `OrtValue.data_ptr` method in the Python binding from `int64_t` to `uintptr_t` for better cross-platform compatibility. [[1]](diffhunk://#diff-666c9002698d1bbd4215237231e5be98d7b33e5054f018dce952407027bd0473L336-R336) [[2]](diffhunk://#diff-666c9002698d1bbd4215237231e5be98d7b33e5054f018dce952407027bd0473L347-R347) * Minor improvements to error messages and device type handling in the Python API (e.g., for `OrtDevice`). [[1]](diffhunk://#diff-f0e8ba8cb8cb07b51b3be675bf62cec07e2eae1461341ce5801d33a57c8f57fdR1176) [[2]](diffhunk://#diff-f0e8ba8cb8cb07b51b3be675bf62cec07e2eae1461341ce5801d33a57c8f57fdR1219-R1221) * Included necessary C++ includes for plugin stream support. These changes collectively improve the flexibility and introspection capabilities of ONNX Runtime's device, memory, and execution provider interfaces, and make advanced features available to Python users. 
### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Depends on: #26021
1 parent b72dd15 commit abc63e8

File tree

8 files changed

+276
-41
lines changed

8 files changed

+276
-41
lines changed

include/onnxruntime/core/session/onnxruntime_cxx_api.h

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1063,11 +1063,25 @@ using UnownedAllocator = detail::AllocatorImpl<detail::Unowned<OrtAllocator>>;
10631063
/** \brief Wrapper around ::OrtSyncStream
10641064
*
10651065
*/
1066-
struct SyncStream : detail::Base<OrtSyncStream> {
1067-
explicit SyncStream(std::nullptr_t) {} ///< Create an empty SyncStream object, must be assigned a valid one to be used
1068-
explicit SyncStream(OrtSyncStream* p) : Base<OrtSyncStream>{p} {} ///< Take ownership of a pointer created by C API
1069-
void* GetHandle() const; ///< Wraps SyncStream_GetHandle
1066+
1067+
namespace detail {
1068+
/// Implementation shared by the SyncStream wrappers, parameterized on the held handle type T.
/// It is instantiated with OrtSyncStream (see SyncStream) and with Unowned<OrtSyncStream>
/// (see UnownedSyncStream).
template <typename T>
1069+
struct SyncStreamImpl : Base<T> {
1070+
using B = Base<T>;
1071+
// Inherit Base<T>'s constructors.
using B::B;
1072+
// Non-const, unlike the previous SyncStream::GetHandle() const.
// NOTE(review): presumably because the C API takes a non-const OrtSyncStream* - confirm.
1073+
void* GetHandle(); ///< Wraps SyncStream_GetHandle
10701074
};
1075+
} // namespace detail
1076+
1077+
/** \brief Owning wrapper around ::OrtSyncStream
 *
 */
struct SyncStream : detail::SyncStreamImpl<OrtSyncStream> {
1078+
/// Create an empty SyncStream object, must be assigned a valid one to be used
1079+
explicit SyncStream(std::nullptr_t) {}
1080+
/// Take ownership of a pointer created by C API
1081+
explicit SyncStream(OrtSyncStream* p) : SyncStreamImpl<OrtSyncStream>{p} {}
1082+
};
1083+
1084+
/// SyncStream interface over detail::Unowned<OrtSyncStream>, for streams this wrapper does not own.
using UnownedSyncStream = detail::SyncStreamImpl<detail::Unowned<OrtSyncStream>>;
10711085

10721086
namespace detail {
10731087
template <typename T>

include/onnxruntime/core/session/onnxruntime_cxx_inline.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -669,9 +669,12 @@ inline void KeyValuePairs::Remove(const char* key) {
669669
GetApi().RemoveKeyValuePair(this->p_, key);
670670
}
671671

672-
inline void* SyncStream::GetHandle() const {
672+
namespace detail {
673+
/// Returns the result of the C API's SyncStream_GetHandle for the wrapped stream pointer.
template <typename T>
674+
inline void* SyncStreamImpl<T>::GetHandle() {
673675
return GetApi().SyncStream_GetHandle(this->p_);
674676
}
677+
} // namespace detail
675678

676679
namespace detail {
677680
template <typename T>

onnxruntime/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,17 @@
3131
OrtAllocatorType, # noqa: F401
3232
OrtArenaCfg, # noqa: F401
3333
OrtCompileApiFlags, # noqa: F401
34+
OrtDeviceMemoryType, # noqa: F401
3435
OrtEpDevice, # noqa: F401
3536
OrtExecutionProviderDevicePolicy, # noqa: F401
3637
OrtExternalInitializerInfo, # noqa: F401
3738
OrtHardwareDevice, # noqa: F401
3839
OrtHardwareDeviceType, # noqa: F401
3940
OrtMemoryInfo, # noqa: F401
41+
OrtMemoryInfoDeviceType, # noqa: F401
4042
OrtMemType, # noqa: F401
4143
OrtSparseFormat, # noqa: F401
44+
OrtSyncStream, # noqa: F401
4245
RunOptions, # noqa: F401
4346
SessionIOBinding, # noqa: F401
4447
SessionOptions, # noqa: F401
@@ -78,6 +81,7 @@
7881
OrtDevice, # noqa: F401
7982
OrtValue, # noqa: F401
8083
SparseTensor, # noqa: F401
84+
copy_tensors, # noqa: F401
8185
)
8286

8387
# TODO: thiagofc: Temporary experimental namespace for new PyTorch front-end

onnxruntime/python/onnxruntime_inference_collection.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,18 @@ def get_modelmeta(self) -> onnxruntime.ModelMetadata:
199199
"Return the metadata. See :class:`onnxruntime.ModelMetadata`."
200200
return self._model_meta
201201

202+
def get_input_memory_infos(self) -> Sequence[onnxruntime.OrtMemoryInfo]:
    """Return the memory info objects for the inputs.

    The returned value is the session's cached ``_input_meminfos`` attribute; the
    elements are :class:`onnxruntime.OrtMemoryInfo` instances (the package exports
    ``OrtMemoryInfo``, there is no ``onnxruntime.MemoryInfo``).
    """
    return self._input_meminfos
205+
206+
def get_output_memory_infos(self) -> Sequence[onnxruntime.OrtMemoryInfo]:
    """Return the memory info objects for the outputs.

    The returned value is the session's cached ``_output_meminfos`` attribute; the
    elements are :class:`onnxruntime.OrtMemoryInfo` instances (the package exports
    ``OrtMemoryInfo``, there is no ``onnxruntime.MemoryInfo``).
    """
    return self._output_meminfos
209+
210+
def get_input_epdevices(self) -> Sequence[onnxruntime.OrtEpDevice]:
    """Return the cached ``OrtEpDevice`` entries for the session inputs."""
    ep_devices = self._input_epdevices
    return ep_devices
213+
202214
def get_providers(self) -> Sequence[str]:
203215
"Return list of registered execution providers."
204216
return self._providers
@@ -576,6 +588,9 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi
576588
self._inputs_meta = self._sess.inputs_meta
577589
self._outputs_meta = self._sess.outputs_meta
578590
self._overridable_initializers = self._sess.overridable_initializers
591+
self._input_meminfos = self._sess.input_meminfos
592+
self._output_meminfos = self._sess.output_meminfos
593+
self._input_epdevices = self._sess.input_epdevices
579594
self._model_meta = self._sess.model_meta
580595
self._providers = self._sess.get_providers()
581596
self._provider_options = self._sess.get_provider_options()
@@ -589,6 +604,9 @@ def _reset_session(self, providers, provider_options) -> None:
589604
self._inputs_meta = None
590605
self._outputs_meta = None
591606
self._overridable_initializers = None
607+
self._input_meminfos = None
608+
self._output_meminfos = None
609+
self._input_epdevices = None
592610
self._model_meta = None
593611
self._providers = None
594612
self._provider_options = None
@@ -1134,6 +1152,15 @@ def update_inplace(self, np_arr) -> None:
11341152
self._ortvalue.update_inplace(np_arr)
11351153

11361154

1155+
def copy_tensors(src: Sequence[OrtValue], dst: Sequence[OrtValue], stream=None) -> None:
    """
    Copy tensor data from source OrtValue sequence to destination OrtValue sequence.

    :param src: source OrtValues; element ``i`` is copied into ``dst[i]``
    :param dst: destination OrtValues; must contain as many elements as ``src``
    :param stream: optional synchronization stream forwarded to the native call, or None
    :raises ValueError: if ``src`` and ``dst`` do not have the same number of elements
    """
    c_sources = [s._get_c_value() for s in src]
    c_dsts = [d._get_c_value() for d in dst]
    # The native binding takes the tensor count from the sources only, so a shorter
    # destination list would be read past its end. Reject the mismatch up front.
    if len(c_sources) != len(c_dsts):
        raise ValueError(
            f"copy_tensors requires the same number of source and destination tensors, "
            f"got {len(c_sources)} and {len(c_dsts)}"
        )
    C.copy_tensors(c_sources, c_dsts, stream)
1162+
1163+
11371164
class OrtDevice:
11381165
"""
11391166
A data structure that exposes the underlying C++ OrtDevice
@@ -1146,6 +1173,7 @@ def __init__(self, c_ort_device):
11461173
if isinstance(c_ort_device, C.OrtDevice):
11471174
self._ort_device = c_ort_device
11481175
else:
1176+
# An end user won't hit this error
11491177
raise ValueError(
11501178
"`Provided object` needs to be of type `onnxruntime.capi.onnxruntime_pybind11_state.OrtDevice`"
11511179
)
@@ -1188,6 +1216,9 @@ def device_type(self):
11881216
def device_vendor_id(self):
11891217
return self._ort_device.vendor_id()
11901218

1219+
def device_mem_type(self):
    """Return the result of ``mem_type()`` on the wrapped C++ OrtDevice."""
    wrapped_device = self._ort_device
    return wrapped_device.mem_type()
1221+
11911222

11921223
class SparseTensor:
11931224
"""

onnxruntime/python/onnxruntime_pybind_ortvalue.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ void addOrtValueMethods(pybind11::module& m) {
333333
})
334334
#endif
335335
// Get a pointer to Tensor data
336-
.def("data_ptr", [](OrtValue* ml_value) -> int64_t {
336+
.def("data_ptr", [](OrtValue* ml_value) -> uintptr_t {
337337
// TODO: Assumes that the OrtValue is a Tensor, make this generic to handle non-Tensors
338338
ORT_ENFORCE(ml_value->IsTensor(), "Only OrtValues that are Tensors are currently supported");
339339

@@ -344,7 +344,7 @@ void addOrtValueMethods(pybind11::module& m) {
344344
}
345345

346346
// Should cover x86 and x64 platforms
347-
return reinterpret_cast<int64_t>(tensor->MutableDataRaw());
347+
return reinterpret_cast<uintptr_t>(tensor->MutableDataRaw());
348348
})
349349
.def("device_name", [](const OrtValue* ort_value) -> std::string {
350350
if (ort_value->IsTensor()) {

onnxruntime/python/onnxruntime_pybind_state.cc

Lines changed: 98 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "core/framework/data_transfer_utils.h"
2323
#include "core/framework/data_types_internal.h"
2424
#include "core/framework/error_code_helper.h"
25+
#include "core/framework/plugin_ep_stream.h"
2526
#include "core/framework/provider_options_utils.h"
2627
#include "core/framework/random_seed.h"
2728
#include "core/framework/sparse_tensor.h"
@@ -1587,6 +1588,18 @@ void addGlobalMethods(py::module& m) {
15871588
},
15881589
R"pbdoc("Validate a compiled model's compatibility information for one or more EP devices.)pbdoc");
15891590

1591+
m.def(
1592+
"copy_tensors",
1593+
[](const std::vector<const OrtValue*>& src, const std::vector<OrtValue*>& dest, py::object& py_arg) {
1594+
const OrtEnv* ort_env = GetOrtEnv();
1595+
OrtSyncStream* stream = nullptr;
1596+
if (!py_arg.is_none()) {
1597+
stream = py_arg.cast<OrtSyncStream*>();
1598+
}
1599+
Ort::ThrowOnError(Ort::GetApi().CopyTensors(ort_env, src.data(), dest.data(), stream, src.size()));
1600+
},
1601+
R"pbdoc("Copy tensors from sources to destinations using specified stream handle (or None))pbdoc");
1602+
15901603
#if defined(USE_OPENVINO) || defined(USE_OPENVINO_PROVIDER_INTERFACE)
15911604
m.def(
15921605
"get_available_openvino_device_ids", []() -> std::vector<std::string> {
@@ -1788,6 +1801,16 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra
17881801
.value("CPU", OrtMemTypeCPU)
17891802
.value("DEFAULT", OrtMemTypeDefault);
17901803

1804+
// Registers OrtMemoryInfoDeviceType on module `m` as a Python enum with the members
// CPU, GPU, NPU and FPGA.
py::enum_<OrtMemoryInfoDeviceType>(m, "OrtMemoryInfoDeviceType")
1805+
.value("CPU", OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_CPU)
1806+
.value("GPU", OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU)
1807+
.value("NPU", OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_NPU)
1808+
.value("FPGA", OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_FPGA);
1809+
1810+
// Registers OrtDeviceMemoryType on module `m` as a Python enum with the members
// DEFAULT and HOST_ACCESSIBLE.
py::enum_<OrtDeviceMemoryType>(m, "OrtDeviceMemoryType")
1811+
.value("DEFAULT", OrtDeviceMemoryType_DEFAULT)
1812+
.value("HOST_ACCESSIBLE", OrtDeviceMemoryType_HOST_ACCESSIBLE);
1813+
17911814
py::class_<OrtDevice> device(m, "OrtDevice", R"pbdoc(ONNXRuntime device information.)pbdoc");
17921815
device.def(py::init<OrtDevice::DeviceType, OrtDevice::MemoryType, OrtDevice::VendorId, OrtDevice::DeviceId>())
17931816
.def(py::init([](OrtDevice::DeviceType type,
@@ -1816,6 +1839,7 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra
18161839
.def("device_id", &OrtDevice::Id, R"pbdoc(Device Id.)pbdoc")
18171840
.def("device_type", &OrtDevice::Type, R"pbdoc(Device Type.)pbdoc")
18181841
.def("vendor_id", &OrtDevice::Vendor, R"pbdoc(Vendor Id.)pbdoc")
1842+
.def("mem_type", &OrtDevice::MemType, R"pbdoc(Device Memory Type.)pbdoc")
18191843
// generic device types that are typically used with a vendor id.
18201844
.def_static("cpu", []() { return OrtDevice::CPU; })
18211845
.def_static("gpu", []() { return OrtDevice::GPU; })
@@ -1866,36 +1890,55 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra
18661890
},
18671891
R"pbdoc(Hardware device's metadata as string key/value pairs.)pbdoc");
18681892

1893+
py::class_<OrtSyncStream> py_sync_stream(m, "OrtSyncStream",
1894+
R"pbdoc(Represents a synchronization stream for model inference.)pbdoc");
1895+
18691896
py::class_<OrtEpDevice> py_ep_device(m, "OrtEpDevice",
18701897
R"pbdoc(Represents a hardware device that an execution provider supports
18711898
for model inference.)pbdoc");
18721899
py_ep_device.def_property_readonly(
18731900
"ep_name",
1874-
[](OrtEpDevice* ep_device) -> std::string { return ep_device->ep_name; },
1901+
[](const OrtEpDevice* ep_device) -> std::string { return ep_device->ep_name; },
18751902
R"pbdoc(The execution provider's name.)pbdoc")
18761903
.def_property_readonly(
18771904
"ep_vendor",
1878-
[](OrtEpDevice* ep_device) -> std::string { return ep_device->ep_vendor; },
1905+
[](const OrtEpDevice* ep_device) -> std::string { return ep_device->ep_vendor; },
18791906
R"pbdoc(The execution provider's vendor name.)pbdoc")
18801907
.def_property_readonly(
18811908
"ep_metadata",
1882-
[](OrtEpDevice* ep_device) -> std::map<std::string, std::string> {
1909+
[](const OrtEpDevice* ep_device) -> std::map<std::string, std::string> {
18831910
return ep_device->ep_metadata.Entries();
18841911
},
18851912
R"pbdoc(The execution provider's additional metadata for the OrtHardwareDevice.)pbdoc")
18861913
.def_property_readonly(
18871914
"ep_options",
1888-
[](OrtEpDevice* ep_device) -> std::map<std::string, std::string> {
1915+
[](const OrtEpDevice* ep_device) -> std::map<std::string, std::string> {
18891916
return ep_device->ep_options.Entries();
18901917
},
18911918
R"pbdoc(The execution provider's options used to configure the provider to use the OrtHardwareDevice.)pbdoc")
18921919
.def_property_readonly(
18931920
"device",
1894-
[](OrtEpDevice* ep_device) -> const OrtHardwareDevice& {
1921+
[](const OrtEpDevice* ep_device) -> const OrtHardwareDevice& {
18951922
return *ep_device->device;
18961923
},
18971924
R"pbdoc(The OrtHardwareDevice instance for the OrtEpDevice.)pbdoc",
1898-
py::return_value_policy::reference_internal);
1925+
py::return_value_policy::reference_internal)
1926+
.def(
1927+
"memory_info",
1928+
[](const OrtEpDevice* ep_device, OrtDeviceMemoryType memory_type) -> const OrtMemoryInfo* {
1929+
Ort::ConstEpDevice ep_dev(ep_device);
1930+
return static_cast<const OrtMemoryInfo*>(ep_dev.GetMemoryInfo(memory_type));
1931+
},
1932+
R"pbdoc(The OrtMemoryInfo instance for the OrtEpDevice specific to the device memory type.)pbdoc",
1933+
py::return_value_policy::reference_internal)
1934+
.def(
1935+
"create_sync_stream",
1936+
[](const OrtEpDevice* ep_device) -> std::unique_ptr<OrtSyncStream> {
1937+
Ort::ConstEpDevice ep_dev(ep_device);
1938+
Ort::SyncStream stream = ep_dev.CreateSyncStream();
1939+
return std::unique_ptr<OrtSyncStream>(stream.release());
1940+
},
1941+
R"pbdoc(The OrtSyncStream instance for the OrtEpDevice.)pbdoc");
18991942

19001943
py::class_<OrtArenaCfg> ort_arena_cfg_binding(m, "OrtArenaCfg");
19011944
// Note: Doesn't expose initial_growth_chunk_sizes_bytes/max_power_of_two_extend_bytes option.
@@ -1941,25 +1984,28 @@ for model inference.)pbdoc");
19411984
.def_readwrite("max_power_of_two_extend_bytes", &OrtArenaCfg::max_power_of_two_extend_bytes);
19421985

19431986
py::class_<OrtMemoryInfo> ort_memory_info_binding(m, "OrtMemoryInfo");
1944-
ort_memory_info_binding.def(py::init([](const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) {
1945-
if (strcmp(name, onnxruntime::CPU) == 0) {
1946-
return std::make_unique<OrtMemoryInfo>(onnxruntime::CPU, type, OrtDevice(), mem_type);
1947-
} else if (strcmp(name, onnxruntime::CUDA) == 0) {
1948-
return std::make_unique<OrtMemoryInfo>(
1949-
onnxruntime::CUDA, type,
1950-
OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA,
1951-
static_cast<OrtDevice::DeviceId>(id)),
1952-
mem_type);
1953-
} else if (strcmp(name, onnxruntime::CUDA_PINNED) == 0) {
1954-
return std::make_unique<OrtMemoryInfo>(
1955-
onnxruntime::CUDA_PINNED, type,
1956-
OrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA,
1957-
static_cast<OrtDevice::DeviceId>(id)),
1958-
mem_type);
1959-
} else {
1960-
throw std::runtime_error("Specified device is not supported.");
1961-
}
1962-
}));
1987+
ort_memory_info_binding.def(
1988+
py::init([](const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) {
1989+
Ort::MemoryInfo result(name, type, id, mem_type);
1990+
return std::unique_ptr<OrtMemoryInfo>(result.release());
1991+
}))
1992+
.def_static(
1993+
"create_v2",
1994+
[](const char* name, OrtMemoryInfoDeviceType device_type, uint32_t vendor_id,
1995+
int32_t device_id, OrtDeviceMemoryType device_mem_type, size_t alignment, OrtAllocatorType type) {
1996+
Ort::MemoryInfo result(name, device_type, vendor_id, device_id, device_mem_type, alignment, type);
1997+
return std::unique_ptr<OrtMemoryInfo>(result.release());
1998+
},
1999+
R"pbdoc(Create an OrtMemoryInfo instance using CreateMemoryInfo_V2())pbdoc")
2000+
.def_property_readonly("name", [](const OrtMemoryInfo* mem_info) -> std::string { return mem_info->name; }, R"pbdoc(Arbitrary name supplied by the user)pbdoc")
2001+
.def_property_readonly("device_id", [](const OrtMemoryInfo* mem_info) -> int { return mem_info->device.Id(); }, R"pbdoc(Device Id.)pbdoc")
2002+
.def_property_readonly("mem_type", [](const OrtMemoryInfo* mem_info) -> OrtMemType { return mem_info->mem_type; }, R"pbdoc(OrtMemoryInfo memory type.)pbdoc")
2003+
.def_property_readonly("allocator_type", [](const OrtMemoryInfo* mem_info) -> OrtAllocatorType { return mem_info->alloc_type; }, R"pbdoc(Allocator type)pbdoc")
2004+
.def_property_readonly("device_mem_type", [](const OrtMemoryInfo* mem_info) -> OrtDeviceMemoryType {
2005+
auto mem_type = mem_info->device.MemType();
2006+
return (mem_type == OrtDevice::MemType::DEFAULT) ?
2007+
OrtDeviceMemoryType_DEFAULT: OrtDeviceMemoryType_HOST_ACCESSIBLE ; }, R"pbdoc(Device memory type (Device or Host accessible).)pbdoc")
2008+
.def_property_readonly("device_vendor_id", [](const OrtMemoryInfo* mem_info) -> uint32_t { return mem_info->device.Vendor(); });
19632009

19642010
py::class_<PySessionOptions>
19652011
sess(m, "SessionOptions", R"pbdoc(Configuration information for a session.)pbdoc");
@@ -2699,6 +2745,33 @@ including arg name, arg type (contains both type and shape).)pbdoc")
26992745
auto res = sess->GetSessionHandle()->GetModelMetadata();
27002746
OrtPybindThrowIfError(res.first);
27012747
return *(res.second); }, py::return_value_policy::reference_internal)
2748+
.def_property_readonly("input_meminfos", [](const PyInferenceSession* sess) -> py::list {
2749+
Ort::ConstSession session(reinterpret_cast<const OrtSession*>(sess->GetSessionHandle()));
2750+
auto inputs_mem_info = session.GetMemoryInfoForInputs();
2751+
py::list result;
2752+
for (const auto& info : inputs_mem_info) {
2753+
const auto* p_info = static_cast<const OrtMemoryInfo*>(info);
2754+
result.append(py::cast(p_info, py::return_value_policy::reference));
2755+
}
2756+
return result; })
2757+
.def_property_readonly("output_meminfos", [](const PyInferenceSession* sess) -> py::list {
2758+
Ort::ConstSession session(reinterpret_cast<const OrtSession*>(sess->GetSessionHandle()));
2759+
auto outputs_mem_info = session.GetMemoryInfoForOutputs();
2760+
py::list result;
2761+
for (const auto& info : outputs_mem_info) {
2762+
const auto* p_info = static_cast<const OrtMemoryInfo*>(info);
2763+
result.append(py::cast(p_info, py::return_value_policy::reference));
2764+
}
2765+
return result; })
2766+
.def_property_readonly("input_epdevices", [](const PyInferenceSession* sess) -> py::list {
2767+
Ort::ConstSession session(reinterpret_cast<const OrtSession*>(sess->GetSessionHandle()));
2768+
auto ep_devices = session.GetEpDeviceForInputs();
2769+
py::list result;
2770+
for (const auto& device : ep_devices) {
2771+
const auto* p_device = static_cast<const OrtEpDevice*>(device);
2772+
result.append(py::cast(p_device, py::return_value_policy::reference));
2773+
}
2774+
return result; })
27022775
.def("run_with_iobinding", [](PyInferenceSession* sess, SessionIOBinding& io_binding, RunOptions* run_options = nullptr) -> void {
27032776

27042777
Status status;

0 commit comments

Comments
 (0)