Skip to content

Commit da274a9

Browse files
committed
unsafe changes to reduce copies going in/out of BEAM for already
allocated binaries
1 parent 7861cc4 commit da274a9

File tree

4 files changed

+128
-179
lines changed

4 files changed

+128
-179
lines changed

native/ortex/src/tensor.rs

Lines changed: 26 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
//! Conversions for packing/unpacking `OrtexTensor`s into different types
22
use ndarray::prelude::*;
3+
use ndarray::Data;
34
use ort::tensor::{DynOrtTensor, FromArray, InputTensor, TensorElementDataType};
45
use ort::OrtError;
56
use rustler::Atom;
@@ -78,88 +79,36 @@ impl OrtexTensor {
7879
}
7980
}
8081

81-
pub fn to_bytes(&self) -> Vec<u8> {
82-
// Annoying and not DRY, Traits are probably the answer here...
83-
// Once num_traits has from_<endian>_bytes we can pull that in
84-
// https://github.com/rust-num/num-traits/pull/224
85-
let contents = match self {
86-
OrtexTensor::s8(y) => y
87-
.clone()
88-
.into_raw_vec()
89-
.iter()
90-
.flat_map(|f| f.to_ne_bytes().to_vec())
91-
.collect(),
92-
OrtexTensor::s16(y) => y
93-
.clone()
94-
.into_raw_vec()
95-
.iter()
96-
.flat_map(|f| f.to_ne_bytes().to_vec())
97-
.collect(),
98-
OrtexTensor::s32(y) => y
99-
.clone()
100-
.into_raw_vec()
101-
.iter()
102-
.flat_map(|f| f.to_ne_bytes().to_vec())
103-
.collect(),
104-
OrtexTensor::s64(y) => y
105-
.clone()
106-
.into_raw_vec()
107-
.iter()
108-
.flat_map(|f| f.to_ne_bytes().to_vec())
109-
.collect(),
110-
OrtexTensor::u8(y) => y
111-
.clone()
112-
.into_raw_vec()
113-
.iter()
114-
.flat_map(|f| f.to_ne_bytes().to_vec())
115-
.collect(),
116-
OrtexTensor::u16(y) => y
117-
.clone()
118-
.into_raw_vec()
119-
.iter()
120-
.flat_map(|f| f.to_ne_bytes().to_vec())
121-
.collect(),
122-
OrtexTensor::u32(y) => y
123-
.clone()
124-
.into_raw_vec()
125-
.iter()
126-
.flat_map(|f| f.to_ne_bytes().to_vec())
127-
.collect(),
128-
OrtexTensor::u64(y) => y
129-
.clone()
130-
.into_raw_vec()
131-
.iter()
132-
.flat_map(|f| f.to_ne_bytes().to_vec())
133-
.collect(),
134-
OrtexTensor::f16(y) => y
135-
.clone()
136-
.into_raw_vec()
137-
.iter()
138-
.flat_map(|f| f.to_ne_bytes().to_vec())
139-
.collect(),
140-
OrtexTensor::bf16(y) => y
141-
.clone()
142-
.into_raw_vec()
143-
.iter()
144-
.flat_map(|f| f.to_ne_bytes().to_vec())
145-
.collect(),
146-
OrtexTensor::f32(y) => y
147-
.clone()
148-
.into_raw_vec()
149-
.iter()
150-
.flat_map(|f| f.to_ne_bytes().to_vec())
151-
.collect(),
152-
OrtexTensor::f64(y) => y
153-
.clone()
154-
.into_raw_vec()
155-
.iter()
156-
.flat_map(|f| f.to_ne_bytes().to_vec())
157-
.collect(),
82+
/// Borrows the tensor's element data as a byte slice without copying.
///
/// Matches on the `OrtexTensor` variant and forwards the inner array to
/// `get_bytes`. The returned slice is tied to the lifetime `'a` of `&self`,
/// so it cannot outlive the tensor it was taken from.
///
/// The bytes are the array's in-memory representation reinterpreted as `u8`,
/// so multi-byte elements come out in the host's native byte order.
pub fn to_bytes<'a>(&'a self) -> &'a [u8] {
83+
// Every arm does the same thing; the match exists only because each variant
// wraps an array with a different element type.
let contents: &'a [u8] = match self {
84+
OrtexTensor::s8(y) => get_bytes(y),
85+
OrtexTensor::s16(y) => get_bytes(y),
86+
OrtexTensor::s32(y) => get_bytes(y),
87+
OrtexTensor::s64(y) => get_bytes(y),
88+
OrtexTensor::u8(y) => get_bytes(y),
89+
OrtexTensor::u16(y) => get_bytes(y),
90+
OrtexTensor::u32(y) => get_bytes(y),
91+
OrtexTensor::u64(y) => get_bytes(y),
92+
OrtexTensor::f16(y) => get_bytes(y),
93+
OrtexTensor::bf16(y) => get_bytes(y),
94+
OrtexTensor::f32(y) => get_bytes(y),
95+
OrtexTensor::f64(y) => get_bytes(y),
15896
};
15997
contents
16098
}
16199
}
162100

101+
/// Reinterprets the elements of a dynamic-dimensional array as a borrowed
/// `&[u8]` of `array.len() * size_of(element)` bytes, starting at
/// `array.as_ptr()`. No data is copied; the slice borrows from `array` for `'a`.
///
/// NOTE(review): the slice is built from the first-element pointer and the
/// element count only, so it is correct only if the elements are laid out
/// contiguously from `as_ptr()`. Nothing here checks that (e.g. via
/// `is_standard_layout()`) — confirm every caller passes such an array.
fn get_bytes<'a, T>(array: &'a ArrayBase<T, IxDyn>) -> &'a [u8]
102+
where
103+
T: Data,
104+
{
105+
// Number of elements, not bytes.
let len = array.len();
106+
// Zero-initialized placeholder of the element type. It is never read as a
// value; it only gives `size_of_val` something to measure when `get(0)`
// returns `None`.
// NOTE(review): `std::mem::size_of::<T::Elem>()` yields the same size without
// creating a zeroed value in `unsafe` (zeroed is UB for types where all-zero
// bits are not a valid value).
let binding = unsafe { std::mem::zeroed() };
107+
let f = array.get(0).unwrap_or(&binding);
108+
// Size in bytes of one element.
let size: usize = std::mem::size_of_val(f);
109+
// SAFETY (as relied on by this code, not verified here): `as_ptr()` must point
// to at least `len * size` readable bytes that stay valid for `'a`.
unsafe { std::slice::from_raw_parts(array.as_ptr() as *const u8, len * size) }
110+
}
111+
163112
impl std::convert::TryFrom<&DynOrtTensor<'_, IxDyn>> for OrtexTensor {
164113
type Error = OrtError;
165114
fn try_from(e: &DynOrtTensor<IxDyn>) -> Result<OrtexTensor, Self::Error> {

native/ortex/src/utils.rs

Lines changed: 22 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,22 @@
33
44
use crate::constants::*;
55
use crate::tensor::OrtexTensor;
6+
use ndarray::{ArrayViewMut, Ix, IxDyn};
67

7-
use ndarray::prelude::*;
88
use ndarray::ShapeError;
99

1010
use rustler::resource::ResourceArc;
11-
use rustler::types::{Binary, OwnedBinary};
12-
use rustler::{Atom, Env, Error, NifResult};
11+
use rustler::types::Binary;
12+
use rustler::{Atom, Env, NifResult};
1313

1414
use ort::{ExecutionProvider, GraphOptimizationLevel};
1515

16+
/// A faster (unsafe) way of creating an Array from an Erlang binary
///
/// Wraps `ptr` in an `ArrayViewMut` of the given `shape` without copying or
/// inspecting the pointed-to memory. The callers visible in this file
/// immediately call `.to_owned()` on the result to obtain an owned array.
///
/// NOTE(review): this function performs no validation. The caller must ensure
/// that `ptr` is non-null, aligned for `T`, and valid for as many `T` elements
/// as `shape` describes; otherwise the view reads out of bounds or unaligned.
///
/// NOTE(review): the return type's lifetime is elided and the only reference
/// parameter is `shape`, so the view's lifetime appears to be tied to the
/// `shape` borrow rather than to the memory behind `ptr` — confirm this is
/// intended.
fn initialize_from_raw_ptr<T>(ptr: *const T, shape: &[Ix]) -> ArrayViewMut<T, IxDyn> {
18+
// The `*const T` is cast to `*mut T` because a mutable view is constructed.
// NOTE(review): nothing visible here writes through the view; a read-only
// view would avoid handing out mutable access to memory received as const.
let array = unsafe { ArrayViewMut::from_shape_ptr(shape, ptr as *mut T) };
19+
array
20+
}
21+
1622
/// Given a Binary term, shape, and dtype from the BEAM, constructs an OrtexTensor and
1723
/// returns the reference to be used as an Nx.Backend representation.
1824
///
@@ -32,115 +38,42 @@ pub fn from_binary(
3238
dtype_str: String,
3339
dtype_bits: usize,
3440
) -> Result<ResourceArc<OrtexTensor>, ShapeError> {
35-
// TODO: make this more DRY, pull out into an impl
3641
match (dtype_str.as_ref(), dtype_bits) {
3742
("bf", 16) => Ok(ResourceArc::new(OrtexTensor::bf16(
38-
Array::from_vec(
39-
bin.as_slice()
40-
.chunks_exact(2)
41-
.map(|c| half::bf16::from_ne_bytes([c[0], c[1]]))
42-
.collect(),
43-
)
44-
.into_shape(shape)?,
43+
initialize_from_raw_ptr(bin.as_ptr() as *const half::bf16, &shape).to_owned(),
4544
))),
4645
("f", 16) => Ok(ResourceArc::new(OrtexTensor::f16(
47-
Array::from_vec(
48-
bin.as_slice()
49-
.chunks_exact(2)
50-
.map(|c| half::f16::from_ne_bytes([c[0], c[1]]))
51-
.collect(),
52-
)
53-
.into_shape(shape)?,
46+
initialize_from_raw_ptr(bin.as_ptr() as *const half::f16, &shape).to_owned(),
5447
))),
5548
("f", 32) => Ok(ResourceArc::new(OrtexTensor::f32(
56-
Array::from_vec(
57-
bin.as_slice()
58-
.chunks_exact(4)
59-
.map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
60-
.collect(),
61-
)
62-
.into_shape(shape)?,
49+
initialize_from_raw_ptr(bin.as_ptr() as *const f32, &shape).to_owned(),
6350
))),
6451
("f", 64) => Ok(ResourceArc::new(OrtexTensor::f64(
65-
Array::from_vec(
66-
bin.as_slice()
67-
.chunks_exact(8)
68-
.map(|c| f64::from_ne_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]))
69-
.collect(),
70-
)
71-
.into_shape(shape)?,
52+
initialize_from_raw_ptr(bin.as_ptr() as *const f64, &shape).to_owned(),
7253
))),
7354
("s", 8) => Ok(ResourceArc::new(OrtexTensor::s8(
74-
Array::from_vec(
75-
bin.as_slice()
76-
.chunks_exact(1)
77-
.map(|c| i8::from_ne_bytes([c[0]]))
78-
.collect(),
79-
)
80-
.into_shape(shape)?,
55+
initialize_from_raw_ptr(bin.as_ptr() as *const i8, &shape).to_owned(),
8156
))),
8257
("s", 16) => Ok(ResourceArc::new(OrtexTensor::s16(
83-
Array::from_vec(
84-
bin.as_slice()
85-
.chunks_exact(2)
86-
.map(|c| i16::from_ne_bytes([c[0], c[1]]))
87-
.collect(),
88-
)
89-
.into_shape(shape)?,
58+
initialize_from_raw_ptr(bin.as_ptr() as *const i16, &shape).to_owned(),
9059
))),
9160
("s", 32) => Ok(ResourceArc::new(OrtexTensor::s32(
92-
Array::from_vec(
93-
bin.as_slice()
94-
.chunks_exact(4)
95-
.map(|c| i32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
96-
.collect(),
97-
)
98-
.into_shape(shape)?,
61+
initialize_from_raw_ptr(bin.as_ptr() as *const i32, &shape).to_owned(),
9962
))),
10063
("s", 64) => Ok(ResourceArc::new(OrtexTensor::s64(
101-
Array::from_vec(
102-
bin.as_slice()
103-
.chunks_exact(8)
104-
.map(|c| i64::from_ne_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]))
105-
.collect(),
106-
)
107-
.into_shape(shape)?,
64+
initialize_from_raw_ptr(bin.as_ptr() as *const i64, &shape).to_owned(),
10865
))),
10966
("u", 8) => Ok(ResourceArc::new(OrtexTensor::u8(
110-
Array::from_vec(
111-
bin.as_slice()
112-
.chunks_exact(1)
113-
.map(|c| u8::from_ne_bytes([c[0]]))
114-
.collect(),
115-
)
116-
.into_shape(shape)?,
67+
initialize_from_raw_ptr(bin.as_ptr() as *const u8, &shape).to_owned(),
11768
))),
11869
("u", 16) => Ok(ResourceArc::new(OrtexTensor::u16(
119-
Array::from_vec(
120-
bin.as_slice()
121-
.chunks_exact(2)
122-
.map(|c| u16::from_ne_bytes([c[0], c[1]]))
123-
.collect(),
124-
)
125-
.into_shape(shape)?,
70+
initialize_from_raw_ptr(bin.as_ptr() as *const u16, &shape).to_owned(),
12671
))),
12772
("u", 32) => Ok(ResourceArc::new(OrtexTensor::u32(
128-
Array::from_vec(
129-
bin.as_slice()
130-
.chunks_exact(4)
131-
.map(|c| u32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
132-
.collect(),
133-
)
134-
.into_shape(shape)?,
73+
initialize_from_raw_ptr(bin.as_ptr() as *const u32, &shape).to_owned(),
13574
))),
13675
("u", 64) => Ok(ResourceArc::new(OrtexTensor::u64(
137-
Array::from_vec(
138-
bin.as_slice()
139-
.chunks_exact(8)
140-
.map(|c| u64::from_ne_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]))
141-
.collect(),
142-
)
143-
.into_shape(shape)?,
76+
initialize_from_raw_ptr(bin.as_ptr() as *const u64, &shape).to_owned(),
14477
))),
14578
(&_, _) => unimplemented!(),
14679
}
@@ -154,12 +87,7 @@ pub fn to_binary<'a>(
15487
_bits: usize,
15588
_limit: usize,
15689
) -> NifResult<Binary<'a>> {
157-
// TODO: implement limit and size so we aren't dumping the entire binary on every
158-
// IO.inspect call
159-
let bytes = reference.to_bytes();
160-
let mut bin = OwnedBinary::new(bytes.len()).ok_or(Error::Term(Box::new("Out of memory")))?;
161-
bin.as_mut_slice().copy_from_slice(&bytes);
162-
Ok(Binary::from_owned(bin, env))
90+
Ok(reference.make_binary(env, |x| x.to_bytes()))
16391
}
16492

16593
/// Takes a vec of Atoms and transforms them into a vec of ExecutionProvider Enums

test/dtype/dtype_test.exs

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Round-trip tests for tensor binaries: for each dtype, the binary read straight
# from a tensor must equal the binary read after transferring that tensor to
# Ortex.Backend and back to Nx.BinaryBackend.
defmodule Ortex.TestDtypes do
2+
use ExUnit.Case
3+
4+
# Evaluated in the module body and stored in @tensor, so every test shares the
# same 100x100 tensor generated from the fixed key 42.
{tensor, _} = Nx.Random.uniform(Nx.Random.key(42), 0, 256, shape: {100, 100})
5+
@tensor tensor
6+
7+
# Casts the shared tensor to `dtype` and returns the binary found under
# data.state, with no backend transfer.
# NOTE(review): the pattern match assumes the tensor's data struct has a
# `state` field (as Nx.BinaryBackend does) — confirm the default backend.
defp bin_binary(dtype) do
8+
%{data: %{state: bin}} = @tensor |> Nx.as_type(dtype)
9+
bin
10+
end
11+
12+
# Same cast as bin_binary/1, but the tensor is moved to Ortex.Backend and back
# to Nx.BinaryBackend before its binary is read.
defp bin_ortex(dtype) do
13+
%{data: %{state: bin}} =
14+
@tensor
15+
|> Nx.as_type(dtype)
16+
|> Nx.backend_transfer(Ortex.Backend)
17+
|> Nx.backend_transfer(Nx.BinaryBackend)
18+
19+
bin
20+
end
21+
22+
# NOTE(review): the input is Nx.tensor(0), a tensor built from the integer 0,
# not an empty tensor — presumably "size 0" means a scalar (rank-0) tensor;
# confirm the intended name.
test "size 0 tensor" do
23+
%{data: %{state: bin1}} = Nx.tensor(0)
24+
25+
%{data: %{state: bin2}} =
26+
Nx.tensor(0)
27+
|> Nx.backend_transfer(Ortex.Backend)
28+
|> Nx.backend_transfer(Nx.BinaryBackend)
29+
30+
assert bin1 == bin2
31+
end
32+
33+
# One test per dtype below; each compares the direct binary with the
# round-tripped binary.
test "u8 conversion" do
34+
assert bin_binary(:u8) == bin_ortex(:u8)
35+
end
36+
37+
test "u16 conversion" do
38+
assert bin_binary(:u16) == bin_ortex(:u16)
39+
end
40+
41+
test "u32 conversion" do
42+
assert bin_binary(:u32) == bin_ortex(:u32)
43+
end
44+
45+
test "u64 conversion" do
46+
assert bin_binary(:u64) == bin_ortex(:u64)
47+
end
48+
49+
test "s8 conversion" do
50+
assert bin_binary(:s8) == bin_ortex(:s8)
51+
end
52+
53+
test "s16 conversion" do
54+
assert bin_binary(:s16) == bin_ortex(:s16)
55+
end
56+
57+
test "s32 conversion" do
58+
assert bin_binary(:s32) == bin_ortex(:s32)
59+
end
60+
61+
test "s64 conversion" do
62+
assert bin_binary(:s64) == bin_ortex(:s64)
63+
end
64+
65+
test "f16 conversion" do
66+
assert bin_binary(:f16) == bin_ortex(:f16)
67+
end
68+
69+
test "bf16 conversion" do
70+
assert bin_binary(:bf16) == bin_ortex(:bf16)
71+
end
72+
73+
test "f32 conversion" do
74+
assert bin_binary(:f32) == bin_ortex(:f32)
75+
end
76+
77+
test "f64 conversion" do
78+
assert bin_binary(:f64) == bin_ortex(:f64)
79+
end
80+
end

test/ortex_test.exs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,6 @@ defmodule OrtexTest do
1212
assert argmax == Nx.tensor([499])
1313
end
1414

15-
test "transfer to Ortex.Backend" do
16-
assert true
17-
end
18-
19-
test "transfer from Ortex.Backend" do
20-
assert true
21-
end
22-
2315
test "Nx.Serving with resnet50" do
2416
model = Ortex.load("./models/resnet50.onnx")
2517

0 commit comments

Comments
 (0)