Skip to content

Commit 2da4c1a

Browse files
Make sure EPs are supported
1 parent a4493c3 commit 2da4c1a

File tree

7 files changed

+45
-60
lines changed

7 files changed

+45
-60
lines changed

mix.exs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ defmodule Ortex.MixProject do
3131
# Run "mix help deps" to learn about dependencies.
3232
defp deps do
3333
[
34-
{:rustler, "~> 0.33"},
34+
{:rustler, "~> 0.36.2"},
3535
{:nx, "~> 0.10"},
3636
{:tokenizers, "~> 0.5", only: :dev},
3737
{:ex_doc, "~> 0.38", only: :dev, runtime: false}

native/ortex/Cargo.lock

Lines changed: 29 additions & 41 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

native/ortex/Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ crate-type = ["cdylib"]
1313
resolver = "2"
1414

1515
[dependencies]
16-
rustler = "0.33"
16+
rustler = "0.36.2"
1717
ort-sys = { version = "=2.0.0-rc.10", default-features = false }
18-
ort = { version = "2.0.0-rc.10", features = ["half"] }
18+
ort = { version = "2.0.0-rc.10", features = ["half", "cuda"] }
1919
ndarray = "0.16.1"
2020
half = "2.6.0"
2121
tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] }

native/ortex/src/lib.rs

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -104,21 +104,13 @@ pub fn concatenate<'a>(
104104
Ok(ResourceArc::new(concatted))
105105
}
106106

107+
pub fn on_load(env: Env) -> bool {
108+
env.register::<OrtexModel>().is_ok() && env.register::<OrtexTensor>().is_ok()
109+
}
110+
107111
rustler::init!(
108112
"Elixir.Ortex.Native",
109-
[
110-
run,
111-
init,
112-
from_binary,
113-
to_binary,
114-
show_session,
115-
slice,
116-
reshape,
117-
concatenate
118-
],
119113
load = |env: Env, _term: Term| -> bool {
120-
rustler::resource!(OrtexModel, env);
121-
rustler::resource!(OrtexTensor, env);
122-
true
114+
on_load(env)
123115
}
124116
);

native/ortex/src/model.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ use ort::execution_providers::ExecutionProviderDispatch;
1717
use ort::session::Session;
1818
use ort::Error;
1919
use rustler::Atom;
20+
use rustler::Resource;
2021
use rustler::ResourceArc;
2122
use std::error::Error as StdError;
2223
use std::sync::Mutex;
@@ -26,6 +27,7 @@ use std::sync::Mutex;
2627
pub struct OrtexModel {
2728
pub session: Mutex<ort::session::Session>,
2829
}
30+
impl Resource for OrtexModel {}
2931

3032
// Since we're only using the session for inference and
3133
// inference is threadsafe, this Sync is safe. Additionally,

native/ortex/src/tensor.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use ndarray::{ArrayBase, ArrayView, Data, IxDyn, IxDynImpl, ViewRepr};
55
use ort::value::Value;
66
use ort::Error;
77
use rustler::Atom;
8+
use rustler::Resource;
89
use rustler::ResourceArc;
910
use std::error::Error as StdError;
1011

@@ -204,6 +205,8 @@ impl OrtexTensor {
204205
}
205206
}
206207

208+
impl Resource for OrtexTensor {}
209+
207210
fn slice_array<'a, T, D>(
208211
array: &'a Array<T, D>,
209212
slice_specs: &'a Vec<(isize, Option<isize>, isize)>,
@@ -441,7 +444,7 @@ macro_rules! concatenate {
441444
// `typ` is the actual datatype, `ort_tensor_kind` is the OrtexTensor variant
442445
($tensors:expr, $axis:expr, $typ:ty, $ort_tensor_kind:ident) => {{
443446
type ArrayType<'a> = ArrayBase<ViewRepr<&'a $typ>, Dim<IxDynImpl>>;
444-
fn filter(tensor: &OrtexTensor) -> Option<ArrayType> {
447+
fn filter<'a>(tensor: &'a OrtexTensor) -> Option<ArrayType<'a>> {
445448
match tensor {
446449
OrtexTensor::$ort_tensor_kind(x) => Some(x.view()),
447450
_ => None,

native/ortex/src/utils.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use ort::execution_providers::ExecutionProviderDispatch;
1515
use ort::session::builder::GraphOptimizationLevel;
1616

1717
/// A faster (unsafe) way of creating an Array from an Erlang binary
18-
fn initialize_from_raw_ptr<T>(ptr: *const T, shape: &[Ix]) -> ArrayViewMut<T, IxDyn> {
18+
fn initialize_from_raw_ptr<T>(ptr: *const T, shape: &[Ix]) -> ArrayViewMut<'_, T, IxDyn> {
1919
let array = unsafe { ArrayViewMut::from_shape_ptr(shape, ptr as *mut T) };
2020
array
2121
}
@@ -96,7 +96,7 @@ pub fn map_eps(env: rustler::env::Env, eps: Vec<Atom>) -> Vec<ExecutionProviderD
9696
eps.iter()
9797
.map(|e| match &e.to_term(env).atom_to_string().unwrap()[..] {
9898
CPU => ort::execution_providers::cpu::CPUExecutionProvider::default().build(),
99-
CUDA => ort::execution_providers::cuda::CUDAExecutionProvider::default().build(),
99+
CUDA => ort::execution_providers::cuda::CUDAExecutionProvider::default().build().error_on_failure(),
100100
TENSORRT => {
101101
ort::execution_providers::tensorrt::TensorRTExecutionProvider::default().build()
102102
}

0 commit comments

Comments
 (0)