17
17
18
18
use crate :: backend:: Id ;
19
19
use crate :: { Backend , Registry } ;
20
+ use anyhow:: anyhow;
20
21
use std:: collections:: HashMap ;
21
22
use std:: hash:: Hash ;
22
23
use std:: { fmt, str:: FromStr } ;
@@ -54,29 +55,57 @@ impl<'a> WasiNnView<'a> {
54
55
}
55
56
}
56
57
57
- pub enum Error {
58
+ /// A wasi-nn error; this appears on the Wasm side as a component model
59
+ /// resource.
60
+ #[ derive( Debug ) ]
61
+ pub struct Error {
62
+ code : ErrorCode ,
63
+ data : anyhow:: Error ,
64
+ }
65
+
66
+ ///
67
+ macro_rules! bail {
68
+ ( $self: ident, $code: expr, $data: expr) => {
69
+ let e = Error {
70
+ code: $code,
71
+ data: $data. into( ) ,
72
+ } ;
73
+ tracing:: error!( "failure: {e:?}" ) ;
74
+ let r = $self. table. push( e) ?;
75
+ return Ok ( Err ( r) ) ;
76
+ } ;
77
+ }
78
+
79
+ impl From < wasmtime:: component:: ResourceTableError > for Error {
80
+ fn from ( error : wasmtime:: component:: ResourceTableError ) -> Self {
81
+ Self {
82
+ code : ErrorCode :: Trap ,
83
+ data : error. into ( ) ,
84
+ }
85
+ }
86
+ }
87
+
88
+ /// The list of error codes available to the `wasi-nn` API; this should match
89
+ /// what is specified in WIT.
90
+ #[ derive( Debug ) ]
91
+ pub enum ErrorCode {
58
92
/// Caller module passed an invalid argument.
59
93
InvalidArgument ,
60
94
/// Invalid encoding.
61
95
InvalidEncoding ,
62
96
/// The operation timed out.
63
97
Timeout ,
64
- /// Runtime Error .
98
+ /// Runtime error .
65
99
RuntimeError ,
66
100
/// Unsupported operation.
67
101
UnsupportedOperation ,
68
102
/// Graph is too large.
69
103
TooLarge ,
70
104
/// Graph not found.
71
105
NotFound ,
72
- /// A runtime error occurred that we should trap on; see `StreamError`.
73
- Trap ( anyhow:: Error ) ,
74
- }
75
-
76
- impl From < wasmtime:: component:: ResourceTableError > for Error {
77
- fn from ( error : wasmtime:: component:: ResourceTableError ) -> Self {
78
- Self :: Trap ( error. into ( ) )
79
- }
106
+ /// A runtime error that Wasmtime should trap on; this will not appear in
107
+ /// the WIT specification.
108
+ Trap ,
80
109
}
81
110
82
111
/// Generate the traits and types from the `wasi-nn` WIT specification.
@@ -91,6 +120,7 @@ mod gen_ {
91
120
"wasi:nn/graph/graph" : crate :: Graph ,
92
121
"wasi:nn/tensor/tensor" : crate :: Tensor ,
93
122
"wasi:nn/inference/graph-execution-context" : crate :: ExecutionContext ,
123
+ "wasi:nn/errors/error" : super :: Error ,
94
124
} ,
95
125
trappable_error_type: {
96
126
"wasi:nn/errors/error" => super :: Error ,
@@ -131,36 +161,45 @@ impl gen::graph::Host for WasiNnView<'_> {
131
161
builders : Vec < GraphBuilder > ,
132
162
encoding : GraphEncoding ,
133
163
target : ExecutionTarget ,
134
- ) -> Result < Resource < crate :: Graph > , Error > {
164
+ ) -> Result < Result < Resource < crate :: Graph > , Resource < Error > > , anyhow :: Error > {
135
165
tracing:: debug!( "load {encoding:?} {target:?}" ) ;
136
166
if let Some ( backend) = self . ctx . backends . get_mut ( & encoding) {
137
167
let slices = builders. iter ( ) . map ( |s| s. as_slice ( ) ) . collect :: < Vec < _ > > ( ) ;
138
168
match backend. load ( & slices, target. into ( ) ) {
139
169
Ok ( graph) => {
140
170
let graph = self . table . push ( graph) ?;
141
- Ok ( graph)
171
+ Ok ( Ok ( graph) )
142
172
}
143
173
Err ( error) => {
144
- tracing:: error!( "failed to load graph: {error:?}" ) ;
145
- Err ( Error :: RuntimeError )
174
+ bail ! ( self , ErrorCode :: RuntimeError , error) ;
146
175
}
147
176
}
148
177
} else {
149
- Err ( Error :: InvalidEncoding )
178
+ bail ! (
179
+ self ,
180
+ ErrorCode :: InvalidEncoding ,
181
+ anyhow!( "unable to find a backend for this encoding" )
182
+ ) ;
150
183
}
151
184
}
152
185
153
- fn load_by_name ( & mut self , name : String ) -> Result < Resource < Graph > , Error > {
186
+ fn load_by_name (
187
+ & mut self ,
188
+ name : String ,
189
+ ) -> wasmtime:: Result < Result < Resource < Graph > , Resource < Error > > > {
154
190
use core:: result:: Result :: * ;
155
191
tracing:: debug!( "load by name {name:?}" ) ;
156
192
let registry = & self . ctx . registry ;
157
193
if let Some ( graph) = registry. get ( & name) {
158
194
let graph = graph. clone ( ) ;
159
195
let graph = self . table . push ( graph) ?;
160
- Ok ( graph)
196
+ Ok ( Ok ( graph) )
161
197
} else {
162
- tracing:: error!( "failed to find graph with name: {name}" ) ;
163
- Err ( Error :: NotFound )
198
+ bail ! (
199
+ self ,
200
+ ErrorCode :: NotFound ,
201
+ anyhow!( "failed to find graph with name: {name}" )
202
+ ) ;
164
203
}
165
204
}
166
205
}
@@ -169,18 +208,17 @@ impl gen::graph::HostGraph for WasiNnView<'_> {
169
208
fn init_execution_context (
170
209
& mut self ,
171
210
graph : Resource < Graph > ,
172
- ) -> Result < Resource < GraphExecutionContext > , Error > {
211
+ ) -> wasmtime :: Result < Result < Resource < GraphExecutionContext > , Resource < Error > > > {
173
212
use core:: result:: Result :: * ;
174
213
tracing:: debug!( "initialize execution context" ) ;
175
214
let graph = self . table . get ( & graph) ?;
176
215
match graph. init_execution_context ( ) {
177
216
Ok ( exec_context) => {
178
217
let exec_context = self . table . push ( exec_context) ?;
179
- Ok ( exec_context)
218
+ Ok ( Ok ( exec_context) )
180
219
}
181
220
Err ( error) => {
182
- tracing:: error!( "failed to initialize execution context: {error:?}" ) ;
183
- Err ( Error :: RuntimeError )
221
+ bail ! ( self , ErrorCode :: RuntimeError , error) ;
184
222
}
185
223
}
186
224
}
@@ -197,47 +235,46 @@ impl gen::inference::HostGraphExecutionContext for WasiNnView<'_> {
197
235
exec_context : Resource < GraphExecutionContext > ,
198
236
name : String ,
199
237
tensor : Resource < Tensor > ,
200
- ) -> Result < ( ) , Error > {
238
+ ) -> wasmtime :: Result < Result < ( ) , Resource < Error > > > {
201
239
let tensor = self . table . get ( & tensor) ?;
202
240
tracing:: debug!( "set input {name:?}: {tensor:?}" ) ;
203
241
let tensor = tensor. clone ( ) ; // TODO: avoid copying the tensor
204
242
let exec_context = self . table . get_mut ( & exec_context) ?;
205
- if let Err ( e) = exec_context. set_input ( Id :: Name ( name) , & tensor) {
206
- tracing:: error!( "failed to set input: {e:?}" ) ;
207
- Err ( Error :: InvalidArgument )
243
+ if let Err ( error) = exec_context. set_input ( Id :: Name ( name) , & tensor) {
244
+ bail ! ( self , ErrorCode :: InvalidArgument , error) ;
208
245
} else {
209
- Ok ( ( ) )
246
+ Ok ( Ok ( ( ) ) )
210
247
}
211
248
}
212
249
213
- fn compute ( & mut self , exec_context : Resource < GraphExecutionContext > ) -> Result < ( ) , Error > {
250
+ fn compute (
251
+ & mut self ,
252
+ exec_context : Resource < GraphExecutionContext > ,
253
+ ) -> wasmtime:: Result < Result < ( ) , Resource < Error > > > {
214
254
let exec_context = & mut self . table . get_mut ( & exec_context) ?;
215
255
tracing:: debug!( "compute" ) ;
216
256
match exec_context. compute ( ) {
217
- Ok ( ( ) ) => Ok ( ( ) ) ,
257
+ Ok ( ( ) ) => Ok ( Ok ( ( ) ) ) ,
218
258
Err ( error) => {
219
- tracing:: error!( "failed to compute: {error:?}" ) ;
220
- Err ( Error :: RuntimeError )
259
+ bail ! ( self , ErrorCode :: RuntimeError , error) ;
221
260
}
222
261
}
223
262
}
224
263
225
- #[ doc = r" Extract the outputs after inference." ]
226
264
fn get_output (
227
265
& mut self ,
228
266
exec_context : Resource < GraphExecutionContext > ,
229
267
name : String ,
230
- ) -> Result < Resource < Tensor > , Error > {
268
+ ) -> wasmtime :: Result < Result < Resource < Tensor > , Resource < Error > > > {
231
269
let exec_context = self . table . get_mut ( & exec_context) ?;
232
270
tracing:: debug!( "get output {name:?}" ) ;
233
271
match exec_context. get_output ( Id :: Name ( name) ) {
234
272
Ok ( tensor) => {
235
273
let tensor = self . table . push ( tensor) ?;
236
- Ok ( tensor)
274
+ Ok ( Ok ( tensor) )
237
275
}
238
276
Err ( error) => {
239
- tracing:: error!( "failed to get output: {error:?}" ) ;
240
- Err ( Error :: RuntimeError )
277
+ bail ! ( self , ErrorCode :: RuntimeError , error) ;
241
278
}
242
279
}
243
280
}
@@ -285,21 +322,51 @@ impl gen::tensor::HostTensor for WasiNnView<'_> {
285
322
}
286
323
}
287
324
288
- impl gen:: tensor:: Host for WasiNnView < ' _ > { }
325
+ impl gen:: errors:: HostError for WasiNnView < ' _ > {
326
+ fn new (
327
+ & mut self ,
328
+ _code : gen:: errors:: ErrorCode ,
329
+ _data : String ,
330
+ ) -> wasmtime:: Result < Resource < Error > > {
331
+ unimplemented ! ( "this should be removed; see https://github.com/WebAssembly/wasi-nn/pull/76" )
332
+ }
333
+
334
+ fn code ( & mut self , error : Resource < Error > ) -> wasmtime:: Result < gen:: errors:: ErrorCode > {
335
+ let error = self . table . get ( & error) ?;
336
+ match error. code {
337
+ ErrorCode :: InvalidArgument => Ok ( gen:: errors:: ErrorCode :: InvalidArgument ) ,
338
+ ErrorCode :: InvalidEncoding => Ok ( gen:: errors:: ErrorCode :: InvalidEncoding ) ,
339
+ ErrorCode :: Timeout => Ok ( gen:: errors:: ErrorCode :: Timeout ) ,
340
+ ErrorCode :: RuntimeError => Ok ( gen:: errors:: ErrorCode :: RuntimeError ) ,
341
+ ErrorCode :: UnsupportedOperation => Ok ( gen:: errors:: ErrorCode :: UnsupportedOperation ) ,
342
+ ErrorCode :: TooLarge => Ok ( gen:: errors:: ErrorCode :: TooLarge ) ,
343
+ ErrorCode :: NotFound => Ok ( gen:: errors:: ErrorCode :: NotFound ) ,
344
+ ErrorCode :: Trap => Err ( anyhow ! ( error. data. to_string( ) ) ) ,
345
+ }
346
+ }
347
+
348
+ fn data ( & mut self , error : Resource < Error > ) -> wasmtime:: Result < String > {
349
+ let error = self . table . get ( & error) ?;
350
+ Ok ( error. data . to_string ( ) )
351
+ }
352
+
353
+ fn drop ( & mut self , error : Resource < Error > ) -> wasmtime:: Result < ( ) > {
354
+ self . table . delete ( error) ?;
355
+ Ok ( ( ) )
356
+ }
357
+ }
358
+
289
359
impl gen:: errors:: Host for WasiNnView < ' _ > {
290
- fn convert_error ( & mut self , err : Error ) -> wasmtime:: Result < gen:: errors:: Error > {
291
- match err {
292
- Error :: InvalidArgument => Ok ( gen:: errors:: Error :: InvalidArgument ) ,
293
- Error :: InvalidEncoding => Ok ( gen:: errors:: Error :: InvalidEncoding ) ,
294
- Error :: Timeout => Ok ( gen:: errors:: Error :: Timeout ) ,
295
- Error :: RuntimeError => Ok ( gen:: errors:: Error :: RuntimeError ) ,
296
- Error :: UnsupportedOperation => Ok ( gen:: errors:: Error :: UnsupportedOperation ) ,
297
- Error :: TooLarge => Ok ( gen:: errors:: Error :: TooLarge ) ,
298
- Error :: NotFound => Ok ( gen:: errors:: Error :: NotFound ) ,
299
- Error :: Trap ( e) => Err ( e) ,
360
+ fn convert_error ( & mut self , err : Error ) -> wasmtime:: Result < Error > {
361
+ if matches ! ( err. code, ErrorCode :: Trap ) {
362
+ Err ( err. data )
363
+ } else {
364
+ Ok ( err)
300
365
}
301
366
}
302
367
}
368
+
369
+ impl gen:: tensor:: Host for WasiNnView < ' _ > { }
303
370
impl gen:: inference:: Host for WasiNnView < ' _ > { }
304
371
305
372
impl Hash for gen:: graph:: GraphEncoding {
0 commit comments