@@ -26,6 +26,9 @@ pub enum OrtexTensor {
2626 bf16( Array < half:: bf16 , IxDyn > ) ,
2727 f32( Array < f32 , IxDyn > ) ,
2828 f64( Array < f64 , IxDyn > ) ,
29+ // the bool input is for internal use only.
30+ // Any Nx facing ops should panic if called on a bool input
31+ bool( Array < bool , IxDyn > ) ,
2932}
3033
3134impl OrtexTensor {
@@ -43,6 +46,7 @@ impl OrtexTensor {
4346 OrtexTensor :: bf16( y) => y. shape ( ) . to_owned ( ) ,
4447 OrtexTensor :: f32( y) => y. shape ( ) . to_owned ( ) ,
4548 OrtexTensor :: f64( y) => y. shape ( ) . to_owned ( ) ,
49+ _ => panic ! ( "Can't convert this type to Nx format" ) ,
4650 }
4751 }
4852
@@ -108,6 +112,7 @@ impl OrtexTensor {
108112 . into_shape_with_order ( shape)
109113 . map_err ( |e| rustler:: Error :: Term ( Box :: new ( e. to_string ( ) ) ) ) ?,
110114 ) ) ,
115+ _ => panic ! ( "Can't convert this type to Nx format" ) ,
111116 }
112117 }
113118
@@ -125,6 +130,7 @@ impl OrtexTensor {
125130 OrtexTensor :: bf16( _) => ( ortex_atoms:: bf ( ) , 16 ) ,
126131 OrtexTensor :: f32( _) => ( ortex_atoms:: f ( ) , 32 ) ,
127132 OrtexTensor :: f64( _) => ( ortex_atoms:: f ( ) , 64 ) ,
133+ _ => panic ! ( "Can't convert this type to Nx format" ) ,
128134 }
129135 }
130136
@@ -142,6 +148,7 @@ impl OrtexTensor {
142148 OrtexTensor :: bf16( y) => get_bytes ( y) ,
143149 OrtexTensor :: f32( y) => get_bytes ( y) ,
144150 OrtexTensor :: f64( y) => get_bytes ( y) ,
151+ _ => panic ! ( "Can't convert this type to Nx format" ) ,
145152 } ;
146153 contents
147154 }
@@ -173,6 +180,25 @@ impl OrtexTensor {
173180 OrtexTensor :: bf16( y) => OrtexTensor :: bf16 ( slice_array ( y, & slice_specs) . to_owned ( ) ) ,
174181 OrtexTensor :: f32( y) => OrtexTensor :: f32 ( slice_array ( y, & slice_specs) . to_owned ( ) ) ,
175182 OrtexTensor :: f64( y) => OrtexTensor :: f64 ( slice_array ( y, & slice_specs) . to_owned ( ) ) ,
183+ _ => panic ! ( "Can't convert this type to Nx format" ) ,
184+ }
185+ }
186+
187+ pub fn to_bool ( self ) -> OrtexTensor {
188+ match self {
189+ OrtexTensor :: u8( y) => {
190+ let bool_tensor = y. to_owned ( ) . mapv ( |x| match x {
191+ 0 => false ,
192+ 1 => true ,
193+ _ => {
194+ panic ! (
195+ "Tried to convert a u8 tensor to bool, but not every element is 0 or 1"
196+ )
197+ }
198+ } ) ;
199+ OrtexTensor :: bool ( bool_tensor)
200+ }
201+ t => panic ! ( "Can't convert this type {:?} to bool" , t. dtype( ) ) ,
176202 }
177203 }
178204}
@@ -253,8 +279,10 @@ impl TryFrom<&Value> for OrtexTensor {
253279 ort:: TensorElementType :: String => {
254280 todo ! ( "Can't return string tensors" )
255281 }
282+ // map the output into u8 space
256283 ort:: TensorElementType :: Bool => {
257- todo ! ( "Can't return bool tensors" )
284+ let nd_array = e. try_extract_tensor :: < bool > ( ) ?. into_owned ( ) ;
285+ OrtexTensor :: u8 ( nd_array. mapv ( |x| x as u8 ) )
258286 }
259287 } ;
260288
@@ -278,11 +306,32 @@ impl TryFrom<&OrtexTensor> for ort::SessionInputValue<'_> {
278306 OrtexTensor :: u16( arr) => arr. clone ( ) . try_into ( ) ?,
279307 OrtexTensor :: u32( arr) => arr. clone ( ) . try_into ( ) ?,
280308 OrtexTensor :: u64( arr) => arr. clone ( ) . try_into ( ) ?,
309+ OrtexTensor :: bool( arr) => arr. clone ( ) . try_into ( ) ?,
281310 } ;
282311 Ok ( r. into ( ) )
283312 }
284313}
285314
315+ impl Clone for OrtexTensor {
316+ fn clone ( & self ) -> Self {
317+ match self {
318+ OrtexTensor :: s8( t) => OrtexTensor :: s8 ( t. clone ( ) ) ,
319+ OrtexTensor :: s16( t) => OrtexTensor :: s16 ( t. clone ( ) ) ,
320+ OrtexTensor :: s32( t) => OrtexTensor :: s32 ( t. clone ( ) ) ,
321+ OrtexTensor :: s64( t) => OrtexTensor :: s64 ( t. clone ( ) ) ,
322+ OrtexTensor :: bf16( t) => OrtexTensor :: bf16 ( t. clone ( ) ) ,
323+ OrtexTensor :: f16( t) => OrtexTensor :: f16 ( t. clone ( ) ) ,
324+ OrtexTensor :: f32( t) => OrtexTensor :: f32 ( t. clone ( ) ) ,
325+ OrtexTensor :: f64( t) => OrtexTensor :: f64 ( t. clone ( ) ) ,
326+ OrtexTensor :: u8( t) => OrtexTensor :: u8 ( t. clone ( ) ) ,
327+ OrtexTensor :: u16( t) => OrtexTensor :: u16 ( t. clone ( ) ) ,
328+ OrtexTensor :: u32( t) => OrtexTensor :: u32 ( t. clone ( ) ) ,
329+ OrtexTensor :: u64( t) => OrtexTensor :: u64 ( t. clone ( ) ) ,
330+ OrtexTensor :: bool( t) => OrtexTensor :: bool ( t. clone ( ) ) ,
331+ }
332+ }
333+ }
334+
286335// Currently only supports concatenating tenors of the same type.
287336//
288337// This is a similar structure to the above match clauses, except each function
0 commit comments