-
Notifications
You must be signed in to change notification settings - Fork 178
Expand file tree
/
Copy pathnnfw_api_wrapper.h
More file actions
254 lines (219 loc) · 8.02 KB
/
nnfw_api_wrapper.h
File metadata and controls
254 lines (219 loc) · 8.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
/*
* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef __ONERT_API_PYTHON_NNFW_API_WRAPPER_H__
#define __ONERT_API_PYTHON_NNFW_API_WRAPPER_H__
#include <string>
#include "nnfw.h"
#include "nnfw_experimental.h"
#include "nnfw_internal.h"
#include <pybind11/stl.h>
#include <pybind11/numpy.h>
namespace onert
{
namespace api
{
namespace python
{
namespace py = pybind11;
/**
 * @brief Pairing of an NNFW_TYPE with its matching numpy dtype.
 *
 * Exposed to Python so scripts can query a tensor's element type and map it
 * onto a numpy array dtype.
 */
struct datatype
{
private:
  NNFW_TYPE _nnfw_type;
  py::dtype _py_dtype;
  // Human-readable dtype name ("float32", "int32", ...); mainly consumed by
  // the Python-side __repr__ implementation.
  const char *_name;

public:
  // Default to float32, the most common tensor element type.
  datatype() : datatype(NNFW_TYPE::NNFW_TYPE_TENSOR_FLOAT32) {}
  explicit datatype(NNFW_TYPE type);

  NNFW_TYPE nnfw_type() const { return _nnfw_type; }
  py::dtype py_dtype() const { return _py_dtype; }
  const char *name() const { return _name; }
  ssize_t itemsize() const { return _py_dtype.itemsize(); }

  // Two datatypes compare equal exactly when they wrap the same NNFW_TYPE.
  bool operator==(const datatype &other) const { return _nnfw_type == other._nnfw_type; }
  bool operator!=(const datatype &other) const { return !(*this == other); }
};
/**
 * @brief tensor info describes the type and shape of tensors
 *
 * This structure is used to describe input and output tensors.
 * Application can get input and output tensor type and shape described in model by using
 * {@link input_tensorinfo} and {@link output_tensorinfo}
 *
 * Maximum rank is 6 (NNFW_MAX_RANK).
 * And tensor's dimension value is filled in 'dims' field from index 0.
 * For example, if tensor's rank is 4,
 * application can get dimension value from dims[0], dims[1], dims[2], and dims[3]
 */
struct tensorinfo
{
  /** The element data type (NNFW_TYPE + numpy dtype pair) */
  datatype dtype;
  /** The number of dimensions (rank); at most NNFW_MAX_RANK */
  int32_t rank;
  /**
   * The dimension of tensor.
   * Maximum rank is 6 (NNFW_MAX_RANK).
   * Only the first 'rank' entries are meaningful; they are filled from index 0.
   */
  int32_t dims[NNFW_MAX_RANK];
};
/**
 * @brief Handle errors with NNFW_STATUS in API functions.
 *
 * This only handles NNFW_STATUS errors; on a non-success status it reports
 * the failure to the caller (Python side) instead of returning silently.
 *
 * @param[in] status The status returned by API functions
 */
void ensure_status(NNFW_STATUS status);
/**
 * @brief Convert a layout given as a string to NNFW_LAYOUT.
 *
 * @param[in] layout layout to be converted (defaults to the empty string)
 * @return proper layout if a matching one exists
 */
NNFW_LAYOUT getLayout(const char *layout = "");
/**
 * @brief Get the total number of elements in nnfw_tensorinfo->dims.
 *
 * This function is called to set the size of the input, output array.
 *
 * @param[in] tensor_info Tensor info (shape, type, etc)
 * @return total number of elements (product of the first 'rank' dims)
 */
uint64_t num_elems(const nnfw_tensorinfo *tensor_info);
/**
 * @brief Get nnfw_tensorinfo->dims.
 *
 * This function is called to get the dimension array of a tensorinfo.
 *
 * @param[in] tensor_info Tensor info (shape, type, etc)
 * @return python list of dims
 */
py::list get_dims(const tensorinfo &tensor_info);
/**
 * @brief Set nnfw_tensorinfo->dims.
 *
 * This function is called to set the dimension array of a tensorinfo.
 *
 * @param[in] tensor_info Tensor info (shape, type, etc)
 * @param[in] array python list used to fill the dimension values
 */
void set_dims(tensorinfo &tensor_info, const py::list &array);
/**
 * @brief RAII-style wrapper around an nnfw_session for the Python binding.
 *
 * Owns the underlying nnfw_session pointer; the destructor releases it.
 */
class NNFW_SESSION
{
private:
  nnfw_session *session;

public:
  /**
   * @brief Create a session, load the model package and set the backends.
   *
   * @param[in] package_file_path Path of the nnpackage to load
   * @param[in] backends Backends to use (e.g. "cpu")
   */
  NNFW_SESSION(const char *package_file_path, const char *backends);
  ~NNFW_SESSION();
  void close_session();
  void set_input_tensorinfo(uint32_t index, const tensorinfo *tensor_info);
  void prepare();
  void run();
  void run_async();
  void wait();
  /**
   * @brief process input array according to data type of numpy array sent by Python
   * (int, float, uint8_t, bool, int64_t, int8_t, int16_t)
   *
   * @param[in] index Index of the input to set
   * @param[in] buffer numpy array holding the input data
   */
  template <typename T> void set_input(uint32_t index, py::array_t<T> &buffer)
  {
    nnfw_tensorinfo tensor_info;
    // Propagate lookup failures instead of reading an uninitialized tensor_info.
    ensure_status(nnfw_input_tensorinfo(this->session, index, &tensor_info));
    NNFW_TYPE type = tensor_info.dtype;
    // Keep 64-bit arithmetic: num_elems() returns uint64_t and a large tensor
    // would overflow a 32-bit element count.
    uint64_t input_elements = num_elems(&tensor_info);
    size_t length = sizeof(T) * input_elements;
    ensure_status(nnfw_set_input(session, index, type, buffer.request().ptr, length));
  }
  /**
   * @brief process output array according to data type of numpy array sent by Python
   * (int, float, uint8_t, bool, int64_t, int8_t, int16_t)
   *
   * @param[in] index Index of the output to set
   * @param[in] buffer numpy array that will receive the output data
   */
  template <typename T> void set_output(uint32_t index, py::array_t<T> &buffer)
  {
    nnfw_tensorinfo tensor_info;
    ensure_status(nnfw_output_tensorinfo(this->session, index, &tensor_info));
    NNFW_TYPE type = tensor_info.dtype;
    uint64_t output_elements = num_elems(&tensor_info);
    size_t length = sizeof(T) * output_elements;
    ensure_status(nnfw_set_output(session, index, type, buffer.request().ptr, length));
  }
  uint32_t input_size();
  uint32_t output_size();
  // process the input layout by receiving a string from Python instead of NNFW_LAYOUT
  void set_input_layout(uint32_t index, const char *layout);
  // process the output layout by receiving a string from Python instead of NNFW_LAYOUT
  void set_output_layout(uint32_t index, const char *layout);
  tensorinfo input_tensorinfo(uint32_t index);
  tensorinfo output_tensorinfo(uint32_t index);
  //////////////////////////////////////////////
  // Internal APIs
  //////////////////////////////////////////////
  py::array get_output(uint32_t index);
  //////////////////////////////////////////////
  // Experimental APIs for inference
  //////////////////////////////////////////////
  void set_prepare_config(NNFW_PREPARE_CONFIG config);
  //////////////////////////////////////////////
  // Experimental APIs for training
  //////////////////////////////////////////////
  nnfw_train_info train_get_traininfo();
  void train_set_traininfo(const nnfw_train_info *info);
  /**
   * @brief Set the training input at `index` from a numpy array.
   *
   * The batch dimension (dims[0]) is taken from the numpy buffer so the
   * caller can feed batches of a size different from the model default.
   */
  template <typename T> void train_set_input(uint32_t index, py::array_t<T> &buffer)
  {
    nnfw_tensorinfo tensor_info;
    ensure_status(nnfw_input_tensorinfo(this->session, index, &tensor_info));
    py::buffer_info buf_info = buffer.request();
    // Reference, not copy: buf_info.shape is a std::vector.
    const auto &buf_shape = buf_info.shape;
    // NOTE(review): assert only — compiled out under NDEBUG; rank mismatches
    // then fall through to the runtime's own validation.
    assert(tensor_info.rank == static_cast<int32_t>(buf_shape.size()) && buf_shape.size() > 0);
    tensor_info.dims[0] = static_cast<int32_t>(buf_shape.at(0));
    // Reuse buf_info instead of calling buffer.request() a second time.
    ensure_status(nnfw_train_set_input(this->session, index, buf_info.ptr, &tensor_info));
  }
  /**
   * @brief Set the expected (label) tensor at `index` from a numpy array.
   *
   * The batch dimension (dims[0]) is taken from the numpy buffer, mirroring
   * train_set_input.
   */
  template <typename T> void train_set_expected(uint32_t index, py::array_t<T> &buffer)
  {
    nnfw_tensorinfo tensor_info;
    ensure_status(nnfw_output_tensorinfo(this->session, index, &tensor_info));
    py::buffer_info buf_info = buffer.request();
    const auto &buf_shape = buf_info.shape;
    assert(tensor_info.rank == static_cast<int32_t>(buf_shape.size()) && buf_shape.size() > 0);
    tensor_info.dims[0] = static_cast<int32_t>(buf_shape.at(0));
    ensure_status(nnfw_train_set_expected(this->session, index, buf_info.ptr, &tensor_info));
  }
  /**
   * @brief Register the numpy array that receives the training output at `index`.
   */
  template <typename T> void train_set_output(uint32_t index, py::array_t<T> &buffer)
  {
    nnfw_tensorinfo tensor_info;
    ensure_status(nnfw_output_tensorinfo(this->session, index, &tensor_info));
    NNFW_TYPE type = tensor_info.dtype;
    uint64_t output_elements = num_elems(&tensor_info);
    size_t length = sizeof(T) * output_elements;
    ensure_status(nnfw_train_set_output(session, index, type, buffer.request().ptr, length));
  }
  void train_prepare();
  void train(bool update_weights);
  float train_get_loss(uint32_t index);
  void train_export_circle(const py::str &path);
  void train_import_checkpoint(const py::str &path);
  void train_export_checkpoint(const py::str &path);
  // TODO Add other apis
};
} // namespace python
} // namespace api
} // namespace onert
#endif // __ONERT_API_PYTHON_NNFW_API_WRAPPER_H__