@@ -30,9 +30,10 @@ pub mod ort_tensor;
3030pub  use  ort_owned_tensor:: { DynOrtTensor ,  OrtOwnedTensor } ; 
3131pub  use  ort_tensor:: OrtTensor ; 
3232
33- use  crate :: { OrtError ,  Result } ; 
33+ use  crate :: tensor:: ort_owned_tensor:: TensorPointerHolder ; 
34+ use  crate :: { error:: call_ort,  OrtError ,  Result } ; 
3435use  onnxruntime_sys:: { self  as  sys,  OnnxEnumInt } ; 
35- use  std:: { fmt,  ptr} ; 
36+ use  std:: { convert :: TryInto   as  _ ,  ffi ,   fmt,  ptr,  rc ,  result ,  string } ; 
3637
3738// FIXME: Use https://docs.rs/bindgen/0.54.1/bindgen/struct.Builder.html#method.rustified_enum 
3839// FIXME: Add tests to cover the commented out types 
@@ -188,14 +189,41 @@ pub trait TensorDataToType: Sized + fmt::Debug {
188189     fn  tensor_element_data_type ( )  -> TensorElementDataType ; 
189190
190191    /// Extract the tensor's data from the ort-owned tensor as a `TensorData`. 
191-      fn  extract_array < ' t ,  D > ( 
192+      fn  extract_data < ' t ,  D > ( 
192193        shape :  D , 
193-         tensor :  * mut  sys:: OrtValue , 
194-     )  -> Result < ndarray:: ArrayView < ' t ,  Self ,  D > > 
194+         tensor_element_len :  usize , 
195+         tensor_ptr :  rc:: Rc < TensorPointerHolder > , 
196+     )  -> Result < TensorData < ' t ,  Self ,  D > > 
195197    where 
196198        D :  ndarray:: Dimension ; 
197199} 
198200
201+ /// Represents the possible ways tensor data can be accessed. 
202+ /// 
203+ /// This should only be used internally. 
204+ #[ derive( Debug ) ]  
205+ pub  enum  TensorData < ' t ,  T ,  D > 
206+ where 
207+     D :  ndarray:: Dimension , 
208+ { 
209+     /// Data resides in ort's tensor, in which case the 't lifetime is what makes this valid. 
210+      /// This is used for data types whose in-memory form from ort is compatible with Rust's, like 
211+      /// primitive numeric types. 
212+      TensorPtr  { 
213+         /// The pointer ort produced. Kept alive so that `array_view` is valid. 
214+          ptr :  rc:: Rc < TensorPointerHolder > , 
215+         /// A view into `ptr` 
216+          array_view :  ndarray:: ArrayView < ' t ,  T ,  D > , 
217+     } , 
218+     /// String data is output differently by ort, and of course is also variable size, so it cannot 
219+      /// use the same simple pointer representation. 
220+      // Since 't outlives this struct, the 't lifetime is more than we need, but no harm done. 
221+     Strings  { 
222+         /// Owned Strings copied out of ort's output 
223+          strings :  ndarray:: Array < T ,  D > , 
224+     } , 
225+ } 
226+ 
199227/// Implements `TensorDataToType` for primitives, which can use `GetTensorMutableData` 
200228macro_rules!  impl_prim_type_from_ort_trait { 
201229    ( $type_: ty,  $variant: ident)  => { 
@@ -204,14 +232,20 @@ macro_rules! impl_prim_type_from_ort_trait {
204232                TensorElementDataType :: $variant
205233            } 
206234
207-             fn  extract_array <' t,  D >( 
235+             fn  extract_data <' t,  D >( 
208236                shape:  D , 
209-                 tensor:  * mut  sys:: OrtValue , 
210-             )  -> Result <ndarray:: ArrayView <' t,  Self ,  D >>
237+                 _tensor_element_len:  usize , 
238+                 tensor_ptr:  rc:: Rc <TensorPointerHolder >, 
239+             )  -> Result <TensorData <' t,  Self ,  D >>
211240            where 
212241                D :  ndarray:: Dimension , 
213242            { 
214-                 extract_primitive_array( shape,  tensor) 
243+                 extract_primitive_array( shape,  tensor_ptr. tensor_ptr) . map( |v| { 
244+                     TensorData :: TensorPtr  { 
245+                         ptr:  tensor_ptr, 
246+                         array_view:  v, 
247+                     } 
248+                 } ) 
215249            } 
216250        } 
217251    } ; 
@@ -255,3 +289,70 @@ impl_prim_type_from_ort_trait!(i64, Int64);
255289impl_prim_type_from_ort_trait ! ( f64 ,  Double ) ; 
256290impl_prim_type_from_ort_trait ! ( u32 ,  Uint32 ) ; 
257291impl_prim_type_from_ort_trait ! ( u64 ,  Uint64 ) ; 
292+ 
293+ impl  TensorDataToType  for  String  { 
294+     fn  tensor_element_data_type ( )  -> TensorElementDataType  { 
295+         TensorElementDataType :: String 
296+     } 
297+ 
298+     fn  extract_data < ' t ,  D :  ndarray:: Dimension > ( 
299+         shape :  D , 
300+         tensor_element_len :  usize , 
301+         tensor_ptr :  rc:: Rc < TensorPointerHolder > , 
302+     )  -> Result < TensorData < ' t ,  Self ,  D > >  { 
303+         // Total length of string data, not including \0 suffix 
304+         let  mut  total_length = 0_u64 ; 
305+         unsafe  { 
306+             call_ort ( |ort| { 
307+                 ort. GetStringTensorDataLength . unwrap ( ) ( tensor_ptr. tensor_ptr ,  & mut  total_length) 
308+             } ) 
309+             . map_err ( OrtError :: GetStringTensorDataLength ) ?
310+         } 
311+ 
312+         // In the JNI impl of this, tensor_element_len was included in addition to total_length, 
313+         // but that seems contrary to the docs of GetStringTensorDataLength, and those extra bytes 
314+         // don't seem to be written to in practice either. 
315+         // If the string data actually did go farther, it would panic below when using the offset 
316+         // data to get slices for each string. 
317+         let  mut  string_contents = vec ! [ 0_u8 ;  total_length as  usize ] ; 
318+         // one extra slot so that the total length can go in the last one, making all per-string 
319+         // length calculations easy 
320+         let  mut  offsets = vec ! [ 0_u64 ;  tensor_element_len as  usize  + 1 ] ; 
321+ 
322+         unsafe  { 
323+             call_ort ( |ort| { 
324+                 ort. GetStringTensorContent . unwrap ( ) ( 
325+                     tensor_ptr. tensor_ptr , 
326+                     string_contents. as_mut_ptr ( )  as  * mut  ffi:: c_void , 
327+                     total_length, 
328+                     offsets. as_mut_ptr ( ) , 
329+                     tensor_element_len as  u64 , 
330+                 ) 
331+             } ) 
332+             . map_err ( OrtError :: GetStringTensorContent ) ?
333+         } 
334+ 
335+         // final offset = overall length so that per-string length calculations work for the last 
336+         // string 
337+         debug_assert_eq ! ( 0 ,  offsets[ tensor_element_len] ) ; 
338+         offsets[ tensor_element_len]  = total_length; 
339+ 
340+         let  strings = offsets
341+             // offsets has 1 extra offset past the end so that all windows work 
342+             . windows ( 2 ) 
343+             . map ( |w| { 
344+                 let  start:  usize  = w[ 0 ] . try_into ( ) . expect ( "Offset didn't fit into usize" ) ; 
345+                 let  next_start:  usize  = w[ 1 ] . try_into ( ) . expect ( "Offset didn't fit into usize" ) ; 
346+ 
347+                 let  slice = & string_contents[ start..next_start] ; 
348+                 String :: from_utf8 ( slice. into ( ) ) 
349+             } ) 
350+             . collect :: < result:: Result < Vec < String > ,  string:: FromUtf8Error > > ( ) 
351+             . map_err ( OrtError :: StringFromUtf8Error ) ?; 
352+ 
353+         let  array = ndarray:: Array :: from_shape_vec ( shape,  strings) 
354+             . expect ( "Shape extracted from tensor didn't match tensor contents" ) ; 
355+ 
356+         Ok ( TensorData :: Strings  {  strings :  array } ) 
357+     } 
358+ } 
0 commit comments