11//! Conversions for packing/unpacking `OrtexTensor`s into different types
22use core:: convert:: TryFrom ;
3+ use half:: { bf16, f16} ;
34use ndarray:: prelude:: * ;
45use ndarray:: { ArrayBase , ArrayView , Data , IxDyn , IxDynImpl , ViewRepr } ;
56use ort:: { DynValue , Error , Value } ;
@@ -26,6 +27,9 @@ pub enum OrtexTensor {
2627 bf16( Array < half:: bf16 , IxDyn > ) ,
2728 f32( Array < f32 , IxDyn > ) ,
2829 f64( Array < f64 , IxDyn > ) ,
30+ // the bool input is for internal use only.
31+ // Any Nx facing ops should panic if called on a bool input
32+ bool( Array < bool , IxDyn > ) ,
2933}
3034
3135impl OrtexTensor {
@@ -43,6 +47,7 @@ impl OrtexTensor {
4347 OrtexTensor :: bf16( y) => y. shape ( ) . to_owned ( ) ,
4448 OrtexTensor :: f32( y) => y. shape ( ) . to_owned ( ) ,
4549 OrtexTensor :: f64( y) => y. shape ( ) . to_owned ( ) ,
50+ _ => panic ! ( "Can't convert this type to Nx format" ) ,
4651 }
4752 }
4853
@@ -108,6 +113,7 @@ impl OrtexTensor {
108113 . into_shape_with_order ( shape)
109114 . map_err ( |e| rustler:: Error :: Term ( Box :: new ( e. to_string ( ) ) ) ) ?,
110115 ) ) ,
116+ _ => panic ! ( "Can't convert this type to Nx format" ) ,
111117 }
112118 }
113119
@@ -125,6 +131,7 @@ impl OrtexTensor {
125131 OrtexTensor :: bf16( _) => ( ortex_atoms:: bf ( ) , 16 ) ,
126132 OrtexTensor :: f32( _) => ( ortex_atoms:: f ( ) , 32 ) ,
127133 OrtexTensor :: f64( _) => ( ortex_atoms:: f ( ) , 64 ) ,
134+ _ => panic ! ( "Can't convert this type to Nx format" ) ,
128135 }
129136 }
130137
@@ -142,6 +149,7 @@ impl OrtexTensor {
142149 OrtexTensor :: bf16( y) => get_bytes ( y) ,
143150 OrtexTensor :: f32( y) => get_bytes ( y) ,
144151 OrtexTensor :: f64( y) => get_bytes ( y) ,
152+ _ => panic ! ( "Can't convert this type to Nx format" ) ,
145153 } ;
146154 contents
147155 }
@@ -173,6 +181,30 @@ impl OrtexTensor {
173181 OrtexTensor :: bf16( y) => OrtexTensor :: bf16 ( slice_array ( y, & slice_specs) . to_owned ( ) ) ,
174182 OrtexTensor :: f32( y) => OrtexTensor :: f32 ( slice_array ( y, & slice_specs) . to_owned ( ) ) ,
175183 OrtexTensor :: f64( y) => OrtexTensor :: f64 ( slice_array ( y, & slice_specs) . to_owned ( ) ) ,
184+ _ => panic ! ( "Can't convert this type to Nx format" ) ,
185+ }
186+ }
187+
188+ pub fn to_bool ( self ) -> OrtexTensor {
189+ match self {
190+ OrtexTensor :: s8( y) => OrtexTensor :: bool ( y. to_owned ( ) . mapv ( |x| x != 0 ) ) ,
191+ OrtexTensor :: s16( y) => OrtexTensor :: bool ( y. to_owned ( ) . mapv ( |x| x != 0 ) ) ,
192+ OrtexTensor :: s32( y) => OrtexTensor :: bool ( y. to_owned ( ) . mapv ( |x| x != 0 ) ) ,
193+ OrtexTensor :: s64( y) => OrtexTensor :: bool ( y. to_owned ( ) . mapv ( |x| x != 0 ) ) ,
194+ OrtexTensor :: u8( y) => OrtexTensor :: bool ( y. to_owned ( ) . mapv ( |x| x != 0 ) ) ,
195+ OrtexTensor :: u16( y) => OrtexTensor :: bool ( y. to_owned ( ) . mapv ( |x| x != 0 ) ) ,
196+ OrtexTensor :: u32( y) => OrtexTensor :: bool ( y. to_owned ( ) . mapv ( |x| x != 0 ) ) ,
197+ OrtexTensor :: u64( y) => OrtexTensor :: bool ( y. to_owned ( ) . mapv ( |x| x != 0 ) ) ,
198+ OrtexTensor :: f16( y) => {
199+ OrtexTensor :: bool ( y. to_owned ( ) . mapv ( |x| x != f16:: ZERO || x != f16:: NEG_ZERO ) )
200+ }
201+ OrtexTensor :: bf16( y) => OrtexTensor :: bool (
202+ y. to_owned ( )
203+ . mapv ( |x| x != bf16:: ZERO || x != bf16:: NEG_ZERO ) ,
204+ ) ,
205+ OrtexTensor :: f32( y) => OrtexTensor :: bool ( y. to_owned ( ) . mapv ( |x| x != 0. ) ) ,
206+ OrtexTensor :: f64( y) => OrtexTensor :: bool ( y. to_owned ( ) . mapv ( |x| x != 0. ) ) ,
207+ _ => panic ! ( "Can't convert this type to bool" ) ,
176208 }
177209 }
178210}
@@ -253,8 +285,10 @@ impl TryFrom<&Value> for OrtexTensor {
253285 ort:: TensorElementType :: String => {
254286 todo ! ( "Can't return string tensors" )
255287 }
288+ // map the output into integer space
256289 ort:: TensorElementType :: Bool => {
257- todo ! ( "Can't return bool tensors" )
290+ let nd_array = e. try_extract_tensor :: < bool > ( ) ?. into_owned ( ) ;
291+ OrtexTensor :: s8 ( nd_array. mapv ( |x| x as i8 ) )
258292 }
259293 } ;
260294
@@ -278,11 +312,32 @@ impl TryFrom<&OrtexTensor> for ort::SessionInputValue<'_> {
278312 OrtexTensor :: u16( arr) => arr. clone ( ) . try_into ( ) ?,
279313 OrtexTensor :: u32( arr) => arr. clone ( ) . try_into ( ) ?,
280314 OrtexTensor :: u64( arr) => arr. clone ( ) . try_into ( ) ?,
315+ OrtexTensor :: bool( arr) => arr. clone ( ) . try_into ( ) ?,
281316 } ;
282317 Ok ( r. into ( ) )
283318 }
284319}
285320
321+ impl Clone for OrtexTensor {
322+ fn clone ( & self ) -> Self {
323+ match self {
324+ OrtexTensor :: s8( t) => OrtexTensor :: s8 ( t. clone ( ) ) ,
325+ OrtexTensor :: s16( t) => OrtexTensor :: s16 ( t. clone ( ) ) ,
326+ OrtexTensor :: s32( t) => OrtexTensor :: s32 ( t. clone ( ) ) ,
327+ OrtexTensor :: s64( t) => OrtexTensor :: s64 ( t. clone ( ) ) ,
328+ OrtexTensor :: bf16( t) => OrtexTensor :: bf16 ( t. clone ( ) ) ,
329+ OrtexTensor :: f16( t) => OrtexTensor :: f16 ( t. clone ( ) ) ,
330+ OrtexTensor :: f32( t) => OrtexTensor :: f32 ( t. clone ( ) ) ,
331+ OrtexTensor :: f64( t) => OrtexTensor :: f64 ( t. clone ( ) ) ,
332+ OrtexTensor :: u8( t) => OrtexTensor :: u8 ( t. clone ( ) ) ,
333+ OrtexTensor :: u16( t) => OrtexTensor :: u16 ( t. clone ( ) ) ,
334+ OrtexTensor :: u32( t) => OrtexTensor :: u32 ( t. clone ( ) ) ,
335+ OrtexTensor :: u64( t) => OrtexTensor :: u64 ( t. clone ( ) ) ,
336+ OrtexTensor :: bool( t) => OrtexTensor :: bool ( t. clone ( ) ) ,
337+ }
338+ }
339+ }
340+
286341// Currently only supports concatenating tenors of the same type.
287342//
288343// This is a similar structure to the above match clauses, except each function
0 commit comments