Skip to content

Commit 29d465b

Browse files
authored
wasi_nn_tensorflowlite.cpp: make this compatible with wasmedge (#4517)
for wasi_ephemeral_nn, * implement u8 input * stop dealing with quantization. * wasi-nn doesn't have a concept of quantization or pre/post-processing. I can't think of any way to make the backend perform zero-point/scale processing without risking breaking other applications. * there seem to be applications that just use u8 inputs/outputs for a quantized model. (see [1] for an example.) for certain kinds of inputs/outputs, it usually just works. this commit keeps the legacy wasi_nn logic intact for now. tested with [1] with [2] applied. WAMR with this patch: ``` Read graph weights, size in bytes: 3561598 [wasi_nn.c:297 WARNING] load_by_name_with_config() not found [wasi_nn_tensorflowlite.cpp:272 WARNING] Default encoding is CPU. Loaded graph into wasi-nn with ID: Graph#0 Read input tensor, size in bytes: 150528 1.) [166](198)Aix galericulata 2.) [34](1)Gallus gallus domesticus 3.) [158](1)Coccothraustes coccothraustes 4.) [778](1)Sitta europaea 5.) [819](1)Anas platyrhynchos ``` wasmedge: ``` Read graph weights, size in bytes: 3561598 Loaded graph into wasi-nn with ID: Graph#0 Read input tensor, size in bytes: 150528 1.) [166](198)Aix galericulata 2.) [34](1)Gallus gallus domesticus 3.) [158](1)Coccothraustes coccothraustes 4.) [778](1)Sitta europaea 5.) [819](1)Anas platyrhynchos ``` and "Aix galericulata" seems like a reasonable classification of the image to my eyes. [1] https://github.com/second-state/WasmEdge-WASINN-examples/tree/67f174bab59d98c1b52f7367ec0928701dc998f9/tflite-birds_v1-image [2] second-state/WasmEdge-WASINN-examples#204 Related: #3555 #2611
1 parent 272a41d commit 29d465b

File tree

1 file changed

+44
-28
lines changed

1 file changed

+44
-28
lines changed

core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp

Lines changed: 44 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
#include "wasi_nn_backend.h"
1010
#include "wasm_export.h"
1111

12+
#include <tensorflow/lite/c/c_api.h>
1213
#include <tensorflow/lite/interpreter.h>
1314
#include <tensorflow/lite/kernels/register.h>
1415
#include <tensorflow/lite/model.h>
@@ -279,29 +280,53 @@ set_input(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
279280
tensor *input_tensor)
280281
{
281282
TFLiteContext *tfl_ctx = (TFLiteContext *)tflite_ctx;
283+
TfLiteType tfl_type;
282284

283-
if (input_tensor->type != fp32) {
284-
NN_ERR_PRINTF("unsupported input tensor type %u", input_tensor->type);
285-
return runtime_error;
285+
switch (input_tensor->type) {
286+
case fp32:
287+
tfl_type = TfLiteType::kTfLiteFloat32;
288+
break;
289+
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
290+
case u8:
291+
tfl_type = TfLiteType::kTfLiteUInt8;
292+
break;
293+
#endif
294+
default:
295+
NN_ERR_PRINTF("unsupported input tensor type %u",
296+
input_tensor->type);
297+
return runtime_error;
286298
}
287299

288300
wasi_nn_error res;
289301
if (success != (res = is_valid_graph_execution_context(tfl_ctx, ctx)))
290302
return res;
291303

292-
uint32_t num_tensors =
293-
tfl_ctx->interpreters[ctx].interpreter->inputs().size();
304+
auto interpreter = tfl_ctx->interpreters[ctx].interpreter.get();
305+
306+
uint32_t num_tensors = interpreter->inputs().size();
294307
NN_DBG_PRINTF("Number of tensors (%d)", num_tensors);
295308
if (index + 1 > num_tensors) {
296309
return runtime_error;
297310
}
298311

299-
auto tensor = tfl_ctx->interpreters[ctx].interpreter->input_tensor(index);
312+
auto tensor = interpreter->input_tensor(index);
300313
if (tensor == NULL) {
301314
NN_ERR_PRINTF("Missing memory");
302315
return too_large;
303316
}
304317

318+
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
319+
if (TfLiteTensorType(tensor) != tfl_type) {
320+
NN_ERR_PRINTF("Type mismatch");
321+
return runtime_error;
322+
}
323+
324+
if (TfLiteTensorCopyFromBuffer(tensor, input_tensor->data.buf,
325+
input_tensor->data.size)
326+
!= kTfLiteOk) {
327+
return runtime_error;
328+
}
329+
#else
305330
uint32_t model_tensor_size = 1;
306331
for (int i = 0; i < tensor->dims->size; ++i)
307332
model_tensor_size *= (uint32_t)tensor->dims->data[i];
@@ -346,6 +371,7 @@ set_input(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
346371
it[i] = (uint8_t)(input_tensor_f[i] / scale + zero_point);
347372
}
348373
}
374+
#endif
349375

350376
return success;
351377
}
@@ -388,14 +414,19 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
388414
return too_large;
389415
}
390416

391-
if (tensor->quantization.type == kTfLiteNoQuantization) {
392-
NN_DBG_PRINTF("No quantization information");
393417
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
394-
if (output_tensor->size < tensor->bytes) {
395-
NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
396-
return too_large;
397-
}
418+
size_t sz = TfLiteTensorByteSize(tensor);
419+
if (output_tensor->size < sz) {
420+
NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
421+
return too_large;
422+
}
423+
if (TfLiteTensorCopyToBuffer(tensor, output_tensor->buf, sz) != kTfLiteOk) {
424+
return runtime_error;
425+
}
426+
*output_tensor_size = sz;
398427
#else
428+
if (tensor->quantization.type == kTfLiteNoQuantization) {
429+
NN_DBG_PRINTF("No quantization information");
399430
/*
400431
* for now, maintain the bug-to-bug compatibility with the old abi,
401432
* where the size here is the number of fp32, not bytes.
@@ -404,18 +435,13 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
404435
NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
405436
return too_large;
406437
}
407-
#endif
408438
bh_memcpy_s(output_tensor->buf, output_tensor->size, tensor->data.data,
409439
tensor->bytes);
410-
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
411-
*output_tensor_size = tensor->bytes;
412-
#else
413440
/*
414441
* for now, maintain the bug-to-bug compatibility with the old abi,
415442
* where the size here is the number of fp32, not bytes.
416443
*/
417444
*output_tensor_size = tensor->bytes / sizeof(float);
418-
#endif
419445
}
420446
else { // TODO: Assuming uint8 quantized networks.
421447
TfLiteAffineQuantization *quant_info =
@@ -429,12 +455,6 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
429455
for (int i = 0; i < (int)tensor->dims->size; ++i)
430456
model_tensor_size *= (uint32_t)tensor->dims->data[i];
431457

432-
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
433-
if (output_tensor->size / sizeof(float) < model_tensor_size) {
434-
NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
435-
return too_large;
436-
}
437-
#else
438458
/*
439459
* for now, maintain the bug-to-bug compatibility with the old abi,
440460
* where the size here is the number of fp32, not bytes.
@@ -443,7 +463,6 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
443463
NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
444464
return too_large;
445465
}
446-
#endif
447466

448467
uint8_t *ot = tfl_ctx->interpreters[ctx]
449468
.interpreter->typed_output_tensor<uint8_t>(index);
@@ -458,16 +477,13 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
458477
output_tensor_f[i] = (ot[i] - zero_point) * scale;
459478
}
460479

461-
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
462-
*output_tensor_size = model_tensor_size * sizeof(float);
463-
#else
464480
/*
465481
* for now, maintain the bug-to-bug compatibility with the old abi,
466482
* where the size here is the number of fp32, not bytes.
467483
*/
468484
*output_tensor_size = model_tensor_size;
469-
#endif
470485
}
486+
#endif
471487

472488
return success;
473489
}

0 commit comments

Comments (0)