diff --git a/crates/wasi-nn/Cargo.toml b/crates/wasi-nn/Cargo.toml index b0977562638b..393431966917 100644 --- a/crates/wasi-nn/Cargo.toml +++ b/crates/wasi-nn/Cargo.toml @@ -19,6 +19,7 @@ wiggle = { workspace = true } # These dependencies are necessary for the wasi-nn implementation: openvino = { version = "0.5.0", features = ["runtime-linking"] } thiserror = { workspace = true } +dashmap = "5.4.0" [build-dependencies] walkdir = { workspace = true } diff --git a/crates/wasi-nn/spec b/crates/wasi-nn/spec index 8adc5b9b3bb8..d54aa2a51452 160000 --- a/crates/wasi-nn/spec +++ b/crates/wasi-nn/spec @@ -1 +1 @@ -Subproject commit 8adc5b9b3bb8f885d44f55b464718e24af892c94 +Subproject commit d54aa2a514523377f543f6d36d7028213e1077a3 diff --git a/crates/wasi-nn/src/api.rs b/crates/wasi-nn/src/api.rs index 2ad6e0edf94e..83774474d6b0 100644 --- a/crates/wasi-nn/src/api.rs +++ b/crates/wasi-nn/src/api.rs @@ -14,6 +14,12 @@ pub(crate) trait Backend: Send + Sync { builders: &GraphBuilderArray<'_>, target: ExecutionTarget, ) -> Result<Box<dyn BackendGraph>, BackendError>; + + fn load_from_bytes( + &mut self, + model_bytes: &Vec<Vec<u8>>, + target: ExecutionTarget, + ) -> Result<Box<dyn BackendGraph>, BackendError>; } /// A [BackendGraph] can create [BackendExecutionContext]s; this is the backing diff --git a/crates/wasi-nn/src/ctx.rs b/crates/wasi-nn/src/ctx.rs index 988bc27bcb03..7ceec1ce0f7c 100644 --- a/crates/wasi-nn/src/ctx.rs +++ b/crates/wasi-nn/src/ctx.rs @@ -3,7 +3,8 @@ use crate::api::{Backend, BackendError, BackendExecutionContext, BackendGraph}; use crate::openvino::OpenvinoBackend; use crate::r#impl::UsageError; -use crate::witx::types::{Graph, GraphEncoding, GraphExecutionContext}; +use crate::witx::types::{ExecutionTarget, Graph, GraphEncoding, GraphExecutionContext}; +use dashmap::DashMap; use std::collections::HashMap; use std::hash::Hash; use thiserror::Error; @@ -14,6 +15,18 @@ pub struct WasiNnCtx { pub(crate) backends: HashMap<u8, Box<dyn Backend>>, pub(crate) graphs: Table<Graph, Box<dyn BackendGraph>>, pub(crate) executions: Table<GraphExecutionContext, Box<dyn BackendExecutionContext>>, + pub(crate) 
model_registry: DashMap<String, RegisteredModel>, + pub(crate) loaded_models: DashMap<String, LoadedModel>, +} + +pub(crate) struct RegisteredModel { + pub(crate) model_bytes: Vec<Vec<u8>>, + pub(crate) encoding: GraphEncoding, + pub(crate) target: ExecutionTarget, +} + +pub(crate) struct LoadedModel { + pub(crate) graph: Graph, } impl WasiNnCtx { @@ -30,6 +43,8 @@ impl WasiNnCtx { backends, graphs: Table::default(), executions: Table::default(), + model_registry: DashMap::new(), + loaded_models: DashMap::new(), }) } } diff --git a/crates/wasi-nn/src/impl.rs b/crates/wasi-nn/src/impl.rs index 0f8da5247a7b..10c13f08fd8a 100644 --- a/crates/wasi-nn/src/impl.rs +++ b/crates/wasi-nn/src/impl.rs @@ -1,13 +1,17 @@ //! Implements the wasi-nn API. -use crate::ctx::WasiNnResult as Result; + +use crate::ctx::{LoadedModel, RegisteredModel, WasiNnResult as Result}; use crate::witx::types::{ - ExecutionTarget, Graph, GraphBuilderArray, GraphEncoding, GraphExecutionContext, Tensor, + ExecutionTarget, Graph, GraphBuilderArray, GraphEncoding, GraphExecutionContext, + Tensor, }; use crate::witx::wasi_ephemeral_nn::WasiEphemeralNn; use crate::WasiNnCtx; use thiserror::Error; use wiggle::GuestPtr; +const MAX_GUEST_MODEL_REGISTRATION_SIZE: usize = 20 * 1024 * 1024; //20M + #[derive(Debug, Error)] pub enum UsageError { #[error("Invalid context; has the load function been called?")] @@ -22,6 +26,27 @@ pub enum UsageError { InvalidExecutionContextHandle, #[error("Not enough memory to copy tensor data of size: {0}")] NotEnoughMemory(u32), + #[error("Model size {0} exceeds allowed quota of {1}")] + ModelTooLarge(usize, usize), +} + +impl WasiNnCtx { + fn build_graph( + &mut self, + model_bytes: &Vec<Vec<u8>>, + encoding: GraphEncoding, + target: ExecutionTarget, + ) -> Result<Graph> { + let encoding_id: u8 = encoding.into(); + let graph = if let Some(backend) = self.backends.get_mut(&encoding_id) { + backend.load_from_bytes(model_bytes, target)? 
+ } else { + return Err(UsageError::InvalidEncoding(encoding).into()); + }; + + let graph_id = self.graphs.insert(graph); + Ok(graph_id) + } } impl<'a> WasiEphemeralNn for WasiNnCtx { @@ -29,7 +54,7 @@ impl<'a> WasiEphemeralNn for WasiNnCtx { &mut self, builders: &GraphBuilderArray<'_>, encoding: GraphEncoding, - target: ExecutionTarget, + target: ExecutionTarget ) -> Result<Graph> { let encoding_id: u8 = encoding.into(); let graph = if let Some(backend) = self.backends.get_mut(&encoding_id) { @@ -41,6 +66,99 @@ impl<'a> WasiEphemeralNn for WasiNnCtx { Ok(graph_id) } + fn load_by_name<'b>(&mut self, model_name: &GuestPtr<'_,str>) -> Result<Graph> { + let model_name = model_name.as_str().unwrap().unwrap().to_string(); + let maybe_loaded_model = self.loaded_models.get(&model_name); + + match maybe_loaded_model { + Some(model) => Ok(model.graph), + None => { + let registered_model = self.model_registry.get(&model_name).unwrap(); + let model_bytes = &registered_model.model_bytes; + let encoding: GraphEncoding = registered_model.encoding; + let target: ExecutionTarget = registered_model.target; + + let encoding_id: u8 = encoding.into(); + let graph = if let Some(backend) = self.backends.get_mut(&encoding_id) { + backend.load_from_bytes(model_bytes, target)? + } else { + return Err(UsageError::InvalidEncoding(encoding).into()); + }; + let graph_id = self.graphs.insert(graph); + + Ok(graph_id) + } + } + } + + fn register_named_model( + &mut self, + model_name: &GuestPtr<'_, str>, + model_bytes: &GraphBuilderArray<'_>, + encoding: GraphEncoding, + target: ExecutionTarget + ) -> Result<()> { + let length: usize = model_bytes.len().try_into().unwrap(); + if length > MAX_GUEST_MODEL_REGISTRATION_SIZE { + return Err( + UsageError::ModelTooLarge(length, MAX_GUEST_MODEL_REGISTRATION_SIZE).into(), + ); + } + + let mut model_bytes_vec: Vec<Vec<u8>> = Vec::with_capacity(length.try_into().unwrap()); + let mut model_bytes = model_bytes.as_ptr(); + for _ in 0..length { + let v = model_bytes + .read()? 
+ .as_slice()? + .expect("cannot use with shared memories; see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)") + .to_vec(); + model_bytes_vec.push(v); + model_bytes = model_bytes.add(1)?; + } + let model_name_key = model_name.as_str().unwrap().unwrap().to_string(); + match target { + ExecutionTarget::Cpu => { + let graph = self.build_graph(&model_bytes_vec, encoding, target)?; + self.loaded_models + .insert(model_name_key, LoadedModel { graph }); + } + _ => { + self.model_registry.insert( + model_name_key, + RegisteredModel { + model_bytes: model_bytes_vec, + encoding, + target + }, + ); + } + }; + Ok(()) + } + + fn get_model_list<'b>(&mut self, + buffer: &GuestPtr<'b, u8>, + model_list: &GuestPtr<'b, GuestPtr<'b, u8>>, + length: u32) -> Result<()> { + let mut model_names: Vec<String> = self.model_registry.iter().map(|e| e.key().to_string()).collect(); + self.loaded_models.iter().for_each(|e| model_names.push(e.key().to_string())); + + println!("Model names: {:?}", model_names); + let model_names_array = StringArray { elems: model_names }; + model_names_array.write_to_guest(buffer, model_list); + Ok(()) + } + + fn get_model_list_sizes(&mut self) -> Result<(u32, u32)> { + let mut model_names: Vec<String> = self.model_registry.iter().map(|e| e.key().to_string()).collect(); + self.loaded_models.iter().for_each(|e| model_names.push(e.key().to_string())); + let lengths: Vec<u32> = model_names.iter().map(|e| e.len() as u32).collect(); + let string_count = lengths.len() as u32; + let buffer_size = lengths.iter().sum::<u32>() as u32 + string_count; + Ok((string_count, buffer_size)) + } + fn init_execution_context(&mut self, graph_id: Graph) -> Result<GraphExecutionContext> { let exec_context = if let Some(graph) = self.graphs.get_mut(graph_id) { graph.init_execution_context()? 
@@ -91,3 +209,52 @@ impl<'a> WasiEphemeralNn for WasiNnCtx { } } } + +pub struct StringArray { + elems: Vec<String>, +} + +impl StringArray { + pub fn new() -> Self { + StringArray { elems: Vec::new() } + } + + pub fn number_elements(&self) -> u32 { + self.elems.len() as u32 + } + + pub fn cumulative_size(&self) -> u32 { + self.elems + .iter() + .map(|e| e.as_bytes().len() + 1) + .sum::<usize>() as u32 + } + + pub fn write_to_guest<'a>( + &self, + buffer: &GuestPtr<'a, u8>, + element_heads: &GuestPtr<'a, GuestPtr<'a, u8>>, + ) -> Result<()> { + println!("Model names to guest: {:?}", self.elems); + let element_heads = element_heads.as_array(self.number_elements()); + let buffer = buffer.as_array(self.cumulative_size()); + let mut cursor = 0; + for (elem, head) in self.elems.iter().zip(element_heads.iter()) { + let bytes = elem.as_bytes(); + let len = bytes.len() as u32; + { + let elem_buffer = buffer + .get_range(cursor..(cursor + len)) + .ok_or(UsageError::InvalidContext)?; // Elements don't fit in buffer provided + elem_buffer.copy_from_slice(bytes)?; + } + buffer + .get(cursor + len) + .ok_or(UsageError::InvalidContext)? + .write(0)?; // 0 terminate + head?.write(buffer.get(cursor).expect("already validated"))?; + cursor += len + 1; + } + Ok(()) + } +} \ No newline at end of file diff --git a/crates/wasi-nn/src/openvino.rs b/crates/wasi-nn/src/openvino.rs index 9924326369f3..e1348f9403d5 100644 --- a/crates/wasi-nn/src/openvino.rs +++ b/crates/wasi-nn/src/openvino.rs @@ -44,6 +44,35 @@ impl Backend for OpenvinoBackend { .read()? .as_slice()? 
.expect("cannot use with shared memories; see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)"); + self.load_from_bytes(&vec![xml.to_vec(), weights.to_vec()], target) + } + + fn load_from_bytes( + &mut self, + model_bytes: &Vec<Vec<u8>>, + target: ExecutionTarget, + ) -> Result<Box<dyn BackendGraph>, BackendError> { + if model_bytes.len() != 2 { + return Err(BackendError::InvalidNumberOfBuilders( + 2, + model_bytes.len().try_into().unwrap(), + ) + .into()); + } + + // Construct the context if none is present; this is done lazily (i.e. + // upon actually loading a model) because it may fail to find and load + // the OpenVINO libraries. The laziness limits the extent of the error + // only to wasi-nn users, not all WASI users. + if self.0.is_none() { + self.0.replace(openvino::Core::new(None)?); + } + + // Read the guest array. + let xml = model_bytes[0].as_slice(); + // .expect("cannot use with shared memories; see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)"); + let weights = model_bytes[1].as_slice(); + // .expect("cannot use with shared memories; see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)"); // Construct OpenVINO graph structures: `cnn_network` contains the graph // structure, `exec_network` can perform inference. @@ -51,7 +80,7 @@ impl Backend for OpenvinoBackend { .0 .as_mut() .expect("openvino::Core was previously constructed"); - let mut cnn_network = core.read_network_from_buffer(&xml, &weights)?; + let mut cnn_network = core.read_network_from_buffer(xml, weights)?; // TODO this is a temporary workaround. We need a more eligant way to specify the layout in the long run. // However, without this newer versions of OpenVINO will fail due to parameter mismatch. 
diff --git a/crates/wasi-nn/src/witx.rs b/crates/wasi-nn/src/witx.rs index e7c877bd907e..be7ad70834f5 100644 --- a/crates/wasi-nn/src/witx.rs +++ b/crates/wasi-nn/src/witx.rs @@ -5,7 +5,7 @@ use anyhow::Result; // Generate the traits and types of wasi-nn in several Rust modules (e.g. `types`). wiggle::from_witx!({ - witx: ["$WASI_ROOT/phases/ephemeral/witx/wasi_ephemeral_nn.witx"], + witx: ["$WASI_ROOT/wasi-nn.witx"], errors: { nn_errno => WasiNnError } });