Skip to content

Commit bee215a

Browse files
authored
directml用コードのコード生成を行った (#6)
1 parent 3f7fed9 commit bee215a

File tree

15 files changed

+41689
-50266
lines changed

15 files changed

+41689
-50266
lines changed

.github/actions/auto_gen_bind_pr/action.yaml

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ inputs:
44
triple:
55
description: "cargo build target"
66
required: true
7+
additional_features:
8+
description: "additional features"
9+
required: false
710
runs:
811
using: "composite"
912
steps:
@@ -22,15 +25,15 @@ runs:
2225
# cargoのキャシュが原因でテストが失敗してることが考えられる場合はバージョン部分を変更する
2326
key: "v1-cargo-test-cache-${{ inputs.triple }}"
2427
- name: gen bind
25-
run: cargo build --target ${{ inputs.triple }} --features generate-bindings && git add --all
28+
run: cargo build -p onnxruntime-sys --target ${{ inputs.triple }} --features generate-bindings,${{ inputs.additional_features }} && git add --all
2629
shell: bash
2730
- name: create_pr
2831
id: cpr
2932
uses: peter-evans/create-pull-request@v4
3033
with:
3134
commit-message: Automated generate bindings for ${{ inputs.triple }}
32-
branch: ${{ steps.extract_branch.outputs.branch }}_auto_gen_bindings_${{ inputs.triple }}
35+
branch: ${{ steps.extract_branch.outputs.branch }}_auto_gen_bindings_${{ inputs.triple }}_${{ inputs.additional_features }}
3336
delete-branch: true
34-
title: Automated generate bindings for ${{ inputs.triple }}
37+
title: Automated generate bindings for ${{ inputs.triple }} ${{ inputs.additional_features }}
3538
body: |
36-
Automated generate bindings for ${{ inputs.triple }}
39+
Automated generate bindings for ${{ inputs.triple }} ${{ inputs.additional_features }}

.github/workflows/gen_bind.yaml

Lines changed: 21 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -6,50 +6,30 @@ on:
66
- ".github/workflows/gen_bind.yaml"
77
- ".github/actions/auto_gen_bind_pr/action.yaml"
88
jobs:
9-
generate_bind_x86_64_linux:
10-
runs-on: ubuntu-latest
9+
generate_bind:
10+
strategy:
11+
fail-fast: false
12+
matrix:
13+
include:
14+
- os: ubuntu-latest
15+
triple: x86_64-unknown-linux-gnu
16+
- os: windows-latest
17+
triple: x86_64-pc-windows-msvc
18+
- os: windows-latest
19+
triple: x86_64-pc-windows-msvc
20+
additional_features: directml
21+
- os: windows-latest
22+
triple: i686-pc-windows-msvc
23+
- os: macos-latest
24+
triple: x86_64-apple-darwin
25+
- os: macos-latest
26+
triple: aarch64-apple-darwin
27+
runs-on: ${{ matrix.os }}
1128
steps:
1229
- uses: actions/checkout@v3
1330
with:
1431
ref: ${{ github.head_ref }}
1532
- uses: ./.github/actions/auto_gen_bind_pr
1633
with:
17-
triple: x86_64-unknown-linux-gnu
18-
generate_bind_x86_64_windows:
19-
runs-on: windows-latest
20-
steps:
21-
- uses: actions/checkout@v3
22-
with:
23-
ref: ${{ github.head_ref }}
24-
- name: x86_64_windows
25-
uses: ./.github/actions/auto_gen_bind_pr
26-
with:
27-
triple: x86_64-pc-windows-msvc
28-
generate_bind_x86_windows:
29-
runs-on: windows-latest
30-
steps:
31-
- uses: actions/checkout@v3
32-
with:
33-
ref: ${{ github.head_ref }}
34-
- name: x86_windows
35-
uses: ./.github/actions/auto_gen_bind_pr
36-
with:
37-
triple: i686-pc-windows-msvc
38-
generate_bind_x86_64_mac:
39-
runs-on: macos-latest
40-
steps:
41-
- uses: actions/checkout@v3
42-
with:
43-
ref: ${{ github.head_ref }}
44-
- uses: ./.github/actions/auto_gen_bind_pr
45-
with:
46-
triple: x86_64-apple-darwin
47-
generate_bind_arm64_mac:
48-
runs-on: macos-latest
49-
steps:
50-
- uses: actions/checkout@v3
51-
with:
52-
ref: ${{ github.head_ref }}
53-
- uses: ./.github/actions/auto_gen_bind_pr
54-
with:
55-
triple: aarch64-apple-darwin
34+
triple: ${{ matrix.triple }}
35+
additional_features: ${{ matrix.additional_features }}

.github/workflows/general.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,10 @@ jobs:
9898
include:
9999
- os: windows-2019
100100
- os: windows-2022
101+
- os: windows-2019
102+
features: directml
103+
- os: windows-2022
104+
features: directml
101105
- os: macos-10.15
102106
- os: macos-11
103107
- os: macos-12
@@ -124,7 +128,7 @@ jobs:
124128
key: "v1-cargo-test-cache-${{ matrix.os }}"
125129
- name: Run cargo test
126130
shell: bash
127-
run: cargo test
131+
run: cargo test --features ,${{ matrix.features }}
128132

129133
clippy:
130134
name: Clippy

onnxruntime-sys/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ keywords = ["neuralnetworks", "onnx", "bindings"]
1818

1919
[build-dependencies]
2020
once_cell = "1.13.0"
21-
bindgen = { version = "0.59", optional = true }
21+
bindgen = { version = "0.60.1", optional = true }
2222
ureq = "2.1"
2323

2424
# Used on Windows

onnxruntime-sys/build.rs

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -179,19 +179,26 @@ fn generate_bindings(include_dir: &Path) {
179179
.join("session")
180180
.display()
181181
),
182+
#[cfg(feature = "directml")]
183+
format!("-D{}", "USE_DML"),
182184
];
183185

186+
#[cfg(not(feature = "directml"))]
187+
let header_name = "wrapper.h";
188+
#[cfg(feature = "directml")]
189+
let header_name = "wrapper_directml.h";
190+
184191
// Tell cargo to invalidate the built crate whenever the wrapper changes
185-
println!("cargo:rerun-if-changed=wrapper.h");
192+
println!("cargo:rerun-if-changed={}", header_name);
186193
println!("cargo:rerun-if-changed=src/generated/bindings.rs");
187194

188195
// The bindgen::Builder is the main entry point
189196
// to bindgen, and lets you build up options for
190197
// the resulting bindings.
191-
let bindings = bindgen::Builder::default()
198+
let mut bind_builder = bindgen::Builder::default()
192199
// The input header we would like to generate
193200
// bindings for.
194-
.header("wrapper.h")
201+
.header(header_name)
195202
// The current working directory is 'onnxruntime-sys'
196203
.clang_args(clang_args)
197204
// Tell cargo to invalidate the built crate whenever any of the
@@ -201,19 +208,28 @@ fn generate_bindings(include_dir: &Path) {
201208
.size_t_is_usize(true)
202209
// Format using rustfmt
203210
.rustfmt_bindings(true)
204-
.rustified_enum("*")
205-
// Finish the builder and generate the bindings.
211+
.rustified_enum("*");
212+
213+
for entry in include_dir.read_dir().unwrap().filter_map(|e| e.ok()) {
214+
let path = entry.path();
215+
let file_name = path.file_name().unwrap().to_str().unwrap().to_string();
216+
bind_builder =
217+
bind_builder.allowlist_file(format!(".*{}", file_name.replace(".h", "\\.h")));
218+
}
219+
let bindings = bind_builder
206220
.generate()
207-
// Unwrap the Result and panic on failure.
208221
.expect("Unable to generate bindings");
209222

210223
// Write the bindings to (source controlled) src/generated/<os>/<arch>/bindings.rs
211224
let generated_file = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap())
212225
.join("src")
213226
.join("generated")
214227
.join(env::var("CARGO_CFG_TARGET_OS").unwrap())
215-
.join(env::var("CARGO_CFG_TARGET_ARCH").unwrap())
216-
.join("bindings.rs");
228+
.join(env::var("CARGO_CFG_TARGET_ARCH").unwrap());
229+
#[cfg(not(feature = "directml"))]
230+
let generated_file = generated_file.join("bindings.rs");
231+
#[cfg(feature = "directml")]
232+
let generated_file = generated_file.join("bindings_directml.rs");
217233
println!("cargo:rerun-if-changed={:?}", generated_file);
218234
bindings
219235
.write_to_file(&generated_file)

onnxruntime-sys/src/generated/bindings.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,18 @@ include!(concat!(
2222
"/src/generated/windows/x86/bindings.rs"
2323
));
2424

25-
#[cfg(all(target_os = "windows", target_arch = "x86_64"))]
25+
#[cfg(all(
26+
target_os = "windows",
27+
target_arch = "x86_64",
28+
not(feature = "directml")
29+
))]
2630
include!(concat!(
2731
env!("CARGO_MANIFEST_DIR"),
2832
"/src/generated/windows/x86_64/bindings.rs"
2933
));
34+
35+
#[cfg(all(target_os = "windows", target_arch = "x86_64", feature = "directml"))]
36+
include!(concat!(
37+
env!("CARGO_MANIFEST_DIR"),
38+
"/src/generated/windows/x86_64/bindings_directml.rs"
39+
));

0 commit comments

Comments
 (0)