44#include "mlrunner.h"
55
66// Pointer to the model in flash
7- static uint32_t * MODEL_ADDRESS = NULL ;
7+ static uint32_t * MODEL_ADDRESS = NULL ;
8+ static uint8_t * model_arena = NULL ;
9+ static size_t input_length = 0 ;
10+ static size_t output_length = 0 ;
811
912/*****************************************************************************/
1013/* Private API */
1114/*****************************************************************************/
1215/**
1316 * @return True if the model header is valid, False otherwise.
1417 */
15- static bool is_model_valid (const void * model_address ) {
18+ static bool is_model_valid (const void * model_address ) {
1619 ml_model_header_t * model_header = (ml_model_header_t * )model_address ;
1720 if (model_header -> magic0 != MODEL_HEADER_MAGIC0 ) {
1821 return false;
@@ -39,8 +42,8 @@ static bool is_model_valid(const void* model_address) {
3942 *
4043 * @return The ML4F model or NULL if the model is not present or invalid.
4144 */
42- static ml4f_header_t * get_ml4f_model () {
43- if (MODEL_ADDRESS == NULL || ! is_model_valid ( MODEL_ADDRESS ) ) {
45+ static inline ml4f_header_t * get_ml4f_model () {
46+ if (MODEL_ADDRESS == NULL ) {
4447 return NULL ;
4548 }
4649 ml_model_header_t * model_header = (ml_model_header_t * )MODEL_ADDRESS ;
@@ -56,13 +59,41 @@ bool ml_setModel(const void *model_address) {
5659 return false;
5760 }
5861 MODEL_ADDRESS = (uint32_t * )model_address ;
62+
63+ // Allocate the model arena
64+ int model_arena_size = ml_getArenaSize ();
65+ if (model_arena_size <= 0 ) {
66+ MODEL_ADDRESS = NULL ;
67+ return false;
68+ }
69+ if (model_arena != NULL ) {
70+ free (model_arena );
71+ }
72+ model_arena = malloc (model_arena_size );
73+ if (model_arena == NULL ) {
74+ MODEL_ADDRESS = NULL ;
75+ return false;
76+ }
77+
78+ // Set the cached input and output lengths
79+ ml_getInputLength ();
80+ ml_getOutputLength ();
81+
5982 return true;
6083}
6184
6285bool ml_isModelPresent () {
6386 return MODEL_ADDRESS != NULL ;
6487}
6588
89+ int ml_getArenaSize () {
90+ ml4f_header_t * ml4f_model = get_ml4f_model ();
91+ if (ml4f_model == NULL ) {
92+ return -1 ;
93+ }
94+ return ml4f_model -> arena_bytes ;
95+ }
96+
6697int ml_getSamplesPeriod () {
6798 const ml_model_header_t * const model_header = (ml_model_header_t * )MODEL_ADDRESS ;
6899 if (model_header == NULL ) {
@@ -88,19 +119,27 @@ int ml_getSampleDimensions() {
88119}
89120
90121int ml_getInputLength () {
91- ml4f_header_t * ml4f_model = get_ml4f_model ();
92- if (ml4f_model == NULL ) {
93- return -1 ;
122+ if (input_length == 0 ) {
123+ ml4f_header_t * ml4f_model = get_ml4f_model ();
124+ if (ml4f_model == NULL ) {
125+ return -1 ;
126+ }
127+ input_length = ml4f_shape_elements (ml4f_input_shape (ml4f_model ));
94128 }
95- return ml4f_shape_elements (ml4f_input_shape (ml4f_model ));
129+
130+ return input_length ;
96131}
97132
98133int ml_getOutputLength () {
99- ml4f_header_t * ml4f_model = get_ml4f_model ();
100- if (ml4f_model == NULL ) {
101- return -1 ;
134+ if (output_length == 0 ) {
135+ ml4f_header_t * ml4f_model = get_ml4f_model ();
136+ if (ml4f_model == NULL ) {
137+ return -1 ;
138+ }
139+ output_length = ml4f_shape_elements (ml4f_output_shape (ml4f_model ));
102140 }
103- return ml4f_shape_elements (ml4f_output_shape (ml4f_model ));
141+
142+ return output_length ;
104143}
105144
106145// TODO: Remove this function and use ml_getLabels instead
@@ -241,69 +280,53 @@ ml_predictions_t *ml_allocatePredictions() {
241280 return predictions ;
242281}
243282
244- bool ml_predict (const float * input , const int in_len , const ml_actions_t * actions , ml_predictions_t * predictions_out ) {
245- if (input == NULL || in_len <= 0 || actions == NULL || predictions_out == NULL ) {
246- return false;
247- }
248-
249- int model_output_len = ml_getOutputLength ();
250- if (model_output_len <= 0 ||
251- model_output_len != (int )actions -> len ||
252- model_output_len != (int )predictions_out -> len ) {
283+ bool ml_predict (const float * input , const size_t in_len , const ml_actions_t * actions , ml_predictions_t * predictions_out ) {
284+ if (actions == NULL || actions -> len != output_length ||
285+ predictions_out == NULL || predictions_out -> len != output_length ) {
253286 return false;
254287 }
255288
256- bool success = ml_runModel (input , in_len , & predictions_out -> prediction , predictions_out -> len );
289+ bool success = ml_runModel (input , in_len , ( float * ) & predictions_out -> prediction , output_length );
257290 if (!success ) {
258291 return false;
259292 }
260- predictions_out -> index = ml_calcPrediction (actions , & predictions_out -> prediction , predictions_out -> len );
293+ predictions_out -> index = ml_calcPrediction (actions , ( float * ) & predictions_out -> prediction , output_length );
261294
262295 return true;
263296}
264297
265298
266- bool ml_runModel (const float * input , const int in_len , float * individual_predictions , const int out_len ) {
267- if (individual_predictions == NULL ) {
268- return false;
269- }
270-
271- int model_input_len = ml_getInputLength ();
272- if (model_input_len <= 0 || model_input_len != in_len ) {
273- return false;
274- }
275- int model_output_len = ml_getOutputLength ();
276- if (model_output_len <= 0 || model_output_len != out_len ) {
299+ bool ml_runModel (const float * input , const size_t in_len , float * individual_predictions , const size_t out_len ) {
300+ if (input == NULL || individual_predictions == NULL || input_length != in_len || output_length != out_len ) {
277301 return false;
278302 }
279303
280304 ml4f_header_t * ml4f_model = get_ml4f_model ();
281- int r = ml4f_full_invoke (ml4f_model , input , individual_predictions );
305+ int r = ml4f_full_invoke_arena (ml4f_model , model_arena , input , individual_predictions );
282306 if (r != 0 ) {
283307 return false;
284308 }
285309
286310 return true;
287311}
288312
289- int ml_calcPrediction (const ml_actions_t * actions , const float * predictions , const int len ) {
290- if (actions == NULL || predictions == NULL || len <= 0 || len != ( int ) actions -> len ) {
313+ int ml_calcPrediction (const ml_actions_t * actions , const float * predictions , const size_t len ) {
314+ if (actions == NULL || predictions == NULL || len != actions -> len ) {
291315 return -1 ;
292316 }
293317
294318 float predictions_above_threshold [len ];
295- for (int i = 0 ; i < len ; i ++ ) {
319+ for (size_t i = 0 ; i < len ; i ++ ) {
296320 if (predictions [i ] >= actions -> action [i ].threshold ) {
297321 predictions_above_threshold [i ] = predictions [i ];
298322 } else {
299323 predictions_above_threshold [i ] = 0.0f ;
300324 }
301325 }
302326 int max_index = ml4f_argmax (predictions_above_threshold , len );
303- if (max_index < 0 || max_index >= len ) {
327+ if (max_index < 0 ) {
304328 return -1 ;
305329 }
306-
307330 // If the max predictionn is 0, then none were above the threshold
308331 if (predictions_above_threshold [max_index ] == 0.0f ) {
309332 max_index = -1 ;
0 commit comments