@@ -174,7 +174,7 @@ convert_wasi_nn_type_to_ort_type(tensor_type type,
174
174
#endif
175
175
default :
176
176
NN_WARN_PRINTF (" Unsupported wasi-nn tensor type: %d" , type);
177
- return false ; // Default to float
177
+ return false ;
178
178
}
179
179
return true ;
180
180
}
@@ -418,13 +418,17 @@ __attribute__((visibility("default"))) wasi_nn_error
418
418
init_execution_context(void *onnx_ctx, graph g, graph_execution_context *ctx)
419
419
{
420
420
OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx;
421
+ if (!onnx_ctx) {
422
+ return runtime_error;
423
+ }
424
+
425
+ std::lock_guard<std::mutex> lock (ort_ctx->mutex );
421
426
422
427
if (g >= MAX_GRAPHS || !ort_ctx->graphs [g].is_initialized ) {
423
428
NN_ERR_PRINTF (" Invalid graph handle: %d" , g);
424
429
return invalid_argument;
425
430
}
426
431
427
- std::lock_guard<std::mutex> lock (ort_ctx->mutex );
428
432
int ctx_index = -1 ;
429
433
for (int i = 0 ; i < MAX_CONTEXTS; i++) {
430
434
if (!ort_ctx->exec_ctxs [i].is_initialized ) {
@@ -516,6 +520,11 @@ set_input(void *onnx_ctx, graph_execution_context ctx, uint32_t index,
516
520
tensor *input_tensor)
517
521
{
518
522
OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx;
523
+ if (!onnx_ctx) {
524
+ return runtime_error;
525
+ }
526
+
527
+ std::lock_guard<std::mutex> lock (ort_ctx->mutex );
519
528
520
529
if (ctx >= MAX_CONTEXTS || !ort_ctx->exec_ctxs [ctx].is_initialized ) {
521
530
NN_ERR_PRINTF (" Invalid execution context handle: %d" , ctx);
@@ -528,7 +537,6 @@ set_input(void *onnx_ctx, graph_execution_context ctx, uint32_t index,
528
537
return invalid_argument;
529
538
}
530
539
531
- std::lock_guard<std::mutex> lock (ort_ctx->mutex );
532
540
OnnxRuntimeExecCtx *exec_ctx = &ort_ctx->exec_ctxs [ctx];
533
541
534
542
OrtTypeInfo *type_info = nullptr ;
@@ -605,13 +613,17 @@ __attribute__((visibility("default"))) wasi_nn_error
605
613
compute(void *onnx_ctx, graph_execution_context ctx)
606
614
{
607
615
OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx;
616
+ if (!onnx_ctx) {
617
+ return runtime_error;
618
+ }
619
+
620
+ std::lock_guard<std::mutex> lock (ort_ctx->mutex );
608
621
609
622
if (ctx >= MAX_CONTEXTS || !ort_ctx->exec_ctxs [ctx].is_initialized ) {
610
623
NN_ERR_PRINTF (" Invalid execution context handle: %d" , ctx);
611
624
return invalid_argument;
612
625
}
613
626
614
- std::lock_guard<std::mutex> lock (ort_ctx->mutex );
615
627
OnnxRuntimeExecCtx *exec_ctx = &ort_ctx->exec_ctxs [ctx];
616
628
617
629
std::vector<OrtValue *> input_values;
@@ -657,6 +669,11 @@ get_output(void *onnx_ctx, graph_execution_context ctx, uint32_t index,
657
669
tensor_data *out_buffer, uint32_t *out_buffer_size)
658
670
{
659
671
OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx;
672
+ if (!onnx_ctx) {
673
+ return runtime_error;
674
+ }
675
+
676
+ std::lock_guard<std::mutex> lock (ort_ctx->mutex );
660
677
661
678
if (ctx >= MAX_CONTEXTS || !ort_ctx->exec_ctxs [ctx].is_initialized ) {
662
679
NN_ERR_PRINTF (" Invalid execution context handle: %d" , ctx);
@@ -669,7 +686,6 @@ get_output(void *onnx_ctx, graph_execution_context ctx, uint32_t index,
669
686
return invalid_argument;
670
687
}
671
688
672
- std::lock_guard<std::mutex> lock (ort_ctx->mutex );
673
689
OnnxRuntimeExecCtx *exec_ctx = &ort_ctx->exec_ctxs [ctx];
674
690
675
691
OrtValue *output_value = exec_ctx->outputs [index];
0 commit comments