@@ -179,19 +179,26 @@ fn generate_bindings(include_dir: &Path) {
179179 . join( "session" )
180180 . display( )
181181 ) ,
182+ #[ cfg( feature = "directml" ) ]
183+ format ! ( "-D{}" , "USE_DML" ) ,
182184 ] ;
183185
186+ #[ cfg( not( feature = "directml" ) ) ]
187+ let header_name = "wrapper.h" ;
188+ #[ cfg( feature = "directml" ) ]
189+ let header_name = "wrapper_directml.h" ;
190+
184191 // Tell cargo to invalidate the built crate whenever the wrapper changes
185- println ! ( "cargo:rerun-if-changed=wrapper.h" ) ;
192+ println ! ( "cargo:rerun-if-changed={}" , header_name ) ;
186193 println ! ( "cargo:rerun-if-changed=src/generated/bindings.rs" ) ;
187194
188195 // The bindgen::Builder is the main entry point
189196 // to bindgen, and lets you build up options for
190197 // the resulting bindings.
191- let bindings = bindgen:: Builder :: default ( )
198+ let mut bind_builder = bindgen:: Builder :: default ( )
192199 // The input header we would like to generate
193200 // bindings for.
194- . header ( "wrapper.h" )
201+ . header ( header_name )
195202 // The current working directory is 'onnxruntime-sys'
196203 . clang_args ( clang_args)
197204 // Tell cargo to invalidate the built crate whenever any of the
@@ -201,19 +208,28 @@ fn generate_bindings(include_dir: &Path) {
201208 . size_t_is_usize ( true )
202209 // Format using rustfmt
203210 . rustfmt_bindings ( true )
204- . rustified_enum ( "*" )
205- // Finish the builder and generate the bindings.
211+ . rustified_enum ( "*" ) ;
212+
213+ for entry in include_dir. read_dir ( ) . unwrap ( ) . filter_map ( |e| e. ok ( ) ) {
214+ let path = entry. path ( ) ;
215+ let file_name = path. file_name ( ) . unwrap ( ) . to_str ( ) . unwrap ( ) . to_string ( ) ;
216+ bind_builder =
217+ bind_builder. allowlist_file ( format ! ( ".*{}" , file_name. replace( ".h" , "\\ .h" ) ) ) ;
218+ }
219+ let bindings = bind_builder
206220 . generate ( )
207- // Unwrap the Result and panic on failure.
208221 . expect ( "Unable to generate bindings" ) ;
209222
210223 // Write the bindings to (source controlled) src/generated/<os>/<arch>/bindings.rs
211224 let generated_file = PathBuf :: from ( env:: var ( "CARGO_MANIFEST_DIR" ) . unwrap ( ) )
212225 . join ( "src" )
213226 . join ( "generated" )
214227 . join ( env:: var ( "CARGO_CFG_TARGET_OS" ) . unwrap ( ) )
215- . join ( env:: var ( "CARGO_CFG_TARGET_ARCH" ) . unwrap ( ) )
216- . join ( "bindings.rs" ) ;
228+ . join ( env:: var ( "CARGO_CFG_TARGET_ARCH" ) . unwrap ( ) ) ;
229+ #[ cfg( not( feature = "directml" ) ) ]
230+ let generated_file = generated_file. join ( "bindings.rs" ) ;
231+ #[ cfg( feature = "directml" ) ]
232+ let generated_file = generated_file. join ( "bindings_directml.rs" ) ;
217233 println ! ( "cargo:rerun-if-changed={:?}" , generated_file) ;
218234 bindings
219235 . write_to_file ( & generated_file)
0 commit comments