Skip to content

Commit b70c7d1

Browse files
authored
Merge pull request #33 from nbigaouette/windows
Work on cross-platform, mostly Windows
2 parents 2195648 + 71619c7 commit b70c7d1

File tree

17 files changed

+17478
-231
lines changed

17 files changed

+17478
-231
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1111

1212
- Update ONNX Runtime to 1.5.2 from 1.4.0 ([#30](https://github.com/nbigaouette/onnxruntime-rs/pull/30))
1313
- Refactor feature flags and how bindings are generated ([#31](https://github.com/nbigaouette/onnxruntime-rs/pull/31))
14+
- Refactor build script for better cross-platform support, including Windows support ([#33](https://github.com/nbigaouette/onnxruntime-rs/pull/33))
1415

1516
## [0.0.9] - 2020-10-13
1617

README.md

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,11 +162,52 @@ that performs simple model download and inference, validating the results.
162162
Bindings (the basis of `onnxruntime-sys`) are committed to the git repository. This means `bindgen` is not
163163
a dependency anymore on every build (it was made optional) and thus build times are better.
164164

165-
To generate new bindings (for example if they don't exists for your platform or if a version bump occurred), run the
166-
following on all platforms and commit the changes:
165+
To generate new bindings (for example if they don't exists for your platform or if a version bump occurred), build the crate with the `generate-bindings` feature.
166+
167+
NOTE: Make sure to have the `rustfmt` rustup component present so that bindings are formatted:
168+
169+
```sh
170+
rustup component add rustfmt
171+
```
172+
173+
Then on each platform build with the proper feature flag:
174+
175+
```sh
176+
cd onnxruntime-sys
177+
❯ cargo build --features generate-bindings
178+
```
179+
180+
### Generating Bindings for Linux With Docker
181+
182+
Prepare the container:
183+
184+
```sh
185+
❯ docker run -it --rm --name rustbuilder -v "$PWD":/usr/src/myapp -w /usr/src/myapp rust:1.47.0 /bin/bash
186+
❯ apt-get update
187+
❯ apt-get install clang
188+
❯ rustup component add rustfmt
189+
```
190+
191+
Generate the bindings:
192+
193+
```sh
194+
❯ docker exec -it --user "$(id -u)":"$(id -g)" rustbuilder /bin/bash
195+
cd onnxruntime-sys
196+
❯ cargo build --features generate-bindings
197+
```
198+
199+
### Generating Bindings for Windows With Vagrant
200+
201+
You can use [nbigaouette/windows_vagrant_rust](https://github.com/nbigaouette/windows_vagrant_rust)
202+
to provision a Windows VM that can build the project and generate the bindings.
203+
204+
Windows can build both x86 and x86_64 bindings:
167205

168206
```sh
169-
cargo build --package onnxruntime-sys --features generate-bindings
207+
❯ rustup target add i686-pc-windows-msvc x86_64-pc-windows-msvc
208+
cd onnxruntime-sys
209+
❯ cargo build --features generate-bindings --target i686-pc-windows-msvc
210+
❯ cargo build --features generate-bindings --target x86_64-pc-windows-msvc
170211
```
171212

172213
## Conduct

onnxruntime-sys/Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@ keywords = ["neuralnetworks", "onnx", "bindings"]
2020
bindgen = {version = "0.55", optional = true}
2121
ureq = "1.5.1"
2222

23-
[target.'cfg(windows)'.build-dependencies]
23+
# Used on Windows
2424
zip = "0.5"
2525

26-
[target.'cfg(unix)'.build-dependencies]
26+
# Used on unix
2727
flate2 = "1.0"
2828
tar = "0.4"
2929

onnxruntime-sys/build.rs

Lines changed: 33 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
#![allow(dead_code)]
22

3-
use io::Write;
43
use std::{
54
env, fs,
6-
io::{self, BufWriter, Read},
5+
io::{self, Read, Write},
76
path::{Path, PathBuf},
87
};
98

@@ -36,8 +35,6 @@ const ORT_PREBUILT_EXTRACT_DIR: &str = "onnxruntime";
3635
#[cfg(feature = "disable-sys-build-script")]
3736
fn main() {
3837
println!("Build script disabled!");
39-
40-
generate_file_including_platform_bindings().unwrap();
4138
}
4239

4340
#[cfg(not(feature = "disable-sys-build-script"))]
@@ -59,8 +56,6 @@ fn main() {
5956
println!("cargo:rerun-if-env-changed={}", ORT_ENV_SYSTEM_LIB_LOCATION);
6057

6158
generate_bindings(&include_dir);
62-
63-
generate_file_including_platform_bindings().unwrap();
6459
}
6560

6661
#[cfg(not(feature = "generate-bindings"))]
@@ -75,6 +70,7 @@ fn generate_bindings(include_dir: &Path) {
7570

7671
// Tell cargo to invalidate the built crate whenever the wrapper changes
7772
println!("cargo:rerun-if-changed=wrapper.h");
73+
println!("cargo:rerun-if-changed=src/generated/bindings.rs");
7874

7975
// The bindgen::Builder is the main entry point
8076
// to bindgen, and lets you build up options for
@@ -88,6 +84,9 @@ fn generate_bindings(include_dir: &Path) {
8884
// Tell cargo to invalidate the built crate whenever any of the
8985
// included header files changed.
9086
.parse_callbacks(Box::new(bindgen::CargoCallbacks))
87+
// Format using rustfmt
88+
.rustfmt_bindings(true)
89+
.rustified_enum("*")
9190
// Finish the builder and generate the bindings.
9291
.generate()
9392
// Unwrap the Result and panic on failure.
@@ -100,48 +99,12 @@ fn generate_bindings(include_dir: &Path) {
10099
.join(env::var("CARGO_CFG_TARGET_OS").unwrap())
101100
.join(env::var("CARGO_CFG_TARGET_ARCH").unwrap())
102101
.join("bindings.rs");
102+
println!("cargo:rerun-if-changed={:?}", generated_file);
103103
bindings
104104
.write_to_file(&generated_file)
105105
.expect("Couldn't write bindings!");
106106
}
107107

108-
fn generate_file_including_platform_bindings() -> Result<(), std::io::Error> {
109-
let generic_binding_path = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap())
110-
.join("src")
111-
.join("generated")
112-
.join("bindings.rs");
113-
114-
let mut fh = BufWriter::new(fs::File::create(&generic_binding_path)?);
115-
116-
let platform_bindings = PathBuf::from("src")
117-
.join("generated")
118-
.join(env::var("CARGO_CFG_TARGET_OS").unwrap())
119-
.join(env::var("CARGO_CFG_TARGET_ARCH").unwrap())
120-
.join("bindings.rs");
121-
122-
// Build a (relative) path, as a string, to the platform-specific bindings.
123-
// Required so that we can escape backslash (Windows path separators) before
124-
// writing to the file.
125-
let include_path = format!(
126-
"{}{}",
127-
std::path::MAIN_SEPARATOR,
128-
platform_bindings.display()
129-
)
130-
.replace(r#"\"#, r#"\\"#);
131-
fh.write_all(
132-
format!(
133-
r#"include!(concat!(
134-
env!("CARGO_MANIFEST_DIR"),
135-
"{}"
136-
));"#,
137-
include_path
138-
)
139-
.as_bytes(),
140-
)?;
141-
142-
Ok(())
143-
}
144-
145108
fn download<P: AsRef<Path>>(source_url: &str, target_file: P) {
146109
let resp = ureq::get(source_url)
147110
.timeout_connect(1_000) // 1 second
@@ -169,13 +132,13 @@ fn download<P: AsRef<Path>>(source_url: &str, target_file: P) {
169132
}
170133

171134
fn extract_archive(filename: &Path, output: &Path) {
172-
#[cfg(target_family = "unix")]
173-
extract_tgz(filename, output);
174-
#[cfg(target_family = "windows")]
175-
extract_zip(filename, output);
135+
match filename.extension().map(|e| e.to_str()) {
136+
Some(Some("zip")) => extract_zip(filename, output),
137+
Some(Some("tgz")) => extract_tgz(filename, output),
138+
_ => unimplemented!(),
139+
}
176140
}
177141

178-
#[cfg(target_family = "unix")]
179142
fn extract_tgz(filename: &Path, output: &Path) {
180143
let file = fs::File::open(&filename).unwrap();
181144
let buf = io::BufReader::new(file);
@@ -184,13 +147,13 @@ fn extract_tgz(filename: &Path, output: &Path) {
184147
archive.unpack(output).unwrap();
185148
}
186149

187-
#[cfg(target_family = "windows")]
188150
fn extract_zip(filename: &Path, outpath: &Path) {
189151
let file = fs::File::open(&filename).unwrap();
190152
let buf = io::BufReader::new(file);
191153
let mut archive = zip::ZipArchive::new(buf).unwrap();
192154
for i in 0..archive.len() {
193155
let mut file = archive.by_index(i).unwrap();
156+
#[allow(deprecated)]
194157
let outpath = outpath.join(file.sanitized_name());
195158
if !(&*file.name()).ends_with('/') {
196159
println!(
@@ -212,6 +175,7 @@ fn extract_zip(filename: &Path, outpath: &Path) {
212175

213176
fn prebuilt_archive_url() -> (PathBuf, String) {
214177
let os = env::var("CARGO_CFG_TARGET_OS").expect("Unable to get TARGET_OS");
178+
let arch = env::var("CARGO_CFG_TARGET_ARCH").expect("Unable to get TARGET_ARCH");
215179

216180
let gpu_str = match env::var(ORT_ENV_GPU) {
217181
Ok(cuda_env) => {
@@ -232,17 +196,28 @@ fn prebuilt_archive_url() -> (PathBuf, String) {
232196
Err(_) => "",
233197
};
234198

235-
let arch_str = match os.as_str() {
236-
"windows" => {
237-
if gpu_str.is_empty() {
238-
"x86"
239-
} else {
240-
"x64"
241-
}
242-
}
243-
_ => "x64",
199+
let arch_str = match arch.as_str() {
200+
"x86_64" => "x64",
201+
"x86" => "x86",
202+
unsupported => panic!("Unsupported architecture {:?}", unsupported),
244203
};
245204

205+
if arch.as_str() == "x86" && os.as_str() != "windows" {
206+
panic!(
207+
"ONNX Runtime only supports x86 (i686) architecture on Windows (not {:?}).",
208+
os
209+
);
210+
}
211+
212+
// Only Windows and Linux x64 support GPU
213+
if !gpu_str.is_empty() {
214+
if arch_str == "x64" && (os == "windows" || os == "linux") {
215+
println!("Supported GPU platform: {} {}", os, arch_str);
216+
} else {
217+
panic!("Unsupported GPU platform: {} {}", os, arch_str);
218+
}
219+
}
220+
246221
let (os_str, archive_extension) = match os.as_str() {
247222
"windows" => ("win", "zip"),
248223
"macos" => ("osx", "tgz"),

0 commit comments

Comments
 (0)