Skip to content

Commit 1ed7bd5

Browse files
committed
Protect internal fields to make dtype immutable
1 parent e6257a5 commit 1ed7bd5

File tree

3 files changed

+40
-31
lines changed

3 files changed

+40
-31
lines changed

runtime/onert/api/python/include/nnfw_api_wrapper.h

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,17 +40,24 @@ namespace py = pybind11;
4040
*/
4141
struct datatype
4242
{
43-
NNFW_TYPE nnfw_type;
44-
py::dtype py_dtype;
43+
private:
44+
NNFW_TYPE _nnfw_type;
45+
py::dtype _py_dtype;
4546
// The name of the dtype, e.g., "float32", "int32", etc.
4647
// This is mainly for the __repr__ implementation.
47-
const char *name;
48+
const char *_name;
4849

49-
datatype() = default;
50+
public:
51+
datatype() : datatype(NNFW_TYPE::NNFW_TYPE_TENSOR_FLOAT32) {}
5052
explicit datatype(NNFW_TYPE type);
5153

52-
bool operator==(const datatype &other) const { return nnfw_type == other.nnfw_type; }
53-
bool operator!=(const datatype &other) const { return nnfw_type != other.nnfw_type; }
54+
const char *name() const { return _name; }
55+
ssize_t itemsize() const { return _py_dtype.itemsize(); }
56+
NNFW_TYPE nnfw_type() const { return _nnfw_type; }
57+
py::dtype py_dtype() const { return _py_dtype; }
58+
59+
bool operator==(const datatype &other) const { return _nnfw_type == other._nnfw_type; }
60+
bool operator!=(const datatype &other) const { return _nnfw_type != other._nnfw_type; }
5461
};
5562

5663
/**

runtime/onert/api/python/src/bindings/nnfw_tensorinfo_bindings.cc

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,17 +46,19 @@ void bind_tensorinfo(py::module_ &m)
4646
py::class_<datatype>(m, "dtype", "Defines the type of the OneRT tensor.", py::module_local())
4747
.def(py::self == py::self)
4848
.def(py::self != py::self)
49-
.def("__repr__", [](const datatype &dt) { return std::string("onert.") + dt.name; })
50-
.def_readonly("name", &datatype::name, "The name of the data type.")
51-
.def_readonly("dtype", &datatype::py_dtype, "A corresponding numpy data type.")
49+
.def("__repr__", [](const datatype &dt) { return std::string("onert.") + dt.name(); })
5250
.def_property_readonly(
53-
"itemsize", [](const datatype &dt) { return dt.py_dtype.itemsize(); },
51+
"name", [](const datatype &dt) { return dt.name(); }, "The name of the data type.")
52+
.def_property_readonly(
53+
"dtype", [](const datatype &dt) { return dt.py_dtype(); }, "A corresponding numpy data type.")
54+
.def_property_readonly(
55+
"itemsize", [](const datatype &dt) { return dt.itemsize(); },
5456
"The element size of this data-type object.");
5557

5658
// Export OneRT dtypes in a submodule, so we can batch import them
5759
auto m_dtypes = m.def_submodule("dtypes", "OneRT tensor data types");
5860
for (const auto &dt : dtypes)
59-
m_dtypes.attr(dt.name) = dt;
61+
m_dtypes.attr(dt.name()) = dt;
6062

6163
py::class_<tensorinfo>(m, "tensorinfo", "tensorinfo describes the type and shape of tensors",
6264
py::module_local())

runtime/onert/api/python/src/wrapper/nnfw_api_wrapper.cc

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -58,41 +58,41 @@ NNFW_LAYOUT getLayout(const char *layout)
5858
throw NnfwError(std::string("Unknown layout type: '") + layout + "'");
5959
}
6060

61-
datatype::datatype(NNFW_TYPE type) : nnfw_type(type)
61+
datatype::datatype(NNFW_TYPE type) : _nnfw_type(type)
6262
{
6363
switch (type)
6464
{
6565
case NNFW_TYPE::NNFW_TYPE_TENSOR_FLOAT32:
66-
py_dtype = py::dtype("float32");
67-
name = "float32";
66+
_py_dtype = py::dtype("float32");
67+
_name = "float32";
6868
return;
6969
case NNFW_TYPE::NNFW_TYPE_TENSOR_INT32:
70-
py_dtype = py::dtype("int32");
71-
name = "int32";
70+
_py_dtype = py::dtype("int32");
71+
_name = "int32";
7272
return;
7373
case NNFW_TYPE::NNFW_TYPE_TENSOR_QUANT8_ASYMM:
74-
py_dtype = py::dtype("uint8");
75-
name = "quint8";
74+
_py_dtype = py::dtype("uint8");
75+
_name = "quint8";
7676
return;
7777
case NNFW_TYPE::NNFW_TYPE_TENSOR_UINT8:
78-
py_dtype = py::dtype("uint8");
79-
name = "uint8";
78+
_py_dtype = py::dtype("uint8");
79+
_name = "uint8";
8080
return;
8181
case NNFW_TYPE::NNFW_TYPE_TENSOR_BOOL:
82-
py_dtype = py::dtype("bool");
83-
name = "bool";
82+
_py_dtype = py::dtype("bool");
83+
_name = "bool";
8484
return;
8585
case NNFW_TYPE::NNFW_TYPE_TENSOR_INT64:
86-
py_dtype = py::dtype("int64");
87-
name = "int64";
86+
_py_dtype = py::dtype("int64");
87+
_name = "int64";
8888
return;
8989
case NNFW_TYPE::NNFW_TYPE_TENSOR_QUANT8_ASYMM_SIGNED:
90-
py_dtype = py::dtype("int8");
91-
name = "qint8";
90+
_py_dtype = py::dtype("int8");
91+
_name = "qint8";
9292
return;
9393
case NNFW_TYPE::NNFW_TYPE_TENSOR_QUANT16_SYMM_SIGNED:
94-
py_dtype = py::dtype("int16");
95-
name = "qint16sym";
94+
_py_dtype = py::dtype("int16");
95+
_name = "qint16sym";
9696
return;
9797
}
9898
// This code should not be reached because compiler will generate a warning
@@ -154,7 +154,7 @@ void NNFW_SESSION::close_session()
154154
void NNFW_SESSION::set_input_tensorinfo(uint32_t index, const tensorinfo *tensor_info)
155155
{
156156
nnfw_tensorinfo ti;
157-
ti.dtype = tensor_info->dtype.nnfw_type;
157+
ti.dtype = tensor_info->dtype.nnfw_type();
158158
ti.rank = tensor_info->rank;
159159
for (int i = 0; i < NNFW_MAX_RANK; i++)
160160
{
@@ -236,8 +236,8 @@ py::array NNFW_SESSION::get_output(uint32_t index)
236236

237237
const auto dtype = datatype(out_info.dtype);
238238
// Wrap the raw buffer in a numpy array;
239-
py::array arr(dtype.py_dtype, shape);
240-
std::memcpy(arr.mutable_data(), out_buffer, num_elements * dtype.py_dtype.itemsize());
239+
py::array arr(dtype.py_dtype(), shape);
240+
std::memcpy(arr.mutable_data(), out_buffer, num_elements * dtype.itemsize());
241241
arr.attr("flags").attr("writeable") = false;
242242

243243
return arr;

0 commit comments

Comments
 (0)