diff --git a/crates/wasi-nn/Cargo.toml b/crates/wasi-nn/Cargo.toml
index 7bc7e0f75bc5..dc06bbd8e631 100644
--- a/crates/wasi-nn/Cargo.toml
+++ b/crates/wasi-nn/Cargo.toml
@@ -70,6 +70,8 @@ openvino = ["dep:openvino"]
 onnx = ["dep:ort"]
 # Use prebuilt ONNX Runtime binaries from ort.
 onnx-download = ["onnx", "ort/download-binaries"]
+# CUDA execution provider for NVIDIA GPU support (requires the CUDA toolkit).
+onnx-cuda = ["onnx", "ort/cuda"]
 # WinML is only available on Windows 10 1809 and later.
 winml = ["dep:windows"]
 # PyTorch is available on all platforms; requires Libtorch to be installed
diff --git a/crates/wasi-nn/examples/classification-component-onnx/README.md b/crates/wasi-nn/examples/classification-component-onnx/README.md
index 9105aa96793f..76e93e6ff3f5 100644
--- a/crates/wasi-nn/examples/classification-component-onnx/README.md
+++ b/crates/wasi-nn/examples/classification-component-onnx/README.md
@@ -3,35 +3,83 @@
 This example demonstrates how to use the `wasi-nn` crate to run a classification
 using the [ONNX Runtime](https://onnxruntime.ai/) backend from a WebAssembly component.
 
+It supports both CPU and GPU (NVIDIA CUDA) execution targets.
+
+**Note:**
+The GPU execution target currently supports only the NVIDIA CUDA execution provider (EP), enabled via the `onnx-cuda` feature.
+
 ## Build
+
 In this directory, run the following command to build the WebAssembly component:
 
 ```console
 cargo component build
 ```
 
+## Running the Example
+
 In the Wasmtime root directory, run the following command to build the Wasmtime CLI and
 run the WebAssembly component:
+
+### Building Wasmtime
+
+#### For CPU-only execution:
 ```sh
-# build wasmtime with component-model and WASI-NN with ONNX runtime support
 cargo build --features component-model,wasi-nn,wasmtime-wasi-nn/onnx-download
+```
+
+#### For GPU (NVIDIA CUDA) support:
+```sh
+cargo build --features component-model,wasi-nn,wasmtime-wasi-nn/onnx-cuda,wasmtime-wasi-nn/onnx-download
+```
+
+### Running with Different Execution Targets
+
+The execution target is selected by passing a single argument to the Wasm module.
+
+Arguments:
+- No argument or `cpu` - use CPU execution
+- `gpu` or `cuda` - use GPU (CUDA) execution
 
-# run the component with wasmtime
+#### CPU Execution (default):
+```sh
 ./target/debug/wasmtime run \
     -Snn \
     --dir ./crates/wasi-nn/examples/classification-component-onnx/fixture/::fixture \
     ./crates/wasi-nn/examples/classification-component-onnx/target/wasm32-wasip1/debug/classification-component-onnx.wasm
 ```
 
-You should get the following output:
+#### GPU (CUDA) Execution:
+```sh
+# path to `libonnxruntime_providers_cuda.so` downloaded by `ort-sys`
+export LD_LIBRARY_PATH={wasmtime_workspace}/target/debug
+
+./target/debug/wasmtime run \
+    -Snn \
+    --dir ./crates/wasi-nn/examples/classification-component-onnx/fixture/::fixture \
+    ./crates/wasi-nn/examples/classification-component-onnx/target/wasm32-wasip1/debug/classification-component-onnx.wasm \
+    gpu
+
+```
+
+## Expected Output
+
+You should get output similar to:
 ```txt
+No execution target specified, defaulting to CPU
 Read ONNX model, size in bytes: 4956208
-Loaded graph into wasi-nn
+Loaded graph into wasi-nn with Cpu target
 Created wasi-nn execution context.
 Read ONNX Labels, # of labels: 1000
-Set input tensor
 Executed graph inference
-Getting inferencing output
 Retrieved output data with length: 4000
 Index: n02099601 golden retriever - Probability: 0.9948673
 Index: n02088094 Afghan hound, Afghan - Probability: 0.002528982
 Index: n02102318 cocker spaniel, English cocker spaniel, cocker - Probability: 0.0010986356
 ```
+
+When using the GPU target, the first line of output will indicate the selected execution target.
+You can monitor GPU usage with `watch -n 1 nvidia-smi`.
+
+## Prerequisites for GPU (CUDA) Support
+- NVIDIA GPU with CUDA support
+- CUDA Toolkit 12.x with cuDNN 9.x
+- Wasmtime built with the `wasmtime-wasi-nn/onnx-cuda` feature
diff --git a/crates/wasi-nn/examples/classification-component-onnx/src/main.rs b/crates/wasi-nn/examples/classification-component-onnx/src/main.rs
index c02fc1ed8da2..affa61681557 100644
--- a/crates/wasi-nn/examples/classification-component-onnx/src/main.rs
+++ b/crates/wasi-nn/examples/classification-component-onnx/src/main.rs
@@ -17,14 +17,46 @@ use self::wasi::nn::{
     tensor::{Tensor, TensorData, TensorDimensions, TensorType},
 };
 
+/// Determine the execution target from a command-line argument.
+/// Usage: wasm_module [cpu|gpu|cuda]
+fn get_execution_target() -> ExecutionTarget {
+    let args: Vec<String> = std::env::args().collect();
+
+    // The first argument (index 0) is the program name; the second (index 1) is the target.
+    // Any arguments after index 1 are ignored.
+    if args.len() >= 2 {
+        match args[1].to_lowercase().as_str() {
+            "gpu" | "cuda" => {
+                println!("Using GPU (CUDA) execution target from argument");
+                return ExecutionTarget::Gpu;
+            }
+            "cpu" => {
+                println!("Using CPU execution target from argument");
+                return ExecutionTarget::Cpu;
+            }
+            _ => {
+                println!("Unknown execution target '{}', defaulting to CPU", args[1]);
+            }
+        }
+    } else {
+        println!("No execution target specified, defaulting to CPU");
+        println!("Usage: [cpu|gpu|cuda]");
+    }
+
+    ExecutionTarget::Cpu
+}
+
 fn main() {
     // Load the ONNX model - SqueezeNet 1.1-7
     // Full details: https://github.com/onnx/models/tree/main/vision/classification/squeezenet
     let model: GraphBuilder = fs::read("fixture/models/squeezenet1.1-7.onnx").unwrap();
     println!("Read ONNX model, size in bytes: {}", model.len());
 
-    let graph = load(&[model], GraphEncoding::Onnx, ExecutionTarget::Cpu).unwrap();
-    println!("Loaded graph into wasi-nn");
+    // Determine the execution target
+    let execution_target = get_execution_target();
+
+    let graph = load(&[model], GraphEncoding::Onnx, execution_target).unwrap();
+    println!("Loaded graph into wasi-nn with {:?} target", execution_target);
 
     let exec_context = Graph::init_execution_context(&graph).unwrap();
     println!("Created wasi-nn execution context.");
diff --git a/crates/wasi-nn/src/backend/onnx.rs b/crates/wasi-nn/src/backend/onnx.rs
index 7aaa6bf4ecec..2cbcf456fa58 100644
--- a/crates/wasi-nn/src/backend/onnx.rs
+++ b/crates/wasi-nn/src/backend/onnx.rs
@@ -7,12 +7,17 @@ use crate::backend::{Id, read};
 use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor, TensorType};
 use crate::{ExecutionContext, Graph};
 use ort::{
+    execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch},
     inputs,
     session::{Input, Output},
     session::{Session, SessionInputValue, builder::GraphOptimizationLevel},
     tensor::TensorElementType,
     value::{Tensor as OrtTensor, ValueType},
 };
+
+#[cfg(feature = "onnx-cuda")]
+use ort::execution_providers::CUDAExecutionProvider;
+
 use std::path::Path;
 use std::sync::{Arc, Mutex};
 
@@ -31,7 +36,11 @@ impl BackendInner for OnnxBackend {
             return Err(BackendError::InvalidNumberOfBuilders(1, builders.len()));
         }
 
+        // Configure execution providers based on the requested target.
+        let execution_providers = configure_execution_providers(target)?;
+
         let session = Session::builder()?
+            .with_execution_providers(execution_providers)?
             .with_optimization_level(GraphOptimizationLevel::Level3)?
             .commit_from_memory(builders[0])?;
 
@@ -45,6 +54,36 @@ impl BackendInner for OnnxBackend {
     }
 }
 
+/// Configure execution providers based on the requested target.
+fn configure_execution_providers(
+    target: ExecutionTarget,
+) -> Result<Vec<ExecutionProviderDispatch>, BackendError> {
+    match target {
+        ExecutionTarget::Cpu => {
+            // Use the CPU execution provider with its default configuration.
+            tracing::debug!("Using CPU execution provider");
+            Ok(vec![CPUExecutionProvider::default().build()])
+        }
+        ExecutionTarget::Gpu => {
+            #[cfg(feature = "onnx-cuda")]
+            {
+                // Use the CUDA execution provider for GPU acceleration.
+                tracing::debug!("Configuring ONNX NVIDIA CUDA execution provider for GPU target");
+                Ok(vec![CUDAExecutionProvider::default().build()])
+            }
+            #[cfg(not(feature = "onnx-cuda"))]
+            {
+                Err(BackendError::BackendAccess(wasmtime::format_err!(
+                    "GPU execution target requested, but the 'onnx-cuda' feature is not enabled"
+                )))
+            }
+        }
+        ExecutionTarget::Tpu => {
+            unimplemented!("TPU execution target is not supported for the ONNX backend yet");
+        }
+    }
+}
+
 impl BackendFromDir for OnnxBackend {
     fn load_from_dir(
         &mut self,
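
A possible follow-up, not part of the patch above: a minimal, hypothetical unit-test sketch for the new `configure_execution_providers` helper. It assumes a `#[cfg(test)]` module at the bottom of `crates/wasi-nn/src/backend/onnx.rs` and only covers the branches that need no CUDA hardware; the GPU path with `onnx-cuda` enabled still requires a real device to exercise.

```rust
// Hypothetical test module (an assumption, not part of this patch); it relies only
// on items already in scope in onnx.rs: `configure_execution_providers` and
// `crate::wit::types::ExecutionTarget`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cpu_target_selects_a_single_provider() {
        // The CPU branch should always succeed and register exactly one provider.
        let providers = configure_execution_providers(ExecutionTarget::Cpu)
            .expect("CPU provider selection should not fail");
        assert_eq!(providers.len(), 1);
    }

    #[cfg(not(feature = "onnx-cuda"))]
    #[test]
    fn gpu_target_errors_without_onnx_cuda_feature() {
        // Without the `onnx-cuda` feature, requesting the GPU target must
        // surface a BackendError instead of silently falling back to CPU.
        assert!(configure_execution_providers(ExecutionTarget::Gpu).is_err());
    }
}
```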