Skip to content

Commit 69d4913

Browse files
author
haixuanTao
committed
Add feature cuda for better ergonomics
1 parent 68036f2 commit 69d4913

File tree

5 files changed

+31
-36
lines changed

5 files changed

+31
-36
lines changed

README.md

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ This is an attempt at a Rust wrapper for
1111

1212
This project consist on two crates:
1313

14-
* [`onnxruntime-sys`](onnxruntime-sys): Low-level binding to the C API;
15-
* [`onnxruntime`](onnxruntime): High-level and safe API.
14+
- [`onnxruntime-sys`](onnxruntime-sys): Low-level binding to the C API;
15+
- [`onnxruntime`](onnxruntime): High-level and safe API.
1616

1717
[Changelog](CHANGELOG.md)
1818

@@ -21,25 +21,25 @@ which provides the following targets:
2121

2222
CPU:
2323

24-
* Linux x86_64
25-
* macOS x86_64
26-
* macOS aarch64 (no pre-built binaries, no CI testing, see [#74](https://github.com/nbigaouette/onnxruntime-rs/pull/74))
27-
* Windows i686
28-
* Windows x86_64
24+
- Linux x86_64
25+
- macOS x86_64
26+
- macOS aarch64 (no pre-built binaries, no CI testing, see [#74](https://github.com/nbigaouette/onnxruntime-rs/pull/74))
27+
- Windows i686
28+
- Windows x86_64
2929

3030
GPU:
3131

32-
* Linux x86_64
33-
* Windows x86_64
32+
- Linux x86_64
33+
- Windows x86_64
3434

3535
---
3636

3737
**WARNING**:
3838

39-
* This is an experiment and work in progress; it is _not_ complete/working/safe. Help welcome!
40-
* Basic inference works, see [`onnxruntime/examples/sample.rs`](onnxruntime/examples/sample.rs) or [`onnxruntime/tests/integration_tests.rs`](onnxruntime/tests/integration_tests.rs)
41-
* ONNX Runtime has many options to control the inference process but those options are not yet exposed.
42-
* This was developed and tested on macOS Catalina. Other platforms should work but have not been tested.
39+
- This is an experiment and work in progress; it is _not_ complete/working/safe. Help welcome!
40+
- Basic inference works, see [`onnxruntime/examples/sample.rs`](onnxruntime/examples/sample.rs) or [`onnxruntime/tests/integration_tests.rs`](onnxruntime/tests/integration_tests.rs)
41+
- ONNX Runtime has many options to control the inference process but those options are not yet exposed.
42+
- This was developed and tested on macOS Catalina. Other platforms should work but have not been tested.
4343

4444
---
4545

@@ -58,14 +58,14 @@ To select which strategy to use, set the `ORT_STRATEGY` environment variable to:
5858
3. `compile`: To compile the library
5959

6060
The `download` strategy supports downloading a version of ONNX that supports CUDA. To use this, set the
61-
environment variable `ORT_USE_CUDA=1` (only supports Linux or Windows).
61+
feature `cuda` in `Cargo.toml`.
6262

6363
Until the build script allow compilation of the runtime, see the [compilation notes](ONNX_Compilation_Notes.md)
6464
for some details on the process.
6565

6666
### Note on using CUDA
6767

68-
To use CUDA you will need to set `ORT_USE_CUDA=1` but also to set your session with the method `use_cuda` as such:
68+
To use CUDA you will need to set the feature `cuda` but also to set your session with the method `use_cuda` as such:
6969

7070
```
7171
let mut session = environment
@@ -86,9 +86,9 @@ dyld: Library not loaded: @rpath/libonnxruntime.1.7.1.dylib
8686

8787
To fix, one can either:
8888

89-
* Set the `LD_LIBRARY_PATH` environment variable to point to the path where the library can be found.
90-
* Adapt the `.cargo/config` file to contain a linker flag to provide the **full** path:
91-
89+
- Set the `LD_LIBRARY_PATH` environment variable to point to the path where the library can be found.
90+
- Adapt the `.cargo/config` file to contain a linker flag to provide the **full** path:
91+
9292
```toml
9393
[target.aarch64-apple-darwin]
9494
rustflags = ["-C", "link-args=-Wl,-rpath,/full/path/to/onnxruntime/lib"]
@@ -269,9 +269,9 @@ instead of the Rust moderation team.
269269

270270
This project is licensed under either of
271271

272-
* Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or
272+
- Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or
273273
http://www.apache.org/licenses/LICENSE-2.0)
274-
* MIT license ([LICENSE-MIT](LICENSE-MIT) or
274+
- MIT license ([LICENSE-MIT](LICENSE-MIT) or
275275
http://opensource.org/licenses/MIT)
276276

277277
at your option.

onnxruntime-sys/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ default = []
3434
disable-sys-build-script = []
3535
# Use bindgen to generate bindings in build.rs
3636
generate-bindings = ["bindgen"]
37+
cuda = []
3738

3839
[package.metadata.docs.rs]
3940
# Disable the build.rs on https://docs.rs since it can cause

onnxruntime-sys/build.rs

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@ const ORT_ENV_STRATEGY: &str = "ORT_STRATEGY";
2828
/// Name of environment variable that, if present, contains the location of a pre-built library.
2929
/// Only used if `ORT_STRATEGY=system`.
3030
const ORT_ENV_SYSTEM_LIB_LOCATION: &str = "ORT_LIB_LOCATION";
31-
/// Name of environment variable that, if present, controls wether to use CUDA or not.
32-
const ORT_ENV_GPU: &str = "ORT_USE_CUDA";
3331

3432
/// Subdirectory (of the 'target' directory) into which to extract the prebuilt library.
3533
const ORT_PREBUILT_EXTRACT_DIR: &str = "onnxruntime";
@@ -54,7 +52,6 @@ fn main() {
5452
println!("cargo:rustc-link-search=native={}", lib_dir.display());
5553

5654
println!("cargo:rerun-if-env-changed={}", ORT_ENV_STRATEGY);
57-
println!("cargo:rerun-if-env-changed={}", ORT_ENV_GPU);
5855
println!("cargo:rerun-if-env-changed={}", ORT_ENV_SYSTEM_LIB_LOCATION);
5956

6057
generate_bindings(&include_dir);
@@ -280,17 +277,6 @@ enum Accelerator {
280277
Gpu,
281278
}
282279

283-
impl FromStr for Accelerator {
284-
type Err = String;
285-
286-
fn from_str(s: &str) -> Result<Self, Self::Err> {
287-
match s.to_lowercase().as_str() {
288-
"1" | "yes" | "true" | "on" => Ok(Accelerator::Gpu),
289-
_ => Ok(Accelerator::None),
290-
}
291-
}
292-
}
293-
294280
impl OnnxPrebuiltArchive for Accelerator {
295281
fn as_onnx_str(&self) -> Cow<str> {
296282
match self {
@@ -353,6 +339,12 @@ impl OnnxPrebuiltArchive for Triplet {
353339
}
354340

355341
fn prebuilt_archive_url() -> (PathBuf, String) {
342+
let accelerator = if cfg!(feature = "cuda") {
343+
Accelerator::Gpu
344+
} else {
345+
Accelerator::None
346+
};
347+
356348
let triplet = Triplet {
357349
os: env::var("CARGO_CFG_TARGET_OS")
358350
.expect("Unable to get TARGET_OS")
@@ -362,7 +354,7 @@ fn prebuilt_archive_url() -> (PathBuf, String) {
362354
.expect("Unable to get TARGET_ARCH")
363355
.parse()
364356
.unwrap(),
365-
accelerator: env::var(ORT_ENV_GPU).unwrap_or_default().parse().unwrap(),
357+
accelerator,
366358
};
367359

368360
let prebuilt_archive = format!(

onnxruntime/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ model-fetching = ["ureq"]
4141
# Disable build script; used for https://docs.rs
4242
disable-sys-build-script = ["onnxruntime-sys/disable-sys-build-script"]
4343
generate-bindings = ["onnxruntime-sys/generate-bindings"]
44+
cuda = ["onnxruntime-sys/cuda"]
4445

4546
[package.metadata.docs.rs]
4647
features = ["disable-sys-build-script", "model-fetching"]

onnxruntime/src/session.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,14 +130,15 @@ impl<'a> SessionBuilder<'a> {
130130
Ok(self)
131131
}
132132

133-
/// Set the session to use cuda
133+
/// Set the session to use cpu
134134
pub fn use_cpu(self, use_arena: i32) -> Result<SessionBuilder<'a>> {
135135
unsafe {
136136
sys::OrtSessionOptionsAppendExecutionProvider_CPU(self.session_options_ptr, use_arena);
137137
}
138138
Ok(self)
139139
}
140140

141+
#[cfg(feature = "cuda")]
141142
/// Set the session to use cuda
142143
pub fn use_cuda(self, device_id: i32) -> Result<SessionBuilder<'a>> {
143144
unsafe {

0 commit comments

Comments
 (0)