33
44use crate :: constants:: * ;
55use crate :: tensor:: OrtexTensor ;
6+ use ndarray:: { ArrayViewMut , Ix , IxDyn } ;
67
7- use ndarray:: prelude:: * ;
88use ndarray:: ShapeError ;
99
1010use rustler:: resource:: ResourceArc ;
11- use rustler:: types:: { Binary , OwnedBinary } ;
12- use rustler:: { Atom , Env , Error , NifResult } ;
11+ use rustler:: types:: Binary ;
12+ use rustler:: { Atom , Env , NifResult } ;
1313
1414use ort:: { ExecutionProvider , GraphOptimizationLevel } ;
1515
16+ /// A faster (unsafe) way of creating an Array from an Erlang binary
17+ fn initialize_from_raw_ptr < T > ( ptr : * const T , shape : & [ Ix ] ) -> ArrayViewMut < T , IxDyn > {
18+ let array = unsafe { ArrayViewMut :: from_shape_ptr ( shape, ptr as * mut T ) } ;
19+ array
20+ }
21+
1622/// Given a Binary term, shape, and dtype from the BEAM, constructs an OrtexTensor and
1723/// returns the reference to be used as an Nx.Backend representation.
1824///
@@ -32,115 +38,42 @@ pub fn from_binary(
3238 dtype_str : String ,
3339 dtype_bits : usize ,
3440) -> Result < ResourceArc < OrtexTensor > , ShapeError > {
35- // TODO: make this more DRY, pull out into an impl
3641 match ( dtype_str. as_ref ( ) , dtype_bits) {
3742 ( "bf" , 16 ) => Ok ( ResourceArc :: new ( OrtexTensor :: bf16 (
38- Array :: from_vec (
39- bin. as_slice ( )
40- . chunks_exact ( 2 )
41- . map ( |c| half:: bf16:: from_ne_bytes ( [ c[ 0 ] , c[ 1 ] ] ) )
42- . collect ( ) ,
43- )
44- . into_shape ( shape) ?,
43+ initialize_from_raw_ptr ( bin. as_ptr ( ) as * const half:: bf16 , & shape) . to_owned ( ) ,
4544 ) ) ) ,
4645 ( "f" , 16 ) => Ok ( ResourceArc :: new ( OrtexTensor :: f16 (
47- Array :: from_vec (
48- bin. as_slice ( )
49- . chunks_exact ( 2 )
50- . map ( |c| half:: f16:: from_ne_bytes ( [ c[ 0 ] , c[ 1 ] ] ) )
51- . collect ( ) ,
52- )
53- . into_shape ( shape) ?,
46+ initialize_from_raw_ptr ( bin. as_ptr ( ) as * const half:: f16 , & shape) . to_owned ( ) ,
5447 ) ) ) ,
5548 ( "f" , 32 ) => Ok ( ResourceArc :: new ( OrtexTensor :: f32 (
56- Array :: from_vec (
57- bin. as_slice ( )
58- . chunks_exact ( 4 )
59- . map ( |c| f32:: from_ne_bytes ( [ c[ 0 ] , c[ 1 ] , c[ 2 ] , c[ 3 ] ] ) )
60- . collect ( ) ,
61- )
62- . into_shape ( shape) ?,
49+ initialize_from_raw_ptr ( bin. as_ptr ( ) as * const f32 , & shape) . to_owned ( ) ,
6350 ) ) ) ,
6451 ( "f" , 64 ) => Ok ( ResourceArc :: new ( OrtexTensor :: f64 (
65- Array :: from_vec (
66- bin. as_slice ( )
67- . chunks_exact ( 8 )
68- . map ( |c| f64:: from_ne_bytes ( [ c[ 0 ] , c[ 1 ] , c[ 2 ] , c[ 3 ] , c[ 4 ] , c[ 5 ] , c[ 6 ] , c[ 7 ] ] ) )
69- . collect ( ) ,
70- )
71- . into_shape ( shape) ?,
52+ initialize_from_raw_ptr ( bin. as_ptr ( ) as * const f64 , & shape) . to_owned ( ) ,
7253 ) ) ) ,
7354 ( "s" , 8 ) => Ok ( ResourceArc :: new ( OrtexTensor :: s8 (
74- Array :: from_vec (
75- bin. as_slice ( )
76- . chunks_exact ( 1 )
77- . map ( |c| i8:: from_ne_bytes ( [ c[ 0 ] ] ) )
78- . collect ( ) ,
79- )
80- . into_shape ( shape) ?,
55+ initialize_from_raw_ptr ( bin. as_ptr ( ) as * const i8 , & shape) . to_owned ( ) ,
8156 ) ) ) ,
8257 ( "s" , 16 ) => Ok ( ResourceArc :: new ( OrtexTensor :: s16 (
83- Array :: from_vec (
84- bin. as_slice ( )
85- . chunks_exact ( 2 )
86- . map ( |c| i16:: from_ne_bytes ( [ c[ 0 ] , c[ 1 ] ] ) )
87- . collect ( ) ,
88- )
89- . into_shape ( shape) ?,
58+ initialize_from_raw_ptr ( bin. as_ptr ( ) as * const i16 , & shape) . to_owned ( ) ,
9059 ) ) ) ,
9160 ( "s" , 32 ) => Ok ( ResourceArc :: new ( OrtexTensor :: s32 (
92- Array :: from_vec (
93- bin. as_slice ( )
94- . chunks_exact ( 4 )
95- . map ( |c| i32:: from_ne_bytes ( [ c[ 0 ] , c[ 1 ] , c[ 2 ] , c[ 3 ] ] ) )
96- . collect ( ) ,
97- )
98- . into_shape ( shape) ?,
61+ initialize_from_raw_ptr ( bin. as_ptr ( ) as * const i32 , & shape) . to_owned ( ) ,
9962 ) ) ) ,
10063 ( "s" , 64 ) => Ok ( ResourceArc :: new ( OrtexTensor :: s64 (
101- Array :: from_vec (
102- bin. as_slice ( )
103- . chunks_exact ( 8 )
104- . map ( |c| i64:: from_ne_bytes ( [ c[ 0 ] , c[ 1 ] , c[ 2 ] , c[ 3 ] , c[ 4 ] , c[ 5 ] , c[ 6 ] , c[ 7 ] ] ) )
105- . collect ( ) ,
106- )
107- . into_shape ( shape) ?,
64+ initialize_from_raw_ptr ( bin. as_ptr ( ) as * const i64 , & shape) . to_owned ( ) ,
10865 ) ) ) ,
10966 ( "u" , 8 ) => Ok ( ResourceArc :: new ( OrtexTensor :: u8 (
110- Array :: from_vec (
111- bin. as_slice ( )
112- . chunks_exact ( 1 )
113- . map ( |c| u8:: from_ne_bytes ( [ c[ 0 ] ] ) )
114- . collect ( ) ,
115- )
116- . into_shape ( shape) ?,
67+ initialize_from_raw_ptr ( bin. as_ptr ( ) as * const u8 , & shape) . to_owned ( ) ,
11768 ) ) ) ,
11869 ( "u" , 16 ) => Ok ( ResourceArc :: new ( OrtexTensor :: u16 (
119- Array :: from_vec (
120- bin. as_slice ( )
121- . chunks_exact ( 2 )
122- . map ( |c| u16:: from_ne_bytes ( [ c[ 0 ] , c[ 1 ] ] ) )
123- . collect ( ) ,
124- )
125- . into_shape ( shape) ?,
70+ initialize_from_raw_ptr ( bin. as_ptr ( ) as * const u16 , & shape) . to_owned ( ) ,
12671 ) ) ) ,
12772 ( "u" , 32 ) => Ok ( ResourceArc :: new ( OrtexTensor :: u32 (
128- Array :: from_vec (
129- bin. as_slice ( )
130- . chunks_exact ( 4 )
131- . map ( |c| u32:: from_ne_bytes ( [ c[ 0 ] , c[ 1 ] , c[ 2 ] , c[ 3 ] ] ) )
132- . collect ( ) ,
133- )
134- . into_shape ( shape) ?,
73+ initialize_from_raw_ptr ( bin. as_ptr ( ) as * const u32 , & shape) . to_owned ( ) ,
13574 ) ) ) ,
13675 ( "u" , 64 ) => Ok ( ResourceArc :: new ( OrtexTensor :: u64 (
137- Array :: from_vec (
138- bin. as_slice ( )
139- . chunks_exact ( 8 )
140- . map ( |c| u64:: from_ne_bytes ( [ c[ 0 ] , c[ 1 ] , c[ 2 ] , c[ 3 ] , c[ 4 ] , c[ 5 ] , c[ 6 ] , c[ 7 ] ] ) )
141- . collect ( ) ,
142- )
143- . into_shape ( shape) ?,
76+ initialize_from_raw_ptr ( bin. as_ptr ( ) as * const u64 , & shape) . to_owned ( ) ,
14477 ) ) ) ,
14578 ( & _, _) => unimplemented ! ( ) ,
14679 }
@@ -154,12 +87,7 @@ pub fn to_binary<'a>(
15487 _bits : usize ,
15588 _limit : usize ,
15689) -> NifResult < Binary < ' a > > {
157- // TODO: implement limit and size so we aren't dumping the entire binary on every
158- // IO.inspect call
159- let bytes = reference. to_bytes ( ) ;
160- let mut bin = OwnedBinary :: new ( bytes. len ( ) ) . ok_or ( Error :: Term ( Box :: new ( "Out of memory" ) ) ) ?;
161- bin. as_mut_slice ( ) . copy_from_slice ( & bytes) ;
162- Ok ( Binary :: from_owned ( bin, env) )
90+ Ok ( reference. make_binary ( env, |x| x. to_bytes ( ) ) )
16391}
16492
16593/// Takes a vec of Atoms and transforms them into a vec of ExecutionProvider Enums
0 commit comments