diff --git a/onnxruntime-sys/build.rs b/onnxruntime-sys/build.rs index 1589555a..e0c46f2a 100644 --- a/onnxruntime-sys/build.rs +++ b/onnxruntime-sys/build.rs @@ -66,7 +66,18 @@ fn generate_bindings(_include_dir: &Path) { #[cfg(feature = "generate-bindings")] fn generate_bindings(include_dir: &Path) { + let os = env::var("CARGO_CFG_TARGET_OS").expect("Unable to get TARGET_OS"); let clang_arg = format!("-I{}", include_dir.display()); + let clang_cuda_arg = match env::var(ORT_ENV_GPU) { + Ok(cuda_env) => match cuda_env.to_lowercase().as_str() { + "1" | "yes" | "true" | "on" => match os.as_str() { + "linux" | "windows" => "-DORT_USE_CUDA", + _ => "", + }, + _ => "", + }, + Err(_) => "", + }; // Tell cargo to invalidate the built crate whenever the wrapper changes println!("cargo:rerun-if-changed=wrapper.h"); @@ -81,6 +92,8 @@ fn generate_bindings(include_dir: &Path) { .header("wrapper.h") // The current working directory is 'onnxruntime-sys' .clang_arg(clang_arg) + // Add define ORT_USE_CUDA + .clang_arg(clang_cuda_arg) // Tell cargo to invalidate the built crate whenever any of the // included header files changed. .parse_callbacks(Box::new(bindgen::CargoCallbacks)) diff --git a/onnxruntime-sys/wrapper.h b/onnxruntime-sys/wrapper.h index e63d3523..b11f5275 100644 --- a/onnxruntime-sys/wrapper.h +++ b/onnxruntime-sys/wrapper.h @@ -1 +1,5 @@ #include "onnxruntime_c_api.h" +#include "cpu_provider_factory.h" +#ifdef ORT_USE_CUDA +#include "cuda_provider_factory.h" +#endif diff --git a/onnxruntime/Cargo.toml b/onnxruntime/Cargo.toml index 9ceec820..13d03a2b 100644 --- a/onnxruntime/Cargo.toml +++ b/onnxruntime/Cargo.toml @@ -36,6 +36,7 @@ tracing-subscriber = "0.2" ureq = "1.5.1" [features] +cuda = [] # Fetch model from ONNX Model Zoo (https://github.com/onnx/models) model-fetching = ["ureq"] # Disable build script; used for https://docs.rs diff --git a/onnxruntime/examples/sample.rs b/onnxruntime/examples/sample.rs index d16d08da..26543142 100644 --- a/onnxruntime/examples/sample.rs +++ b/onnxruntime/examples/sample.rs @@ -25,6 +25,15 @@ fn run() -> Result<(), Error> { tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed"); + #[cfg(feature = "cuda")] + let environment = Environment::builder() + .with_name("test") + .with_gpu(0) + // The ONNX Runtime's log level can be different than the one of the wrapper crate or the application. + .with_log_level(LoggingLevel::Info) + .build()?; + + #[cfg(not(feature = "cuda"))] let environment = Environment::builder() .with_name("test") // The ONNX Runtime's log level can be different than the one of the wrapper crate or the application. diff --git a/onnxruntime/src/session.rs b/onnxruntime/src/session.rs index 04f9cf1c..e22b7ee0 100644 --- a/onnxruntime/src/session.rs +++ b/onnxruntime/src/session.rs @@ -124,6 +124,29 @@ impl<'a> SessionBuilder<'a> { Ok(self) } + /// Set the session use cpu provider + pub fn with_cpu(self, use_arena: bool) -> Result> { + unsafe { + sys::OrtSessionOptionsAppendExecutionProvider_CPU( + self.session_options_ptr, + use_arena.into(), + ); + } + Ok(self) + } + + /// Set the session use cuda provider + #[cfg(feature = "cuda")] + pub fn with_cuda(self, device_id: i32) -> Result> { + unsafe { + sys::OrtSessionOptionsAppendExecutionProvider_CUDA( + self.session_options_ptr, + device_id.into(), + ); + } + Ok(self) + } + /// Set the session's allocator /// /// Defaults to [`AllocatorType::Arena`](../enum.AllocatorType.html#variant.Arena)