|
22 | 22 | #include "core/framework/data_transfer_utils.h"
|
23 | 23 | #include "core/framework/data_types_internal.h"
|
24 | 24 | #include "core/framework/error_code_helper.h"
|
| 25 | +#include "core/framework/plugin_ep_stream.h" |
25 | 26 | #include "core/framework/provider_options_utils.h"
|
26 | 27 | #include "core/framework/random_seed.h"
|
27 | 28 | #include "core/framework/sparse_tensor.h"
|
@@ -1587,6 +1588,18 @@ void addGlobalMethods(py::module& m) {
|
1587 | 1588 | },
|
1588 | 1589 | R"pbdoc("Validate a compiled model's compatibility information for one or more EP devices.)pbdoc");
|
1589 | 1590 |
|
| 1591 | + m.def( |
| 1592 | + "copy_tensors", |
| 1593 | + [](const std::vector<const OrtValue*>& src, const std::vector<OrtValue*>& dest, py::object& py_arg) { |
| 1594 | + const OrtEnv* ort_env = GetOrtEnv(); |
| 1595 | + OrtSyncStream* stream = nullptr; |
| 1596 | + if (!py_arg.is_none()) { |
| 1597 | + stream = py_arg.cast<OrtSyncStream*>(); |
| 1598 | + } |
| 1599 | + Ort::ThrowOnError(Ort::GetApi().CopyTensors(ort_env, src.data(), dest.data(), stream, src.size())); |
| 1600 | + }, |
| 1601 | + R"pbdoc("Copy tensors from sources to destinations using specified stream handle (or None))pbdoc"); |
| 1602 | + |
1590 | 1603 | #if defined(USE_OPENVINO) || defined(USE_OPENVINO_PROVIDER_INTERFACE)
|
1591 | 1604 | m.def(
|
1592 | 1605 | "get_available_openvino_device_ids", []() -> std::vector<std::string> {
|
@@ -1788,6 +1801,16 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra
|
1788 | 1801 | .value("CPU", OrtMemTypeCPU)
|
1789 | 1802 | .value("DEFAULT", OrtMemTypeDefault);
|
1790 | 1803 |
|
| 1804 | + py::enum_<OrtMemoryInfoDeviceType>(m, "OrtMemoryInfoDeviceType") |
| 1805 | + .value("CPU", OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_CPU) |
| 1806 | + .value("GPU", OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU) |
| 1807 | + .value("NPU", OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_NPU) |
| 1808 | + .value("FPGA", OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_FPGA); |
| 1809 | + |
| 1810 | + py::enum_<OrtDeviceMemoryType>(m, "OrtDeviceMemoryType") |
| 1811 | + .value("DEFAULT", OrtDeviceMemoryType_DEFAULT) |
| 1812 | + .value("HOST_ACCESSIBLE", OrtDeviceMemoryType_HOST_ACCESSIBLE); |
| 1813 | + |
1791 | 1814 | py::class_<OrtDevice> device(m, "OrtDevice", R"pbdoc(ONNXRuntime device information.)pbdoc");
|
1792 | 1815 | device.def(py::init<OrtDevice::DeviceType, OrtDevice::MemoryType, OrtDevice::VendorId, OrtDevice::DeviceId>())
|
1793 | 1816 | .def(py::init([](OrtDevice::DeviceType type,
|
@@ -1816,6 +1839,7 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra
|
1816 | 1839 | .def("device_id", &OrtDevice::Id, R"pbdoc(Device Id.)pbdoc")
|
1817 | 1840 | .def("device_type", &OrtDevice::Type, R"pbdoc(Device Type.)pbdoc")
|
1818 | 1841 | .def("vendor_id", &OrtDevice::Vendor, R"pbdoc(Vendor Id.)pbdoc")
|
| 1842 | + .def("mem_type", &OrtDevice::MemType, R"pbdoc(Device Memory Type.)pbdoc") |
1819 | 1843 | // generic device types that are typically used with a vendor id.
|
1820 | 1844 | .def_static("cpu", []() { return OrtDevice::CPU; })
|
1821 | 1845 | .def_static("gpu", []() { return OrtDevice::GPU; })
|
@@ -1866,36 +1890,55 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra
|
1866 | 1890 | },
|
1867 | 1891 | R"pbdoc(Hardware device's metadata as string key/value pairs.)pbdoc");
|
1868 | 1892 |
|
| 1893 | + py::class_<OrtSyncStream> py_sync_stream(m, "OrtSyncStream", |
| 1894 | + R"pbdoc(Represents a synchronization stream for model inference.)pbdoc"); |
| 1895 | + |
1869 | 1896 | py::class_<OrtEpDevice> py_ep_device(m, "OrtEpDevice",
|
1870 | 1897 | R"pbdoc(Represents a hardware device that an execution provider supports
|
1871 | 1898 | for model inference.)pbdoc");
|
1872 | 1899 | py_ep_device.def_property_readonly(
|
1873 | 1900 | "ep_name",
|
1874 |
| - [](OrtEpDevice* ep_device) -> std::string { return ep_device->ep_name; }, |
| 1901 | + [](const OrtEpDevice* ep_device) -> std::string { return ep_device->ep_name; }, |
1875 | 1902 | R"pbdoc(The execution provider's name.)pbdoc")
|
1876 | 1903 | .def_property_readonly(
|
1877 | 1904 | "ep_vendor",
|
1878 |
| - [](OrtEpDevice* ep_device) -> std::string { return ep_device->ep_vendor; }, |
| 1905 | + [](const OrtEpDevice* ep_device) -> std::string { return ep_device->ep_vendor; }, |
1879 | 1906 | R"pbdoc(The execution provider's vendor name.)pbdoc")
|
1880 | 1907 | .def_property_readonly(
|
1881 | 1908 | "ep_metadata",
|
1882 |
| - [](OrtEpDevice* ep_device) -> std::map<std::string, std::string> { |
| 1909 | + [](const OrtEpDevice* ep_device) -> std::map<std::string, std::string> { |
1883 | 1910 | return ep_device->ep_metadata.Entries();
|
1884 | 1911 | },
|
1885 | 1912 | R"pbdoc(The execution provider's additional metadata for the OrtHardwareDevice.)pbdoc")
|
1886 | 1913 | .def_property_readonly(
|
1887 | 1914 | "ep_options",
|
1888 |
| - [](OrtEpDevice* ep_device) -> std::map<std::string, std::string> { |
| 1915 | + [](const OrtEpDevice* ep_device) -> std::map<std::string, std::string> { |
1889 | 1916 | return ep_device->ep_options.Entries();
|
1890 | 1917 | },
|
1891 | 1918 | R"pbdoc(The execution provider's options used to configure the provider to use the OrtHardwareDevice.)pbdoc")
|
1892 | 1919 | .def_property_readonly(
|
1893 | 1920 | "device",
|
1894 |
| - [](OrtEpDevice* ep_device) -> const OrtHardwareDevice& { |
| 1921 | + [](const OrtEpDevice* ep_device) -> const OrtHardwareDevice& { |
1895 | 1922 | return *ep_device->device;
|
1896 | 1923 | },
|
1897 | 1924 | R"pbdoc(The OrtHardwareDevice instance for the OrtEpDevice.)pbdoc",
|
1898 |
| - py::return_value_policy::reference_internal); |
| 1925 | + py::return_value_policy::reference_internal) |
| 1926 | + .def( |
| 1927 | + "memory_info", |
| 1928 | + [](const OrtEpDevice* ep_device, OrtDeviceMemoryType memory_type) -> const OrtMemoryInfo* { |
| 1929 | + Ort::ConstEpDevice ep_dev(ep_device); |
| 1930 | + return static_cast<const OrtMemoryInfo*>(ep_dev.GetMemoryInfo(memory_type)); |
| 1931 | + }, |
| 1932 | + R"pbdoc(The OrtMemoryInfo instance for the OrtEpDevice specific to the device memory type.)pbdoc", |
| 1933 | + py::return_value_policy::reference_internal) |
| 1934 | + .def( |
| 1935 | + "create_sync_stream", |
| 1936 | + [](const OrtEpDevice* ep_device) -> std::unique_ptr<OrtSyncStream> { |
| 1937 | + Ort::ConstEpDevice ep_dev(ep_device); |
| 1938 | + Ort::SyncStream stream = ep_dev.CreateSyncStream(); |
| 1939 | + return std::unique_ptr<OrtSyncStream>(stream.release()); |
| 1940 | + }, |
| 1941 | + R"pbdoc(The OrtSyncStream instance for the OrtEpDevice.)pbdoc"); |
1899 | 1942 |
|
1900 | 1943 | py::class_<OrtArenaCfg> ort_arena_cfg_binding(m, "OrtArenaCfg");
|
1901 | 1944 | // Note: Doesn't expose initial_growth_chunk_sizes_bytes/max_power_of_two_extend_bytes option.
|
@@ -1941,25 +1984,28 @@ for model inference.)pbdoc");
|
1941 | 1984 | .def_readwrite("max_power_of_two_extend_bytes", &OrtArenaCfg::max_power_of_two_extend_bytes);
|
1942 | 1985 |
|
1943 | 1986 | py::class_<OrtMemoryInfo> ort_memory_info_binding(m, "OrtMemoryInfo");
|
1944 |
| - ort_memory_info_binding.def(py::init([](const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) { |
1945 |
| - if (strcmp(name, onnxruntime::CPU) == 0) { |
1946 |
| - return std::make_unique<OrtMemoryInfo>(onnxruntime::CPU, type, OrtDevice(), mem_type); |
1947 |
| - } else if (strcmp(name, onnxruntime::CUDA) == 0) { |
1948 |
| - return std::make_unique<OrtMemoryInfo>( |
1949 |
| - onnxruntime::CUDA, type, |
1950 |
| - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, |
1951 |
| - static_cast<OrtDevice::DeviceId>(id)), |
1952 |
| - mem_type); |
1953 |
| - } else if (strcmp(name, onnxruntime::CUDA_PINNED) == 0) { |
1954 |
| - return std::make_unique<OrtMemoryInfo>( |
1955 |
| - onnxruntime::CUDA_PINNED, type, |
1956 |
| - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA, |
1957 |
| - static_cast<OrtDevice::DeviceId>(id)), |
1958 |
| - mem_type); |
1959 |
| - } else { |
1960 |
| - throw std::runtime_error("Specified device is not supported."); |
1961 |
| - } |
1962 |
| - })); |
| 1987 | + ort_memory_info_binding.def( |
| 1988 | + py::init([](const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) { |
| 1989 | + Ort::MemoryInfo result(name, type, id, mem_type); |
| 1990 | + return std::unique_ptr<OrtMemoryInfo>(result.release()); |
| 1991 | + })) |
| 1992 | + .def_static( |
| 1993 | + "create_v2", |
| 1994 | + [](const char* name, OrtMemoryInfoDeviceType device_type, uint32_t vendor_id, |
| 1995 | + int32_t device_id, OrtDeviceMemoryType device_mem_type, size_t alignment, OrtAllocatorType type) { |
| 1996 | + Ort::MemoryInfo result(name, device_type, vendor_id, device_id, device_mem_type, alignment, type); |
| 1997 | + return std::unique_ptr<OrtMemoryInfo>(result.release()); |
| 1998 | + }, |
| 1999 | + R"pbdoc(Create an OrtMemoryInfo instance using CreateMemoryInfo_V2())pbdoc") |
| 2000 | + .def_property_readonly("name", [](const OrtMemoryInfo* mem_info) -> std::string { return mem_info->name; }, R"pbdoc(Arbitrary name supplied by the user)pbdoc") |
| 2001 | + .def_property_readonly("device_id", [](const OrtMemoryInfo* mem_info) -> int { return mem_info->device.Id(); }, R"pbdoc(Device Id.)pbdoc") |
| 2002 | + .def_property_readonly("mem_type", [](const OrtMemoryInfo* mem_info) -> OrtMemType { return mem_info->mem_type; }, R"pbdoc(OrtMemoryInfo memory type.)pbdoc") |
| 2003 | + .def_property_readonly("allocator_type", [](const OrtMemoryInfo* mem_info) -> OrtAllocatorType { return mem_info->alloc_type; }, R"pbdoc(Allocator type)pbdoc") |
| 2004 | + .def_property_readonly("device_mem_type", [](const OrtMemoryInfo* mem_info) -> OrtDeviceMemoryType { |
| 2005 | + auto mem_type = mem_info->device.MemType(); |
| 2006 | + return (mem_type == OrtDevice::MemType::DEFAULT) ? |
| 2007 | + OrtDeviceMemoryType_DEFAULT: OrtDeviceMemoryType_HOST_ACCESSIBLE ; }, R"pbdoc(Device memory type (Device or Host accessible).)pbdoc") |
| 2008 | + .def_property_readonly("device_vendor_id", [](const OrtMemoryInfo* mem_info) -> uint32_t { return mem_info->device.Vendor(); }); |
1963 | 2009 |
|
1964 | 2010 | py::class_<PySessionOptions>
|
1965 | 2011 | sess(m, "SessionOptions", R"pbdoc(Configuration information for a session.)pbdoc");
|
@@ -2699,6 +2745,33 @@ including arg name, arg type (contains both type and shape).)pbdoc")
|
2699 | 2745 | auto res = sess->GetSessionHandle()->GetModelMetadata();
|
2700 | 2746 | OrtPybindThrowIfError(res.first);
|
2701 | 2747 | return *(res.second); }, py::return_value_policy::reference_internal)
|
| 2748 | + .def_property_readonly("input_meminfos", [](const PyInferenceSession* sess) -> py::list { |
| 2749 | + Ort::ConstSession session(reinterpret_cast<const OrtSession*>(sess->GetSessionHandle())); |
| 2750 | + auto inputs_mem_info = session.GetMemoryInfoForInputs(); |
| 2751 | + py::list result; |
| 2752 | + for (const auto& info : inputs_mem_info) { |
| 2753 | + const auto* p_info = static_cast<const OrtMemoryInfo*>(info); |
| 2754 | + result.append(py::cast(p_info, py::return_value_policy::reference)); |
| 2755 | + } |
| 2756 | + return result; }) |
| 2757 | + .def_property_readonly("output_meminfos", [](const PyInferenceSession* sess) -> py::list { |
| 2758 | + Ort::ConstSession session(reinterpret_cast<const OrtSession*>(sess->GetSessionHandle())); |
| 2759 | + auto outputs_mem_info = session.GetMemoryInfoForOutputs(); |
| 2760 | + py::list result; |
| 2761 | + for (const auto& info : outputs_mem_info) { |
| 2762 | + const auto* p_info = static_cast<const OrtMemoryInfo*>(info); |
| 2763 | + result.append(py::cast(p_info, py::return_value_policy::reference)); |
| 2764 | + } |
| 2765 | + return result; }) |
| 2766 | + .def_property_readonly("input_epdevices", [](const PyInferenceSession* sess) -> py::list { |
| 2767 | + Ort::ConstSession session(reinterpret_cast<const OrtSession*>(sess->GetSessionHandle())); |
| 2768 | + auto ep_devices = session.GetEpDeviceForInputs(); |
| 2769 | + py::list result; |
| 2770 | + for (const auto& device : ep_devices) { |
| 2771 | + const auto* p_device = static_cast<const OrtEpDevice*>(device); |
| 2772 | + result.append(py::cast(p_device, py::return_value_policy::reference)); |
| 2773 | + } |
| 2774 | + return result; }) |
2702 | 2775 | .def("run_with_iobinding", [](PyInferenceSession* sess, SessionIOBinding& io_binding, RunOptions* run_options = nullptr) -> void {
|
2703 | 2776 |
|
2704 | 2777 | Status status;
|
|
0 commit comments