17
17
18
18
use crate :: backend:: Id ;
19
19
use crate :: { Backend , Registry } ;
20
+ use anyhow:: anyhow;
20
21
use std:: collections:: HashMap ;
21
22
use std:: hash:: Hash ;
22
23
use std:: { fmt, str:: FromStr } ;
@@ -54,7 +55,26 @@ impl<'a> WasiNnView<'a> {
54
55
}
55
56
}
56
57
57
- pub enum Error {
58
+ #[ derive( Debug ) ]
59
+ pub struct Error {
60
+ code : ErrorCode ,
61
+ data : anyhow:: Error ,
62
+ }
63
+
64
+ macro_rules! bail {
65
+ ( $self: ident, $code: expr, $data: expr) => {
66
+ let e = Error {
67
+ code: $code,
68
+ data: $data. into( ) ,
69
+ } ;
70
+ tracing:: error!( "failure: {e:?}" ) ;
71
+ let r = $self. table. push( e) ?;
72
+ return Ok ( Err ( r) ) ;
73
+ } ;
74
+ }
75
+
76
+ #[ derive( Debug ) ]
77
+ pub enum ErrorCode {
58
78
/// Caller module passed an invalid argument.
59
79
InvalidArgument ,
60
80
/// Invalid encoding.
@@ -70,12 +90,15 @@ pub enum Error {
70
90
/// Graph not found.
71
91
NotFound ,
72
92
/// A runtime error occurred that we should trap on; see `StreamError`.
73
- Trap ( anyhow :: Error ) ,
93
+ Trap ,
74
94
}
75
95
76
96
impl From < wasmtime:: component:: ResourceTableError > for Error {
77
97
fn from ( error : wasmtime:: component:: ResourceTableError ) -> Self {
78
- Self :: Trap ( error. into ( ) )
98
+ Self {
99
+ code : ErrorCode :: Trap ,
100
+ data : error. into ( ) ,
101
+ }
79
102
}
80
103
}
81
104
@@ -91,6 +114,7 @@ mod gen_ {
91
114
"wasi:nn/graph/graph" : crate :: Graph ,
92
115
"wasi:nn/tensor/tensor" : crate :: Tensor ,
93
116
"wasi:nn/inference/graph-execution-context" : crate :: ExecutionContext ,
117
+ "wasi:nn/errors/error" : super :: Error ,
94
118
} ,
95
119
trappable_error_type: {
96
120
"wasi:nn/errors/error" => super :: Error ,
@@ -131,36 +155,45 @@ impl gen::graph::Host for WasiNnView<'_> {
131
155
builders : Vec < GraphBuilder > ,
132
156
encoding : GraphEncoding ,
133
157
target : ExecutionTarget ,
134
- ) -> Result < Resource < crate :: Graph > , Error > {
158
+ ) -> Result < Result < Resource < crate :: Graph > , Resource < Error > > , anyhow :: Error > {
135
159
tracing:: debug!( "load {encoding:?} {target:?}" ) ;
136
160
if let Some ( backend) = self . ctx . backends . get_mut ( & encoding) {
137
161
let slices = builders. iter ( ) . map ( |s| s. as_slice ( ) ) . collect :: < Vec < _ > > ( ) ;
138
162
match backend. load ( & slices, target. into ( ) ) {
139
163
Ok ( graph) => {
140
164
let graph = self . table . push ( graph) ?;
141
- Ok ( graph)
165
+ Ok ( Ok ( graph) )
142
166
}
143
167
Err ( error) => {
144
- tracing:: error!( "failed to load graph: {error:?}" ) ;
145
- Err ( Error :: RuntimeError )
168
+ bail ! ( self , ErrorCode :: RuntimeError , error) ;
146
169
}
147
170
}
148
171
} else {
149
- Err ( Error :: InvalidEncoding )
172
+ bail ! (
173
+ self ,
174
+ ErrorCode :: InvalidEncoding ,
175
+ anyhow!( "unable to find a backend for this encoding" )
176
+ ) ;
150
177
}
151
178
}
152
179
153
- fn load_by_name ( & mut self , name : String ) -> Result < Resource < Graph > , Error > {
180
+ fn load_by_name (
181
+ & mut self ,
182
+ name : String ,
183
+ ) -> wasmtime:: Result < Result < Resource < Graph > , Resource < Error > > > {
154
184
use core:: result:: Result :: * ;
155
185
tracing:: debug!( "load by name {name:?}" ) ;
156
186
let registry = & self . ctx . registry ;
157
187
if let Some ( graph) = registry. get ( & name) {
158
188
let graph = graph. clone ( ) ;
159
189
let graph = self . table . push ( graph) ?;
160
- Ok ( graph)
190
+ Ok ( Ok ( graph) )
161
191
} else {
162
- tracing:: error!( "failed to find graph with name: {name}" ) ;
163
- Err ( Error :: NotFound )
192
+ bail ! (
193
+ self ,
194
+ ErrorCode :: NotFound ,
195
+ anyhow!( "failed to find graph with name: {name}" )
196
+ ) ;
164
197
}
165
198
}
166
199
}
@@ -169,18 +202,17 @@ impl gen::graph::HostGraph for WasiNnView<'_> {
169
202
fn init_execution_context (
170
203
& mut self ,
171
204
graph : Resource < Graph > ,
172
- ) -> Result < Resource < GraphExecutionContext > , Error > {
205
+ ) -> wasmtime :: Result < Result < Resource < GraphExecutionContext > , Resource < Error > > > {
173
206
use core:: result:: Result :: * ;
174
207
tracing:: debug!( "initialize execution context" ) ;
175
208
let graph = self . table . get ( & graph) ?;
176
209
match graph. init_execution_context ( ) {
177
210
Ok ( exec_context) => {
178
211
let exec_context = self . table . push ( exec_context) ?;
179
- Ok ( exec_context)
212
+ Ok ( Ok ( exec_context) )
180
213
}
181
214
Err ( error) => {
182
- tracing:: error!( "failed to initialize execution context: {error:?}" ) ;
183
- Err ( Error :: RuntimeError )
215
+ bail ! ( self , ErrorCode :: RuntimeError , error) ;
184
216
}
185
217
}
186
218
}
@@ -197,27 +229,28 @@ impl gen::inference::HostGraphExecutionContext for WasiNnView<'_> {
197
229
exec_context : Resource < GraphExecutionContext > ,
198
230
name : String ,
199
231
tensor : Resource < Tensor > ,
200
- ) -> Result < ( ) , Error > {
232
+ ) -> wasmtime :: Result < Result < ( ) , Resource < Error > > > {
201
233
let tensor = self . table . get ( & tensor) ?;
202
234
tracing:: debug!( "set input {name:?}: {tensor:?}" ) ;
203
235
let tensor = tensor. clone ( ) ; // TODO: avoid copying the tensor
204
236
let exec_context = self . table . get_mut ( & exec_context) ?;
205
- if let Err ( e) = exec_context. set_input ( Id :: Name ( name) , & tensor) {
206
- tracing:: error!( "failed to set input: {e:?}" ) ;
207
- Err ( Error :: InvalidArgument )
237
+ if let Err ( error) = exec_context. set_input ( Id :: Name ( name) , & tensor) {
238
+ bail ! ( self , ErrorCode :: InvalidArgument , error) ;
208
239
} else {
209
- Ok ( ( ) )
240
+ Ok ( Ok ( ( ) ) )
210
241
}
211
242
}
212
243
213
- fn compute ( & mut self , exec_context : Resource < GraphExecutionContext > ) -> Result < ( ) , Error > {
244
+ fn compute (
245
+ & mut self ,
246
+ exec_context : Resource < GraphExecutionContext > ,
247
+ ) -> wasmtime:: Result < Result < ( ) , Resource < Error > > > {
214
248
let exec_context = & mut self . table . get_mut ( & exec_context) ?;
215
249
tracing:: debug!( "compute" ) ;
216
250
match exec_context. compute ( ) {
217
- Ok ( ( ) ) => Ok ( ( ) ) ,
251
+ Ok ( ( ) ) => Ok ( Ok ( ( ) ) ) ,
218
252
Err ( error) => {
219
- tracing:: error!( "failed to compute: {error:?}" ) ;
220
- Err ( Error :: RuntimeError )
253
+ bail ! ( self , ErrorCode :: RuntimeError , error) ;
221
254
}
222
255
}
223
256
}
@@ -227,17 +260,16 @@ impl gen::inference::HostGraphExecutionContext for WasiNnView<'_> {
227
260
& mut self ,
228
261
exec_context : Resource < GraphExecutionContext > ,
229
262
name : String ,
230
- ) -> Result < Resource < Tensor > , Error > {
263
+ ) -> wasmtime :: Result < Result < Resource < Tensor > , Resource < Error > > > {
231
264
let exec_context = self . table . get_mut ( & exec_context) ?;
232
265
tracing:: debug!( "get output {name:?}" ) ;
233
266
match exec_context. get_output ( Id :: Name ( name) ) {
234
267
Ok ( tensor) => {
235
268
let tensor = self . table . push ( tensor) ?;
236
- Ok ( tensor)
269
+ Ok ( Ok ( tensor) )
237
270
}
238
271
Err ( error) => {
239
- tracing:: error!( "failed to get output: {error:?}" ) ;
240
- Err ( Error :: RuntimeError )
272
+ bail ! ( self , ErrorCode :: RuntimeError , error) ;
241
273
}
242
274
}
243
275
}
@@ -285,21 +317,50 @@ impl gen::tensor::HostTensor for WasiNnView<'_> {
285
317
}
286
318
}
287
319
288
- impl gen:: tensor:: Host for WasiNnView < ' _ > { }
320
+ impl gen:: errors:: HostError for WasiNnView < ' _ > {
321
+ fn new (
322
+ & mut self ,
323
+ _code : gen:: errors:: ErrorCode ,
324
+ _data : wasmtime:: component:: __internal:: String ,
325
+ ) -> wasmtime:: Result < wasmtime:: component:: Resource < gen:: errors:: Error > > {
326
+ unimplemented ! ( )
327
+ }
328
+
329
+ fn code ( & mut self , error : Resource < Error > ) -> wasmtime:: Result < gen:: errors:: ErrorCode > {
330
+ let error = self . table . get ( & error) ?;
331
+ match error. code {
332
+ ErrorCode :: InvalidArgument => Ok ( gen:: errors:: ErrorCode :: InvalidArgument ) ,
333
+ ErrorCode :: InvalidEncoding => Ok ( gen:: errors:: ErrorCode :: InvalidEncoding ) ,
334
+ ErrorCode :: Timeout => Ok ( gen:: errors:: ErrorCode :: Timeout ) ,
335
+ ErrorCode :: RuntimeError => Ok ( gen:: errors:: ErrorCode :: RuntimeError ) ,
336
+ ErrorCode :: UnsupportedOperation => Ok ( gen:: errors:: ErrorCode :: UnsupportedOperation ) ,
337
+ ErrorCode :: TooLarge => Ok ( gen:: errors:: ErrorCode :: TooLarge ) ,
338
+ ErrorCode :: NotFound => Ok ( gen:: errors:: ErrorCode :: NotFound ) ,
339
+ ErrorCode :: Trap => Err ( anyhow ! ( error. data. to_string( ) ) ) ,
340
+ }
341
+ }
342
+
343
+ fn data ( & mut self , error : Resource < Error > ) -> wasmtime:: Result < String > {
344
+ let error = self . table . get ( & error) ?;
345
+ Ok ( error. data . to_string ( ) )
346
+ }
347
+
348
+ fn drop ( & mut self , error : Resource < Error > ) -> wasmtime:: Result < ( ) > {
349
+ self . table . delete ( error) ?;
350
+ Ok ( ( ) )
351
+ }
352
+ }
353
+
289
354
impl gen:: errors:: Host for WasiNnView < ' _ > {
290
- fn convert_error ( & mut self , err : Error ) -> wasmtime:: Result < gen:: errors:: Error > {
291
- match err {
292
- Error :: InvalidArgument => Ok ( gen:: errors:: Error :: InvalidArgument ) ,
293
- Error :: InvalidEncoding => Ok ( gen:: errors:: Error :: InvalidEncoding ) ,
294
- Error :: Timeout => Ok ( gen:: errors:: Error :: Timeout ) ,
295
- Error :: RuntimeError => Ok ( gen:: errors:: Error :: RuntimeError ) ,
296
- Error :: UnsupportedOperation => Ok ( gen:: errors:: Error :: UnsupportedOperation ) ,
297
- Error :: TooLarge => Ok ( gen:: errors:: Error :: TooLarge ) ,
298
- Error :: NotFound => Ok ( gen:: errors:: Error :: NotFound ) ,
299
- Error :: Trap ( e) => Err ( e) ,
355
+ fn convert_error ( & mut self , err : Error ) -> wasmtime:: Result < Error > {
356
+ if matches ! ( err. code, ErrorCode :: Trap ) {
357
+ Err ( err. data )
358
+ } else {
359
+ Ok ( err)
300
360
}
301
361
}
302
362
}
363
+ impl gen:: tensor:: Host for WasiNnView < ' _ > { }
303
364
impl gen:: inference:: Host for WasiNnView < ' _ > { }
304
365
305
366
impl Hash for gen:: graph:: GraphEncoding {
0 commit comments