Skip to content

Commit 34c4c98

Browse files
authored
Merge branch 'elixir-nx:main' into bool_tensors
2 parents 78b0865 + 3eaa913 commit 34c4c98

File tree

7 files changed

+79
-4
lines changed

7 files changed

+79
-4
lines changed

examples/distilbert/README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# DistilBert exported to ONNX with HuggingFace transformers
2+
3+
### Running
4+
5+
Run `python export.py` to create the ONNX model for distilbert/distilbert-base-uncased-finetuned-sst-2-english, then `mix run` the `distilbert_classification.exs` script.
6+
7+
### Labels
8+
9+
When exporting the model from Hugging Face transformers to ONNX, a `config.json` file is added to the chosen directory. This file contains the id-to-label mappings, and you can extract them directly to give a label to the input, as shown in `distilbert_classification.exs`.
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
defmodule Inference do
  # Translate a predicted class id into its human-readable label, using the
  # `id2label` mapping from the config.json exported next to the ONNX model.
  def id_to_label(id) do
    {:ok, raw_config} = File.read("./models/distilbert-onnx/config.json")
    {:ok, %{"id2label" => mapping}} = Jason.decode(raw_config)
    Map.get(mapping, to_string(id))
  end

  # Tokenize a sample sentence, run it through the exported DistilBERT model,
  # and print both the raw logits and the resolved sentiment label.
  def run() do
    text =
      "the movie had a lot of nuance and interesting artistic choices, would like to see more support in the industry for these types of productions"

    session = Ortex.load("./models/distilbert-onnx/model.onnx")
    {:ok, tokenizer} = Tokenizers.Tokenizer.from_file("./models/distilbert-onnx/tokenizer.json")
    {:ok, encoding} = Tokenizers.Tokenizer.encode(tokenizer, text)

    # Batch of one: wrap the token ids / attention mask in an outer list.
    ids_tensor = Nx.tensor([Tokenizers.Encoding.get_ids(encoding)])
    attention_tensor = Nx.tensor([Tokenizers.Encoding.get_attention_mask(encoding)])

    {logits} = Ortex.run(session, {ids_tensor, attention_tensor})

    IO.inspect(logits)

    label =
      logits
      |> Nx.backend_transfer()
      |> Nx.argmax()
      |> Nx.to_number()
      |> id_to_label()

    IO.inspect(label)
  end
end

Inference.run()

examples/distilbert/export.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
"""
2+
### Install dependencies:
3+
4+
$ pip install transformers
5+
$ pip install optimum
6+
$ pip install "transformers[onnx]"
7+
8+
"""
9+
10+
from transformers import DistilBertTokenizer
11+
from optimum.onnxruntime import ORTModelForSequenceClassification
12+
13+
save_directory = "./models/distilbert-onnx/"
14+
15+
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
16+
model = ORTModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english", export=True)
17+
print(model)
18+
19+
model.save_pretrained(save_directory)
20+
tokenizer.save_pretrained(save_directory)

lib/ortex/native.ex

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,27 @@
11
defmodule Ortex.Native do
22
@moduledoc false
33

4+
@rustler_version Application.spec(:rustler, :vsn) |> to_string() |> Version.parse!()
5+
46
# We have to compile the crate before `use Rustler` compiles the crate since
57
# cargo downloads the onnxruntime shared libraries and they are not available
68
# to load or copy into Elixir's during the on_load or Elixir compile steps.
79
# In the future, this may be configurable in Rustler.
8-
Rustler.Compiler.compile_crate(__MODULE__, otp_app: :ortex, crate: :ortex)
10+
if Version.compare(@rustler_version, "0.30.0") in [:gt, :eq] do
11+
Rustler.Compiler.compile_crate(:ortex, Application.compile_env(:ortex, __MODULE__, []),
12+
otp_app: :ortex,
13+
crate: :ortex
14+
)
15+
else
16+
Rustler.Compiler.compile_crate(__MODULE__, otp_app: :ortex, crate: :ortex)
17+
end
18+
919
Ortex.Util.copy_ort_libs()
1020

1121
use Rustler,
1222
otp_app: :ortex,
13-
crate: :ortex
23+
crate: :ortex,
24+
skip_compilation?: true
1425

1526
# When loading a NIF module, dummy clauses for all NIF function are required.
1627
# NIF dummies usually just error out when called when the NIF is not loaded, as that should never normally happen.

mix.exs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ defmodule Ortex.MixProject do
3131
# Run "mix help deps" to learn about dependencies.
3232
defp deps do
3333
[
34-
{:rustler, "~> 0.29.0"},
34+
{:rustler, "~> 0.27"},
3535
{:nx, "~> 0.6"},
3636
{:tokenizers, "~> 0.4", only: :dev},
3737
{:ex_doc, "0.29.4", only: :dev, runtime: false},

python/export_resnet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,5 @@
1515
output_names=["output"],
1616
dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
1717
export_params=True,
18+
opset_version=19,
1819
)

python/multi_input.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,5 +37,5 @@ def forward(self, x, y):
3737
"x": {0: "batch_size"},
3838
"y": {0: "batch_size"},
3939
},
40-
opset_version=19
40+
opset_version=19,
4141
)

0 commit comments

Comments
 (0)