2 changes: 2 additions & 0 deletions crates/wasi-nn/Cargo.toml
@@ -70,6 +70,8 @@ openvino = ["dep:openvino"]
onnx = ["dep:ort"]
# Use prebuilt ONNX Runtime binaries from ort.
onnx-download = ["onnx", "ort/download-binaries"]
# CUDA execution provider for NVIDIA GPU support (requires CUDA toolkit)
onnx-cuda = ["onnx", "ort/cuda"]
# WinML is only available on Windows 10 1809 and later.
winml = ["dep:windows"]
# PyTorch is available on all platforms; requires Libtorch to be installed
60 changes: 54 additions & 6 deletions crates/wasi-nn/examples/classification-component-onnx/README.md
@@ -3,35 +3,83 @@
This example demonstrates how to use the `wasi-nn` crate to run image classification with the
[ONNX Runtime](https://onnxruntime.ai/) backend from a WebAssembly component.

It supports both CPU and GPU (NVIDIA CUDA) execution targets.

**Note:**
For now, the GPU execution target only supports NVIDIA CUDA (the `onnx-cuda` feature) as the execution provider (EP).

## Build

In this directory, run the following command to build the WebAssembly component:
```console
cargo component build
```
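
Assuming a default debug build, the component should land at the path the run commands below expect; a quick sanity check from this directory:

```sh
ls target/wasm32-wasip1/debug/classification-component-onnx.wasm
```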

## Running the Example

From the Wasmtime root directory, build the Wasmtime CLI and then run the WebAssembly component as follows.

### Building Wasmtime

#### For CPU-only execution:
```sh
# build wasmtime with component-model and WASI-NN with ONNX runtime support
cargo build --features component-model,wasi-nn,wasmtime-wasi-nn/onnx-download
```

#### For GPU (Nvidia CUDA) support:
```sh
cargo build --features component-model,wasi-nn,wasmtime-wasi-nn/onnx-cuda,wasmtime-wasi-nn/onnx-download
```

### Running with Different Execution Targets

The execution target is controlled by passing a single argument to the Wasm component:

- No argument or `cpu` - use CPU execution
- `gpu` or `cuda` - use GPU (CUDA) execution

#### CPU Execution (default):
```sh
./target/debug/wasmtime run \
-Snn \
--dir ./crates/wasi-nn/examples/classification-component-onnx/fixture/::fixture \
./crates/wasi-nn/examples/classification-component-onnx/target/wasm32-wasip1/debug/classification-component-onnx.wasm
```

#### GPU (CUDA) Execution:
```sh
# path to `libonnxruntime_providers_cuda.so` downloaded by `ort-sys`
export LD_LIBRARY_PATH={wasmtime_workspace}/target/debug

./target/debug/wasmtime run \
-Snn \
--dir ./crates/wasi-nn/examples/classification-component-onnx/fixture/::fixture \
./crates/wasi-nn/examples/classification-component-onnx/target/wasm32-wasip1/debug/classification-component-onnx.wasm \
gpu
```
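
If you are unsure where `ort` placed the CUDA provider library, a quick way to locate it (a sketch, assuming a Linux host and the downloaded binaries) is:

```sh
find target -name 'libonnxruntime_providers_cuda*'
```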

## Expected Output

You should get output similar to:
```txt
No execution target specified, defaulting to CPU
Read ONNX model, size in bytes: 4956208
Loaded graph into wasi-nn
Loaded graph into wasi-nn with Cpu target
Created wasi-nn execution context.
Read ONNX Labels, # of labels: 1000
Set input tensor
Executed graph inference
Getting inferencing output
Retrieved output data with length: 4000
Index: n02099601 golden retriever - Probability: 0.9948673
Index: n02088094 Afghan hound, Afghan - Probability: 0.002528982
Index: n02102318 cocker spaniel, English cocker spaniel, cocker - Probability: 0.0010986356
```

When using the GPU target, the first line will instead indicate the selected execution target.
You can monitor GPU usage with `watch -n 1 nvidia-smi`.
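
With the `gpu` argument, the first lines should look roughly like this (inferred from the `println!` calls in the example source):

```txt
Using GPU (CUDA) execution target from argument
Read ONNX model, size in bytes: 4956208
Loaded graph into wasi-nn with Gpu target
```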

## Prerequisites for GPU (CUDA) Support
- NVIDIA GPU with CUDA support
- CUDA Toolkit 12.x with cuDNN 9.x
- Wasmtime built with the `wasmtime-wasi-nn/onnx-cuda` feature
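
A quick way to sanity-check the first two prerequisites (assuming the CUDA toolkit is on your `PATH`):

```sh
# Driver and GPU visibility
nvidia-smi
# Toolkit version (should report 12.x)
nvcc --version
```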
@@ -17,14 +17,46 @@ use self::wasi::nn::{
    tensor::{Tensor, TensorData, TensorDimensions, TensorType},
};

/// Determine execution target from command-line argument
/// Usage: wasm_module [cpu|gpu|cuda]
fn get_execution_target() -> ExecutionTarget {
    let args: Vec<String> = std::env::args().collect();

    // First argument (index 0) is the program name, second (index 1) is the target
    // Ignore any arguments after index 1
    if args.len() >= 2 {
        match args[1].to_lowercase().as_str() {
            "gpu" | "cuda" => {
                println!("Using GPU (CUDA) execution target from argument");
                return ExecutionTarget::Gpu;
            }
            "cpu" => {
                println!("Using CPU execution target from argument");
                return ExecutionTarget::Cpu;
            }
            _ => {
                println!("Unknown execution target '{}', defaulting to CPU", args[1]);
            }
        }
    } else {
        println!("No execution target specified, defaulting to CPU");
        println!("Usage: <program> [cpu|gpu|cuda]");
    }

    ExecutionTarget::Cpu
}

fn main() {
    // Load the ONNX model - SqueezeNet 1.1-7
    // Full details: https://github.com/onnx/models/tree/main/vision/classification/squeezenet
    let model: GraphBuilder = fs::read("fixture/models/squeezenet1.1-7.onnx").unwrap();
    println!("Read ONNX model, size in bytes: {}", model.len());

    // Determine execution target
    let execution_target = get_execution_target();

    let graph = load(&[model], GraphEncoding::Onnx, execution_target).unwrap();
    println!("Loaded graph into wasi-nn with {:?} target", execution_target);

    let exec_context = Graph::init_execution_context(&graph).unwrap();
    println!("Created wasi-nn execution context.");
39 changes: 39 additions & 0 deletions crates/wasi-nn/src/backend/onnx.rs
@@ -7,12 +7,17 @@ use crate::backend::{Id, read};
use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor, TensorType};
use crate::{ExecutionContext, Graph};
use ort::{
    execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch},
    inputs,
    session::{Input, Output},
    session::{Session, SessionInputValue, builder::GraphOptimizationLevel},
    tensor::TensorElementType,
    value::{Tensor as OrtTensor, ValueType},
};

#[cfg(feature = "onnx-cuda")]
use ort::execution_providers::CUDAExecutionProvider;

use std::path::Path;
use std::sync::{Arc, Mutex};

@@ -31,7 +36,11 @@ impl BackendInner for OnnxBackend {
            return Err(BackendError::InvalidNumberOfBuilders(1, builders.len()));
        }

        // Configure execution providers based on target
        let execution_providers = configure_execution_providers(target)?;

        let session = Session::builder()?
            .with_execution_providers(execution_providers)?
            .with_optimization_level(GraphOptimizationLevel::Level3)?
            .commit_from_memory(builders[0])?;

@@ -45,6 +54,36 @@
    }
}

/// Configure execution providers based on the target
fn configure_execution_providers(
    target: ExecutionTarget,
) -> Result<Vec<ExecutionProviderDispatch>, BackendError> {
    match target {
        ExecutionTarget::Cpu => {
            // Use CPU execution provider with default configuration
            tracing::debug!("Using CPU execution provider");
            Ok(vec![CPUExecutionProvider::default().build()])
        }
        ExecutionTarget::Gpu => {
            #[cfg(feature = "onnx-cuda")]
            {
                // Use CUDA execution provider for GPU acceleration
                tracing::debug!("Configuring ONNX Nvidia CUDA execution provider for GPU target");
                Ok(vec![CUDAExecutionProvider::default().build()])
            }
            #[cfg(not(feature = "onnx-cuda"))]
            {
                Err(BackendError::BackendAccess(wasmtime::format_err!(
                    "GPU execution target is requested, but 'onnx-cuda' feature is not enabled"
                )))
            }
        }
        ExecutionTarget::Tpu => {
            unimplemented!("TPU execution target is not supported for ONNX backend yet");
        }
    }
}

impl BackendFromDir for OnnxBackend {
    fn load_from_dir(
        &mut self,