
Commit c919981

Support NVIDIA CUDA execution provider for the wasi-nn ONNX backend
1 parent: c65ba68

File tree: 4 files changed, +129 -8 lines


crates/wasi-nn/Cargo.toml (2 additions, 0 deletions)

```diff
@@ -70,6 +70,8 @@ openvino = ["dep:openvino"]
 onnx = ["dep:ort"]
 # Use prebuilt ONNX Runtime binaries from ort.
 onnx-download = ["onnx", "ort/download-binaries"]
+# CUDA execution provider for NVIDIA GPU support (requires CUDA toolkit)
+onnx-cuda = ["onnx", "ort/cuda"]
 # WinML is only available on Windows 10 1809 and later.
 winml = ["dep:windows"]
 # PyTorch is available on all platforms; requires Libtorch to be installed
```
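
The new `onnx-cuda` feature is additive: it implies `onnx` and enables the `cuda` feature of the `ort` crate, which is what makes the CUDA execution provider type available to the backend. As a minimal sketch (assuming the same feature name in a consumer crate), downstream code gates CUDA-only paths on the feature so builds without the CUDA toolkit are unaffected; this mirrors the pattern used in `crates/wasi-nn/src/backend/onnx.rs` below:

```rust
// Sketch: compile-time gating on the `onnx-cuda` cargo feature. When the
// feature is off, this import and anything behind the cfg simply does not
// exist, so no CUDA toolkit is required to build.
#[cfg(feature = "onnx-cuda")]
use ort::execution_providers::CUDAExecutionProvider;

// `make_cuda_provider` is a hypothetical helper name, not part of the commit.
#[cfg(feature = "onnx-cuda")]
fn make_cuda_provider() -> ort::execution_providers::ExecutionProviderDispatch {
    // `build()` erases the concrete provider into a dispatchable handle.
    CUDAExecutionProvider::default().build()
}
```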

crates/wasi-nn/examples/classification-component-onnx/README.md (54 additions, 6 deletions)

````diff
@@ -3,35 +3,83 @@
 This example demonstrates how to use the `wasi-nn` crate to run a classification using the
 [ONNX Runtime](https://onnxruntime.ai/) backend from a WebAssembly component.
 
+It supports CPU and GPU (NVIDIA CUDA) execution targets.
+
+**Note:**
+The GPU execution target currently supports only NVIDIA CUDA (`onnx-cuda`) as its execution provider (EP).
+
 ## Build
+
 In this directory, run the following command to build the WebAssembly component:
 ```console
 cargo component build
 ```
 
+## Running the Example
+
 In the Wasmtime root directory, run the following command to build the Wasmtime CLI and run the WebAssembly component:
+
+### Building Wasmtime
+
+#### For CPU-only execution:
 ```sh
-# build wasmtime with component-model and WASI-NN with ONNX runtime support
 cargo build --features component-model,wasi-nn,wasmtime-wasi-nn/onnx-download
+```
+
+#### For GPU (NVIDIA CUDA) support:
+```sh
+cargo build --features component-model,wasi-nn,wasmtime-wasi-nn/onnx-cuda,wasmtime-wasi-nn/onnx-download
+```
+
+### Running with Different Execution Targets
+
+The execution target is controlled by passing a single argument to the WASM module.
+
+Arguments:
+- No argument or `cpu` - Use CPU execution
+- `gpu` or `cuda` - Use GPU/CUDA execution
 
-# run the component with wasmtime
+#### CPU Execution (default):
+```sh
 ./target/debug/wasmtime run \
     -Snn \
     --dir ./crates/wasi-nn/examples/classification-component-onnx/fixture/::fixture \
     ./crates/wasi-nn/examples/classification-component-onnx/target/wasm32-wasip1/debug/classification-component-onnx.wasm
 ```
 
-You should get the following output:
+#### GPU (CUDA) Execution:
+```sh
+# path to `libonnxruntime_providers_cuda.so` downloaded by `ort-sys`
+export LD_LIBRARY_PATH={wasmtime_workspace}/target/debug
+
+./target/debug/wasmtime run \
+    -Snn \
+    --dir ./crates/wasi-nn/examples/classification-component-onnx/fixture/::fixture \
+    ./crates/wasi-nn/examples/classification-component-onnx/target/wasm32-wasip1/debug/classification-component-onnx.wasm \
+    gpu
+```
+
+## Expected Output
+
+You should get output similar to:
 ```txt
+No execution target specified, defaulting to CPU
 Read ONNX model, size in bytes: 4956208
-Loaded graph into wasi-nn
+Loaded graph into wasi-nn with Cpu target
 Created wasi-nn execution context.
 Read ONNX Labels, # of labels: 1000
-Set input tensor
 Executed graph inference
-Getting inferencing output
 Retrieved output data with length: 4000
 Index: n02099601 golden retriever - Probability: 0.9948673
 Index: n02088094 Afghan hound, Afghan - Probability: 0.002528982
 Index: n02102318 cocker spaniel, English cocker spaniel, cocker - Probability: 0.0010986356
 ```
+
+When using the GPU target, the first line of output will indicate the selected execution target.
+You can monitor GPU usage with `watch -n 1 nvidia-smi`.
+
+## Prerequisites for GPU (CUDA) Support
+- NVIDIA GPU with CUDA support
+- CUDA Toolkit 12.x with cuDNN 9.x
+- Build Wasmtime with the `wasmtime-wasi-nn/onnx-cuda` feature
````
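
Before running the GPU path, it can help to verify that the provider library is actually visible. The sketch below is a hypothetical preflight helper, not part of this commit; the library name is taken from the README above:

```rust
// Hypothetical preflight check: scan LD_LIBRARY_PATH for the CUDA execution
// provider library that `ort-sys` downloads alongside the ONNX Runtime.
use std::env;

fn main() {
    let dirs = env::var_os("LD_LIBRARY_PATH").unwrap_or_default();
    let found = env::split_paths(&dirs)
        .any(|dir| dir.join("libonnxruntime_providers_cuda.so").exists());
    println!("libonnxruntime_providers_cuda.so on LD_LIBRARY_PATH: {found}");
}
```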

crates/wasi-nn/examples/classification-component-onnx/src/main.rs (34 additions, 2 deletions)

```diff
@@ -17,14 +17,46 @@ use self::wasi::nn::{
     tensor::{Tensor, TensorData, TensorDimensions, TensorType},
 };
 
+/// Determine execution target from command-line argument
+/// Usage: wasm_module [cpu|gpu|cuda]
+fn get_execution_target() -> ExecutionTarget {
+    let args: Vec<String> = std::env::args().collect();
+
+    // First argument (index 0) is the program name, second (index 1) is the target.
+    // Ignore any arguments after index 1.
+    if args.len() >= 2 {
+        match args[1].to_lowercase().as_str() {
+            "gpu" | "cuda" => {
+                println!("Using GPU (CUDA) execution target from argument");
+                return ExecutionTarget::Gpu;
+            }
+            "cpu" => {
+                println!("Using CPU execution target from argument");
+                return ExecutionTarget::Cpu;
+            }
+            _ => {
+                println!("Unknown execution target '{}', defaulting to CPU", args[1]);
+            }
+        }
+    } else {
+        println!("No execution target specified, defaulting to CPU");
+        println!("Usage: <program> [cpu|gpu|cuda]");
+    }
+
+    ExecutionTarget::Cpu
+}
+
 fn main() {
     // Load the ONNX model - SqueezeNet 1.1-7
     // Full details: https://github.com/onnx/models/tree/main/vision/classification/squeezenet
     let model: GraphBuilder = fs::read("fixture/models/squeezenet1.1-7.onnx").unwrap();
     println!("Read ONNX model, size in bytes: {}", model.len());
 
-    let graph = load(&[model], GraphEncoding::Onnx, ExecutionTarget::Cpu).unwrap();
-    println!("Loaded graph into wasi-nn");
+    // Determine execution target
+    let execution_target = get_execution_target();
+
+    let graph = load(&[model], GraphEncoding::Onnx, execution_target).unwrap();
+    println!("Loaded graph into wasi-nn with {:?} target", execution_target);
 
     let exec_context = Graph::init_execution_context(&graph).unwrap();
     println!("Created wasi-nn execution context.");
```
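
`get_execution_target` mixes argument parsing with console output, which makes it awkward to unit-test. A testable variant would split out the pure mapping; the sketch below is a hypothetical refactor, not part of this commit, and assumes the `ExecutionTarget` enum from the generated wasi-nn bindings:

```rust
// Hypothetical refactor: a pure string-to-target mapping that can be
// unit-tested without touching std::env or stdout.
fn parse_execution_target(arg: Option<&str>) -> ExecutionTarget {
    match arg.map(str::to_lowercase).as_deref() {
        Some("gpu") | Some("cuda") => ExecutionTarget::Gpu,
        // No argument, "cpu", or anything unrecognized falls back to CPU,
        // matching the behavior of get_execution_target above.
        _ => ExecutionTarget::Cpu,
    }
}

#[test]
fn falls_back_to_cpu() {
    assert!(matches!(parse_execution_target(None), ExecutionTarget::Cpu));
    assert!(matches!(parse_execution_target(Some("CUDA")), ExecutionTarget::Gpu));
    assert!(matches!(parse_execution_target(Some("tpu")), ExecutionTarget::Cpu));
}
```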

crates/wasi-nn/src/backend/onnx.rs (39 additions, 0 deletions)

```diff
@@ -7,12 +7,17 @@ use crate::backend::{Id, read};
 use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor, TensorType};
 use crate::{ExecutionContext, Graph};
 use ort::{
+    execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch},
     inputs,
     session::{Input, Output},
     session::{Session, SessionInputValue, builder::GraphOptimizationLevel},
     tensor::TensorElementType,
     value::{Tensor as OrtTensor, ValueType},
 };
+
+#[cfg(feature = "onnx-cuda")]
+use ort::execution_providers::CUDAExecutionProvider;
+
 use std::path::Path;
 use std::sync::{Arc, Mutex};
 
@@ -31,7 +36,11 @@ impl BackendInner for OnnxBackend {
             return Err(BackendError::InvalidNumberOfBuilders(1, builders.len()));
         }
 
+        // Configure execution providers based on target
+        let execution_providers = configure_execution_providers(target)?;
+
         let session = Session::builder()?
+            .with_execution_providers(execution_providers)?
             .with_optimization_level(GraphOptimizationLevel::Level3)?
             .commit_from_memory(builders[0])?;
 
@@ -45,6 +54,36 @@ impl BackendInner for OnnxBackend {
     }
 }
 
+/// Configure execution providers based on the target
+fn configure_execution_providers(
+    target: ExecutionTarget,
+) -> Result<Vec<ExecutionProviderDispatch>, BackendError> {
+    match target {
+        ExecutionTarget::Cpu => {
+            // Use the CPU execution provider with its default configuration
+            tracing::debug!("Using CPU execution provider");
+            Ok(vec![CPUExecutionProvider::default().build()])
+        }
+        ExecutionTarget::Gpu => {
+            #[cfg(feature = "onnx-cuda")]
+            {
+                // Use the CUDA execution provider for GPU acceleration
+                tracing::debug!("Configuring ONNX NVIDIA CUDA execution provider for GPU target");
+                Ok(vec![CUDAExecutionProvider::default().build()])
+            }
+            #[cfg(not(feature = "onnx-cuda"))]
+            {
+                Err(BackendError::BackendAccess(wasmtime::format_err!(
+                    "GPU execution target was requested, but the 'onnx-cuda' feature is not enabled"
+                )))
+            }
+        }
+        ExecutionTarget::Tpu => {
+            unimplemented!("TPU execution target is not supported for the ONNX backend yet");
+        }
+    }
+}
+
 impl BackendFromDir for OnnxBackend {
     fn load_from_dir(
         &mut self,
```
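
A possible follow-up, sketched here under the assumption (from `ort`'s documentation) that execution providers are tried in list order, with later entries serving as fallbacks when an earlier one fails to register: the GPU arm could append a CPU provider as an explicit safety net instead of failing outright at session creation.

```rust
// Hypothetical variant of the GPU arm (not in this commit): CUDA first,
// CPU as an ordered fallback if CUDA registration fails at runtime.
#[cfg(feature = "onnx-cuda")]
fn gpu_providers_with_cpu_fallback() -> Vec<ExecutionProviderDispatch> {
    vec![
        CUDAExecutionProvider::default().build(),
        CPUExecutionProvider::default().build(),
    ]
}
```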
