Skip to content

Commit c36e868

Browse files
committed
[onert] Export dedicated OneRT data types to Python
NNFW types and numpy data types do not map one to one, because NNFW has quantized types represented as uint8 or int16. Because of that, it is more practical to export a custom data type object which maps between these two type systems. Additionally, for convenience, the dedicated types are exported in the top-level onert Python module, so one can use them as follows: > import numpy as np, onert > np.array([2, 42, 42], dtype=onert.float32) ONE-DCO-1.0-Signed-off-by: Arkadiusz Bokowy <a.bokowy@samsung.com>
1 parent 1de77d5 commit c36e868

File tree

6 files changed

+101
-71
lines changed

6 files changed

+101
-71
lines changed

runtime/onert/api/python/include/nnfw_api_wrapper.h

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,24 @@ namespace python
3333

3434
namespace py = pybind11;
3535

36+
/**
37+
* @brief Data type mapping between NNFW_TYPE and numpy dtype.
38+
*/
39+
struct dtype
40+
{
41+
NNFW_TYPE nnfw_type;
42+
py::dtype py_dtype;
43+
// The name of the dtype, e.g., "float32", "int32", etc.
44+
// This is mainly for the __repr__ implementation.
45+
const char *name;
46+
47+
dtype() = default;
48+
explicit dtype(NNFW_TYPE type);
49+
50+
bool operator==(const struct dtype &other) const { return nnfw_type == other.nnfw_type; }
51+
bool operator!=(const struct dtype &other) const { return nnfw_type != other.nnfw_type; }
52+
};
53+
3654
/**
3755
* @brief tensor info describes the type and shape of tensors
3856
*
@@ -48,7 +66,7 @@ namespace py = pybind11;
4866
struct tensorinfo
4967
{
5068
/** The data type */
51-
const char *dtype;
69+
struct dtype dtype;
5270
/** The number of dimensions (rank) */
5371
int32_t rank;
5472
/**
@@ -75,22 +93,6 @@ void ensure_status(NNFW_STATUS status);
7593
*/
7694
NNFW_LAYOUT getLayout(const char *layout = "");
7795

78-
/**
79-
* Convert the type with string to NNFW_TYPE
80-
*
81-
* @param[in] type type to be converted
82-
* @return proper type if exists
83-
*/
84-
NNFW_TYPE getType(const char *type = "");
85-
86-
/**
87-
* Convert the type with NNFW_TYPE to string
88-
*
89-
* @param[in] type type to be converted
90-
* @return proper type
91-
*/
92-
const char *getStringType(NNFW_TYPE type);
93-
9496
/**
9597
* @brief Get the total number of elements in nnfw_tensorinfo->dims.
9698
*
Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
# Define the public API of the onert package
2-
__all__ = ["infer", "tensorinfo", "experimental"]
2+
__all__ = ["dtype", "infer", "tensorinfo", "experimental"]
3+
4+
# Import and expose tensorinfo and tensor data types
5+
from .native.libnnfw_api_pybind import dtype, tensorinfo
6+
from .native.libnnfw_api_pybind.dtypes import *
37

48
# Import and expose the infer module's functionalities
59
from . import infer
610

7-
# Import and expose tensorinfo
8-
from .common import tensorinfo
9-
1011
# Import and expose the experimental module's functionalities
1112
from . import experimental

runtime/onert/api/python/src/bindings/nnfw_tensorinfo_bindings.cc

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
#include "nnfw_api_wrapper.h"
2020

21+
#include <pybind11/operators.h>
22+
2123
namespace onert::api::python
2224
{
2325

@@ -26,6 +28,36 @@ namespace py = pybind11;
2628
// Bind the `tensorinfo` class
2729
void bind_tensorinfo(py::module_ &m)
2830
{
31+
32+
static const struct dtype dtypes[] = {
33+
dtype(NNFW_TYPE::NNFW_TYPE_TENSOR_FLOAT32),
34+
dtype(NNFW_TYPE::NNFW_TYPE_TENSOR_INT32),
35+
dtype(NNFW_TYPE::NNFW_TYPE_TENSOR_QUANT8_ASYMM),
36+
dtype(NNFW_TYPE::NNFW_TYPE_TENSOR_UINT8),
37+
dtype(NNFW_TYPE::NNFW_TYPE_TENSOR_BOOL),
38+
dtype(NNFW_TYPE::NNFW_TYPE_TENSOR_INT64),
39+
dtype(NNFW_TYPE::NNFW_TYPE_TENSOR_QUANT8_ASYMM_SIGNED),
40+
dtype(NNFW_TYPE::NNFW_TYPE_TENSOR_QUANT16_SYMM_SIGNED),
41+
};
42+
43+
// Export dedicated OneRT type for tensor types. The presence of the "dtype"
44+
// property allows this type to be used directly with numpy, e.g.:
45+
// >>> np.array([3, 6, 3], dtype=onert.float32)
46+
py::class_<dtype>(m, "dtype", "Defines the type of the OneRT tensor.", py::module_local())
47+
.def(py::self == py::self)
48+
.def(py::self != py::self)
49+
.def("__repr__", [](const dtype &dt) { return std::string("onert.") + dt.name; })
50+
.def_readonly("name", &dtype::name, "The name of the data type.")
51+
.def_readonly("dtype", &dtype::py_dtype, "A corresponding numpy data type.")
52+
.def_property_readonly(
53+
"itemsize", [](const dtype &dt) { return dt.py_dtype.itemsize(); },
54+
"The element size of this data-type object.");
55+
56+
// Export OneRT dtypes in a submodule, so we can batch import them
57+
auto m_dtypes = m.def_submodule("dtypes", "OneRT tensor data types");
58+
for (const auto &dt : dtypes)
59+
m_dtypes.attr(dt.name) = dt;
60+
2961
py::class_<tensorinfo>(m, "tensorinfo", "tensorinfo describes the type and shape of tensors",
3062
py::module_local())
3163
.def(py::init<>(), "The constructor of tensorinfo")

runtime/onert/api/python/src/wrapper/nnfw_api_wrapper.cc

Lines changed: 41 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -51,57 +51,54 @@ NNFW_LAYOUT getLayout(const char *layout)
5151
{
5252
if (std::strcmp(layout, "NCHW") == 0)
5353
return NNFW_LAYOUT::NNFW_LAYOUT_CHANNELS_FIRST;
54-
else if (std::strcmp(layout, "NHWC") == 0)
54+
if (std::strcmp(layout, "NHWC") == 0)
5555
return NNFW_LAYOUT::NNFW_LAYOUT_CHANNELS_LAST;
56-
else if (std::strcmp(layout, "NONE") == 0)
56+
if (std::strcmp(layout, "NONE") == 0)
5757
return NNFW_LAYOUT::NNFW_LAYOUT_NONE;
58-
else
59-
throw NnfwError(std::string("Unknown layout type: '") + layout + "'");
58+
throw NnfwError(std::string("Unknown layout type: '") + layout + "'");
6059
}
6160

62-
NNFW_TYPE getType(const char *type)
63-
{
64-
if (std::strcmp(type, "float32") == 0)
65-
return NNFW_TYPE::NNFW_TYPE_TENSOR_FLOAT32;
66-
else if (std::strcmp(type, "int32") == 0)
67-
return NNFW_TYPE::NNFW_TYPE_TENSOR_INT32;
68-
else if (std::strcmp(type, "bool") == 0)
69-
return NNFW_TYPE::NNFW_TYPE_TENSOR_UINT8;
70-
else if (std::strcmp(type, "bool") == 0)
71-
return NNFW_TYPE::NNFW_TYPE_TENSOR_BOOL;
72-
else if (std::strcmp(type, "int64") == 0)
73-
return NNFW_TYPE::NNFW_TYPE_TENSOR_INT64;
74-
else if (std::strcmp(type, "int8") == 0)
75-
return NNFW_TYPE::NNFW_TYPE_TENSOR_QUANT8_ASYMM_SIGNED;
76-
else if (std::strcmp(type, "int16") == 0)
77-
return NNFW_TYPE::NNFW_TYPE_TENSOR_QUANT16_SYMM_SIGNED;
78-
else
79-
throw NnfwError(std::string("Cannot convert string to NNFW_TYPE: '") + type + "'");
80-
}
81-
82-
const char *getStringType(NNFW_TYPE type)
61+
dtype::dtype(NNFW_TYPE type) : nnfw_type(type)
8362
{
8463
switch (type)
8564
{
8665
case NNFW_TYPE::NNFW_TYPE_TENSOR_FLOAT32:
87-
return "float32";
66+
py_dtype = py::dtype("float32");
67+
name = "float32";
68+
return;
8869
case NNFW_TYPE::NNFW_TYPE_TENSOR_INT32:
89-
return "int32";
70+
py_dtype = py::dtype("int32");
71+
name = "int32";
72+
return;
9073
case NNFW_TYPE::NNFW_TYPE_TENSOR_QUANT8_ASYMM:
74+
py_dtype = py::dtype("uint8");
75+
name = "quint8";
76+
return;
9177
case NNFW_TYPE::NNFW_TYPE_TENSOR_UINT8:
92-
return "uint8";
78+
py_dtype = py::dtype("uint8");
79+
name = "uint8";
80+
return;
9381
case NNFW_TYPE::NNFW_TYPE_TENSOR_BOOL:
94-
return "bool";
82+
py_dtype = py::dtype("bool");
83+
name = "bool";
84+
return;
9585
case NNFW_TYPE::NNFW_TYPE_TENSOR_INT64:
96-
return "int64";
86+
py_dtype = py::dtype("int64");
87+
name = "int64";
88+
return;
9789
case NNFW_TYPE::NNFW_TYPE_TENSOR_QUANT8_ASYMM_SIGNED:
98-
return "int8";
90+
py_dtype = py::dtype("int8");
91+
name = "qint8";
92+
return;
9993
case NNFW_TYPE::NNFW_TYPE_TENSOR_QUANT16_SYMM_SIGNED:
100-
return "int16";
101-
default:
102-
throw NnfwError(std::string("Cannot convert NNFW_TYPE enum to string (value=") +
103-
std::to_string(static_cast<int>(type)) + ")");
94+
py_dtype = py::dtype("int16");
95+
name = "qint16sym";
96+
return;
10497
}
98+
// This code should not be reached because compiler will generate a warning
99+
// if some type is not handled in the switch block above.
100+
throw NnfwError(std::string("Cannot convert NNFW_TYPE enum to onert.dtype (value=") +
101+
std::to_string(static_cast<int>(type)) + ")");
105102
}
106103

107104
uint64_t num_elems(const nnfw_tensorinfo *tensor_info)
@@ -153,10 +150,11 @@ void NNFW_SESSION::close_session()
153150
ensure_status(nnfw_close_session(this->session));
154151
this->session = nullptr;
155152
}
153+
156154
void NNFW_SESSION::set_input_tensorinfo(uint32_t index, const tensorinfo *tensor_info)
157155
{
158156
nnfw_tensorinfo ti;
159-
ti.dtype = getType(tensor_info->dtype);
157+
ti.dtype = tensor_info->dtype.nnfw_type;
160158
ti.rank = tensor_info->rank;
161159
for (int i = 0; i < NNFW_MAX_RANK; i++)
162160
{
@@ -187,25 +185,27 @@ void NNFW_SESSION::set_input_layout(uint32_t index, const char *layout)
187185
NNFW_LAYOUT nnfw_layout = getLayout(layout);
188186
ensure_status(nnfw_set_input_layout(session, index, nnfw_layout));
189187
}
188+
190189
tensorinfo NNFW_SESSION::input_tensorinfo(uint32_t index)
191190
{
192191
nnfw_tensorinfo tensor_info = nnfw_tensorinfo();
193192
ensure_status(nnfw_input_tensorinfo(session, index, &tensor_info));
194193
tensorinfo ti;
195-
ti.dtype = getStringType(tensor_info.dtype);
194+
ti.dtype = dtype(tensor_info.dtype);
196195
ti.rank = tensor_info.rank;
197196
for (int i = 0; i < NNFW_MAX_RANK; i++)
198197
{
199198
ti.dims[i] = tensor_info.dims[i];
200199
}
201200
return ti;
202201
}
202+
203203
tensorinfo NNFW_SESSION::output_tensorinfo(uint32_t index)
204204
{
205205
nnfw_tensorinfo tensor_info = nnfw_tensorinfo();
206206
ensure_status(nnfw_output_tensorinfo(session, index, &tensor_info));
207207
tensorinfo ti;
208-
ti.dtype = getStringType(tensor_info.dtype);
208+
ti.dtype = dtype(tensor_info.dtype);
209209
ti.rank = tensor_info.rank;
210210
for (int i = 0; i < NNFW_MAX_RANK; i++)
211211
{
@@ -234,13 +234,10 @@ py::array NNFW_SESSION::get_output(uint32_t index)
234234
num_elements *= static_cast<size_t>(out_info.dims[i]);
235235
}
236236

237+
const auto type = dtype(out_info.dtype);
237238
// Wrap the raw buffer in a numpy array;
238-
auto np = py::module_::import("numpy");
239-
py::dtype dt = np.attr("dtype")(py::str(getStringType(out_info.dtype))).cast<py::dtype>();
240-
size_t itemsize = dt.attr("itemsize").cast<size_t>();
241-
242-
py::array arr(dt, shape);
243-
std::memcpy(arr.mutable_data(), out_buffer, num_elements * itemsize);
239+
py::array arr(type.py_dtype, shape);
240+
std::memcpy(arr.mutable_data(), out_buffer, num_elements * type.py_dtype.itemsize());
244241
arr.attr("flags").attr("writeable") = false;
245242

246243
return arr;

runtime/onert/sample/minimal-python/inference_benchmark.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import numpy as np
33
import psutil
44
import os
5-
from typing import List
5+
from typing import List, Optional
66
from onert import infer, tensorinfo
77

88

@@ -45,8 +45,8 @@ def get_validated_input_tensorinfos(sess: infer.session,
4545
return updated_infos
4646

4747

48-
def benchmark_inference(nnpackage_path: str, backends: str, input_shapes: List[List[int]],
49-
repeat: int):
48+
def benchmark_inference(nnpackage_path: str, backends: str,
49+
input_shapes: Optional[List[List[int]]], repeat: int):
5050
mem_before_kb = get_memory_usage_mb() * 1024
5151

5252
sess = infer.session(path=nnpackage_path, backends=backends)

runtime/onert/sample/minimal-python/minimal.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,8 @@ def main(nnpackage_path, backends="cpu"):
1212
input_infos = session.get_inputs_tensorinfo()
1313
dummy_inputs = []
1414
for info in input_infos:
15-
# Retrieve the dimensions list from tensorinfo property.
16-
dims = list(info.dims)
1715
# Build the shape tuple from tensorinfo dimensions.
18-
shape = tuple(dims[:info.rank])
16+
shape = tuple(info.dims[:info.rank])
1917
# Create a dummy numpy array filled with zeros.
2018
dummy_inputs.append(np.zeros(shape, dtype=info.dtype))
2119

0 commit comments

Comments
 (0)