diff --git a/runtime/onert/api/python/include/nnfw_api_wrapper.h b/runtime/onert/api/python/include/nnfw_api_wrapper.h index 6db0c118ead..daa969a542f 100644 --- a/runtime/onert/api/python/include/nnfw_api_wrapper.h +++ b/runtime/onert/api/python/include/nnfw_api_wrapper.h @@ -35,6 +35,31 @@ namespace python namespace py = pybind11; +/** + * @brief Data type mapping between NNFW_TYPE and numpy dtype. + */ +struct datatype +{ +private: + NNFW_TYPE _nnfw_type; + py::dtype _py_dtype; + // The name of the dtype, e.g., "float32", "int32", etc. + // This is mainly for the __repr__ implementation. + const char *_name; + +public: + datatype() : datatype(NNFW_TYPE::NNFW_TYPE_TENSOR_FLOAT32) {} + explicit datatype(NNFW_TYPE type); + + const char *name() const { return _name; } + ssize_t itemsize() const { return _py_dtype.itemsize(); } + NNFW_TYPE nnfw_type() const { return _nnfw_type; } + py::dtype py_dtype() const { return _py_dtype; } + + bool operator==(const datatype &other) const { return _nnfw_type == other._nnfw_type; } + bool operator!=(const datatype &other) const { return _nnfw_type != other._nnfw_type; } +}; + /** * @brief tensor info describes the type and shape of tensors * @@ -50,7 +75,7 @@ namespace py = pybind11; struct tensorinfo { /** The data type */ - std::string dtype; + datatype dtype; /** The number of dimensions (rank) */ int32_t rank; /** @@ -77,22 +102,6 @@ void ensure_status(NNFW_STATUS status); */ NNFW_LAYOUT getLayout(const char *layout = ""); -/** - * Convert the type with string to NNFW_TYPE - * - * @param[in] type type to be converted - * @return proper type if exists - */ -NNFW_TYPE getType(const char *type = ""); - -/** - * Convert the type with NNFW_TYPE to string - * - * @param[in] type type to be converted - * @return proper type - */ -const char *getStringType(NNFW_TYPE type); - /** * @brief Get the total number of elements in nnfw_tensorinfo->dims. * diff --git a/runtime/onert/api/python/onert/__init__.py b/runtime/onert/api/python/onert/__init__.py index 71b5512dd0f..c251310dae3 100644 --- a/runtime/onert/api/python/onert/__init__.py +++ b/runtime/onert/api/python/onert/__init__.py @@ -1,11 +1,12 @@ # Define the public API of the onert package -__all__ = ["infer", "tensorinfo", "experimental"] +__all__ = ["dtype", "infer", "tensorinfo", "experimental"] + +# Import and expose tensorinfo and tensor data types +from .native.libnnfw_api_pybind import dtype, tensorinfo +from .native.libnnfw_api_pybind.dtypes import * # Import and expose the infer module's functionalities from . import infer -# Import and expose tensorinfo -from .common import tensorinfo - # Import and expose the experimental module's functionalities from . import experimental diff --git a/runtime/onert/api/python/onert/experimental/train/dataloader.py b/runtime/onert/api/python/onert/experimental/train/dataloader.py index 9b27b3c4a74..b9ef6d0bf76 100644 --- a/runtime/onert/api/python/onert/experimental/train/dataloader.py +++ b/runtime/onert/api/python/onert/experimental/train/dataloader.py @@ -1,6 +1,7 @@ import os import numpy as np from typing import List, Tuple, Union, Optional, Any, Iterator +import onert class DataLoader: @@ -14,7 +15,7 @@ def __init__(self, batch_size: int, input_shape: Optional[Tuple[int, ...]] = None, expected_shape: Optional[Tuple[int, ...]] = None, - dtype: Any = np.float32) -> None: + dtype: Any = onert.float32) -> None: """ Initialize the DataLoader. @@ -28,7 +29,7 @@ def __init__(self, batch_size (int): Number of samples per batch. input_shape (tuple[int, ...], optional): Shape of the input data if raw format is used. expected_shape (tuple[int, ...], optional): Shape of the expected data if raw format is used. - dtype (type, optional): Data type of the raw file (default: np.float32). + dtype (type, optional): Data type of the raw file (default: onert.float32). """ self.batch_size: int = batch_size self.inputs: List[np.ndarray] = self._process_dataset(input_dataset, input_shape, @@ -49,7 +50,7 @@ def __init__(self, def _process_dataset(self, data: Union[List[np.ndarray], np.ndarray, str], shape: Optional[Tuple[int, ...]], - dtype: Any = np.float32) -> List[np.ndarray]: + dtype: Any = onert.float32) -> List[np.ndarray]: """ Process a dataset or file path. @@ -83,14 +84,14 @@ def _process_dataset(self, def _load_data(self, file_path: str, shape: Optional[Tuple[int, ...]], - dtype: Any = np.float32) -> np.ndarray: + dtype: Any = onert.float32) -> np.ndarray: """ Load data from a file, supporting both .npy and raw formats. Args: file_path (str): Path to the file to load. shape (tuple[int, ...], optional): Shape of the data if raw format is used. - dtype (type, optional): Data type of the raw file (default: np.float32). + dtype (type, optional): Data type of the raw file (default: onert.float32). Returns: np.ndarray: Loaded data as a NumPy array. diff --git a/runtime/onert/api/python/src/bindings/nnfw_tensorinfo_bindings.cc b/runtime/onert/api/python/src/bindings/nnfw_tensorinfo_bindings.cc index 1350dda1f13..967d1cde500 100644 --- a/runtime/onert/api/python/src/bindings/nnfw_tensorinfo_bindings.cc +++ b/runtime/onert/api/python/src/bindings/nnfw_tensorinfo_bindings.cc @@ -18,6 +18,8 @@ #include "nnfw_api_wrapper.h" +#include + namespace onert::api::python { @@ -26,6 +28,38 @@ namespace py = pybind11; // Bind the `tensorinfo` class void bind_tensorinfo(py::module_ &m) { + + static const datatype dtypes[] = { + datatype(NNFW_TYPE::NNFW_TYPE_TENSOR_FLOAT32), + datatype(NNFW_TYPE::NNFW_TYPE_TENSOR_INT32), + datatype(NNFW_TYPE::NNFW_TYPE_TENSOR_QUANT8_ASYMM), + datatype(NNFW_TYPE::NNFW_TYPE_TENSOR_UINT8), + datatype(NNFW_TYPE::NNFW_TYPE_TENSOR_BOOL), + datatype(NNFW_TYPE::NNFW_TYPE_TENSOR_INT64), + datatype(NNFW_TYPE::NNFW_TYPE_TENSOR_QUANT8_ASYMM_SIGNED), + datatype(NNFW_TYPE::NNFW_TYPE_TENSOR_QUANT16_SYMM_SIGNED), + }; + + // Export dedicated OneRT type for tensor types. The presence of the "dtype" + // property allows this type to be used directly with numpy, e.g.: + // >>> np.array([3, 6, 3], dtype=onert.float32) + py::class_(m, "dtype", "Defines the type of the OneRT tensor.", py::module_local()) + .def(py::self == py::self) + .def(py::self != py::self) + .def("__repr__", [](const datatype &dt) { return std::string("onert.") + dt.name(); }) + .def_property_readonly( + "name", [](const datatype &dt) { return dt.name(); }, "The name of the data type.") + .def_property_readonly( + "dtype", [](const datatype &dt) { return dt.py_dtype(); }, "A corresponding numpy data type.") + .def_property_readonly( + "itemsize", [](const datatype &dt) { return dt.itemsize(); }, + "The element size of this data-type object."); + + // Export OneRT dtypes in a submodule, so we can batch import them + auto m_dtypes = m.def_submodule("dtypes", "OneRT tensor data types"); + for (const auto &dt : dtypes) + m_dtypes.attr(dt.name()) = dt; + py::class_(m, "tensorinfo", "tensorinfo describes the type and shape of tensors", py::module_local()) .def(py::init<>(), "The constructor of tensorinfo") diff --git a/runtime/onert/api/python/src/wrapper/nnfw_api_wrapper.cc b/runtime/onert/api/python/src/wrapper/nnfw_api_wrapper.cc index 2d660223d59..cfcb00c7916 100644 --- a/runtime/onert/api/python/src/wrapper/nnfw_api_wrapper.cc +++ b/runtime/onert/api/python/src/wrapper/nnfw_api_wrapper.cc @@ -51,57 +51,54 @@ NNFW_LAYOUT getLayout(const char *layout) { if (std::strcmp(layout, "NCHW") == 0) return NNFW_LAYOUT::NNFW_LAYOUT_CHANNELS_FIRST; - else if (std::strcmp(layout, "NHWC") == 0) + if (std::strcmp(layout, "NHWC") == 0) return NNFW_LAYOUT::NNFW_LAYOUT_CHANNELS_LAST; - else if (std::strcmp(layout, "NONE") == 0) + if (std::strcmp(layout, "NONE") == 0) return NNFW_LAYOUT::NNFW_LAYOUT_NONE; - else - throw NnfwError(std::string("Unknown layout type: '") + layout + "'"); + throw NnfwError(std::string("Unknown layout type: '") + layout + "'"); } -NNFW_TYPE getType(const char *type) -{ - if (std::strcmp(type, "float32") == 0) - return NNFW_TYPE::NNFW_TYPE_TENSOR_FLOAT32; - else if (std::strcmp(type, "int32") == 0) - return NNFW_TYPE::NNFW_TYPE_TENSOR_INT32; - else if (std::strcmp(type, "bool") == 0) - return NNFW_TYPE::NNFW_TYPE_TENSOR_UINT8; - else if (std::strcmp(type, "bool") == 0) - return NNFW_TYPE::NNFW_TYPE_TENSOR_BOOL; - else if (std::strcmp(type, "int64") == 0) - return NNFW_TYPE::NNFW_TYPE_TENSOR_INT64; - else if (std::strcmp(type, "int8") == 0) - return NNFW_TYPE::NNFW_TYPE_TENSOR_QUANT8_ASYMM_SIGNED; - else if (std::strcmp(type, "int16") == 0) - return NNFW_TYPE::NNFW_TYPE_TENSOR_QUANT16_SYMM_SIGNED; - else - throw NnfwError(std::string("Cannot convert string to NNFW_TYPE: '") + type + "'"); -} - -const char *getStringType(NNFW_TYPE type) +datatype::datatype(NNFW_TYPE type) : _nnfw_type(type) { switch (type) { case NNFW_TYPE::NNFW_TYPE_TENSOR_FLOAT32: - return "float32"; + _py_dtype = py::dtype("float32"); + _name = "float32"; + return; case NNFW_TYPE::NNFW_TYPE_TENSOR_INT32: - return "int32"; + _py_dtype = py::dtype("int32"); + _name = "int32"; + return; case NNFW_TYPE::NNFW_TYPE_TENSOR_QUANT8_ASYMM: + _py_dtype = py::dtype("uint8"); + _name = "quint8"; + return; case NNFW_TYPE::NNFW_TYPE_TENSOR_UINT8: - return "uint8"; + _py_dtype = py::dtype("uint8"); + _name = "uint8"; + return; case NNFW_TYPE::NNFW_TYPE_TENSOR_BOOL: - return "bool"; + _py_dtype = py::dtype("bool"); + _name = "bool"; + return; case NNFW_TYPE::NNFW_TYPE_TENSOR_INT64: - return "int64"; + _py_dtype = py::dtype("int64"); + _name = "int64"; + return; case NNFW_TYPE::NNFW_TYPE_TENSOR_QUANT8_ASYMM_SIGNED: - return "int8"; + _py_dtype = py::dtype("int8"); + _name = "qint8"; + return; case NNFW_TYPE::NNFW_TYPE_TENSOR_QUANT16_SYMM_SIGNED: - return "int16"; - default: - throw NnfwError(std::string("Cannot convert NNFW_TYPE enum to string (value=") + - std::to_string(static_cast(type)) + ")"); + _py_dtype = py::dtype("int16"); + _name = "qint16sym"; + return; } + // This code should not be reached because compiler will generate a warning + // if some type is not handled in the switch block above. + throw NnfwError(std::string("Cannot convert NNFW_TYPE enum to onert.dtype (value=") + + std::to_string(static_cast(type)) + ")"); } uint64_t num_elems(const nnfw_tensorinfo *tensor_info) @@ -153,10 +150,11 @@ void NNFW_SESSION::close_session() ensure_status(nnfw_close_session(this->session)); this->session = nullptr; } + void NNFW_SESSION::set_input_tensorinfo(uint32_t index, const tensorinfo *tensor_info) { nnfw_tensorinfo ti; - ti.dtype = getType(tensor_info->dtype.c_str()); + ti.dtype = tensor_info->dtype.nnfw_type(); ti.rank = tensor_info->rank; for (int i = 0; i < NNFW_MAX_RANK; i++) { @@ -187,12 +185,13 @@ void NNFW_SESSION::set_input_layout(uint32_t index, const char *layout) NNFW_LAYOUT nnfw_layout = getLayout(layout); ensure_status(nnfw_set_input_layout(session, index, nnfw_layout)); } + tensorinfo NNFW_SESSION::input_tensorinfo(uint32_t index) { nnfw_tensorinfo tensor_info = nnfw_tensorinfo(); ensure_status(nnfw_input_tensorinfo(session, index, &tensor_info)); tensorinfo ti; - ti.dtype = getStringType(tensor_info.dtype); + ti.dtype = datatype(tensor_info.dtype); ti.rank = tensor_info.rank; for (int i = 0; i < NNFW_MAX_RANK; i++) { @@ -200,12 +199,13 @@ tensorinfo NNFW_SESSION::input_tensorinfo(uint32_t index) } return ti; } + tensorinfo NNFW_SESSION::output_tensorinfo(uint32_t index) { nnfw_tensorinfo tensor_info = nnfw_tensorinfo(); ensure_status(nnfw_output_tensorinfo(session, index, &tensor_info)); tensorinfo ti; - ti.dtype = getStringType(tensor_info.dtype); + ti.dtype = datatype(tensor_info.dtype); ti.rank = tensor_info.rank; for (int i = 0; i < NNFW_MAX_RANK; i++) { @@ -234,13 +234,10 @@ py::array NNFW_SESSION::get_output(uint32_t index) num_elements *= static_cast(out_info.dims[i]); } + const auto dtype = datatype(out_info.dtype); // Wrap the raw buffer in a numpy array; - auto np = py::module_::import("numpy"); - py::dtype dt = np.attr("dtype")(py::str(getStringType(out_info.dtype))).cast(); - size_t itemsize = dt.attr("itemsize").cast(); - - py::array arr(dt, shape); - std::memcpy(arr.mutable_data(), out_buffer, num_elements * itemsize); + py::array arr(dtype.py_dtype(), shape); + std::memcpy(arr.mutable_data(), out_buffer, num_elements * dtype.itemsize()); arr.attr("flags").attr("writeable") = false; return arr;