Skip to content

Commit 801eb2b

Browse files
put checks under the lock
1 parent 4b50341 commit 801eb2b

File tree

1 file changed

+21
-5
lines changed

1 file changed

+21
-5
lines changed

core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ convert_wasi_nn_type_to_ort_type(tensor_type type,
174174
#endif
175175
default:
176176
NN_WARN_PRINTF("Unsupported wasi-nn tensor type: %d", type);
177-
return false; // Default to float
177+
return false;
178178
}
179179
return true;
180180
}
@@ -418,13 +418,17 @@ __attribute__((visibility("default"))) wasi_nn_error
418418
init_execution_context(void *onnx_ctx, graph g, graph_execution_context *ctx)
419419
{
420420
OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx;
421+
if (!onnx_ctx) {
422+
return runtime_error;
423+
}
424+
425+
std::lock_guard<std::mutex> lock(ort_ctx->mutex);
421426

422427
if (g >= MAX_GRAPHS || !ort_ctx->graphs[g].is_initialized) {
423428
NN_ERR_PRINTF("Invalid graph handle: %d", g);
424429
return invalid_argument;
425430
}
426431

427-
std::lock_guard<std::mutex> lock(ort_ctx->mutex);
428432
int ctx_index = -1;
429433
for (int i = 0; i < MAX_CONTEXTS; i++) {
430434
if (!ort_ctx->exec_ctxs[i].is_initialized) {
@@ -516,6 +520,11 @@ set_input(void *onnx_ctx, graph_execution_context ctx, uint32_t index,
516520
tensor *input_tensor)
517521
{
518522
OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx;
523+
if (!onnx_ctx) {
524+
return runtime_error;
525+
}
526+
527+
std::lock_guard<std::mutex> lock(ort_ctx->mutex);
519528

520529
if (ctx >= MAX_CONTEXTS || !ort_ctx->exec_ctxs[ctx].is_initialized) {
521530
NN_ERR_PRINTF("Invalid execution context handle: %d", ctx);
@@ -528,7 +537,6 @@ set_input(void *onnx_ctx, graph_execution_context ctx, uint32_t index,
528537
return invalid_argument;
529538
}
530539

531-
std::lock_guard<std::mutex> lock(ort_ctx->mutex);
532540
OnnxRuntimeExecCtx *exec_ctx = &ort_ctx->exec_ctxs[ctx];
533541

534542
OrtTypeInfo *type_info = nullptr;
@@ -605,13 +613,17 @@ __attribute__((visibility("default"))) wasi_nn_error
605613
compute(void *onnx_ctx, graph_execution_context ctx)
606614
{
607615
OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx;
616+
if (!onnx_ctx) {
617+
return runtime_error;
618+
}
619+
620+
std::lock_guard<std::mutex> lock(ort_ctx->mutex);
608621

609622
if (ctx >= MAX_CONTEXTS || !ort_ctx->exec_ctxs[ctx].is_initialized) {
610623
NN_ERR_PRINTF("Invalid execution context handle: %d", ctx);
611624
return invalid_argument;
612625
}
613626

614-
std::lock_guard<std::mutex> lock(ort_ctx->mutex);
615627
OnnxRuntimeExecCtx *exec_ctx = &ort_ctx->exec_ctxs[ctx];
616628

617629
std::vector<OrtValue *> input_values;
@@ -657,6 +669,11 @@ get_output(void *onnx_ctx, graph_execution_context ctx, uint32_t index,
657669
tensor_data *out_buffer, uint32_t *out_buffer_size)
658670
{
659671
OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx;
672+
if (!onnx_ctx) {
673+
return runtime_error;
674+
}
675+
676+
std::lock_guard<std::mutex> lock(ort_ctx->mutex);
660677

661678
if (ctx >= MAX_CONTEXTS || !ort_ctx->exec_ctxs[ctx].is_initialized) {
662679
NN_ERR_PRINTF("Invalid execution context handle: %d", ctx);
@@ -669,7 +686,6 @@ get_output(void *onnx_ctx, graph_execution_context ctx, uint32_t index,
669686
return invalid_argument;
670687
}
671688

672-
std::lock_guard<std::mutex> lock(ort_ctx->mutex);
673689
OnnxRuntimeExecCtx *exec_ctx = &ort_ctx->exec_ctxs[ctx];
674690

675691
OrtValue *output_value = exec_ctx->outputs[index];

0 commit comments

Comments
 (0)