11#![ allow( dead_code) ]
22
3- use io:: Write ;
43use std:: {
54 env, fs,
6- io:: { self , BufWriter , Read } ,
5+ io:: { self , Read , Write } ,
76 path:: { Path , PathBuf } ,
87} ;
98
@@ -36,8 +35,6 @@ const ORT_PREBUILT_EXTRACT_DIR: &str = "onnxruntime";
3635#[ cfg( feature = "disable-sys-build-script" ) ]
3736fn main ( ) {
3837 println ! ( "Build script disabled!" ) ;
39-
40- generate_file_including_platform_bindings ( ) . unwrap ( ) ;
4138}
4239
4340#[ cfg( not( feature = "disable-sys-build-script" ) ) ]
@@ -59,8 +56,6 @@ fn main() {
5956 println ! ( "cargo:rerun-if-env-changed={}" , ORT_ENV_SYSTEM_LIB_LOCATION ) ;
6057
6158 generate_bindings ( & include_dir) ;
62-
63- generate_file_including_platform_bindings ( ) . unwrap ( ) ;
6459}
6560
6661#[ cfg( not( feature = "generate-bindings" ) ) ]
@@ -75,6 +70,7 @@ fn generate_bindings(include_dir: &Path) {
7570
7671 // Tell cargo to invalidate the built crate whenever the wrapper changes
7772 println ! ( "cargo:rerun-if-changed=wrapper.h" ) ;
73+ println ! ( "cargo:rerun-if-changed=src/generated/bindings.rs" ) ;
7874
7975 // The bindgen::Builder is the main entry point
8076 // to bindgen, and lets you build up options for
@@ -88,6 +84,9 @@ fn generate_bindings(include_dir: &Path) {
8884 // Tell cargo to invalidate the built crate whenever any of the
8985 // included header files changed.
9086 . parse_callbacks ( Box :: new ( bindgen:: CargoCallbacks ) )
87+ // Format using rustfmt
88+ . rustfmt_bindings ( true )
89+ . rustified_enum ( "*" )
9190 // Finish the builder and generate the bindings.
9291 . generate ( )
9392 // Unwrap the Result and panic on failure.
@@ -100,48 +99,12 @@ fn generate_bindings(include_dir: &Path) {
10099 . join ( env:: var ( "CARGO_CFG_TARGET_OS" ) . unwrap ( ) )
101100 . join ( env:: var ( "CARGO_CFG_TARGET_ARCH" ) . unwrap ( ) )
102101 . join ( "bindings.rs" ) ;
102+ println ! ( "cargo:rerun-if-changed={:?}" , generated_file) ;
103103 bindings
104104 . write_to_file ( & generated_file)
105105 . expect ( "Couldn't write bindings!" ) ;
106106}
107107
108- fn generate_file_including_platform_bindings ( ) -> Result < ( ) , std:: io:: Error > {
109- let generic_binding_path = PathBuf :: from ( env:: var ( "CARGO_MANIFEST_DIR" ) . unwrap ( ) )
110- . join ( "src" )
111- . join ( "generated" )
112- . join ( "bindings.rs" ) ;
113-
114- let mut fh = BufWriter :: new ( fs:: File :: create ( & generic_binding_path) ?) ;
115-
116- let platform_bindings = PathBuf :: from ( "src" )
117- . join ( "generated" )
118- . join ( env:: var ( "CARGO_CFG_TARGET_OS" ) . unwrap ( ) )
119- . join ( env:: var ( "CARGO_CFG_TARGET_ARCH" ) . unwrap ( ) )
120- . join ( "bindings.rs" ) ;
121-
122- // Build a (relative) path, as a string, to the platform-specific bindings.
123- // Required so that we can escape backslash (Windows path separators) before
124- // writing to the file.
125- let include_path = format ! (
126- "{}{}" ,
127- std:: path:: MAIN_SEPARATOR ,
128- platform_bindings. display( )
129- )
130- . replace ( r#"\"# , r#"\\"# ) ;
131- fh. write_all (
132- format ! (
133- r#"include!(concat!(
134- env!("CARGO_MANIFEST_DIR"),
135- "{}"
136- ));"# ,
137- include_path
138- )
139- . as_bytes ( ) ,
140- ) ?;
141-
142- Ok ( ( ) )
143- }
144-
145108fn download < P : AsRef < Path > > ( source_url : & str , target_file : P ) {
146109 let resp = ureq:: get ( source_url)
147110 . timeout_connect ( 1_000 ) // 1 second
@@ -169,13 +132,13 @@ fn download<P: AsRef<Path>>(source_url: &str, target_file: P) {
169132}
170133
171134fn extract_archive ( filename : & Path , output : & Path ) {
172- #[ cfg( target_family = "unix" ) ]
173- extract_tgz ( filename, output) ;
174- #[ cfg( target_family = "windows" ) ]
175- extract_zip ( filename, output) ;
135+ match filename. extension ( ) . map ( |e| e. to_str ( ) ) {
136+ Some ( Some ( "zip" ) ) => extract_zip ( filename, output) ,
137+ Some ( Some ( "tgz" ) ) => extract_tgz ( filename, output) ,
138+ _ => unimplemented ! ( ) ,
139+ }
176140}
177141
178- #[ cfg( target_family = "unix" ) ]
179142fn extract_tgz ( filename : & Path , output : & Path ) {
180143 let file = fs:: File :: open ( & filename) . unwrap ( ) ;
181144 let buf = io:: BufReader :: new ( file) ;
@@ -184,13 +147,13 @@ fn extract_tgz(filename: &Path, output: &Path) {
184147 archive. unpack ( output) . unwrap ( ) ;
185148}
186149
187- #[ cfg( target_family = "windows" ) ]
188150fn extract_zip ( filename : & Path , outpath : & Path ) {
189151 let file = fs:: File :: open ( & filename) . unwrap ( ) ;
190152 let buf = io:: BufReader :: new ( file) ;
191153 let mut archive = zip:: ZipArchive :: new ( buf) . unwrap ( ) ;
192154 for i in 0 ..archive. len ( ) {
193155 let mut file = archive. by_index ( i) . unwrap ( ) ;
156+ #[ allow( deprecated) ]
194157 let outpath = outpath. join ( file. sanitized_name ( ) ) ;
195158 if !( & * file. name ( ) ) . ends_with ( '/' ) {
196159 println ! (
@@ -212,6 +175,7 @@ fn extract_zip(filename: &Path, outpath: &Path) {
212175
213176fn prebuilt_archive_url ( ) -> ( PathBuf , String ) {
214177 let os = env:: var ( "CARGO_CFG_TARGET_OS" ) . expect ( "Unable to get TARGET_OS" ) ;
178+ let arch = env:: var ( "CARGO_CFG_TARGET_ARCH" ) . expect ( "Unable to get TARGET_ARCH" ) ;
215179
216180 let gpu_str = match env:: var ( ORT_ENV_GPU ) {
217181 Ok ( cuda_env) => {
@@ -232,17 +196,28 @@ fn prebuilt_archive_url() -> (PathBuf, String) {
232196 Err ( _) => "" ,
233197 } ;
234198
235- let arch_str = match os. as_str ( ) {
236- "windows" => {
237- if gpu_str. is_empty ( ) {
238- "x86"
239- } else {
240- "x64"
241- }
242- }
243- _ => "x64" ,
199+ let arch_str = match arch. as_str ( ) {
200+ "x86_64" => "x64" ,
201+ "x86" => "x86" ,
202+ unsupported => panic ! ( "Unsupported architecture {:?}" , unsupported) ,
244203 } ;
245204
205+ if arch. as_str ( ) == "x86" && os. as_str ( ) != "windows" {
206+ panic ! (
207+ "ONNX Runtime only supports x86 (i686) architecture on Windows (not {:?})." ,
208+ os
209+ ) ;
210+ }
211+
212+ // Only Windows and Linux x64 support GPU
213+ if !gpu_str. is_empty ( ) {
214+ if arch_str == "x64" && ( os == "windows" || os == "linux" ) {
215+ println ! ( "Supported GPU platform: {} {}" , os, arch_str) ;
216+ } else {
217+ panic ! ( "Unsupported GPU platform: {} {}" , os, arch_str) ;
218+ }
219+ }
220+
246221 let ( os_str, archive_extension) = match os. as_str ( ) {
247222 "windows" => ( "win" , "zip" ) ,
248223 "macos" => ( "osx" , "tgz" ) ,
0 commit comments