Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/wasi-nn/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ wiggle = { workspace = true }
# These dependencies are necessary for the wasi-nn implementation:
openvino = { version = "0.5.0", features = ["runtime-linking"] }
thiserror = { workspace = true }
dashmap = "5.4.0"

[build-dependencies]
walkdir = { workspace = true }
2 changes: 1 addition & 1 deletion crates/wasi-nn/spec
Submodule spec updated 120 files
6 changes: 6 additions & 0 deletions crates/wasi-nn/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ pub(crate) trait Backend: Send + Sync {
builders: &GraphBuilderArray<'_>,
target: ExecutionTarget,
) -> Result<Box<dyn BackendGraph>, BackendError>;

fn load_from_bytes(
&mut self,
model_bytes: &Vec<Vec<u8>>,
target: ExecutionTarget,
) -> Result<Box<dyn BackendGraph>, BackendError>;
}

/// A [BackendGraph] can create [BackendExecutionContext]s; this is the backing
Expand Down
17 changes: 16 additions & 1 deletion crates/wasi-nn/src/ctx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
use crate::api::{Backend, BackendError, BackendExecutionContext, BackendGraph};
use crate::openvino::OpenvinoBackend;
use crate::r#impl::UsageError;
use crate::witx::types::{Graph, GraphEncoding, GraphExecutionContext};
use crate::witx::types::{ExecutionTarget, Graph, GraphEncoding, GraphExecutionContext};
use dashmap::DashMap;
use std::collections::HashMap;
use std::hash::Hash;
use thiserror::Error;
Expand All @@ -14,6 +15,18 @@ pub struct WasiNnCtx {
pub(crate) backends: HashMap<u8, Box<dyn Backend>>,
pub(crate) graphs: Table<Graph, Box<dyn BackendGraph>>,
pub(crate) executions: Table<GraphExecutionContext, Box<dyn BackendExecutionContext>>,
pub(crate) model_registry: DashMap<String, RegisteredModel>,
pub(crate) loaded_models: DashMap<String, LoadedModel>,
}

/// A model registered via `register_named_model` but not yet instantiated
/// into a backend graph; the raw bytes are retained so the graph can be
/// built lazily on the first `load_by_name` call.
pub(crate) struct RegisteredModel {
    // Builder segments copied out of guest memory; for the OpenVINO backend
    // this is exactly two segments (model XML, then weights).
    pub(crate) model_bytes: Vec<Vec<u8>>,
    // Which backend the bytes are encoded for.
    pub(crate) encoding: GraphEncoding,
    // Where the graph should execute once instantiated (CPU, GPU, ...).
    pub(crate) target: ExecutionTarget,
}

/// A named model that has already been instantiated into a backend graph;
/// `graph` is a handle into `WasiNnCtx::graphs`.
pub(crate) struct LoadedModel {
    pub(crate) graph: Graph,
}

impl WasiNnCtx {
Expand All @@ -30,6 +43,8 @@ impl WasiNnCtx {
backends,
graphs: Table::default(),
executions: Table::default(),
model_registry: DashMap::new(),
loaded_models: DashMap::new(),
})
}
}
Expand Down
173 changes: 170 additions & 3 deletions crates/wasi-nn/src/impl.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
//! Implements the wasi-nn API.
use crate::ctx::WasiNnResult as Result;

use crate::ctx::{LoadedModel, RegisteredModel, WasiNnResult as Result};
use crate::witx::types::{
ExecutionTarget, Graph, GraphBuilderArray, GraphEncoding, GraphExecutionContext, Tensor,
ExecutionTarget, Graph, GraphBuilderArray, GraphEncoding, GraphExecutionContext,
Tensor,
};
use crate::witx::wasi_ephemeral_nn::WasiEphemeralNn;
use crate::WasiNnCtx;
use thiserror::Error;
use wiggle::GuestPtr;

const MAX_GUEST_MODEL_REGISTRATION_SIZE: usize = 20 * 1024 * 1024; //20M

#[derive(Debug, Error)]
pub enum UsageError {
#[error("Invalid context; has the load function been called?")]
Expand All @@ -22,14 +26,35 @@ pub enum UsageError {
InvalidExecutionContextHandle,
#[error("Not enough memory to copy tensor data of size: {0}")]
NotEnoughMemory(u32),
#[error("Model size {0} exceeds allowed quota of {1}")]
ModelTooLarge(usize, usize),
}

impl WasiNnCtx {
    /// Instantiate a graph from raw model bytes using the backend registered
    /// for `encoding`, store it in the context's graph table, and return its
    /// handle. Fails with `InvalidEncoding` when no backend matches.
    fn build_graph(
        &mut self,
        model_bytes: &Vec<Vec<u8>>,
        encoding: GraphEncoding,
        target: ExecutionTarget,
    ) -> Result<Graph> {
        let backend_key: u8 = encoding.into();
        match self.backends.get_mut(&backend_key) {
            Some(backend) => {
                let graph = backend.load_from_bytes(model_bytes, target)?;
                Ok(self.graphs.insert(graph))
            }
            None => Err(UsageError::InvalidEncoding(encoding).into()),
        }
    }
}

impl<'a> WasiEphemeralNn for WasiNnCtx {
fn load<'b>(
&mut self,
builders: &GraphBuilderArray<'_>,
encoding: GraphEncoding,
target: ExecutionTarget,
target: ExecutionTarget
) -> Result<Graph> {
let encoding_id: u8 = encoding.into();
let graph = if let Some(backend) = self.backends.get_mut(&encoding_id) {
Expand All @@ -41,6 +66,99 @@ impl<'a> WasiEphemeralNn for WasiNnCtx {
Ok(graph_id)
}

fn load_by_name<'b>(&mut self, model_name: &GuestPtr<'_, str>) -> Result<Graph> {
    // Read the model name out of guest memory; propagate guest-memory errors
    // instead of panicking on untrusted guest input.
    let model_name = model_name
        .as_str()?
        .expect("cannot use with shared memories; see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)")
        .to_string();

    // Fast path: the model was already instantiated into a graph.
    if let Some(model) = self.loaded_models.get(&model_name) {
        return Ok(model.graph);
    }

    // Slow path: build the graph from the registered raw bytes. An unknown
    // name is a guest error, not a host panic.
    // NOTE(review): a dedicated "model not registered" UsageError variant
    // would be clearer than reusing InvalidContext — the variant enum is not
    // fully visible here to extend safely.
    let registered_model = self
        .model_registry
        .get(&model_name)
        .ok_or(UsageError::InvalidContext)?;
    let model_bytes = &registered_model.model_bytes;
    let encoding: GraphEncoding = registered_model.encoding;
    let target: ExecutionTarget = registered_model.target;

    // This duplicates `build_graph` because the registry guard borrows
    // `self.model_registry` while the backend lookup needs `self.backends`;
    // the field-disjoint borrows keep this compiling without cloning bytes.
    let encoding_id: u8 = encoding.into();
    let graph = if let Some(backend) = self.backends.get_mut(&encoding_id) {
        backend.load_from_bytes(model_bytes, target)?
    } else {
        return Err(UsageError::InvalidEncoding(encoding).into());
    };
    let graph_id = self.graphs.insert(graph);

    Ok(graph_id)
}

fn register_named_model(
    &mut self,
    model_name: &GuestPtr<'_, str>,
    model_bytes: &GraphBuilderArray<'_>,
    encoding: GraphEncoding,
    target: ExecutionTarget,
) -> Result<()> {
    let builder_count: usize = model_bytes.len().try_into().unwrap();
    // Sanity-bound the guest-supplied builder count before allocating the
    // holding vector (preserves the original guard against absurd lengths).
    if builder_count > MAX_GUEST_MODEL_REGISTRATION_SIZE {
        return Err(
            UsageError::ModelTooLarge(builder_count, MAX_GUEST_MODEL_REGISTRATION_SIZE).into(),
        );
    }

    // Copy each builder segment out of guest memory, enforcing the quota on
    // the *cumulative byte size* — `model_bytes.len()` above is only the
    // number of segments, not the model size.
    let mut model_bytes_vec: Vec<Vec<u8>> = Vec::with_capacity(builder_count);
    let mut total_bytes: usize = 0;
    let mut cursor = model_bytes.as_ptr();
    for _ in 0..builder_count {
        let v = cursor
            .read()?
            .as_slice()?
            .expect("cannot use with shared memories; see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)")
            .to_vec();
        total_bytes += v.len();
        if total_bytes > MAX_GUEST_MODEL_REGISTRATION_SIZE {
            return Err(
                UsageError::ModelTooLarge(total_bytes, MAX_GUEST_MODEL_REGISTRATION_SIZE).into(),
            );
        }
        model_bytes_vec.push(v);
        cursor = cursor.add(1)?;
    }

    // Propagate guest-memory errors instead of panicking on untrusted input.
    let model_name_key = model_name
        .as_str()?
        .expect("cannot use with shared memories; see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)")
        .to_string();
    match target {
        // CPU models are instantiated eagerly and cached as loaded graphs.
        ExecutionTarget::Cpu => {
            let graph = self.build_graph(&model_bytes_vec, encoding, target)?;
            self.loaded_models
                .insert(model_name_key, LoadedModel { graph });
        }
        // Other targets keep the raw bytes for lazy instantiation on first
        // `load_by_name`.
        _ => {
            self.model_registry.insert(
                model_name_key,
                RegisteredModel {
                    model_bytes: model_bytes_vec,
                    encoding,
                    target,
                },
            );
        }
    };
    Ok(())
}

/// Write the names of all registered and loaded models into guest memory.
/// `buffer` receives the packed NUL-terminated names; `model_list` receives
/// a pointer to the start of each name. `_length` is unused here — the guest
/// sizes its buffers via `get_model_list_sizes`.
fn get_model_list<'b>(
    &mut self,
    buffer: &GuestPtr<'b, u8>,
    model_list: &GuestPtr<'b, GuestPtr<'b, u8>>,
    _length: u32,
) -> Result<()> {
    let mut model_names: Vec<String> = self
        .model_registry
        .iter()
        .map(|e| e.key().to_string())
        .collect();
    model_names.extend(self.loaded_models.iter().map(|e| e.key().to_string()));

    let model_names_array = StringArray { elems: model_names };
    // Propagate guest-memory write errors; the original dropped this Result.
    model_names_array.write_to_guest(buffer, model_list)?;
    Ok(())
}

/// Report how many model names exist and the total buffer size the guest
/// must allocate to receive them via `get_model_list`.
fn get_model_list_sizes(&mut self) -> Result<(u32, u32)> {
    // Gather the byte length of every name, registered and loaded alike.
    let registry_lengths = self.model_registry.iter().map(|e| e.key().len());
    let loaded_lengths = self.loaded_models.iter().map(|e| e.key().len());
    let lengths: Vec<usize> = registry_lengths.chain(loaded_lengths).collect();

    let string_count = lengths.len() as u32;
    // One extra byte per string accounts for the 0 terminator that
    // `StringArray::write_to_guest` appends after each name.
    let buffer_size = lengths.iter().sum::<usize>() as u32 + string_count;
    Ok((string_count, buffer_size))
}

fn init_execution_context(&mut self, graph_id: Graph) -> Result<GraphExecutionContext> {
let exec_context = if let Some(graph) = self.graphs.get_mut(graph_id) {
graph.init_execution_context()?
Expand Down Expand Up @@ -91,3 +209,52 @@ impl<'a> WasiEphemeralNn for WasiNnCtx {
}
}
}

/// A list of strings destined for guest memory: written as consecutive
/// NUL-terminated byte runs plus an array of pointers to each run's start.
pub struct StringArray {
    elems: Vec<String>,
}

impl StringArray {
    /// Create an empty string array.
    pub fn new() -> Self {
        StringArray { elems: Vec::new() }
    }

    /// Number of strings held.
    pub fn number_elements(&self) -> u32 {
        self.elems.len() as u32
    }

    /// Total guest-buffer size required: each string's UTF-8 bytes plus one
    /// terminating 0 byte per string.
    pub fn cumulative_size(&self) -> u32 {
        self.elems
            .iter()
            .map(|e| e.as_bytes().len() + 1)
            .sum::<usize>() as u32
    }

    /// Write the strings into `buffer` as consecutive NUL-terminated byte
    /// runs, recording a pointer to the start of each run in
    /// `element_heads`. The caller must size `buffer` with at least
    /// [`Self::cumulative_size`] bytes and `element_heads` with
    /// [`Self::number_elements`] entries.
    pub fn write_to_guest<'a>(
        &self,
        buffer: &GuestPtr<'a, u8>,
        element_heads: &GuestPtr<'a, GuestPtr<'a, u8>>,
    ) -> Result<()> {
        // (Removed leftover `println!` debug output from host stdout.)
        let element_heads = element_heads.as_array(self.number_elements());
        let buffer = buffer.as_array(self.cumulative_size());
        let mut cursor = 0;
        for (elem, head) in self.elems.iter().zip(element_heads.iter()) {
            let bytes = elem.as_bytes();
            let len = bytes.len() as u32;
            {
                let elem_buffer = buffer
                    .get_range(cursor..(cursor + len))
                    .ok_or(UsageError::InvalidContext)?; // Elements don't fit in buffer provided
                elem_buffer.copy_from_slice(bytes)?;
            }
            // 0-terminate so guests can treat each entry as a C string.
            buffer
                .get(cursor + len)
                .ok_or(UsageError::InvalidContext)?
                .write(0)?;
            head?.write(buffer.get(cursor).expect("already validated"))?;
            cursor += len + 1;
        }
        Ok(())
    }
}
31 changes: 30 additions & 1 deletion crates/wasi-nn/src/openvino.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,43 @@ impl Backend for OpenvinoBackend {
.read()?
.as_slice()?
.expect("cannot use with shared memories; see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)");
self.load_from_bytes(&vec![xml.to_vec(), weights.to_vec()], target)
}

fn load_from_bytes(
&mut self,
model_bytes: &Vec<Vec<u8>>,
target: ExecutionTarget,
) -> Result<Box<dyn BackendGraph>, BackendError> {
if model_bytes.len() != 2 {
return Err(BackendError::InvalidNumberOfBuilders(
2,
model_bytes.len().try_into().unwrap(),
)
.into());
}

// Construct the context if none is present; this is done lazily (i.e.
// upon actually loading a model) because it may fail to find and load
// the OpenVINO libraries. The laziness limits the extent of the error
// only to wasi-nn users, not all WASI users.
if self.0.is_none() {
self.0.replace(openvino::Core::new(None)?);
}

// Read the guest array.
let xml = model_bytes[0].as_slice();
// .expect("cannot use with shared memories; see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)");
let weights = model_bytes[1].as_slice();
// .expect("cannot use with shared memories; see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)");

// Construct OpenVINO graph structures: `cnn_network` contains the graph
// structure, `exec_network` can perform inference.
let core = self
.0
.as_mut()
.expect("openvino::Core was previously constructed");
let mut cnn_network = core.read_network_from_buffer(&xml, &weights)?;
let mut cnn_network = core.read_network_from_buffer(xml, weights)?;

// TODO this is a temporary workaround. We need a more elegant way to specify the layout in the long run.
// However, without this newer versions of OpenVINO will fail due to parameter mismatch.
Expand Down
2 changes: 1 addition & 1 deletion crates/wasi-nn/src/witx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use anyhow::Result;

// Generate the traits and types of wasi-nn in several Rust modules (e.g. `types`).
wiggle::from_witx!({
witx: ["$WASI_ROOT/phases/ephemeral/witx/wasi_ephemeral_nn.witx"],
witx: ["$WASI_ROOT/wasi-nn.witx"],
errors: { nn_errno => WasiNnError }
});

Expand Down