Skip to content

Commit 2a93b24

Browse files
Performance improvements (~50%) running the model.
Cached some values and removed some checks that were performed more than once. Before model would take ~150 microseconds to run, and after these chanages ~75 microseconds.
1 parent 521c940 commit 2a93b24

File tree

5 files changed

+113
-57
lines changed

5 files changed

+113
-57
lines changed

mlrunner/ml4f.c

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,17 @@ int ml4f_full_invoke(const ml4f_header_t *model, const float *input, float *outp
124124
return r;
125125
}
126126

127+
int ml4f_full_invoke_arena(const ml4f_header_t *model, uint8_t *arena, const float *input, float *output) {
128+
if (!ml4f_is_valid_header(model))
129+
return -1;
130+
memcpy(arena + model->input_offset, input,
131+
ml4f_shape_size(ml4f_input_shape(model), model->input_type));
132+
int r = ml4f_invoke(model, arena);
133+
memcpy(output, arena + model->output_offset,
134+
ml4f_shape_size(ml4f_output_shape(model), model->output_type));
135+
return r;
136+
}
137+
127138
int ml4f_full_invoke_argmax(const ml4f_header_t *model, const float *input) {
128139
if (!ml4f_is_valid_header(model))
129140
return -1;

mlrunner/ml4f.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ uint32_t ml4f_shape_size(const uint32_t *shape, uint32_t type);
5353
int ml4f_argmax(const float *data, uint32_t size);
5454

5555
int ml4f_full_invoke(const ml4f_header_t *model, const float *input, float *output);
56+
int ml4f_full_invoke_arena(const ml4f_header_t *model, uint8_t *arena, const float *input, float *output);
5657
int ml4f_full_invoke_argmax(const ml4f_header_t *model, const float *input);
5758

5859
#ifdef __cplusplus

mlrunner/mlrunner.c

Lines changed: 63 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,18 @@
44
#include "mlrunner.h"
55

66
// Pointer to the model in flash
7-
static uint32_t* MODEL_ADDRESS = NULL;
7+
static uint32_t *MODEL_ADDRESS = NULL;
8+
static uint8_t *model_arena = NULL;
9+
static size_t input_length = 0;
10+
static size_t output_length = 0;
811

912
/*****************************************************************************/
1013
/* Private API */
1114
/*****************************************************************************/
1215
/**
1316
* @return True if the model header is valid, False otherwise.
1417
*/
15-
static bool is_model_valid(const void* model_address) {
18+
static bool is_model_valid(const void *model_address) {
1619
ml_model_header_t *model_header = (ml_model_header_t *)model_address;
1720
if (model_header->magic0 != MODEL_HEADER_MAGIC0) {
1821
return false;
@@ -39,8 +42,8 @@ static bool is_model_valid(const void* model_address) {
3942
*
4043
* @return The ML4F model or NULL if the model is not present or invalid.
4144
*/
42-
static ml4f_header_t* get_ml4f_model() {
43-
if (MODEL_ADDRESS == NULL || !is_model_valid(MODEL_ADDRESS)) {
45+
static inline ml4f_header_t* get_ml4f_model() {
46+
if (MODEL_ADDRESS == NULL) {
4447
return NULL;
4548
}
4649
ml_model_header_t *model_header = (ml_model_header_t *)MODEL_ADDRESS;
@@ -56,13 +59,41 @@ bool ml_setModel(const void *model_address) {
5659
return false;
5760
}
5861
MODEL_ADDRESS = (uint32_t *)model_address;
62+
63+
// Allocate the model arena
64+
int model_arena_size = ml_getArenaSize();
65+
if (model_arena_size <= 0) {
66+
MODEL_ADDRESS = NULL;
67+
return false;
68+
}
69+
if (model_arena != NULL) {
70+
free(model_arena);
71+
}
72+
model_arena = malloc(model_arena_size);
73+
if (model_arena == NULL) {
74+
MODEL_ADDRESS = NULL;
75+
return false;
76+
}
77+
78+
// Set the cached input and output lengths
79+
ml_getInputLength();
80+
ml_getOutputLength();
81+
5982
return true;
6083
}
6184

6285
bool ml_isModelPresent() {
6386
return MODEL_ADDRESS != NULL;
6487
}
6588

89+
int ml_getArenaSize() {
90+
ml4f_header_t *ml4f_model = get_ml4f_model();
91+
if (ml4f_model == NULL) {
92+
return -1;
93+
}
94+
return ml4f_model->arena_bytes;
95+
}
96+
6697
int ml_getSamplesPeriod() {
6798
const ml_model_header_t* const model_header = (ml_model_header_t*)MODEL_ADDRESS;
6899
if (model_header == NULL) {
@@ -88,19 +119,27 @@ int ml_getSampleDimensions() {
88119
}
89120

90121
int ml_getInputLength() {
91-
ml4f_header_t *ml4f_model = get_ml4f_model();
92-
if (ml4f_model == NULL) {
93-
return -1;
122+
if (input_length == 0) {
123+
ml4f_header_t *ml4f_model = get_ml4f_model();
124+
if (ml4f_model == NULL) {
125+
return -1;
126+
}
127+
input_length = ml4f_shape_elements(ml4f_input_shape(ml4f_model));
94128
}
95-
return ml4f_shape_elements(ml4f_input_shape(ml4f_model));
129+
130+
return input_length;
96131
}
97132

98133
int ml_getOutputLength() {
99-
ml4f_header_t *ml4f_model = get_ml4f_model();
100-
if (ml4f_model == NULL) {
101-
return -1;
134+
if (output_length == 0) {
135+
ml4f_header_t *ml4f_model = get_ml4f_model();
136+
if (ml4f_model == NULL) {
137+
return -1;
138+
}
139+
output_length = ml4f_shape_elements(ml4f_output_shape(ml4f_model));
102140
}
103-
return ml4f_shape_elements(ml4f_output_shape(ml4f_model));
141+
142+
return output_length;
104143
}
105144

106145
// TODO: Remove this function and use ml_getLabels instead
@@ -241,69 +280,53 @@ ml_predictions_t *ml_allocatePredictions() {
241280
return predictions;
242281
}
243282

244-
bool ml_predict(const float *input, const int in_len, const ml_actions_t *actions, ml_predictions_t *predictions_out) {
245-
if (input == NULL || in_len <= 0 || actions == NULL || predictions_out == NULL) {
246-
return false;
247-
}
248-
249-
int model_output_len = ml_getOutputLength();
250-
if (model_output_len <= 0 ||
251-
model_output_len != (int)actions->len ||
252-
model_output_len != (int)predictions_out->len) {
283+
bool ml_predict(const float *input, const size_t in_len, const ml_actions_t *actions, ml_predictions_t *predictions_out) {
284+
if (actions == NULL || actions->len != output_length ||
285+
predictions_out == NULL || predictions_out->len != output_length) {
253286
return false;
254287
}
255288

256-
bool success = ml_runModel(input, in_len, &predictions_out->prediction, predictions_out->len);
289+
bool success = ml_runModel(input, in_len, (float *)&predictions_out->prediction, output_length);
257290
if (!success) {
258291
return false;
259292
}
260-
predictions_out->index = ml_calcPrediction(actions, &predictions_out->prediction, predictions_out->len);
293+
predictions_out->index = ml_calcPrediction(actions, (float *)&predictions_out->prediction, output_length);
261294

262295
return true;
263296
}
264297

265298

266-
bool ml_runModel(const float *input, const int in_len, float* individual_predictions, const int out_len) {
267-
if (individual_predictions == NULL) {
268-
return false;
269-
}
270-
271-
int model_input_len = ml_getInputLength();
272-
if (model_input_len <= 0 || model_input_len != in_len) {
273-
return false;
274-
}
275-
int model_output_len = ml_getOutputLength();
276-
if (model_output_len <= 0 || model_output_len != out_len) {
299+
bool ml_runModel(const float *input, const size_t in_len, float* individual_predictions, const size_t out_len) {
300+
if (input == NULL || individual_predictions == NULL || input_length != in_len || output_length != out_len) {
277301
return false;
278302
}
279303

280304
ml4f_header_t *ml4f_model = get_ml4f_model();
281-
int r = ml4f_full_invoke(ml4f_model, input, individual_predictions);
305+
int r = ml4f_full_invoke_arena(ml4f_model, model_arena, input, individual_predictions);
282306
if (r != 0) {
283307
return false;
284308
}
285309

286310
return true;
287311
}
288312

289-
int ml_calcPrediction(const ml_actions_t *actions, const float* predictions, const int len) {
290-
if (actions == NULL || predictions == NULL || len <= 0 || len != (int)actions->len) {
313+
int ml_calcPrediction(const ml_actions_t *actions, const float* predictions, const size_t len) {
314+
if (actions == NULL || predictions == NULL || len != actions->len) {
291315
return -1;
292316
}
293317

294318
float predictions_above_threshold[len];
295-
for (int i = 0; i < len; i++) {
319+
for (size_t i = 0; i < len; i++) {
296320
if (predictions[i] >= actions->action[i].threshold) {
297321
predictions_above_threshold[i] = predictions[i];
298322
} else {
299323
predictions_above_threshold[i] = 0.0f;
300324
}
301325
}
302326
int max_index = ml4f_argmax(predictions_above_threshold, len);
303-
if (max_index < 0 || max_index >= len) {
327+
if (max_index < 0) {
304328
return -1;
305329
}
306-
307330
// If the max predictionn is 0, then none were above the threshold
308331
if (predictions_above_threshold[max_index] == 0.0f) {
309332
max_index = -1;

mlrunner/mlrunner.h

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,14 @@ bool ml_setModel(const void *model_address);
8888
*/
8989
bool ml_isModelPresent();
9090

91+
/**
92+
* @brief Get the arena size that has been allocated to run the loaded model.
93+
*
94+
* @return The size, in bytes, of the arena required for the model.
95+
* Or -1 if the model is not present.
96+
*/
97+
int ml_getArenaSize();
98+
9199
/**
92100
* @brief Get the period between samples required for the model.
93101
*
@@ -179,7 +187,7 @@ ml_predictions_t *ml_allocatePredictions();
179187
* Or -1 if the model is not present, the actions or input length
180188
* doesn't match, or the prediction failed.
181189
*/
182-
bool ml_predict(const float *input, const int in_len, const ml_actions_t *actions, ml_predictions_t *predictions_out);
190+
bool ml_predict(const float *input, const size_t in_len, const ml_actions_t *actions, ml_predictions_t *predictions_out);
183191

184192
/**
185193
* @brief Run the model and return the individual predictions for each action.
@@ -191,7 +199,7 @@ bool ml_predict(const float *input, const int in_len, const ml_actions_t *action
191199
* @return True if the model is present and the model run was successful,
192200
* False otherwise.
193201
*/
194-
bool ml_runModel(const float *input, const int in_len, float* predictions_out, const int out_len);
202+
bool ml_runModel(const float *input, const size_t in_len, float* predictions_out, const size_t out_len);
195203

196204
/**
197205
* @brief Calculate the overall prediction based on the actions thresholds.
@@ -204,7 +212,7 @@ bool ml_runModel(const float *input, const int in_len, float* predictions_out, c
204212
* Or -1 if the model is not present, the actions or predictions length
205213
* doesn't match, or the prediction failed.
206214
*/
207-
int ml_calcPrediction(const ml_actions_t *actions, const float* predictions, const int len);
215+
int ml_calcPrediction(const ml_actions_t *actions, const float* predictions, const size_t len);
208216

209217
#ifdef __cplusplus
210218
} // extern "C"

testextension.cpp

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
namespace testrunner {
2222
static ml_actions_t *actions = NULL;
23+
static ml_predictions_t *predictions = NULL;
2324
static bool initialised = false;
2425
static const uint16_t ML_CODAL_TIMER_VALUE = 1;
2526

@@ -40,18 +41,12 @@ namespace testrunner {
4041
void runModel() {
4142
if (!initialised) return;
4243

43-
ml_predictions_t *predictions = ml_allocatePredictions();
44-
if (predictions == NULL) {
45-
DEBUG_PRINT("Failed to allocate memory for predictions\n");
46-
uBit.panic(TEST_RUNNER_ERROR + 9);
47-
}
48-
4944
unsigned int time_start = system_timer_current_time_us();
5045

5146
float *modelData = mlDataProcessor.getProcessedData();
5247
if (modelData == NULL) {
5348
DEBUG_PRINT("Failed to processed data for the model\n");
54-
uBit.panic(TEST_RUNNER_ERROR + 10);
49+
uBit.panic(TEST_RUNNER_ERROR + 21);
5550
}
5651

5752
unsigned int time_mid = system_timer_current_time_us();
@@ -60,7 +55,7 @@ namespace testrunner {
6055
modelData, mlDataProcessor.getProcessedDataSize(), actions, predictions);
6156
if (!success) {
6257
DEBUG_PRINT("Failed to run model\n");
63-
uBit.panic(TEST_RUNNER_ERROR + 11);
58+
uBit.panic(TEST_RUNNER_ERROR + 22);
6459
}
6560

6661
unsigned int time_end = system_timer_current_time_us();
@@ -74,7 +69,7 @@ namespace testrunner {
7469
} else {
7570
DEBUG_PRINT("None\n");
7671
}
77-
DEBUG_PRINT("\tPredictions:");
72+
DEBUG_PRINT("\tIndividual:");
7873
for (size_t i = 0; i < actions->len; i++) {
7974
DEBUG_PRINT(" %s [%d]",
8075
actions->action[i].label,
@@ -83,8 +78,6 @@ namespace testrunner {
8378
DEBUG_PRINT("\n\n");
8479

8580
MicroBitEvent evt(TEST_RUNNER_ID_INFERENCE, predictions->index + 2);
86-
87-
free(predictions);
8881
}
8982

9083
void recordAccData(MicroBitEvent) {
@@ -172,21 +165,41 @@ namespace testrunner {
172165
uBit.panic(TEST_RUNNER_ERROR + 6);
173166
}
174167

168+
const int modelOutputLen = ml_getInputLength();
169+
DEBUG_PRINT("\tModel output length: %d\n", modelOutputLen);
170+
if (modelOutputLen <= 0) {
171+
DEBUG_PRINT("Model output length invalid\n");
172+
uBit.panic(TEST_RUNNER_ERROR + 7);
173+
}
174+
175+
const int modelArenaSize = ml_getArenaSize();
176+
DEBUG_PRINT("\tModel arena size: %d bytes\n", modelArenaSize);
177+
if (modelArenaSize <= 0) {
178+
DEBUG_PRINT("Model arena size length invalid\n");
179+
uBit.panic(TEST_RUNNER_ERROR + 8);
180+
}
181+
175182
actions = ml_allocateActions();
176183
if (actions == NULL) {
177184
DEBUG_PRINT("Failed to allocate memory for actions\n");
178-
uBit.panic(TEST_RUNNER_ERROR + 7);
185+
uBit.panic(TEST_RUNNER_ERROR + 9);
179186
}
180187
const bool getActionsSuccess = ml_getActions(actions);
181188
if (!getActionsSuccess) {
182189
DEBUG_PRINT("Failed to retrieve actions\n");
183-
uBit.panic(TEST_RUNNER_ERROR + 8);
190+
uBit.panic(TEST_RUNNER_ERROR + 10);
184191
}
185192
DEBUG_PRINT("\tActions (%d):\n", actions->len);
186193
for (size_t i = 0; i < actions->len; i++) {
187194
DEBUG_PRINT("\t\t'%s' threshold = %d%%\n", actions->action[i].label, (int)(actions->action[i].threshold * 100));
188195
}
189196

197+
predictions = ml_allocatePredictions();
198+
if (predictions == NULL) {
199+
DEBUG_PRINT("Failed to allocate memory for predictions\n");
200+
uBit.panic(TEST_RUNNER_ERROR + 11);
201+
}
202+
190203
const MlDataProcessorConfig_t mlDataConfig = {
191204
.samples = samplesLen,
192205
.dimensions = sampleDimensions,
@@ -198,7 +211,7 @@ namespace testrunner {
198211
if (mlInitResult != MLDP_SUCCESS) {
199212
DEBUG_PRINT("Failed to initialise ML data processor (%d)\n", mlInitResult);
200213
// TODO: Check error type and set panic value accordingly
201-
uBit.panic(TEST_RUNNER_ERROR + 8);
214+
uBit.panic(TEST_RUNNER_ERROR + 12);
202215
}
203216

204217
// Set up background timer to collect data and run model

0 commit comments

Comments
 (0)