Skip to content

Commit cf7127c

Browse files
authored
update snapshot to match wasmtime-wasi-nn v36 (#154)
1 parent 72acb40 commit cf7127c

File tree

4 files changed

+199
-196
lines changed

4 files changed

+199
-196
lines changed

samples/wasm/rust/examples/snapshot/Cargo.lock

Lines changed: 21 additions & 173 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

samples/wasm/rust/examples/snapshot/Cargo.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ edition = "2021"
77

88
[dependencies]
99
ndarray = "0.16.1"
10-
wit-bindgen = "0.22"
10+
wit-bindgen = "0.32"
1111
wasm_graph_sdk = { version = "=1.1.1", registry = "aio-sdks" }
1212
serde = { version = "1", features = [
1313
"derive",
@@ -16,7 +16,6 @@ serde = { version = "1", features = [
1616
serde_json = { version = "1.0", default-features = false, features = [
1717
"alloc",
1818
] }
19-
wasi-nn = {git = "https://github.com/bytecodealliance/wasi-nn.git"}
2019

2120
[lib]
2221
crate-type = ["cdylib"]

samples/wasm/rust/examples/snapshot/src/lib.rs

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@
44
// Generated by `wit_bindgen::generate` expansion.
55
#![allow(clippy::missing_safety_doc)]
66

7+
wit_bindgen::generate!({
8+
path: "wit",
9+
world: "ml",
10+
});
11+
712
mod map_snapshot {
813
use std::sync::LazyLock;
914

@@ -23,7 +28,7 @@ mod map_snapshot {
2328
use ndarray::{Array, Dim};
2429
use std::io::BufRead;
2530

26-
use wasi_nn::{
31+
use crate::wasi::nn::{
2732
graph::{load, ExecutionTarget, Graph, GraphEncoding, GraphExecutionContext},
2833
tensor::{Tensor, TensorData, TensorDimensions, TensorType},
2934
};
@@ -76,21 +81,14 @@ mod map_snapshot {
7681
let tensor = Tensor::new(&dimensions, TensorType::Fp32, &data);
7782

7883
unsafe {
79-
match (*exec_context).set_input("data", tensor) {
80-
Ok(()) => {}
81-
Err(e) => {
82-
logger::log(
83-
Level::Error,
84-
"module-snapshot/map",
85-
&format!("Error setting input tensor: {e:?}"),
86-
);
87-
panic!("Error setting input tensor: {:?}", e);
88-
}
89-
}
90-
9184
// Execute the inferencing
92-
match (*exec_context).compute() {
93-
Ok(()) => {}
85+
let output_data = match (*exec_context).compute(vec![("data".to_owned(), tensor)]) {
86+
Ok(result) => result
87+
.iter()
88+
.find(|p| p.0 == MODEL_OUTPUT_NAME)
89+
.unwrap()
90+
.1
91+
.data(),
9492
Err(e) => {
9593
logger::log(
9694
Level::Error,
@@ -99,13 +97,8 @@ mod map_snapshot {
9997
);
10098
panic!("Error executing graph inference: {e:?}");
10199
}
102-
}
100+
};
103101

104-
// Get the inferencing result (bytes) and convert it to f32
105-
let output_data = (*exec_context)
106-
.get_output(MODEL_OUTPUT_NAME)
107-
.unwrap()
108-
.data();
109102
let output_f32 = bytes_to_f32_vec(&output_data);
110103

111104
let output_shape = [1, 1000, 1, 1];

0 commit comments

Comments
 (0)