Skip to content

Commit 3d4a98b

Browse files
authored
Merge pull request #4 from relaypro-open/rust_unsafe
Rust unsafe
2 parents 872144e + da274a9 commit 3d4a98b

File tree

9 files changed

+157
-198
lines changed

9 files changed

+157
-198
lines changed

lib/ortex/util.ex

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
defmodule Ortex.Util do
22
def copy_ort_libs() do
3-
4-
build_root = Path.absname(:code.priv_dir(:ortex)) |> Path.dirname
3+
build_root = Path.absname(:code.priv_dir(:ortex)) |> Path.dirname()
54

65
rust_env =
7-
case Path.join([build_root, "native/ortex/release"]) |> File.ls do
6+
case Path.join([build_root, "native/ortex/release"]) |> File.ls() do
87
{:ok, _} -> "release"
98
_ -> "debug"
109
end

mix.exs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ defmodule Ortex.MixProject do
3131
# Run "mix help deps" to learn about dependencies.
3232
defp deps do
3333
[
34-
{:rustler, "~> 0.26.0"},
34+
{:rustler, "~> 0.28.0"},
3535
{:nx, "~>0.5.3"},
3636
{:tokenizers, "~> 0.3.0", only: :dev},
3737
{:ex_doc, "0.29.4", only: :dev, runtime: false},

mix.lock

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
"nimble_parsec": {:hex, :nimble_parsec, "1.3.1", "2c54013ecf170e249e9291ed0a62e5832f70a476c61da16f6aac6dca0189f2af", [:mix], [], "hexpm", "2682e3c0b2eb58d90c6375fc0cc30bc7be06f365bf72608804fb9cffa5e1b167"},
1919
"nx": {:hex, :nx, "0.5.3", "6ad5534f9b82429dafa12329952708c2fdd6ab01b306e86333fdea72383147ee", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "d1072fc4423809ed09beb729e73c200ce177ddecac425d9eb6fba643669623ec"},
2020
"protox": {:hex, :protox, "1.6.10", "41d0b0c5b9190e7d5e6a2b1a03a09257ead6f3d95e6a0cf8b81430b526126908", [:mix], [{:decimal, "~> 1.9 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: false]}, {:jason, "~> 1.2", [hex: :jason, repo: "hexpm", optional: true]}], "hexpm", "9769fca26ae7abfc5cc61308a1e8d9e2400ff89a799599cee7930d21132832d9"},
21-
"rustler": {:hex, :rustler, "0.26.0", "06a2773d453ee3e9109efda643cf2ae633dedea709e2455ac42b83637c9249bf", [:mix], [{:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:toml, "~> 0.6", [hex: :toml, repo: "hexpm", optional: false]}], "hexpm", "42961e9d2083d004d5a53e111ad1f0c347efd9a05cb2eb2ffa1d037cdc74db91"},
21+
"rustler": {:hex, :rustler, "0.28.0", "b8e2c43013e12dd06f61dcf87033d2e2c8245feddb121b82179c923be31ad319", [:mix], [{:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:toml, "~> 0.6", [hex: :toml, repo: "hexpm", optional: false]}], "hexpm", "99f811f58c993f0343851adb0af589a99cfd3dc20f2efb8ef08d1a8447980b98"},
2222
"rustler_precompiled": {:hex, :rustler_precompiled, "0.6.1", "160b545bce8bf9a3f1b436b2c10f53574036a0db628e40f393328cbbe593602f", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, "~> 0.23", [hex: :rustler, repo: "hexpm", optional: true]}], "hexpm", "0dd269fa261c4e3df290b12031c575fff07a542749f7b0e8b744d72d66c43600"},
2323
"telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"},
2424
"tokenizers": {:hex, :tokenizers, "0.3.2", "78c6238690a0467c613c8ba3c59338235594a78f870e8f8151b9614516dee0fd", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, ">= 0.0.0", [hex: :rustler, repo: "hexpm", optional: true]}, {:rustler_precompiled, "~> 0.6", [hex: :rustler_precompiled, repo: "hexpm", optional: false]}], "hexpm", "f6dd9a798e81cf2f3359e1731836ed0a351cae4da5d5d570a7ef3d0543e9cf85"},

native/ortex/Cargo.lock

Lines changed: 24 additions & 13 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

native/ortex/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ path = "src/lib.rs"
1010
crate-type = ["cdylib"]
1111

1212
[dependencies]
13-
rustler = "0.26.0"
13+
rustler = "0.28.0"
1414
ort = {version = "1.14.6", default-features = false, features = ["half", "copy-dylibs"]}
1515
ndarray = "0.15.6"
1616
half = "2.2.1"

native/ortex/src/tensor.rs

Lines changed: 26 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
//! Conversions for packing/unpacking `OrtexTensor`s into different types
22
use ndarray::prelude::*;
3+
use ndarray::Data;
34
use ort::tensor::{DynOrtTensor, FromArray, InputTensor, TensorElementDataType};
45
use ort::OrtError;
56
use rustler::Atom;
@@ -78,88 +79,36 @@ impl OrtexTensor {
7879
}
7980
}
8081

81-
pub fn to_bytes(&self) -> Vec<u8> {
82-
// Annoying and not DRY, Traits are probably the answer here...
83-
// Once num_traits has from_<endian>_bytes we can pull that in
84-
// https://github.com/rust-num/num-traits/pull/224
85-
let contents = match self {
86-
OrtexTensor::s8(y) => y
87-
.clone()
88-
.into_raw_vec()
89-
.iter()
90-
.flat_map(|f| f.to_ne_bytes().to_vec())
91-
.collect(),
92-
OrtexTensor::s16(y) => y
93-
.clone()
94-
.into_raw_vec()
95-
.iter()
96-
.flat_map(|f| f.to_ne_bytes().to_vec())
97-
.collect(),
98-
OrtexTensor::s32(y) => y
99-
.clone()
100-
.into_raw_vec()
101-
.iter()
102-
.flat_map(|f| f.to_ne_bytes().to_vec())
103-
.collect(),
104-
OrtexTensor::s64(y) => y
105-
.clone()
106-
.into_raw_vec()
107-
.iter()
108-
.flat_map(|f| f.to_ne_bytes().to_vec())
109-
.collect(),
110-
OrtexTensor::u8(y) => y
111-
.clone()
112-
.into_raw_vec()
113-
.iter()
114-
.flat_map(|f| f.to_ne_bytes().to_vec())
115-
.collect(),
116-
OrtexTensor::u16(y) => y
117-
.clone()
118-
.into_raw_vec()
119-
.iter()
120-
.flat_map(|f| f.to_ne_bytes().to_vec())
121-
.collect(),
122-
OrtexTensor::u32(y) => y
123-
.clone()
124-
.into_raw_vec()
125-
.iter()
126-
.flat_map(|f| f.to_ne_bytes().to_vec())
127-
.collect(),
128-
OrtexTensor::u64(y) => y
129-
.clone()
130-
.into_raw_vec()
131-
.iter()
132-
.flat_map(|f| f.to_ne_bytes().to_vec())
133-
.collect(),
134-
OrtexTensor::f16(y) => y
135-
.clone()
136-
.into_raw_vec()
137-
.iter()
138-
.flat_map(|f| f.to_ne_bytes().to_vec())
139-
.collect(),
140-
OrtexTensor::bf16(y) => y
141-
.clone()
142-
.into_raw_vec()
143-
.iter()
144-
.flat_map(|f| f.to_ne_bytes().to_vec())
145-
.collect(),
146-
OrtexTensor::f32(y) => y
147-
.clone()
148-
.into_raw_vec()
149-
.iter()
150-
.flat_map(|f| f.to_ne_bytes().to_vec())
151-
.collect(),
152-
OrtexTensor::f64(y) => y
153-
.clone()
154-
.into_raw_vec()
155-
.iter()
156-
.flat_map(|f| f.to_ne_bytes().to_vec())
157-
.collect(),
82+
pub fn to_bytes<'a>(&'a self) -> &'a [u8] {
83+
let contents: &'a [u8] = match self {
84+
OrtexTensor::s8(y) => get_bytes(y),
85+
OrtexTensor::s16(y) => get_bytes(y),
86+
OrtexTensor::s32(y) => get_bytes(y),
87+
OrtexTensor::s64(y) => get_bytes(y),
88+
OrtexTensor::u8(y) => get_bytes(y),
89+
OrtexTensor::u16(y) => get_bytes(y),
90+
OrtexTensor::u32(y) => get_bytes(y),
91+
OrtexTensor::u64(y) => get_bytes(y),
92+
OrtexTensor::f16(y) => get_bytes(y),
93+
OrtexTensor::bf16(y) => get_bytes(y),
94+
OrtexTensor::f32(y) => get_bytes(y),
95+
OrtexTensor::f64(y) => get_bytes(y),
15896
};
15997
contents
16098
}
16199
}
162100

101+
fn get_bytes<'a, T>(array: &'a ArrayBase<T, IxDyn>) -> &'a [u8]
102+
where
103+
T: Data,
104+
{
105+
let len = array.len();
106+
let binding = unsafe { std::mem::zeroed() };
107+
let f = array.get(0).unwrap_or(&binding);
108+
let size: usize = std::mem::size_of_val(f);
109+
unsafe { std::slice::from_raw_parts(array.as_ptr() as *const u8, len * size) }
110+
}
111+
163112
impl std::convert::TryFrom<&DynOrtTensor<'_, IxDyn>> for OrtexTensor {
164113
type Error = OrtError;
165114
fn try_from(e: &DynOrtTensor<IxDyn>) -> Result<OrtexTensor, Self::Error> {

0 commit comments

Comments
 (0)