1616#include " core/framework/ort_value.h"
1717#include " core/session/inference_session.h"
1818
19+ #include < variant>
20+
1921PYBIND11_MAKE_OPAQUE (std::vector<OrtValue>);
2022
2123namespace onnxruntime {
@@ -40,6 +42,8 @@ MLDataType NumpyTypeToOnnxRuntimeTensorType(int numpy_type);
4042
4143using MemCpyFunc = void (*)(void *, const void *, size_t );
4244
45+ using DataTransferAlternative = std::variant<const DataTransferManager*, MemCpyFunc>;
46+
4347void CpuToCpuMemCpy (void *, const void *, size_t );
4448
4549void CopyDataToTensor (const pybind11::array& py_array, int npy_type, Tensor& tensor, MemCpyFunc mem_cpy_to_device = CpuToCpuMemCpy);
@@ -117,9 +121,42 @@ void CreateGenericMLValue(const onnxruntime::InputDefList* input_def_list, const
117121 const std::string& name_input, const pybind11::object& value, OrtValue* p_mlvalue,
118122 bool accept_only_numpy_array = false , bool use_numpy_data_memory = true , MemCpyFunc mem_cpy_to_device = CpuToCpuMemCpy);
119123
120- void GetPyObjFromTensor (const Tensor& rtensor, pybind11::object& obj,
121- const DataTransferManager* data_transfer_manager = nullptr ,
122- const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* mem_cpy_to_host_functions = nullptr );
124+ pybind11::object GetPyObjFromTensor (const OrtValue& rtensor,
125+ const DataTransferManager* data_transfer_manager = nullptr ,
126+ const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* mem_cpy_to_host_functions = nullptr );
127+
128+ // The below two functions are used to convert OrtValue to numpy arrays
129+
130+ // / <summary>
131+ // / This function operates on string tensors. Strings are always
132+ // / copied to python and converted to UTF-16/UCS-4/32 depending on the platform.
133+ // / This is accomplished using py::cast()
134+ // /
135+ // / It is an error to pass a non-tensor or a non-string tensor to this function.
136+ // / </summary>
137+ // / <param name="tensor">Tensor that contains strings</param>
138+ // / <returns>py::array object</returns>
139+ pybind11::array StringTensorToNumpyArray (const Tensor& tensor);
140+
141+ // / <summary>
142+ // / Creates a numpy array with shape over OrtValue memory. Numpy array
143+ // / does not own the memory, but it holds a copy or OrtValue in a py::capsule.
144+ // / OrtValue is destroyed when the numpy array is garbage collected.
145+ // / This is used when the OrtValue memory is on CPU.
146+ // / </summary>
147+ // / <param name="ort_value">OrtValue with data</param>
148+ // / <returns>numpy array</returns>
149+ pybind11::array PrimitiveTensorToNumpyOverOrtValue (const OrtValue& ort_value);
150+
151+ // / <summary>
152+ // / Creates a numpy array with shape with a copy of OrtValue data.
153+ // / This function is used when the OrtValue memory is not on CPU.
154+ // / </summary>
155+ // / <param name="ort_value">Source memory that is not on CPU.</param>
156+ // / <param name="data_transfer">a variant encapsulating alternatives for copying data</param>
157+ // / <returns></returns>
158+ pybind11::array PrimitiveTensorToNumpyFromDevice (const OrtValue& ort_value,
159+ const DataTransferAlternative& data_transfer);
123160
124161template <class T >
125162struct DecRefFn {
0 commit comments