diff --git a/build-scripts/config_common.cmake b/build-scripts/config_common.cmake
index d89c3a285e..366abb16ab 100644
--- a/build-scripts/config_common.cmake
+++ b/build-scripts/config_common.cmake
@@ -546,7 +546,8 @@ if (WAMR_BUILD_WASI_NN EQUAL 1)
   # Variant backends
   if (NOT WAMR_BUILD_WASI_NN_TFLITE EQUAL 1 AND
       NOT WAMR_BUILD_WASI_NN_OPENVINO EQUAL 1 AND
-      NOT WAMR_BUILD_WASI_NN_LLAMACPP EQUAL 1)
+      NOT WAMR_BUILD_WASI_NN_LLAMACPP EQUAL 1 AND
+      NOT WAMR_BUILD_WASI_NN_ONNX EQUAL 1)
     message (FATAL_ERROR " Need to select a backend for WASI-NN")
   endif ()
 
@@ -562,6 +563,10 @@ if (WAMR_BUILD_WASI_NN EQUAL 1)
     message (" WASI-NN: backend llamacpp enabled")
     add_definitions (-DWASM_ENABLE_WASI_NN_LLAMACPP)
   endif ()
+  if (WAMR_BUILD_WASI_NN_ONNX EQUAL 1)
+    message (" WASI-NN: backend onnx enabled")
+    add_definitions (-DWASM_ENABLE_WASI_NN_ONNX)
+  endif ()
   # Variant devices
   if (WAMR_BUILD_WASI_NN_ENABLE_GPU EQUAL 1)
     message (" WASI-NN: GPU enabled")
diff --git a/core/iwasm/libraries/wasi-nn/README.md b/core/iwasm/libraries/wasi-nn/README.md
index 2e926a0327..e16891a1ba 100644
--- a/core/iwasm/libraries/wasi-nn/README.md
+++ b/core/iwasm/libraries/wasi-nn/README.md
@@ -26,6 +26,7 @@ $ cmake -DWAMR_BUILD_WASI_NN=1 ...
 - `WAMR_BUILD_WASI_NN_TFLITE`. This option designates TensorFlow Lite as the backend.
 - `WAMR_BUILD_WASI_NN_OPENVINO`. This option designates OpenVINO as the backend.
 - `WAMR_BUILD_WASI_NN_LLAMACPP`. This option designates Llama.cpp as the backend.
+- `WAMR_BUILD_WASI_NN_ONNX`. This option designates ONNX Runtime as the backend.
 
 ### Wasm
 
@@ -151,7 +152,7 @@ docker run \
 
 Supported:
 
-- Graph encoding: `tensorflowlite`, `openvino` and `ggml`
+- Graph encoding: `tensorflowlite`, `openvino`, `ggml` and `onnx`
 - Execution target: `cpu` for all. `gpu` and `tpu` for `tensorflowlite`.
 - Tensor type: `fp32`.
diff --git a/core/iwasm/libraries/wasi-nn/cmake/Findonnxruntime.cmake b/core/iwasm/libraries/wasi-nn/cmake/Findonnxruntime.cmake
new file mode 100644
index 0000000000..db8f287e36
--- /dev/null
+++ b/core/iwasm/libraries/wasi-nn/cmake/Findonnxruntime.cmake
@@ -0,0 +1,86 @@
+# Copyright 2025 Sony Semiconductor Solutions Corporation.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Find ONNX Runtime library
+#
+# This module defines the following variables:
+#
+# ::
+#
+# onnxruntime_FOUND - True if onnxruntime is found
+# onnxruntime_INCLUDE_DIRS - Include directories for onnxruntime
+# onnxruntime_LIBRARIES - List of libraries for onnxruntime
+# onnxruntime_VERSION - Version of onnxruntime
+#
+# ::
+#
+# Example usage:
+#
+# find_package(onnxruntime)
+# if(onnxruntime_FOUND)
+#   target_link_libraries(app onnxruntime::onnxruntime)
+# endif()
+
+# First try to find ONNX Runtime using the CMake config file
+# FIXME: This is a temporary workaround for ONNX Runtime's broken CMake config on Linux.
+# See https://github.com/microsoft/onnxruntime/issues/25279
+# Once the upstream issue is fixed, this conditional can be safely removed.
+if(NOT CMAKE_SYSTEM_NAME STREQUAL "Linux")
+  find_package(onnxruntime CONFIG QUIET)
+  if(onnxruntime_FOUND)
+    return()
+  endif()
+endif()
+
+# If not found via CMake config, try to find manually
+find_path(onnxruntime_INCLUDE_DIR
+  NAMES onnxruntime_c_api.h
+  PATHS
+  /usr/include
+  /usr/local/include
+  /opt/onnxruntime/include
+  $ENV{ONNXRUNTIME_ROOT}/include
+  ${CMAKE_CURRENT_LIST_DIR}/../../../../..
+) + +find_library(onnxruntime_LIBRARY + NAMES onnxruntime + PATHS + /usr/lib + /usr/local/lib + /opt/onnxruntime/lib + $ENV{ONNXRUNTIME_ROOT}/lib + ${CMAKE_CURRENT_LIST_DIR}/../../../../.. +) + +# Try to determine version from header file +if(onnxruntime_INCLUDE_DIR) + file(STRINGS "${onnxruntime_INCLUDE_DIR}/onnxruntime_c_api.h" onnxruntime_version_str + REGEX "^#define[\t ]+ORT_API_VERSION[\t ]+[0-9]+") + + if(onnxruntime_version_str) + string(REGEX REPLACE "^#define[\t ]+ORT_API_VERSION[\t ]+([0-9]+)" "\\1" + onnxruntime_VERSION "${onnxruntime_version_str}") + endif() +endif() + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(onnxruntime + REQUIRED_VARS onnxruntime_LIBRARY onnxruntime_INCLUDE_DIR + VERSION_VAR onnxruntime_VERSION +) + +if(onnxruntime_FOUND) + set(onnxruntime_LIBRARIES ${onnxruntime_LIBRARY}) + set(onnxruntime_INCLUDE_DIRS ${onnxruntime_INCLUDE_DIR}) + + if(NOT TARGET onnxruntime::onnxruntime) + add_library(onnxruntime::onnxruntime UNKNOWN IMPORTED) + set_target_properties(onnxruntime::onnxruntime PROPERTIES + IMPORTED_LOCATION "${onnxruntime_LIBRARY}" + INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_INCLUDE_DIRS}" + ) + endif() +endif() + +mark_as_advanced(onnxruntime_INCLUDE_DIR onnxruntime_LIBRARY) diff --git a/core/iwasm/libraries/wasi-nn/cmake/wasi_nn.cmake b/core/iwasm/libraries/wasi-nn/cmake/wasi_nn.cmake index b771b1c402..56a7b44e4a 100644 --- a/core/iwasm/libraries/wasi-nn/cmake/wasi_nn.cmake +++ b/core/iwasm/libraries/wasi-nn/cmake/wasi_nn.cmake @@ -109,3 +109,24 @@ if(WAMR_BUILD_WASI_NN_LLAMACPP EQUAL 1) install(TARGETS wasi_nn_llamacpp DESTINATION lib) endif() + +# - onnx +if(WAMR_BUILD_WASI_NN_ONNX EQUAL 1) + find_package(onnxruntime REQUIRED) + enable_language(CXX) + + add_library( + wasi_nn_onnx + SHARED + ${WASI_NN_ROOT}/src/wasi_nn_onnx.cpp + ) + + target_link_libraries( + wasi_nn_onnx + PUBLIC + vmlib + onnxruntime::onnxruntime + ) + + install(TARGETS wasi_nn_onnx DESTINATION lib) +endif() diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index 7921ec9539..787c3a432d 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -33,6 +33,7 @@ #define TFLITE_BACKEND_LIB "libwasi_nn_tflite" LIB_EXTENTION #define OPENVINO_BACKEND_LIB "libwasi_nn_openvino" LIB_EXTENTION #define LLAMACPP_BACKEND_LIB "libwasi_nn_llamacpp" LIB_EXTENTION +#define ONNX_BACKEND_LIB "libwasi_nn_onnx" LIB_EXTENTION /* Global variables */ static korp_mutex wasi_nn_lock; @@ -240,6 +241,17 @@ choose_a_backend() return openvino; } +#ifndef NDEBUG + NN_WARN_PRINTF("%s", dlerror()); +#endif + + handle = dlopen(ONNX_BACKEND_LIB, RTLD_LAZY); + if (handle) { + NN_INFO_PRINTF("Using onnx backend"); + dlclose(handle); + return onnx; + } + #ifndef NDEBUG NN_WARN_PRINTF("%s", dlerror()); #endif @@ -363,6 +375,8 @@ graph_encoding_to_backend_lib_name(graph_encoding encoding) return TFLITE_BACKEND_LIB; case ggml: return LLAMACPP_BACKEND_LIB; + case onnx: + return ONNX_BACKEND_LIB; default: return NULL; } diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp b/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp new file mode 100644 index 0000000000..44d8d66135 --- /dev/null +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp @@ -0,0 +1,795 @@ +/* + * Copyright 2025 Sony Semiconductor Solutions Corporation. 
+ * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+ */
+
+#include <stdint.h>
+#include <stdlib.h>
+#include <string.h>
+#include <mutex>
+#include <unordered_map>
+#include <vector>
+#include "bh_platform.h"
+#include "wasi_nn_backend.h"
+#include "utils/logger.h"
+#include "onnxruntime_c_api.h"
+
+#if WASM_ENABLE_WASI_EPHEMERAL_NN == 0
+#error This backend doesn't support the legacy "wasi_nn" ABI. Please enable WASM_ENABLE_WASI_EPHEMERAL_NN.
+#endif
+
+/* Maximum number of graphs and execution contexts */
+#define MAX_GRAPHS 4
+#define MAX_CONTEXTS 4
+
+/* Graph structure */
+typedef struct {
+    OrtSession *session;
+    bool is_initialized;
+} OnnxRuntimeGraph;
+
+/* Execution context structure */
+typedef struct {
+    OrtMemoryInfo *memory_info;
+    std::vector<const char *> input_names;
+    std::vector<const char *> output_names;
+    std::unordered_map<uint32_t, OrtValue *> inputs;
+    std::unordered_map<uint32_t, OrtValue *> outputs;
+    OnnxRuntimeGraph *graph;
+    bool is_initialized;
+} OnnxRuntimeExecCtx;
+
+/* ONNX Runtime context structure */
+typedef struct {
+    OrtEnv *env;
+    OrtSessionOptions *session_options;
+    OrtAllocator *allocator;
+    const OrtApi *ort_api;
+    std::mutex mutex;
+    bool is_initialized;
+    OnnxRuntimeGraph graphs[MAX_GRAPHS];
+    OnnxRuntimeExecCtx exec_ctxs[MAX_CONTEXTS];
+} OnnxRuntimeContext;
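+
+/*
+ * Handle model: a wasi-nn `graph` or `graph_execution_context` handle is an
+ * index into the fixed-size `graphs` / `exec_ctxs` arrays above, so at most
+ * MAX_GRAPHS models and MAX_CONTEXTS execution contexts can be live at once.
+ * Every backend entry point serializes on `mutex`.
+ */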
+
+static wasi_nn_error
+convert_ort_error_to_wasi_nn_error(const OnnxRuntimeContext *ctx,
+                                   OrtStatus *status)
+{
+    if (status == nullptr) {
+        return success;
+    }
+
+    wasi_nn_error err;
+    OrtErrorCode code = ctx->ort_api->GetErrorCode(status);
+    const char *msg = ctx->ort_api->GetErrorMessage(status);
+
+    NN_ERR_PRINTF("ONNX Runtime error: %s", msg);
+
+    switch (code) {
+        case ORT_INVALID_ARGUMENT:
+            err = invalid_argument;
+            break;
+        case ORT_RUNTIME_EXCEPTION:
+            err = runtime_error;
+            break;
+        case ORT_NOT_IMPLEMENTED:
+            err = unsupported_operation;
+            break;
+        case ORT_INVALID_PROTOBUF:
+            err = invalid_encoding;
+            break;
+        case ORT_MODEL_LOADED:
+            err = too_large;
+            break;
+        case ORT_INVALID_GRAPH:
+            err = invalid_encoding;
+            break;
+        default:
+            err = runtime_error;
+            break;
+    }
+
+    ctx->ort_api->ReleaseStatus(status);
+    return err;
+}
+
+static bool
+convert_wasi_nn_type_to_ort_type(tensor_type type,
+                                 ONNXTensorElementDataType *ort_type)
+{
+    switch (type) {
+        case fp32:
+            *ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
+            break;
+        case fp16:
+            *ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
+            break;
+        case fp64:
+            *ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
+            break;
+        case u8:
+            *ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
+            break;
+        case i32:
+            *ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
+            break;
+        case i64:
+            *ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
+            break;
+        default:
+            NN_WARN_PRINTF("Unsupported wasi-nn tensor type: %d", type);
+            return false;
+    }
+    return true;
+}
+
+/* Backend API implementation */
+
+extern "C" {
+
+__attribute__((visibility("default"))) wasi_nn_error
+init_backend(void **onnx_ctx)
+{
+    wasi_nn_error err = success;
+    OrtStatus *status = nullptr;
+    OnnxRuntimeContext *ctx = new OnnxRuntimeContext();
+    ctx->ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
+    if (!ctx->ort_api) {
+        NN_ERR_PRINTF("Failed to get ONNX Runtime API");
+        err = runtime_error;
+        goto fail;
+    }
+
+    NN_INFO_PRINTF("Creating ONNX Runtime environment...");
+    status = ctx->ort_api->CreateEnv(ORT_LOGGING_LEVEL_VERBOSE, "wasi-nn",
+                                     &ctx->env);
+    if (status != nullptr) {
+        /* Log the message before convert_ort_error_to_wasi_nn_error()
+           releases the status that owns it */
+        NN_ERR_PRINTF("Failed to create ONNX Runtime environment: %s",
+                      ctx->ort_api->GetErrorMessage(status));
+        err = convert_ort_error_to_wasi_nn_error(ctx, status);
+        goto fail;
+    }
+    NN_INFO_PRINTF("ONNX Runtime environment created successfully");
+
+    status = ctx->ort_api->CreateSessionOptions(&ctx->session_options);
+    if (status != nullptr) {
+        err = convert_ort_error_to_wasi_nn_error(ctx, status);
+        ctx->ort_api->ReleaseEnv(ctx->env);
+        NN_ERR_PRINTF("Failed to create ONNX Runtime session options");
+        goto fail;
+    }
+
+    status = ctx->ort_api->SetSessionGraphOptimizationLevel(
+        ctx->session_options, ORT_ENABLE_BASIC);
+    if (status != nullptr) {
+        err = convert_ort_error_to_wasi_nn_error(ctx, status);
+        ctx->ort_api->ReleaseSessionOptions(ctx->session_options);
+        ctx->ort_api->ReleaseEnv(ctx->env);
+        NN_ERR_PRINTF("Failed to set graph optimization level");
+        goto fail;
+    }
+
+    status = ctx->ort_api->GetAllocatorWithDefaultOptions(&ctx->allocator);
+    if (status != nullptr) {
+        err = convert_ort_error_to_wasi_nn_error(ctx, status);
+        ctx->ort_api->ReleaseSessionOptions(ctx->session_options);
+        ctx->ort_api->ReleaseEnv(ctx->env);
+        NN_ERR_PRINTF("Failed to get default allocator");
+        goto fail;
+    }
+
+    for (int i = 0; i < MAX_GRAPHS; i++) {
+        ctx->graphs[i].is_initialized = false;
+        ctx->graphs[i].session = nullptr;
+    }
+
+    for (int i = 0; i < MAX_CONTEXTS; i++) {
+        ctx->exec_ctxs[i].is_initialized = false;
+        ctx->exec_ctxs[i].memory_info = nullptr;
+        ctx->exec_ctxs[i].graph = nullptr;
+        ctx->exec_ctxs[i].input_names.clear();
+        ctx->exec_ctxs[i].output_names.clear();
+        ctx->exec_ctxs[i].inputs.clear();
+        ctx->exec_ctxs[i].outputs.clear();
+    }
+
+    ctx->is_initialized = true;
+    *onnx_ctx = ctx;
+
+    NN_INFO_PRINTF("ONNX Runtime backend initialized");
+    return success;
+
+fail:
+    delete ctx;
+    return err;
+}
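+
+/*
+ * deinit_backend(): tear down every live session and execution context.
+ * Tensor values are released through the ORT API; the input/output name
+ * strings were handed out by the ORT default allocator and are returned to
+ * that same allocator.
+ */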
+
+__attribute__((visibility("default"))) wasi_nn_error
+deinit_backend(void *onnx_ctx)
+{
+    OnnxRuntimeContext *ctx = (OnnxRuntimeContext *)onnx_ctx;
+
+    {
+        std::lock_guard<std::mutex> lock(ctx->mutex);
+
+        if (!ctx->is_initialized) {
+            return success;
+        }
+
+        for (int i = 0; i < MAX_GRAPHS; i++) {
+            if (ctx->graphs[i].is_initialized) {
+                ctx->ort_api->ReleaseSession(ctx->graphs[i].session);
+                ctx->graphs[i].is_initialized = false;
+            }
+        }
+
+        for (int i = 0; i < MAX_CONTEXTS; i++) {
+            if (ctx->exec_ctxs[i].is_initialized) {
+                for (auto &input : ctx->exec_ctxs[i].inputs) {
+                    ctx->ort_api->ReleaseValue(input.second);
+                }
+                for (auto &output : ctx->exec_ctxs[i].outputs) {
+                    ctx->ort_api->ReleaseValue(output.second);
+                }
+
+                /* The name strings were allocated by the ORT default
+                   allocator, so they must be returned to it, not free()d */
+                for (auto name : ctx->exec_ctxs[i].input_names) {
+                    ctx->allocator->Free(ctx->allocator, (void *)name);
+                }
+                ctx->exec_ctxs[i].input_names.clear();
+
+                for (auto name : ctx->exec_ctxs[i].output_names) {
+                    ctx->allocator->Free(ctx->allocator, (void *)name);
+                }
+                ctx->exec_ctxs[i].output_names.clear();
+
+                ctx->ort_api->ReleaseMemoryInfo(ctx->exec_ctxs[i].memory_info);
+                ctx->exec_ctxs[i].is_initialized = false;
+            }
+        }
+
+        ctx->ort_api->ReleaseSessionOptions(ctx->session_options);
+        ctx->ort_api->ReleaseEnv(ctx->env);
+        ctx->is_initialized = false;
+    }
+
+    /* Unlock before destroying the context, which owns the mutex */
+    delete ctx;
+
+    NN_INFO_PRINTF("ONNX Runtime backend deinitialized");
+    return success;
+}
+
+__attribute__((visibility("default"))) wasi_nn_error
+load(void *onnx_ctx, graph_builder_array *builder, graph_encoding encoding,
+     execution_target target, graph *g)
+{
+    if (!onnx_ctx) {
+        return runtime_error;
+    }
+
+    if (encoding != onnx) {
+        NN_ERR_PRINTF("Unsupported encoding: %d", encoding);
+        return invalid_encoding;
+    }
+
+    if (target != cpu) {
+        NN_ERR_PRINTF("Only CPU target is supported");
+        return unsupported_operation;
+    }
+
+    OnnxRuntimeContext *ctx = (OnnxRuntimeContext *)onnx_ctx;
+    std::lock_guard<std::mutex> lock(ctx->mutex);
+
+    int graph_index = -1;
+    for (int i = 0; i < MAX_GRAPHS; i++) {
+        if (!ctx->graphs[i].is_initialized) {
+            graph_index = i;
+            break;
+        }
+    }
+
+    if (graph_index == -1) {
+        NN_ERR_PRINTF("Maximum number of graphs reached");
+        return runtime_error;
+    }
+
+    if (builder->size == 0 || builder->buf == NULL) {
+        NN_ERR_PRINTF("No model data provided");
+        return invalid_argument;
+    }
+
+    NN_INFO_PRINTF("[ONNX Runtime] Loading model of size %u bytes...",
+                   builder->buf[0].size);
+
+    if (builder->buf[0].size >= 8) {
+        NN_INFO_PRINTF(
+            "Model header bytes: %02x %02x %02x %02x %02x %02x %02x %02x",
+            ((uint8_t *)builder->buf[0].buf)[0],
+            ((uint8_t *)builder->buf[0].buf)[1],
+            ((uint8_t *)builder->buf[0].buf)[2],
+            ((uint8_t *)builder->buf[0].buf)[3],
+            ((uint8_t *)builder->buf[0].buf)[4],
+            ((uint8_t *)builder->buf[0].buf)[5],
+            ((uint8_t *)builder->buf[0].buf)[6],
+            ((uint8_t *)builder->buf[0].buf)[7]);
+    }
+
+    OrtStatus *status = ctx->ort_api->CreateSessionFromArray(
+        ctx->env, builder->buf[0].buf, builder->buf[0].size,
+        ctx->session_options, &ctx->graphs[graph_index].session);
+
+    if (status != nullptr) {
+        /* Log before the conversion releases the status owning the message */
+        NN_ERR_PRINTF("Failed to create ONNX Runtime session: %s",
+                      ctx->ort_api->GetErrorMessage(status));
+        return convert_ort_error_to_wasi_nn_error(ctx, status);
+    }
+
+    NN_INFO_PRINTF("ONNX Runtime session created successfully");
+
+    ctx->graphs[graph_index].is_initialized = true;
+    *g = graph_index;
+
+    NN_INFO_PRINTF("ONNX model loaded as graph %d", graph_index);
+    return success;
+}
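+
+/*
+ * load_by_name(): create a session directly from a model file path. The ORT
+ * CreateSession() API expects a NUL-terminated path, so this relies on the
+ * runtime passing one; `filename_len` is not used here.
+ */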
+
+__attribute__((visibility("default"))) wasi_nn_error
+load_by_name(void *onnx_ctx, const char *name, uint32_t filename_len, graph *g)
+{
+    if (!onnx_ctx) {
+        return runtime_error;
+    }
+
+    OnnxRuntimeContext *ctx = (OnnxRuntimeContext *)onnx_ctx;
+    std::lock_guard<std::mutex> lock(ctx->mutex);
+
+    int graph_index = -1;
+    for (int i = 0; i < MAX_GRAPHS; i++) {
+        if (!ctx->graphs[i].is_initialized) {
+            graph_index = i;
+            break;
+        }
+    }
+
+    if (graph_index == -1) {
+        NN_ERR_PRINTF("Maximum number of graphs reached");
+        return runtime_error;
+    }
+
+    OrtStatus *status =
+        ctx->ort_api->CreateSession(ctx->env, name, ctx->session_options,
+                                    &ctx->graphs[graph_index].session);
+
+    if (status != nullptr) {
+        wasi_nn_error err = convert_ort_error_to_wasi_nn_error(ctx, status);
+        NN_ERR_PRINTF("Failed to create ONNX Runtime session from file: %s",
+                      name);
+        return err;
+    }
+
+    ctx->graphs[graph_index].is_initialized = true;
+    *g = graph_index;
+
+    NN_INFO_PRINTF("ONNX model loaded from file %s as graph %d", name,
+                   graph_index);
+    return success;
+}
+
+__attribute__((visibility("default"))) wasi_nn_error
+init_execution_context(void *onnx_ctx, graph g, graph_execution_context *ctx)
+{
+    OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx;
+    if (!onnx_ctx) {
+        return runtime_error;
+    }
+
+    std::lock_guard<std::mutex> lock(ort_ctx->mutex);
+
+    if (g >= MAX_GRAPHS || !ort_ctx->graphs[g].is_initialized) {
+        NN_ERR_PRINTF("Invalid graph handle: %u", g);
+        return invalid_argument;
+    }
+
+    int ctx_index = -1;
+    for (int i = 0; i < MAX_CONTEXTS; i++) {
+        if (!ort_ctx->exec_ctxs[i].is_initialized) {
+            ctx_index = i;
+            break;
+        }
+    }
+
+    if (ctx_index == -1) {
+        NN_ERR_PRINTF("Maximum number of execution contexts reached");
+        return runtime_error;
+    }
+
+    OnnxRuntimeExecCtx *exec_ctx = &ort_ctx->exec_ctxs[ctx_index];
+    exec_ctx->graph = &ort_ctx->graphs[g];
+
+    OrtStatus *status = ort_ctx->ort_api->CreateCpuMemoryInfo(
+        OrtArenaAllocator, OrtMemTypeDefault, &exec_ctx->memory_info);
+    if (status != nullptr) {
+        wasi_nn_error err = convert_ort_error_to_wasi_nn_error(ort_ctx, status);
+        NN_ERR_PRINTF("Failed to create CPU memory info");
+        return err;
+    }
+
+    size_t num_input_nodes;
+    status = ort_ctx->ort_api->SessionGetInputCount(exec_ctx->graph->session,
+                                                    &num_input_nodes);
+    if (status != nullptr) {
+        wasi_nn_error err = convert_ort_error_to_wasi_nn_error(ort_ctx, status);
+        ort_ctx->ort_api->ReleaseMemoryInfo(exec_ctx->memory_info);
+        NN_ERR_PRINTF("Failed to get input count");
+        return err;
+    }
+
+    for (size_t i = 0; i < num_input_nodes; i++) {
+        char *input_name;
+        status = ort_ctx->ort_api->SessionGetInputName(
+            exec_ctx->graph->session, i, ort_ctx->allocator, &input_name);
+        if (status != nullptr) {
+            wasi_nn_error err =
+                convert_ort_error_to_wasi_nn_error(ort_ctx, status);
+            ort_ctx->ort_api->ReleaseMemoryInfo(exec_ctx->memory_info);
+            for (const char *name : exec_ctx->input_names) {
+                ort_ctx->allocator->Free(ort_ctx->allocator, (void *)name);
+            }
+            /* Clear the vector so a later reuse of this slot does not see
+               dangling pointers */
+            exec_ctx->input_names.clear();
+            NN_ERR_PRINTF("Failed to get input name");
+            return err;
+        }
+        exec_ctx->input_names.push_back(input_name);
+    }
+
+    size_t num_output_nodes;
+    status = ort_ctx->ort_api->SessionGetOutputCount(exec_ctx->graph->session,
+                                                     &num_output_nodes);
+    if (status != nullptr) {
+        wasi_nn_error err = convert_ort_error_to_wasi_nn_error(ort_ctx, status);
+        ort_ctx->ort_api->ReleaseMemoryInfo(exec_ctx->memory_info);
+        for (const char *name : exec_ctx->input_names) {
+            ort_ctx->allocator->Free(ort_ctx->allocator, (void *)name);
+        }
+        exec_ctx->input_names.clear();
+        NN_ERR_PRINTF("Failed to get output count");
+        return err;
+    }
+
+    for (size_t i = 0; i < num_output_nodes; i++) {
+        char *output_name;
+        status = ort_ctx->ort_api->SessionGetOutputName(
+            exec_ctx->graph->session, i, ort_ctx->allocator, &output_name);
+        if (status != nullptr) {
+            wasi_nn_error err =
+                convert_ort_error_to_wasi_nn_error(ort_ctx, status);
+            ort_ctx->ort_api->ReleaseMemoryInfo(exec_ctx->memory_info);
+            for (const char *name : exec_ctx->input_names) {
+                ort_ctx->allocator->Free(ort_ctx->allocator, (void *)name);
+            }
+            exec_ctx->input_names.clear();
+            for (const char *name : exec_ctx->output_names) {
+                ort_ctx->allocator->Free(ort_ctx->allocator, (void *)name);
+            }
+            exec_ctx->output_names.clear();
+            NN_ERR_PRINTF("Failed to get output name");
+            return err;
+        }
+        exec_ctx->output_names.push_back(output_name);
+    }
+
+    exec_ctx->is_initialized = true;
+    *ctx = ctx_index;
+
+    NN_INFO_PRINTF("Execution context %d initialized for graph %u", ctx_index,
+                   g);
+    return success;
+}
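+
+/*
+ * set_input(): wrap the caller's buffer in an OrtValue with
+ * CreateTensorWithDataAsOrtValue(), which does not copy the data. The
+ * buffer therefore has to remain valid until compute() has finished.
+ */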
+
+__attribute__((visibility("default"))) wasi_nn_error
+set_input(void *onnx_ctx, graph_execution_context ctx, uint32_t index,
+          tensor *input_tensor)
+{
+    OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx;
+    if (!onnx_ctx) {
+        return runtime_error;
+    }
+
+    std::lock_guard<std::mutex> lock(ort_ctx->mutex);
+
+    if (ctx >= MAX_CONTEXTS || !ort_ctx->exec_ctxs[ctx].is_initialized) {
+        NN_ERR_PRINTF("Invalid execution context handle: %u", ctx);
+        return invalid_argument;
+    }
+
+    if (index >= ort_ctx->exec_ctxs[ctx].input_names.size()) {
+        NN_ERR_PRINTF("Invalid input index: %u (max: %zu)", index,
+                      ort_ctx->exec_ctxs[ctx].input_names.size() - 1);
+        return invalid_argument;
+    }
+
+    OnnxRuntimeExecCtx *exec_ctx = &ort_ctx->exec_ctxs[ctx];
+
+    /* Query the model's expected input shape (currently informational) */
+    OrtTypeInfo *type_info = nullptr;
+    OrtStatus *status = ort_ctx->ort_api->SessionGetInputTypeInfo(
+        exec_ctx->graph->session, index, &type_info);
+    if (status != nullptr) {
+        return convert_ort_error_to_wasi_nn_error(ort_ctx, status);
+    }
+
+    const OrtTensorTypeAndShapeInfo *tensor_info;
+    status =
+        ort_ctx->ort_api->CastTypeInfoToTensorInfo(type_info, &tensor_info);
+    if (status != nullptr) {
+        ort_ctx->ort_api->ReleaseTypeInfo(type_info);
+        return convert_ort_error_to_wasi_nn_error(ort_ctx, status);
+    }
+
+    size_t num_model_dims;
+    status =
+        ort_ctx->ort_api->GetDimensionsCount(tensor_info, &num_model_dims);
+    if (status != nullptr) {
+        ort_ctx->ort_api->ReleaseTypeInfo(type_info);
+        return convert_ort_error_to_wasi_nn_error(ort_ctx, status);
+    }
+
+    std::vector<int64_t> model_dims(num_model_dims);
+    status = ort_ctx->ort_api->GetDimensions(tensor_info, model_dims.data(),
+                                             num_model_dims);
+    ort_ctx->ort_api->ReleaseTypeInfo(type_info);
+    if (status != nullptr) {
+        return convert_ort_error_to_wasi_nn_error(ort_ctx, status);
+    }
+
+    size_t num_dims = input_tensor->dimensions->size;
+    int64_t *ort_dims = (int64_t *)malloc(num_dims * sizeof(int64_t));
+    if (!ort_dims) {
+        NN_ERR_PRINTF("Failed to allocate memory for tensor dimensions");
+        return runtime_error;
+    }
+
+    for (size_t i = 0; i < num_dims; i++) {
+        ort_dims[i] = input_tensor->dimensions->buf[i];
+    }
+
+    ONNXTensorElementDataType ort_type;
+    if (!convert_wasi_nn_type_to_ort_type(
+            static_cast<tensor_type>(input_tensor->type), &ort_type)) {
+        NN_ERR_PRINTF("Failed to convert tensor type");
+        free(ort_dims);
+        return runtime_error;
+    }
+
+    OrtValue *input_value = nullptr;
+    status = ort_ctx->ort_api->CreateTensorWithDataAsOrtValue(
+        exec_ctx->memory_info, input_tensor->data.buf, input_tensor->data.size,
+        ort_dims, num_dims, ort_type, &input_value);
+
+    free(ort_dims);
+
+    if (status != nullptr) {
+        wasi_nn_error err = convert_ort_error_to_wasi_nn_error(ort_ctx, status);
+        NN_ERR_PRINTF("Failed to create input tensor");
+        return err;
+    }
+
+    if (exec_ctx->inputs.count(index) > 0) {
+        ort_ctx->ort_api->ReleaseValue(exec_ctx->inputs[index]);
+    }
+    exec_ctx->inputs[index] = input_value;
+
+    NN_INFO_PRINTF("Input tensor set for context %u, index %u", ctx, index);
+    return success;
+}
+
+__attribute__((visibility("default"))) wasi_nn_error
+compute(void *onnx_ctx, graph_execution_context ctx)
+{
+    OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx;
+    if (!onnx_ctx) {
+        return runtime_error;
+    }
+
+    std::lock_guard<std::mutex> lock(ort_ctx->mutex);
+
+    if (ctx >= MAX_CONTEXTS || !ort_ctx->exec_ctxs[ctx].is_initialized) {
+        NN_ERR_PRINTF("Invalid execution context handle: %u", ctx);
+        return invalid_argument;
+    }
+
+    OnnxRuntimeExecCtx *exec_ctx = &ort_ctx->exec_ctxs[ctx];
+
+    std::vector<OrtValue *> input_values;
+    std::vector<const char *> input_names;
+
+    for (size_t i = 0; i < exec_ctx->input_names.size(); i++) {
+        if (exec_ctx->inputs.count(i) == 0) {
+            NN_ERR_PRINTF("Input tensor not set for index %zu", i);
+            return invalid_argument;
+        }
+        input_values.push_back(exec_ctx->inputs[i]);
+        input_names.push_back(exec_ctx->input_names[i]);
+    }
+
+    for (auto &output : exec_ctx->outputs) {
+        ort_ctx->ort_api->ReleaseValue(output.second);
+    }
+    exec_ctx->outputs.clear();
+
+    std::vector<OrtValue *> output_values(exec_ctx->output_names.size());
+
+    OrtStatus *status = ort_ctx->ort_api->Run(
+        exec_ctx->graph->session, nullptr, input_names.data(),
+        input_values.data(), input_values.size(),
+        exec_ctx->output_names.data(), exec_ctx->output_names.size(),
+        output_values.data());
+
+    if (status != nullptr) {
+        wasi_nn_error err = convert_ort_error_to_wasi_nn_error(ort_ctx, status);
+        NN_ERR_PRINTF("Failed to run inference");
+        return err;
+    }
+
+    /* Keep the outputs only after Run() has succeeded */
+    for (size_t i = 0; i < output_values.size(); i++) {
+        exec_ctx->outputs[i] = output_values[i];
+    }
+
+    NN_INFO_PRINTF("Inference computed for context %u", ctx);
+    return success;
+}
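+
+/*
+ * get_output(): copy one output tensor into the caller's buffer. When the
+ * buffer is too small, *out_buffer_size is set to the required byte count
+ * and `too_large` is returned.
+ */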
+
+__attribute__((visibility("default"))) wasi_nn_error
+get_output(void *onnx_ctx, graph_execution_context ctx, uint32_t index,
+           tensor_data *out_buffer, uint32_t *out_buffer_size)
+{
+    OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx;
+    if (!onnx_ctx) {
+        return runtime_error;
+    }
+
+    std::lock_guard<std::mutex> lock(ort_ctx->mutex);
+
+    if (ctx >= MAX_CONTEXTS || !ort_ctx->exec_ctxs[ctx].is_initialized) {
+        NN_ERR_PRINTF("Invalid execution context handle: %u", ctx);
+        return invalid_argument;
+    }
+
+    if (index >= ort_ctx->exec_ctxs[ctx].output_names.size()) {
+        NN_ERR_PRINTF("Invalid output index: %u (max: %zu)", index,
+                      ort_ctx->exec_ctxs[ctx].output_names.size() - 1);
+        return invalid_argument;
+    }
+
+    OnnxRuntimeExecCtx *exec_ctx = &ort_ctx->exec_ctxs[ctx];
+
+    auto output_it = exec_ctx->outputs.find(index);
+    if (output_it == exec_ctx->outputs.end() || !output_it->second) {
+        NN_ERR_PRINTF("Output tensor not available for index %u", index);
+        return runtime_error;
+    }
+    OrtValue *output_value = output_it->second;
+
+    OrtTensorTypeAndShapeInfo *tensor_info;
+    OrtStatus *status =
+        ort_ctx->ort_api->GetTensorTypeAndShape(output_value, &tensor_info);
+    if (status != nullptr) {
+        wasi_nn_error err = convert_ort_error_to_wasi_nn_error(ort_ctx, status);
+        NN_ERR_PRINTF("Failed to get tensor type and shape");
+        return err;
+    }
+
+    ONNXTensorElementDataType element_type;
+    status =
+        ort_ctx->ort_api->GetTensorElementType(tensor_info, &element_type);
+    if (status != nullptr) {
+        wasi_nn_error err = convert_ort_error_to_wasi_nn_error(ort_ctx, status);
+        ort_ctx->ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info);
+        NN_ERR_PRINTF("Failed to get tensor element type");
+        return err;
+    }
+
+    size_t num_dims;
+    status = ort_ctx->ort_api->GetDimensionsCount(tensor_info, &num_dims);
+    if (status != nullptr) {
+        wasi_nn_error err = convert_ort_error_to_wasi_nn_error(ort_ctx, status);
+        ort_ctx->ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info);
+        NN_ERR_PRINTF("Failed to get tensor dimensions count");
+        return err;
+    }
+
+    int64_t *dims = (int64_t *)malloc(num_dims * sizeof(int64_t));
+    if (!dims) {
+        ort_ctx->ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info);
+        NN_ERR_PRINTF("Failed to allocate memory for tensor dimensions");
+        return runtime_error;
+    }
+
+    status = ort_ctx->ort_api->GetDimensions(tensor_info, dims, num_dims);
+    if (status != nullptr) {
+        wasi_nn_error err = convert_ort_error_to_wasi_nn_error(ort_ctx, status);
+        free(dims);
+        ort_ctx->ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info);
+        NN_ERR_PRINTF("Failed to get tensor dimensions");
+        return err;
+    }
+
+    size_t tensor_size;
+    status =
+        ort_ctx->ort_api->GetTensorShapeElementCount(tensor_info, &tensor_size);
+    if (status != nullptr) {
+        wasi_nn_error err = convert_ort_error_to_wasi_nn_error(ort_ctx, status);
+        free(dims);
+        ort_ctx->ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info);
+        NN_ERR_PRINTF("Failed to get tensor element count");
+        return err;
+    }
+
+    NN_INFO_PRINTF("Output tensor dimensions: ");
+    for (size_t i = 0; i < num_dims; i++) {
+        NN_INFO_PRINTF(" dim[%zu] = %lld", i, (long long)dims[i]);
+    }
+    NN_INFO_PRINTF("Total elements: %zu", tensor_size);
+
+    ort_ctx->ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info);
+    free(dims);
+
+    if (tensor_size == 0) {
+        NN_ERR_PRINTF("Tensor is empty (zero elements)");
+        return runtime_error;
+    }
+
+    void *tensor_data = nullptr;
+    status =
+        ort_ctx->ort_api->GetTensorMutableData(output_value, &tensor_data);
+    if (status != nullptr) {
+        wasi_nn_error err = convert_ort_error_to_wasi_nn_error(ort_ctx, status);
+        NN_ERR_PRINTF("Failed to get tensor data");
+        return err;
+    }
+
+    if (tensor_data == nullptr) {
+        NN_ERR_PRINTF("Tensor data pointer is null");
+        return runtime_error;
+    }
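+
+    /* Map the ONNX element type to a per-element byte size so the total
+       output size can be computed and checked against the caller's buffer */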
+    size_t element_size;
+    switch (element_type) {
+        case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
+            element_size = sizeof(float);
+            break;
+        case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
+            element_size = sizeof(uint16_t);
+            break;
+        case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
+            element_size = sizeof(double);
+            break;
+        case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
+            element_size = sizeof(int32_t);
+            break;
+        case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
+            element_size = sizeof(int64_t);
+            break;
+        case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
+            element_size = sizeof(uint8_t);
+            break;
+        default:
+            NN_ERR_PRINTF("Unsupported tensor element type: %d", element_type);
+            return unsupported_operation;
+    }
+
+    size_t output_size_bytes = tensor_size * element_size;
+    if (out_buffer->size < output_size_bytes) {
+        NN_ERR_PRINTF(
+            "Output buffer too small: %u bytes provided, %zu bytes needed",
+            out_buffer->size, output_size_bytes);
+        *out_buffer_size = (uint32_t)output_size_bytes;
+        return too_large;
+    }
+    NN_INFO_PRINTF("Output tensor size: %zu elements, element size: %zu bytes, "
+                   "total: %zu bytes",
+                   tensor_size, element_size, output_size_bytes);
+
+    if (out_buffer->buf == nullptr) {
+        NN_ERR_PRINTF("Output buffer is null");
+        return invalid_argument;
+    }
+
+    memcpy(out_buffer->buf, tensor_data, output_size_bytes);
+    *out_buffer_size = (uint32_t)output_size_bytes;
+
+    NN_INFO_PRINTF(
+        "Output tensor retrieved for context %u, index %u, size %zu bytes", ctx,
+        index, output_size_bytes);
+    return success;
+}
+
+} /* End of extern "C" */
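
For reference, a typical configure line exercising this backend, following the
options documented in the README above (WAMR_BUILD_WASI_EPHEMERAL_NN is
assumed to be the option that sets the WASM_ENABLE_WASI_EPHEMERAL_NN guard
required by wasi_nn_onnx.cpp; other options and paths are illustrative):

    cmake -DWAMR_BUILD_WASI_NN=1 \
          -DWAMR_BUILD_WASI_NN_ONNX=1 \
          -DWAMR_BUILD_WASI_EPHEMERAL_NN=1 ..

The build installs the libwasi_nn_onnx shared library alongside the other
backends; wasi_nn.c dlopen()s it when a module loads a graph with the `onnx`
encoding, or when choose_a_backend() probes the available backend libraries.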