
Commit dddab8c

Integrated tflite runtime for inference
1 parent dc0b6a1 commit dddab8c

5,855 files changed: +819,744 -87 lines


MODULE.bazel

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ bazel_dep(name = "apple_support", version = "1.17.1", repo_name = "build_bazel_a
 bazel_dep(name = "curl", version = "8.8.0")
 bazel_dep(name = "nlohmann_json", version = "3.11.3")
 bazel_dep(name = "hedron_compile_commands", dev_dependency = True)
+bazel_dep(name = "flatbuffers", version = "24.3.25")

 # Hedron's Compile Commands Extractor for Bazel
 git_override(
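The new flatbuffers module is presumably pulled in because .tflite model files are FlatBuffer binaries; the TFLite runtime integrated below depends on the FlatBuffers library to parse them.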

MODULE.bazel.lock

Lines changed: 45 additions & 86 deletions
Some generated files are not rendered by default.

framework/src/vx_context.cpp

Lines changed: 2 additions & 0 deletions
@@ -34,6 +34,8 @@ vx_char targetModules[][VX_MAX_TARGET_NAME] = {
 #endif
     "openvx-c_model",
     "openvx-onnxRT",
+    "openvx-ai-server",
+    "openvx-liteRT",
 };

 const vx_char extensions[] =
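The targetModules table lists the target back-ends the context attempts to load at creation, so with the two new entries work can be steered onto the LiteRT back-end. A minimal sketch, assuming this framework honors the standard OpenVX vxSetNodeTarget() call and accepts the module name as the target string (neither convention is shown in this commit):

    #include <VX/vx.h>

    // Hedged sketch: route a node to the new LiteRT target module. The
    // target-string naming convention is an assumption, not part of the diff.
    vx_status route_to_litert(vx_node node)
    {
        return vxSetNodeTarget(node, VX_TARGET_STRING, "openvx-liteRT");
    }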

include/VX/vx_corevx_ext.h

Lines changed: 4 additions & 0 deletions
@@ -52,6 +52,10 @@ enum vx_kernel_ext_e
      * \brief The AI Model Server Chatbot kernel.
      */
     VX_KERNEL_AIS_CHATBOT = VX_KERNEL_BASE(VX_ID_EDGE_AI, VX_LIBRARY_KHR_BASE) + 0x2,
+    /*!
+     * \brief The LiteRT CPU Inference kernel.
+     */
+    VX_KERNEL_LITERT_CPU_INF = VX_KERNEL_BASE(VX_ID_EDGE_AI, VX_LIBRARY_KHR_BASE) + 0x3,
 };

 /*! \brief additional tensor attributes.
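With the new enum value, clients can instantiate the kernel through the standard OpenVX generic-node path. A minimal sketch; vxGetKernelByEnum() and vxCreateGenericNode() are core OpenVX APIs, but the kernel's parameter list is not part of this commit, so the wiring is left as a comment:

    #include <VX/vx.h>
    #include <VX/vx_corevx_ext.h>

    vx_node create_litert_node(vx_context context, vx_graph graph)
    {
        // Look up the newly registered LiteRT CPU inference kernel by its enum.
        vx_kernel kernel = vxGetKernelByEnum(context, VX_KERNEL_LITERT_CPU_INF);

        // Create an unparameterized node for it in the graph.
        vx_node node = vxCreateGenericNode(graph, kernel);

        // ... bind the model path and tensors with vxSetParameterByIndex() as
        // the kernel's signature (not shown in this diff) requires ...

        vxReleaseKernel(&kernel);
        return node;
    }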

kernels/liteRT/BUILD

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
2+
cc_library(
3+
name = "liteRT_kernels",
4+
srcs = glob([
5+
"*.cpp",
6+
]),
7+
hdrs = glob([
8+
"*.h",
9+
"*.hpp",
10+
]),
11+
includes = [
12+
".",
13+
"//framework/include",
14+
],
15+
deps = [
16+
"//:corevx",
17+
"//third_party:tflite",
18+
"//third_party:tflite-hdrs",
19+
],
20+
visibility = ["//visibility:public"]
21+
)
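Note on layout: the kernel implementations live here as a plain cc_library, while targets/liteRT (further down) wraps them into the loadable openvx-liteRT shared library. The //third_party:tflite and //third_party:tflite-hdrs labels are not part of this diff; they presumably point at the Bazel targets of the vendored TFLite runtime.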

kernels/liteRT/tflite.hpp

Lines changed: 255 additions & 0 deletions
@@ -0,0 +1,255 @@
/**
 * @file tflite.hpp
 * @brief Minimal wrapper around the TFLite interpreter for CPU inference.
 * @version 0.1
 * @date 2025-04-19
 *
 * @copyright Copyright (c) 2025
 *
 */
#pragma once

#include <cstdio>
#include <cstdlib>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <VX/vx.h>

#include "tensorflow/lite/core/interpreter_builder.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/model_builder.h"
#include "tensorflow/lite/optional_debug_tools.h"

#define TFLITE_MINIMAL_CHECK(x)                                      \
    if (!(x))                                                        \
    {                                                                \
        fprintf(stderr, "Error at %s:%d\n", __FILE__, __LINE__);     \
        return VX_FAILURE;                                           \
    }

/**
 * @brief Class to run TFLite models
 */
class TFLiteRunner
{
public:
    /**
     * @brief TFLiteRunner Constructor
     */
    TFLiteRunner() : modelLoaded(false) {}

    /**
     * @brief Initialize the TFLite interpreter (load the model)
     * @param filename Path to the TFLite model file
     * @return VX_SUCCESS on success, VX_FAILURE otherwise
     */
    vx_status init(std::string &filename)
    {
        TFLITE_MINIMAL_CHECK(false == filename.empty());

        if (!modelLoaded)
        {
            // Load the FlatBuffer model from disk
            model = tflite::FlatBufferModel::BuildFromFile(filename.c_str());
            TFLITE_MINIMAL_CHECK(model != nullptr);

            // Build the interpreter with the InterpreterBuilder.
            // Note: all Interpreters should be built with the InterpreterBuilder,
            // which allocates memory for the Interpreter and does various set up
            // tasks so that the Interpreter can read the provided model.
            tflite::ops::builtin::BuiltinOpResolver resolver;
            tflite::InterpreterBuilder builder(*model, resolver);
            builder(&interpreter);
            TFLITE_MINIMAL_CHECK(interpreter != nullptr);

            modelLoaded = true;

            printf("=== Pre-invoke Interpreter State ===\n");
            tflite::PrintInterpreterState(interpreter.get());
        }

        return VX_SUCCESS;
    }

    /**
     * @brief Validate input/output parameters against the loaded model
     * @param inputDims Expected input tensor dimensions
     * @param outputDims Expected output tensor dimensions
     * @return VX_SUCCESS on success, VX_FAILURE otherwise
     */
    vx_status validate(std::vector<std::vector<size_t>> &inputDims, std::vector<std::vector<size_t>> &outputDims)
    {
        vx_status status = VX_SUCCESS;

        // Validate the number of input tensors
        if (inputDims.size() != interpreter->inputs().size())
        {
            fprintf(stderr, "Mismatch in number of input tensors: expected %zu, got %zu\n",
                    inputDims.size(), interpreter->inputs().size());
            return VX_FAILURE;
        }

        for (std::size_t i = 0; i < interpreter->inputs().size(); ++i)
        {
            TfLiteTensor *input_tensor = interpreter->tensor(interpreter->inputs()[i]);
            if (input_tensor == nullptr)
            {
                fprintf(stderr, "Input tensor at index %zu is null.\n", i);
                return VX_FAILURE;
            }

            // Get the shape of the input tensor
            std::vector<size_t> tensor_shape(input_tensor->dims->size);
            for (int j = 0; j < input_tensor->dims->size; ++j)
            {
                tensor_shape[j] = input_tensor->dims->data[j];
            }

            // Compare with the expected shape
            if (tensor_shape != inputDims[i])
            {
                fprintf(stderr, "Mismatch in input tensor %zu shape: expected {", i);
                for (size_t dim : inputDims[i])
                    fprintf(stderr, "%zu,", dim);
                fprintf(stderr, "} but got {");
                for (size_t dim : tensor_shape)
                    fprintf(stderr, "%zu,", dim);
                fprintf(stderr, "}\n");
                return VX_FAILURE;
            }
        }

        // Validate the number of output tensors
        if (outputDims.size() != interpreter->outputs().size())
        {
            fprintf(stderr, "Mismatch in number of output tensors: expected %zu, got %zu\n",
                    outputDims.size(), interpreter->outputs().size());
            return VX_FAILURE;
        }

        for (std::size_t i = 0; i < interpreter->outputs().size(); ++i)
        {
            TfLiteTensor *output_tensor = interpreter->tensor(interpreter->outputs()[i]);
            if (output_tensor == nullptr)
            {
                fprintf(stderr, "Output tensor at index %zu is null.\n", i);
                return VX_FAILURE;
            }

            // Get the shape of the output tensor
            std::vector<size_t> tensor_shape(output_tensor->dims->size);
            for (int j = 0; j < output_tensor->dims->size; ++j)
            {
                tensor_shape[j] = output_tensor->dims->data[j];
            }

            // Compare with the expected shape
            if (tensor_shape != outputDims[i])
            {
                fprintf(stderr, "Mismatch in output tensor %zu shape: expected {", i);
                for (size_t dim : outputDims[i])
                    fprintf(stderr, "%zu,", dim);
                fprintf(stderr, "} but got {");
                for (size_t dim : tensor_shape)
                    fprintf(stderr, "%zu,", dim);
                fprintf(stderr, "}\n");
                return VX_FAILURE;
            }
        }

        return status;
    }

    /**
     * @brief Bind caller-provided buffers to the model's input and output tensors
     * @param inputTensors Pointer/size (in bytes) pairs for each input tensor
     * @param outputTensors Pointer/size (in bytes) pairs for each output tensor
     * @return VX_SUCCESS on success, VX_FAILURE otherwise
     */
    vx_status allocate(std::vector<std::pair<float *, vx_size>> &inputTensors, std::vector<std::pair<float *, vx_size>> &outputTensors)
    {
        vx_status status = VX_SUCCESS;

        // Bind the caller's input buffers; the interpreter reads inputs
        // directly from this memory, so no copy is needed before Invoke().
        for (std::size_t i = 0; i < interpreter->inputs().size(); ++i)
        {
            status |= bindMemory(interpreter->inputs()[i], inputTensors[i].first, inputTensors[i].second);
        }

        // Bind the caller's output buffers; results are written directly
        // into this memory, so no copy is needed after Invoke().
        for (std::size_t i = 0; i < interpreter->outputs().size(); ++i)
        {
            status |= bindMemory(interpreter->outputs()[i], outputTensors[i].first, outputTensors[i].second);
        }

        // Allocate the remaining (intermediate) tensor buffers.
        TFLITE_MINIMAL_CHECK(interpreter->AllocateTensors() == kTfLiteOk);

        return status;
    }

    /**
     * @brief Run inference (execute the model)
     * @return VX_SUCCESS on success, VX_FAILURE otherwise
     */
    vx_status run()
    {
        // Run inference
        TFLITE_MINIMAL_CHECK(interpreter->Invoke() == kTfLiteOk);
        printf("\n\n=== Post-invoke Interpreter State ===\n");
        tflite::PrintInterpreterState(interpreter.get());
        return VX_SUCCESS;
    }

private:
    // Set once the model has been loaded and the interpreter built
    bool modelLoaded = false;
    // The loaded FlatBuffer model
    std::unique_ptr<tflite::FlatBufferModel> model;
    // Pointer to the TFLite interpreter
    std::unique_ptr<tflite::Interpreter> interpreter;

    /**
     * @brief Bind pre-allocated memory to a tensor
     * @param tensor_index Index of the tensor to bind
     * @param pre_allocated_memory Pointer to the pre-allocated memory
     * @param size_in_bytes Size of the pre-allocated memory in bytes
     * @return VX_SUCCESS on success, VX_FAILURE otherwise
     */
    vx_status bindMemory(int tensor_index, void *pre_allocated_memory, size_t size_in_bytes)
    {
        vx_status status = VX_SUCCESS;

        // Get the tensor
        TfLiteTensor *tensor = interpreter->tensor(tensor_index);

        // Check if the tensor exists
        if (tensor == nullptr)
        {
            fprintf(stderr, "Tensor at index %d does not exist.\n", tensor_index);
            status = VX_FAILURE;
        }

        // Ensure the tensor size matches the pre-allocated memory
        if (VX_SUCCESS == status && tensor->bytes != size_in_bytes)
        {
            fprintf(stderr, "Pre-allocated memory size (%zu) does not match tensor size (%zu).\n",
                    size_in_bytes, tensor->bytes);
            status = VX_FAILURE;
        }

        if (VX_SUCCESS == status)
        {
            // Bind the pre-allocated memory to the tensor
            TFLITE_MINIMAL_CHECK(kTfLiteOk == interpreter->SetCustomAllocationForTensor(
                tensor_index,
                {pre_allocated_memory, size_in_bytes},
                kTfLiteCustomAllocationFlagsSkipAlignCheck));
        }

        return status;
    }
};
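A usage sketch of the runner's lifecycle as defined above (init → validate → allocate → run). The model path and shapes are hypothetical placeholders, and the buffers would normally come from mapped vx_tensor objects:

    #include <string>
    #include <utility>
    #include <vector>

    #include "tflite.hpp"

    vx_status run_example()
    {
        TFLiteRunner runner;

        // Hypothetical model and shapes -- not part of the commit.
        std::string modelPath = "model.tflite";
        TFLITE_MINIMAL_CHECK(runner.init(modelPath) == VX_SUCCESS);

        std::vector<std::vector<size_t>> inDims = {{1, 224, 224, 3}};
        std::vector<std::vector<size_t>> outDims = {{1, 1001}};
        TFLITE_MINIMAL_CHECK(runner.validate(inDims, outDims) == VX_SUCCESS);

        // Caller-owned buffers; sizes are passed in bytes.
        std::vector<float> inBuf(1 * 224 * 224 * 3);
        std::vector<float> outBuf(1 * 1001);
        std::vector<std::pair<float *, vx_size>> ins = {{inBuf.data(), inBuf.size() * sizeof(float)}};
        std::vector<std::pair<float *, vx_size>> outs = {{outBuf.data(), outBuf.size() * sizeof(float)}};
        TFLITE_MINIMAL_CHECK(runner.allocate(ins, outs) == VX_SUCCESS);

        // Results land directly in outBuf thanks to the zero-copy binding.
        return runner.run();
    }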

targets/liteRT/BUILD

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
cc_library(
    name = "liteRT",
    srcs = glob([
        "*.cpp",
        "*.h",
    ]),
    includes = [
        ".",
        "//framework/include",
    ],
    deps = [
        "//:corevx",
        "//kernels/liteRT:liteRT_kernels",
    ],
    visibility = ["//visibility:public"],
)

cc_shared_library(
    name = "openvx-liteRT",
    deps = [":liteRT"],
    visibility = ["//visibility:public"],
)

cc_import(
    name = "imported_openvx_liteRT",
    shared_library = ":openvx-liteRT",
    visibility = ["//visibility:public"],
)
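The cc_shared_library name openvx-liteRT matches the string appended to targetModules in vx_context.cpp above, which suggests the context resolves and loads this shared object by that name at startup; the cc_import target presumably lets tests or samples link against the prebuilt module.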
