Skip to content

Commit a8ea901

Browse files
committed
wasi-nn: track upstream specification
In bytecodealliance#8873, we stopped tracking the wasi-nn's upstream WIT files temporarily because it was not clear to me at the time how to implement errors as CM resources. This PR fixes that, resuming tracking in the `vendor-wit.sh` and implementing what is needed in the wasi-nn crate. This leaves several threads unresolved, though: - it looks like the `vendor-wit.sh` has a new mechanism for retrieving files into `wit/deps`--at some point wasi-nn should migrate to use that (?) - it's not clear to me that "errors as resources" is even the best approach here; I've opened [bytecodealliance#75] to consider the possibility of using "errors as records" instead. [bytecodealliance#75]: WebAssembly/wasi-nn#75
1 parent ba864e9 commit a8ea901

File tree

4 files changed

+123
-49
lines changed

4 files changed

+123
-49
lines changed

ci/vendor-wit.sh

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,6 @@ rm -rf $cache_dir
6565
# Separately (for now), vendor the `wasi-nn` WIT files since their retrieval is
6666
# slightly different than above.
6767
repo=https://raw.githubusercontent.com/WebAssembly/wasi-nn
68-
revision=e2310b
68+
revision=0.2.0-rc-2024-06-25
6969
curl -L $repo/$revision/wasi-nn.witx -o crates/wasi-nn/witx/wasi-nn.witx
70-
# TODO: the in-tree `wasi-nn` implementation does not yet fully support the
71-
# latest WIT specification on `main`. To create a baseline for moving forward,
72-
# the in-tree WIT incorporates some but not all of the upstream changes. This
73-
# TODO can be removed once the implementation catches up with the spec.
74-
# curl -L $repo/$revision/wit/wasi-nn.wit -o crates/wasi-nn/wit/wasi-nn.wit
70+
curl -L $repo/$revision/wit/wasi-nn.wit -o crates/wasi-nn/wit/wasi-nn.wit

crates/wasi-nn/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ impl std::ops::DerefMut for ExecutionContext {
103103
}
104104
}
105105

106+
107+
106108
/// A container for graphs.
107109
pub struct Registry(Box<dyn GraphRegistry>);
108110
impl std::ops::Deref for Registry {

crates/wasi-nn/src/wit.rs

Lines changed: 101 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
1818
use crate::backend::Id;
1919
use crate::{Backend, Registry};
20+
use anyhow::anyhow;
2021
use std::collections::HashMap;
2122
use std::hash::Hash;
2223
use std::{fmt, str::FromStr};
@@ -54,7 +55,26 @@ impl<'a> WasiNnView<'a> {
5455
}
5556
}
5657

57-
pub enum Error {
58+
#[derive(Debug)]
59+
pub struct Error {
60+
code: ErrorCode,
61+
data: anyhow::Error,
62+
}
63+
64+
macro_rules! bail {
65+
($self:ident, $code:expr, $data:expr) => {
66+
let e = Error {
67+
code: $code,
68+
data: $data.into(),
69+
};
70+
tracing::error!("failure: {e:?}");
71+
let r = $self.table.push(e)?;
72+
return Ok(Err(r));
73+
};
74+
}
75+
76+
#[derive(Debug)]
77+
pub enum ErrorCode {
5878
/// Caller module passed an invalid argument.
5979
InvalidArgument,
6080
/// Invalid encoding.
@@ -70,12 +90,15 @@ pub enum Error {
7090
/// Graph not found.
7191
NotFound,
7292
/// A runtime error occurred that we should trap on; see `StreamError`.
73-
Trap(anyhow::Error),
93+
Trap,
7494
}
7595

7696
impl From<wasmtime::component::ResourceTableError> for Error {
7797
fn from(error: wasmtime::component::ResourceTableError) -> Self {
78-
Self::Trap(error.into())
98+
Self {
99+
code: ErrorCode::Trap,
100+
data: error.into(),
101+
}
79102
}
80103
}
81104

@@ -91,6 +114,7 @@ mod gen_ {
91114
"wasi:nn/graph/graph": crate::Graph,
92115
"wasi:nn/tensor/tensor": crate::Tensor,
93116
"wasi:nn/inference/graph-execution-context": crate::ExecutionContext,
117+
"wasi:nn/errors/error": super::Error,
94118
},
95119
trappable_error_type: {
96120
"wasi:nn/errors/error" => super::Error,
@@ -131,36 +155,45 @@ impl gen::graph::Host for WasiNnView<'_> {
131155
builders: Vec<GraphBuilder>,
132156
encoding: GraphEncoding,
133157
target: ExecutionTarget,
134-
) -> Result<Resource<crate::Graph>, Error> {
158+
) -> Result<Result<Resource<crate::Graph>, Resource<Error>>, anyhow::Error> {
135159
tracing::debug!("load {encoding:?} {target:?}");
136160
if let Some(backend) = self.ctx.backends.get_mut(&encoding) {
137161
let slices = builders.iter().map(|s| s.as_slice()).collect::<Vec<_>>();
138162
match backend.load(&slices, target.into()) {
139163
Ok(graph) => {
140164
let graph = self.table.push(graph)?;
141-
Ok(graph)
165+
Ok(Ok(graph))
142166
}
143167
Err(error) => {
144-
tracing::error!("failed to load graph: {error:?}");
145-
Err(Error::RuntimeError)
168+
bail!(self, ErrorCode::RuntimeError, error);
146169
}
147170
}
148171
} else {
149-
Err(Error::InvalidEncoding)
172+
bail!(
173+
self,
174+
ErrorCode::InvalidEncoding,
175+
anyhow!("unable to find a backend for this encoding")
176+
);
150177
}
151178
}
152179

153-
fn load_by_name(&mut self, name: String) -> Result<Resource<Graph>, Error> {
180+
fn load_by_name(
181+
&mut self,
182+
name: String,
183+
) -> wasmtime::Result<Result<Resource<Graph>, Resource<Error>>> {
154184
use core::result::Result::*;
155185
tracing::debug!("load by name {name:?}");
156186
let registry = &self.ctx.registry;
157187
if let Some(graph) = registry.get(&name) {
158188
let graph = graph.clone();
159189
let graph = self.table.push(graph)?;
160-
Ok(graph)
190+
Ok(Ok(graph))
161191
} else {
162-
tracing::error!("failed to find graph with name: {name}");
163-
Err(Error::NotFound)
192+
bail!(
193+
self,
194+
ErrorCode::NotFound,
195+
anyhow!("failed to find graph with name: {name}")
196+
);
164197
}
165198
}
166199
}
@@ -169,18 +202,17 @@ impl gen::graph::HostGraph for WasiNnView<'_> {
169202
fn init_execution_context(
170203
&mut self,
171204
graph: Resource<Graph>,
172-
) -> Result<Resource<GraphExecutionContext>, Error> {
205+
) -> wasmtime::Result<Result<Resource<GraphExecutionContext>, Resource<Error>>> {
173206
use core::result::Result::*;
174207
tracing::debug!("initialize execution context");
175208
let graph = self.table.get(&graph)?;
176209
match graph.init_execution_context() {
177210
Ok(exec_context) => {
178211
let exec_context = self.table.push(exec_context)?;
179-
Ok(exec_context)
212+
Ok(Ok(exec_context))
180213
}
181214
Err(error) => {
182-
tracing::error!("failed to initialize execution context: {error:?}");
183-
Err(Error::RuntimeError)
215+
bail!(self, ErrorCode::RuntimeError, error);
184216
}
185217
}
186218
}
@@ -197,27 +229,28 @@ impl gen::inference::HostGraphExecutionContext for WasiNnView<'_> {
197229
exec_context: Resource<GraphExecutionContext>,
198230
name: String,
199231
tensor: Resource<Tensor>,
200-
) -> Result<(), Error> {
232+
) -> wasmtime::Result<Result<(), Resource<Error>>> {
201233
let tensor = self.table.get(&tensor)?;
202234
tracing::debug!("set input {name:?}: {tensor:?}");
203235
let tensor = tensor.clone(); // TODO: avoid copying the tensor
204236
let exec_context = self.table.get_mut(&exec_context)?;
205-
if let Err(e) = exec_context.set_input(Id::Name(name), &tensor) {
206-
tracing::error!("failed to set input: {e:?}");
207-
Err(Error::InvalidArgument)
237+
if let Err(error) = exec_context.set_input(Id::Name(name), &tensor) {
238+
bail!(self, ErrorCode::InvalidArgument, error);
208239
} else {
209-
Ok(())
240+
Ok(Ok(()))
210241
}
211242
}
212243

213-
fn compute(&mut self, exec_context: Resource<GraphExecutionContext>) -> Result<(), Error> {
244+
fn compute(
245+
&mut self,
246+
exec_context: Resource<GraphExecutionContext>,
247+
) -> wasmtime::Result<Result<(), Resource<Error>>> {
214248
let exec_context = &mut self.table.get_mut(&exec_context)?;
215249
tracing::debug!("compute");
216250
match exec_context.compute() {
217-
Ok(()) => Ok(()),
251+
Ok(()) => Ok(Ok(())),
218252
Err(error) => {
219-
tracing::error!("failed to compute: {error:?}");
220-
Err(Error::RuntimeError)
253+
bail!(self, ErrorCode::RuntimeError, error);
221254
}
222255
}
223256
}
@@ -227,17 +260,16 @@ impl gen::inference::HostGraphExecutionContext for WasiNnView<'_> {
227260
&mut self,
228261
exec_context: Resource<GraphExecutionContext>,
229262
name: String,
230-
) -> Result<Resource<Tensor>, Error> {
263+
) -> wasmtime::Result<Result<Resource<Tensor>, Resource<Error>>> {
231264
let exec_context = self.table.get_mut(&exec_context)?;
232265
tracing::debug!("get output {name:?}");
233266
match exec_context.get_output(Id::Name(name)) {
234267
Ok(tensor) => {
235268
let tensor = self.table.push(tensor)?;
236-
Ok(tensor)
269+
Ok(Ok(tensor))
237270
}
238271
Err(error) => {
239-
tracing::error!("failed to get output: {error:?}");
240-
Err(Error::RuntimeError)
272+
bail!(self, ErrorCode::RuntimeError, error);
241273
}
242274
}
243275
}
@@ -285,21 +317,50 @@ impl gen::tensor::HostTensor for WasiNnView<'_> {
285317
}
286318
}
287319

288-
impl gen::tensor::Host for WasiNnView<'_> {}
320+
impl gen::errors::HostError for WasiNnView<'_> {
321+
fn new(
322+
&mut self,
323+
_code: gen::errors::ErrorCode,
324+
_data: wasmtime::component::__internal::String,
325+
) -> wasmtime::Result<wasmtime::component::Resource<gen::errors::Error>> {
326+
unimplemented!()
327+
}
328+
329+
fn code(&mut self, error: Resource<Error>) -> wasmtime::Result<gen::errors::ErrorCode> {
330+
let error = self.table.get(&error)?;
331+
match error.code {
332+
ErrorCode::InvalidArgument => Ok(gen::errors::ErrorCode::InvalidArgument),
333+
ErrorCode::InvalidEncoding => Ok(gen::errors::ErrorCode::InvalidEncoding),
334+
ErrorCode::Timeout => Ok(gen::errors::ErrorCode::Timeout),
335+
ErrorCode::RuntimeError => Ok(gen::errors::ErrorCode::RuntimeError),
336+
ErrorCode::UnsupportedOperation => Ok(gen::errors::ErrorCode::UnsupportedOperation),
337+
ErrorCode::TooLarge => Ok(gen::errors::ErrorCode::TooLarge),
338+
ErrorCode::NotFound => Ok(gen::errors::ErrorCode::NotFound),
339+
ErrorCode::Trap => Err(anyhow!(error.data.to_string())),
340+
}
341+
}
342+
343+
fn data(&mut self, error: Resource<Error>) -> wasmtime::Result<String> {
344+
let error = self.table.get(&error)?;
345+
Ok(error.data.to_string())
346+
}
347+
348+
fn drop(&mut self, error: Resource<Error>) -> wasmtime::Result<()> {
349+
self.table.delete(error)?;
350+
Ok(())
351+
}
352+
}
353+
289354
impl gen::errors::Host for WasiNnView<'_> {
290-
fn convert_error(&mut self, err: Error) -> wasmtime::Result<gen::errors::Error> {
291-
match err {
292-
Error::InvalidArgument => Ok(gen::errors::Error::InvalidArgument),
293-
Error::InvalidEncoding => Ok(gen::errors::Error::InvalidEncoding),
294-
Error::Timeout => Ok(gen::errors::Error::Timeout),
295-
Error::RuntimeError => Ok(gen::errors::Error::RuntimeError),
296-
Error::UnsupportedOperation => Ok(gen::errors::Error::UnsupportedOperation),
297-
Error::TooLarge => Ok(gen::errors::Error::TooLarge),
298-
Error::NotFound => Ok(gen::errors::Error::NotFound),
299-
Error::Trap(e) => Err(e),
355+
fn convert_error(&mut self, err: Error) -> wasmtime::Result<Error> {
356+
if matches!(err.code, ErrorCode::Trap) {
357+
Err(err.data)
358+
} else {
359+
Ok(err)
300360
}
301361
}
302362
}
363+
impl gen::tensor::Host for WasiNnView<'_> {}
303364
impl gen::inference::Host for WasiNnView<'_> {}
304365

305366
impl Hash for gen::graph::GraphEncoding {

crates/wasi-nn/wit/wasi-nn.wit

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package wasi:nn;
1+
package wasi:nn@0.2.0-rc-2024-06-25;
22

33
/// `wasi-nn` is a WASI API for performing machine learning (ML) inference. The API is not (yet)
44
/// capable of performing ML training. WebAssembly programs that want to use a host's ML
@@ -134,7 +134,7 @@ interface inference {
134134

135135
/// TODO: create function-specific errors (https://github.com/WebAssembly/wasi-nn/issues/42)
136136
interface errors {
137-
enum error {
137+
enum error-code {
138138
// Caller module passed an invalid argument.
139139
invalid-argument,
140140
// Invalid encoding.
@@ -148,6 +148,21 @@ interface errors {
148148
// Graph is too large.
149149
too-large,
150150
// Graph not found.
151-
not-found
151+
not-found,
152+
// The operation is insecure or has insufficient privilege to be performed.
153+
// e.g., cannot access a hardware feature requested
154+
security,
155+
// The operation failed for an unspecified reason.
156+
unknown
157+
}
158+
159+
resource error {
160+
constructor(code: error-code, data: string);
161+
162+
/// Return the error code.
163+
code: func() -> error-code;
164+
165+
/// Errors can propagated with backend specific status through a string value.
166+
data: func() -> string;
152167
}
153168
}

0 commit comments

Comments
 (0)