Skip to content

Commit 6cab94c

Browse files
follow up some review comments
1. The type converter between wasi-nn and ONNX Runtime now returns a bool success flag instead of the type itself. 2. out_buffer_size does not hold the expected size. 3. ONNX Runtime does not need to calculate the input tensor size.
1 parent 56b6195 commit 6cab94c

File tree

2 files changed

+56
-65
lines changed

2 files changed

+56
-65
lines changed

core/iwasm/libraries/wasi-nn/include/wasi_nn.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
#else
2222
#define WASI_NN_IMPORT(name) \
2323
__attribute__((import_module("wasi_nn"), import_name(name)))
24-
#warning "You are using \"wasi_nn\", which is a legacy WAMR-specific ABI. It's deprecated and will likely be removed in future versions of WAMR. Please use \"wasi_ephemeral_nn\" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.)"
24+
#warning \
25+
"You are using \"wasi_nn\", which is a legacy WAMR-specific ABI. It's deprecated and will likely be removed in future versions of WAMR. Please use \"wasi_ephemeral_nn\" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.)"
2526
#endif
2627

2728
/**

core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp

Lines changed: 54 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -102,91 +102,81 @@ convert_ort_error_to_wasi_nn_error(OrtStatus *status)
102102
return err;
103103
}
104104

105-
static tensor_type
106-
convert_ort_type_to_wasi_nn_type(ONNXTensorElementDataType ort_type)
105+
static bool
106+
convert_ort_type_to_wasi_nn_type(ONNXTensorElementDataType ort_type, tensor_type *tensor_type)
107107
{
108108
switch (ort_type) {
109109
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
110-
return fp32;
110+
*tensor_type = fp32;
111+
break;
111112
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
112-
return fp16;
113+
*tensor_type = fp16;
114+
break;
113115
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
114116
case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
115-
return fp64;
117+
*tensor_type = fp64;
118+
break;
116119
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
117-
return u8;
120+
*tensor_type = u8;
121+
break;
118122
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
119-
return i32;
123+
*tensor_type = i32;
124+
break;
120125
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
121-
return i64;
126+
*tensor_type = i64;
127+
break;
122128
#else
123129
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
124-
return up8;
130+
*tensor_type = up8;
131+
break;
125132
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
126-
return ip32;
133+
*tensor_type = ip32;
134+
break;
127135
#endif
128136
default:
129137
NN_WARN_PRINTF("Unsupported ONNX tensor type: %d", ort_type);
130-
return fp32; // Default to fp32
138+
return false;
131139
}
132-
}
133140

134-
static ONNXTensorElementDataType
135-
convert_wasi_nn_type_to_ort_type(tensor_type type)
136-
{
137-
switch (type) {
138-
case fp32:
139-
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
140-
case fp16:
141-
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
142-
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
143-
case fp64:
144-
return ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
145-
case u8:
146-
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
147-
case i32:
148-
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
149-
case i64:
150-
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
151-
#else
152-
case up8:
153-
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
154-
case ip32:
155-
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
156-
#endif
157-
default:
158-
NN_WARN_PRINTF("Unsupported wasi-nn tensor type: %d", type);
159-
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; // Default to float
160-
}
141+
return true;
161142
}
162143

163-
static size_t
164-
get_tensor_element_size(tensor_type type)
144+
static bool
145+
convert_wasi_nn_type_to_ort_type(tensor_type type, ONNXTensorElementDataType *ort_type)
165146
{
166147
switch (type) {
167148
case fp32:
168-
return 4;
149+
*ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
150+
break;
169151
case fp16:
170-
return 2;
152+
*ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
153+
break;
171154
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
172155
case fp64:
173-
return 8;
156+
*ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
157+
break;
174158
case u8:
175-
return 1;
159+
*ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
160+
break;
176161
case i32:
177-
return 4;
162+
*ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
163+
break;
178164
case i64:
179-
return 8;
165+
*ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
166+
break;
180167
#else
181168
case up8:
182-
return 1;
169+
*ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
170+
break;
183171
case ip32:
184-
return 4;
172+
*ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
173+
break;
185174
#endif
186175
default:
187-
NN_WARN_PRINTF("Unsupported tensor type: %d", type);
188-
return 4; // Default to 4 bytes (float)
176+
NN_WARN_PRINTF("Unsupported wasi-nn tensor type: %d", type);
177+
return false; // Default to float
189178
}
179+
return true;
190180
}
191181

192182
/* Backend API implementation */
@@ -579,8 +569,12 @@ set_input(void *onnx_ctx, graph_execution_context ctx, uint32_t index,
579569
ort_dims[i] = input_tensor->dimensions->buf[i];
580570
}
581571

582-
ONNXTensorElementDataType ort_type = convert_wasi_nn_type_to_ort_type(
583-
static_cast<tensor_type>(input_tensor->type));
572+
ONNXTensorElementDataType ort_type;
573+
if (!convert_wasi_nn_type_to_ort_type(
574+
static_cast<tensor_type>(input_tensor->type), &ort_type)) {
575+
NN_ERR_PRINTF("Failed to convert tensor type");
576+
return runtime_error;
577+
}
584578

585579
OrtValue *input_value = nullptr;
586580
size_t total_elements = 1;
@@ -589,9 +583,7 @@ set_input(void *onnx_ctx, graph_execution_context ctx, uint32_t index,
589583
}
590584

591585
status = ort_ctx->ort_api->CreateTensorWithDataAsOrtValue(
592-
exec_ctx->memory_info, input_tensor->data.buf,
593-
get_tensor_element_size(static_cast<tensor_type>(input_tensor->type))
594-
* total_elements,
586+
exec_ctx->memory_info, input_tensor->data.buf,input_tensor->data.size,
595587
ort_dims, num_dims, ort_type, &input_value);
596588

597589
free(ort_dims);
@@ -793,18 +785,16 @@ get_output(void *onnx_ctx, graph_execution_context ctx, uint32_t index,
793785
}
794786

795787
size_t output_size_bytes = tensor_size * element_size;
796-
797-
NN_INFO_PRINTF("Output tensor size: %zu elements, element size: %zu bytes, "
798-
"total: %zu bytes",
799-
tensor_size, element_size, output_size_bytes);
800-
801-
if (*out_buffer_size < output_size_bytes) {
788+
if (out_buffer->size < output_size_bytes) {
802789
NN_ERR_PRINTF(
803790
"Output buffer too small: %u bytes provided, %zu bytes needed",
804-
*out_buffer_size, output_size_bytes);
791+
out_buffer->size, output_size_bytes);
805792
*out_buffer_size = output_size_bytes;
806-
return invalid_argument;
793+
return too_large;
807794
}
795+
NN_INFO_PRINTF("Output tensor size: %zu elements, element size: %zu bytes, "
796+
"total: %zu bytes",
797+
tensor_size, element_size, output_size_bytes);
808798

809799
if (tensor_data == nullptr) {
810800
NN_ERR_PRINTF("Tensor data is null");

0 commit comments

Comments
 (0)