Skip to content

Commit cb645a2

Browse files
authored
Merge pull request #38 from gregszumel/rc8_upgrade
Upgrade to rc-8
2 parents 0609a0e + f78d11f commit cb645a2

File tree

7 files changed

+183
-271
lines changed

7 files changed

+183
-271
lines changed

lib/ortex/util.ex

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -31,20 +31,5 @@ defmodule Ortex.Util do
3131
|> Enum.map(fn x ->
3232
File.cp!(x, Path.join([destination_dir, Path.basename(x)]))
3333
end)
34-
35-
# Currently ORT doesn't write the .so file we need (fix incoming https://github.com/pykeio/ort/commit/634e49ab7c960782cc2fb83d84cc219e7bb4ae1f),
36-
# so we're hacking a fix here
37-
onnx_runtime_filenames = Enum.map(onnx_runtime_paths, &Path.basename/1)
38-
39-
case "libonnxruntime.so.1.17.0" in onnx_runtime_filenames do
40-
true ->
41-
nil
42-
43-
false ->
44-
File.cp!(
45-
Path.join([destination_dir, "libonnxruntime.so"]),
46-
Path.join([destination_dir, "libonnxruntime.so.1.17.0"])
47-
)
48-
end
4934
end
5035
end

native/ortex/Cargo.lock

Lines changed: 46 additions & 95 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

native/ortex/Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ crate-type = ["cdylib"]
1111

1212
[dependencies]
1313
rustler = "0.29.0"
14-
ort = { version = "2.0.0-rc.0" }
15-
ndarray = "0.15.6"
14+
ort = { version = "2.0.0-rc.8" }
15+
ndarray = "0.16.1"
1616
half = "2.2.1"
1717
tracing-subscriber = { version = "0.3", features = [ "env-filter", "fmt" ] }
1818
num-traits = "0.2.15"

native/ortex/src/lib.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,9 @@ fn init(
2424
opt: i32,
2525
) -> NifResult<ResourceArc<model::OrtexModel>> {
2626
let eps = utils::map_eps(env, eps);
27-
Ok(ResourceArc::new(
28-
model::init(model_path, eps, opt)
29-
.map_err(|e| rustler::Error::Term(Box::new(e.to_string())))?,
30-
))
27+
let model = model::init(model_path, eps, opt)
28+
.map_err(|e| rustler::Error::Term(Box::new(e.to_string())))?;
29+
Ok(ResourceArc::new(model))
3130
}
3231

3332
#[rustler::nif]

native/ortex/src/model.rs

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
1111
use crate::tensor::OrtexTensor;
1212
use crate::utils::map_opt_level;
13-
use std::convert::{Into, TryFrom};
13+
use std::convert::{TryFrom, TryInto};
1414

15-
use ort::{Error, ExecutionProviderDispatch, Session, Value};
15+
use ort::{Error, ExecutionProviderDispatch, Session};
1616
use rustler::resource::ResourceArc;
1717
use rustler::Atom;
1818

@@ -37,15 +37,11 @@ pub fn init(
3737
) -> Result<OrtexModel, Error> {
3838
// TODO: send tracing logs to erlang/elixir _somehow_
3939
// tracing_subscriber::fmt::init();
40-
ort::init()
41-
.with_execution_providers(&eps)
42-
.with_name("ortex-model")
43-
.commit()?;
4440

4541
let session = Session::builder()?
46-
.with_execution_providers(&eps)?
4742
.with_optimization_level(map_opt_level(opt))?
48-
.with_model_from_file(model_path)?;
43+
.with_execution_providers(eps)?
44+
.commit_from_file(model_path)?;
4945

5046
let state = OrtexModel { session };
5147
Ok(state)
@@ -88,21 +84,22 @@ pub fn run(
8884
inputs: Vec<ResourceArc<OrtexTensor>>,
8985
) -> Result<Vec<(ResourceArc<OrtexTensor>, Vec<usize>, Atom, usize)>, Error> {
9086
// TODO: can we handle an error more elegantly than just .unwrap()?
91-
let final_input: Vec<Value> = inputs
92-
.into_iter()
93-
.map(|x| Value::try_from(&*x).unwrap())
94-
.collect();
87+
88+
let mut ortified_inputs: Vec<ort::SessionInputValue> = Vec::new();
89+
for input in inputs {
90+
let derefed_input: &OrtexTensor = &input;
91+
let v: ort::SessionInputValue = derefed_input.try_into()?;
92+
ortified_inputs.push(v);
93+
}
9594

9695
// Grab the session and run a forward pass with it
97-
let session = &model.session;
96+
let session: &ort::Session = &model.session;
9897

9998
// Construct a Vec of ModelOutput enums based on the DynOrtTensor data type
100-
let outputs = session.run(&final_input[..])?;
101-
99+
let outputs = session.run(&ortified_inputs[..])?;
102100
outputs
103101
.iter()
104102
.map(|(_name, val)| {
105-
let val: &Value = val;
106103
let ortextensor: OrtexTensor = OrtexTensor::try_from(val)?;
107104
let shape = ortextensor.shape();
108105
let (dtype, bits) = ortextensor.dtype();

0 commit comments

Comments
 (0)