Skip to content

Commit ca07911

Browse files
committed
wasi-nn: track upstream specification
In bytecodealliance#8873, we stopped tracking the wasi-nn's upstream WIT files temporarily because it was not clear to me at the time how to implement errors as CM resources. This PR fixes that, resuming tracking in the `vendor-wit.sh` and implementing what is needed in the wasi-nn crate. This leaves several threads unresolved, though: - it looks like the `vendor-wit.sh` has a new mechanism for retrieving files into `wit/deps`--at some point wasi-nn should migrate to use that (?) - it's not clear to me that "errors as resources" is even the best approach here; I've opened [bytecodealliance#75] to consider the possibility of using "errors as records" instead. - this PR identifies that the constructor for errors is in fact unnecessary, prompting an upstream change ([bytecodealliance#76]) that should eventually be implemented here. [bytecodealliance#75]: WebAssembly/wasi-nn#75 [bytecodealliance#76]: WebAssembly/wasi-nn#76 prtest:full
1 parent ba864e9 commit ca07911

File tree

3 files changed

+135
-57
lines changed

3 files changed

+135
-57
lines changed

ci/vendor-wit.sh

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,6 @@ rm -rf $cache_dir
6565
# Separately (for now), vendor the `wasi-nn` WIT files since their retrieval is
6666
# slightly different than above.
6767
repo=https://raw.githubusercontent.com/WebAssembly/wasi-nn
68-
revision=e2310b
68+
revision=0.2.0-rc-2024-06-25
6969
curl -L $repo/$revision/wasi-nn.witx -o crates/wasi-nn/witx/wasi-nn.witx
70-
# TODO: the in-tree `wasi-nn` implementation does not yet fully support the
71-
# latest WIT specification on `main`. To create a baseline for moving forward,
72-
# the in-tree WIT incorporates some but not all of the upstream changes. This
73-
# TODO can be removed once the implementation catches up with the spec.
74-
# curl -L $repo/$revision/wit/wasi-nn.wit -o crates/wasi-nn/wit/wasi-nn.wit
70+
curl -L $repo/$revision/wit/wasi-nn.wit -o crates/wasi-nn/wit/wasi-nn.wit

crates/wasi-nn/src/wit.rs

Lines changed: 115 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
1818
use crate::backend::Id;
1919
use crate::{Backend, Registry};
20+
use anyhow::anyhow;
2021
use std::collections::HashMap;
2122
use std::hash::Hash;
2223
use std::{fmt, str::FromStr};
@@ -54,29 +55,57 @@ impl<'a> WasiNnView<'a> {
5455
}
5556
}
5657

57-
pub enum Error {
58+
/// A wasi-nn error; this appears on the Wasm side as a component model
59+
/// resource.
60+
#[derive(Debug)]
61+
pub struct Error {
62+
code: ErrorCode,
63+
data: anyhow::Error,
64+
}
65+
66+
///
67+
macro_rules! bail {
68+
($self:ident, $code:expr, $data:expr) => {
69+
let e = Error {
70+
code: $code,
71+
data: $data.into(),
72+
};
73+
tracing::error!("failure: {e:?}");
74+
let r = $self.table.push(e)?;
75+
return Ok(Err(r));
76+
};
77+
}
78+
79+
impl From<wasmtime::component::ResourceTableError> for Error {
80+
fn from(error: wasmtime::component::ResourceTableError) -> Self {
81+
Self {
82+
code: ErrorCode::Trap,
83+
data: error.into(),
84+
}
85+
}
86+
}
87+
88+
/// The list of error codes available to the `wasi-nn` API; this should match
89+
/// what is specified in WIT.
90+
#[derive(Debug)]
91+
pub enum ErrorCode {
5892
/// Caller module passed an invalid argument.
5993
InvalidArgument,
6094
/// Invalid encoding.
6195
InvalidEncoding,
6296
/// The operation timed out.
6397
Timeout,
64-
/// Runtime Error.
98+
/// Runtime error.
6599
RuntimeError,
66100
/// Unsupported operation.
67101
UnsupportedOperation,
68102
/// Graph is too large.
69103
TooLarge,
70104
/// Graph not found.
71105
NotFound,
72-
/// A runtime error occurred that we should trap on; see `StreamError`.
73-
Trap(anyhow::Error),
74-
}
75-
76-
impl From<wasmtime::component::ResourceTableError> for Error {
77-
fn from(error: wasmtime::component::ResourceTableError) -> Self {
78-
Self::Trap(error.into())
79-
}
106+
/// A runtime error that Wasmtime should trap on; this will not appear in
107+
/// the WIT specification.
108+
Trap,
80109
}
81110

82111
/// Generate the traits and types from the `wasi-nn` WIT specification.
@@ -91,6 +120,7 @@ mod gen_ {
91120
"wasi:nn/graph/graph": crate::Graph,
92121
"wasi:nn/tensor/tensor": crate::Tensor,
93122
"wasi:nn/inference/graph-execution-context": crate::ExecutionContext,
123+
"wasi:nn/errors/error": super::Error,
94124
},
95125
trappable_error_type: {
96126
"wasi:nn/errors/error" => super::Error,
@@ -131,36 +161,45 @@ impl gen::graph::Host for WasiNnView<'_> {
131161
builders: Vec<GraphBuilder>,
132162
encoding: GraphEncoding,
133163
target: ExecutionTarget,
134-
) -> Result<Resource<crate::Graph>, Error> {
164+
) -> Result<Result<Resource<crate::Graph>, Resource<Error>>, anyhow::Error> {
135165
tracing::debug!("load {encoding:?} {target:?}");
136166
if let Some(backend) = self.ctx.backends.get_mut(&encoding) {
137167
let slices = builders.iter().map(|s| s.as_slice()).collect::<Vec<_>>();
138168
match backend.load(&slices, target.into()) {
139169
Ok(graph) => {
140170
let graph = self.table.push(graph)?;
141-
Ok(graph)
171+
Ok(Ok(graph))
142172
}
143173
Err(error) => {
144-
tracing::error!("failed to load graph: {error:?}");
145-
Err(Error::RuntimeError)
174+
bail!(self, ErrorCode::RuntimeError, error);
146175
}
147176
}
148177
} else {
149-
Err(Error::InvalidEncoding)
178+
bail!(
179+
self,
180+
ErrorCode::InvalidEncoding,
181+
anyhow!("unable to find a backend for this encoding")
182+
);
150183
}
151184
}
152185

153-
fn load_by_name(&mut self, name: String) -> Result<Resource<Graph>, Error> {
186+
fn load_by_name(
187+
&mut self,
188+
name: String,
189+
) -> wasmtime::Result<Result<Resource<Graph>, Resource<Error>>> {
154190
use core::result::Result::*;
155191
tracing::debug!("load by name {name:?}");
156192
let registry = &self.ctx.registry;
157193
if let Some(graph) = registry.get(&name) {
158194
let graph = graph.clone();
159195
let graph = self.table.push(graph)?;
160-
Ok(graph)
196+
Ok(Ok(graph))
161197
} else {
162-
tracing::error!("failed to find graph with name: {name}");
163-
Err(Error::NotFound)
198+
bail!(
199+
self,
200+
ErrorCode::NotFound,
201+
anyhow!("failed to find graph with name: {name}")
202+
);
164203
}
165204
}
166205
}
@@ -169,18 +208,17 @@ impl gen::graph::HostGraph for WasiNnView<'_> {
169208
fn init_execution_context(
170209
&mut self,
171210
graph: Resource<Graph>,
172-
) -> Result<Resource<GraphExecutionContext>, Error> {
211+
) -> wasmtime::Result<Result<Resource<GraphExecutionContext>, Resource<Error>>> {
173212
use core::result::Result::*;
174213
tracing::debug!("initialize execution context");
175214
let graph = self.table.get(&graph)?;
176215
match graph.init_execution_context() {
177216
Ok(exec_context) => {
178217
let exec_context = self.table.push(exec_context)?;
179-
Ok(exec_context)
218+
Ok(Ok(exec_context))
180219
}
181220
Err(error) => {
182-
tracing::error!("failed to initialize execution context: {error:?}");
183-
Err(Error::RuntimeError)
221+
bail!(self, ErrorCode::RuntimeError, error);
184222
}
185223
}
186224
}
@@ -197,47 +235,46 @@ impl gen::inference::HostGraphExecutionContext for WasiNnView<'_> {
197235
exec_context: Resource<GraphExecutionContext>,
198236
name: String,
199237
tensor: Resource<Tensor>,
200-
) -> Result<(), Error> {
238+
) -> wasmtime::Result<Result<(), Resource<Error>>> {
201239
let tensor = self.table.get(&tensor)?;
202240
tracing::debug!("set input {name:?}: {tensor:?}");
203241
let tensor = tensor.clone(); // TODO: avoid copying the tensor
204242
let exec_context = self.table.get_mut(&exec_context)?;
205-
if let Err(e) = exec_context.set_input(Id::Name(name), &tensor) {
206-
tracing::error!("failed to set input: {e:?}");
207-
Err(Error::InvalidArgument)
243+
if let Err(error) = exec_context.set_input(Id::Name(name), &tensor) {
244+
bail!(self, ErrorCode::InvalidArgument, error);
208245
} else {
209-
Ok(())
246+
Ok(Ok(()))
210247
}
211248
}
212249

213-
fn compute(&mut self, exec_context: Resource<GraphExecutionContext>) -> Result<(), Error> {
250+
fn compute(
251+
&mut self,
252+
exec_context: Resource<GraphExecutionContext>,
253+
) -> wasmtime::Result<Result<(), Resource<Error>>> {
214254
let exec_context = &mut self.table.get_mut(&exec_context)?;
215255
tracing::debug!("compute");
216256
match exec_context.compute() {
217-
Ok(()) => Ok(()),
257+
Ok(()) => Ok(Ok(())),
218258
Err(error) => {
219-
tracing::error!("failed to compute: {error:?}");
220-
Err(Error::RuntimeError)
259+
bail!(self, ErrorCode::RuntimeError, error);
221260
}
222261
}
223262
}
224263

225-
#[doc = r" Extract the outputs after inference."]
226264
fn get_output(
227265
&mut self,
228266
exec_context: Resource<GraphExecutionContext>,
229267
name: String,
230-
) -> Result<Resource<Tensor>, Error> {
268+
) -> wasmtime::Result<Result<Resource<Tensor>, Resource<Error>>> {
231269
let exec_context = self.table.get_mut(&exec_context)?;
232270
tracing::debug!("get output {name:?}");
233271
match exec_context.get_output(Id::Name(name)) {
234272
Ok(tensor) => {
235273
let tensor = self.table.push(tensor)?;
236-
Ok(tensor)
274+
Ok(Ok(tensor))
237275
}
238276
Err(error) => {
239-
tracing::error!("failed to get output: {error:?}");
240-
Err(Error::RuntimeError)
277+
bail!(self, ErrorCode::RuntimeError, error);
241278
}
242279
}
243280
}
@@ -285,21 +322,51 @@ impl gen::tensor::HostTensor for WasiNnView<'_> {
285322
}
286323
}
287324

288-
impl gen::tensor::Host for WasiNnView<'_> {}
325+
impl gen::errors::HostError for WasiNnView<'_> {
326+
fn new(
327+
&mut self,
328+
_code: gen::errors::ErrorCode,
329+
_data: String,
330+
) -> wasmtime::Result<Resource<Error>> {
331+
unimplemented!("this should be removed; see https://github.com/WebAssembly/wasi-nn/pull/76")
332+
}
333+
334+
fn code(&mut self, error: Resource<Error>) -> wasmtime::Result<gen::errors::ErrorCode> {
335+
let error = self.table.get(&error)?;
336+
match error.code {
337+
ErrorCode::InvalidArgument => Ok(gen::errors::ErrorCode::InvalidArgument),
338+
ErrorCode::InvalidEncoding => Ok(gen::errors::ErrorCode::InvalidEncoding),
339+
ErrorCode::Timeout => Ok(gen::errors::ErrorCode::Timeout),
340+
ErrorCode::RuntimeError => Ok(gen::errors::ErrorCode::RuntimeError),
341+
ErrorCode::UnsupportedOperation => Ok(gen::errors::ErrorCode::UnsupportedOperation),
342+
ErrorCode::TooLarge => Ok(gen::errors::ErrorCode::TooLarge),
343+
ErrorCode::NotFound => Ok(gen::errors::ErrorCode::NotFound),
344+
ErrorCode::Trap => Err(anyhow!(error.data.to_string())),
345+
}
346+
}
347+
348+
fn data(&mut self, error: Resource<Error>) -> wasmtime::Result<String> {
349+
let error = self.table.get(&error)?;
350+
Ok(error.data.to_string())
351+
}
352+
353+
fn drop(&mut self, error: Resource<Error>) -> wasmtime::Result<()> {
354+
self.table.delete(error)?;
355+
Ok(())
356+
}
357+
}
358+
289359
impl gen::errors::Host for WasiNnView<'_> {
290-
fn convert_error(&mut self, err: Error) -> wasmtime::Result<gen::errors::Error> {
291-
match err {
292-
Error::InvalidArgument => Ok(gen::errors::Error::InvalidArgument),
293-
Error::InvalidEncoding => Ok(gen::errors::Error::InvalidEncoding),
294-
Error::Timeout => Ok(gen::errors::Error::Timeout),
295-
Error::RuntimeError => Ok(gen::errors::Error::RuntimeError),
296-
Error::UnsupportedOperation => Ok(gen::errors::Error::UnsupportedOperation),
297-
Error::TooLarge => Ok(gen::errors::Error::TooLarge),
298-
Error::NotFound => Ok(gen::errors::Error::NotFound),
299-
Error::Trap(e) => Err(e),
360+
fn convert_error(&mut self, err: Error) -> wasmtime::Result<Error> {
361+
if matches!(err.code, ErrorCode::Trap) {
362+
Err(err.data)
363+
} else {
364+
Ok(err)
300365
}
301366
}
302367
}
368+
369+
impl gen::tensor::Host for WasiNnView<'_> {}
303370
impl gen::inference::Host for WasiNnView<'_> {}
304371

305372
impl Hash for gen::graph::GraphEncoding {

crates/wasi-nn/wit/wasi-nn.wit

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package wasi:nn;
1+
package wasi:nn@0.2.0-rc-2024-06-25;
22

33
/// `wasi-nn` is a WASI API for performing machine learning (ML) inference. The API is not (yet)
44
/// capable of performing ML training. WebAssembly programs that want to use a host's ML
@@ -134,7 +134,7 @@ interface inference {
134134

135135
/// TODO: create function-specific errors (https://github.com/WebAssembly/wasi-nn/issues/42)
136136
interface errors {
137-
enum error {
137+
enum error-code {
138138
// Caller module passed an invalid argument.
139139
invalid-argument,
140140
// Invalid encoding.
@@ -148,6 +148,21 @@ interface errors {
148148
// Graph is too large.
149149
too-large,
150150
// Graph not found.
151-
not-found
151+
not-found,
152+
// The operation is insecure or has insufficient privilege to be performed.
153+
// e.g., cannot access a hardware feature requested
154+
security,
155+
// The operation failed for an unspecified reason.
156+
unknown
157+
}
158+
159+
resource error {
160+
constructor(code: error-code, data: string);
161+
162+
/// Return the error code.
163+
code: func() -> error-code;
164+
165+
/// Errors can propagated with backend specific status through a string value.
166+
data: func() -> string;
152167
}
153168
}

0 commit comments

Comments
 (0)