@@ -102,91 +102,81 @@ convert_ort_error_to_wasi_nn_error(OrtStatus *status)
102
102
return err;
103
103
}
104
104
105
- static tensor_type
106
- convert_ort_type_to_wasi_nn_type (ONNXTensorElementDataType ort_type)
105
+ static bool
106
+ convert_ort_type_to_wasi_nn_type (ONNXTensorElementDataType ort_type, tensor_type *tensor_type )
107
107
{
108
108
switch (ort_type) {
109
109
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
110
- return fp32;
110
+ *tensor_type = fp32;
111
+ break ;
111
112
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
112
- return fp16;
113
+ *tensor_type = fp16;
114
+ break ;
113
115
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
114
116
case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
115
- return fp64;
117
+ *tensor_type = fp64;
118
+ break ;
116
119
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
117
- return u8 ;
120
+ *tensor_type = u8 ;
121
+ break ;
118
122
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
119
- return i32 ;
123
+ *tensor_type = i32 ;
124
+ break ;
120
125
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
121
- return i64 ;
126
+ *tensor_type = i64 ;
127
+ break ;
122
128
#else
123
129
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
124
- return up8;
130
+ *tensor_type = up8;
131
+ break ;
125
132
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
126
- return ip32;
133
+ *tensor_type = ip32;
134
+ break ;
127
135
#endif
128
136
default :
129
137
NN_WARN_PRINTF (" Unsupported ONNX tensor type: %d" , ort_type);
130
- return fp32; // Default to fp32
138
+ return false ;
131
139
}
132
- }
133
140
134
- static ONNXTensorElementDataType
135
- convert_wasi_nn_type_to_ort_type (tensor_type type)
136
- {
137
- switch (type) {
138
- case fp32:
139
- return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
140
- case fp16:
141
- return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
142
- #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
143
- case fp64:
144
- return ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
145
- case u8 :
146
- return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
147
- case i32 :
148
- return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
149
- case i64 :
150
- return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
151
- #else
152
- case up8:
153
- return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
154
- case ip32:
155
- return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
156
- #endif
157
- default :
158
- NN_WARN_PRINTF (" Unsupported wasi-nn tensor type: %d" , type);
159
- return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; // Default to float
160
- }
141
+ return true ;
161
142
}
162
143
163
- static size_t
164
- get_tensor_element_size (tensor_type type)
144
+ static bool
145
+ convert_wasi_nn_type_to_ort_type (tensor_type type, ONNXTensorElementDataType *ort_type )
165
146
{
166
147
switch (type) {
167
148
case fp32:
168
- return 4 ;
149
+ *ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
150
+ break ;
169
151
case fp16:
170
- return 2 ;
152
+ *ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
153
+ break ;
171
154
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
172
155
case fp64:
173
- return 8 ;
156
+ *ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
157
+ break ;
174
158
case u8 :
175
- return 1 ;
159
+ *ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
160
+ break ;
176
161
case i32 :
177
- return 4 ;
162
+ *ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
163
+ break ;
178
164
case i64 :
179
- return 8 ;
165
+ *ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
166
+ break ;
180
167
#else
181
168
case up8:
182
- return 1 ;
169
+ *ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
170
+ break ;
183
171
case ip32:
184
- return 4 ;
172
+ *ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
173
+ break ;
185
174
#endif
186
175
default :
187
- NN_WARN_PRINTF (" Unsupported tensor type: %d" , type);
188
- return 4 ; // Default to 4 bytes ( float)
176
+ NN_WARN_PRINTF (" Unsupported wasi-nn tensor type: %d" , type);
177
+ return false ; // Default to float
189
178
}
179
+ return true ;
190
180
}
191
181
192
182
/* Backend API implementation */
@@ -579,8 +569,12 @@ set_input(void *onnx_ctx, graph_execution_context ctx, uint32_t index,
579
569
ort_dims[i] = input_tensor->dimensions ->buf [i];
580
570
}
581
571
582
- ONNXTensorElementDataType ort_type = convert_wasi_nn_type_to_ort_type (
583
- static_cast <tensor_type>(input_tensor->type ));
572
+ ONNXTensorElementDataType ort_type;
573
+ if (!convert_wasi_nn_type_to_ort_type (
574
+ static_cast <tensor_type>(input_tensor->type ), &ort_type)) {
575
+ NN_ERR_PRINTF (" Failed to convert tensor type" );
576
+ return runtime_error;
577
+ }
584
578
585
579
OrtValue *input_value = nullptr ;
586
580
size_t total_elements = 1 ;
@@ -589,9 +583,7 @@ set_input(void *onnx_ctx, graph_execution_context ctx, uint32_t index,
589
583
}
590
584
591
585
status = ort_ctx->ort_api ->CreateTensorWithDataAsOrtValue (
592
- exec_ctx->memory_info , input_tensor->data .buf ,
593
- get_tensor_element_size (static_cast <tensor_type>(input_tensor->type ))
594
- * total_elements,
586
+ exec_ctx->memory_info , input_tensor->data .buf ,input_tensor->data .size ,
595
587
ort_dims, num_dims, ort_type, &input_value);
596
588
597
589
free (ort_dims);
@@ -793,18 +785,16 @@ get_output(void *onnx_ctx, graph_execution_context ctx, uint32_t index,
793
785
}
794
786
795
787
size_t output_size_bytes = tensor_size * element_size;
796
-
797
- NN_INFO_PRINTF (" Output tensor size: %zu elements, element size: %zu bytes, "
798
- " total: %zu bytes" ,
799
- tensor_size, element_size, output_size_bytes);
800
-
801
- if (*out_buffer_size < output_size_bytes) {
788
+ if (out_buffer->size < output_size_bytes) {
802
789
NN_ERR_PRINTF (
803
790
" Output buffer too small: %u bytes provided, %zu bytes needed" ,
804
- *out_buffer_size , output_size_bytes);
791
+ out_buffer-> size , output_size_bytes);
805
792
*out_buffer_size = output_size_bytes;
806
- return invalid_argument ;
793
+ return too_large ;
807
794
}
795
+ NN_INFO_PRINTF (" Output tensor size: %zu elements, element size: %zu bytes, "
796
+ " total: %zu bytes" ,
797
+ tensor_size, element_size, output_size_bytes);
808
798
809
799
if (tensor_data == nullptr ) {
810
800
NN_ERR_PRINTF (" Tensor data is null" );
0 commit comments