@@ -30,9 +30,10 @@ pub mod ort_tensor;
3030pub use ort_owned_tensor:: { DynOrtTensor , OrtOwnedTensor } ;
3131pub use ort_tensor:: OrtTensor ;
3232
33- use crate :: { OrtError , Result } ;
33+ use crate :: tensor:: ort_owned_tensor:: TensorPointerHolder ;
34+ use crate :: { error:: call_ort, OrtError , Result } ;
3435use onnxruntime_sys:: { self as sys, OnnxEnumInt } ;
35- use std:: { fmt, ptr} ;
36+ use std:: { convert :: TryInto as _ , ffi , fmt, ptr, rc , result , string } ;
3637
3738// FIXME: Use https://docs.rs/bindgen/0.54.1/bindgen/struct.Builder.html#method.rustified_enum
3839// FIXME: Add tests to cover the commented out types
@@ -188,14 +189,41 @@ pub trait TensorDataToType: Sized + fmt::Debug {
188189 fn tensor_element_data_type ( ) -> TensorElementDataType ;
189190
190191 /// Extract an `ArrayView` from the ort-owned tensor.
191- fn extract_array < ' t , D > (
192+ fn extract_data < ' t , D > (
192193 shape : D ,
193- tensor : * mut sys:: OrtValue ,
194- ) -> Result < ndarray:: ArrayView < ' t , Self , D > >
194+ tensor_element_len : usize ,
195+ tensor_ptr : rc:: Rc < TensorPointerHolder > ,
196+ ) -> Result < TensorData < ' t , Self , D > >
195197 where
196198 D : ndarray:: Dimension ;
197199}
198200
201+ /// Represents the possible ways tensor data can be accessed.
202+ ///
203+ /// This should only be used internally.
204+ #[ derive( Debug ) ]
205+ pub enum TensorData < ' t , T , D >
206+ where
207+ D : ndarray:: Dimension ,
208+ {
209+ /// Data resides in ort's tensor, in which case the 't lifetime is what makes this valid.
210+ /// This is used for data types whose in-memory form from ort is compatible with Rust's, like
211+ /// primitive numeric types.
212+ TensorPtr {
213+ /// The pointer ort produced. Kept alive so that `array_view` is valid.
214+ ptr : rc:: Rc < TensorPointerHolder > ,
215+ /// A view into `ptr`
216+ array_view : ndarray:: ArrayView < ' t , T , D > ,
217+ } ,
218+ /// String data is output differently by ort, and of course is also variable size, so it cannot
219+ /// use the same simple pointer representation.
220+ // Since 't outlives this struct, the 't lifetime is more than we need, but no harm done.
221+ Strings {
222+ /// Owned Strings copied out of ort's output
223+ strings : ndarray:: Array < T , D > ,
224+ } ,
225+ }
226+
199227/// Implements `OwnedTensorDataToType` for primitives, which can use `GetTensorMutableData`
200228macro_rules! impl_prim_type_from_ort_trait {
201229 ( $type_: ty, $variant: ident) => {
@@ -204,14 +232,20 @@ macro_rules! impl_prim_type_from_ort_trait {
204232 TensorElementDataType :: $variant
205233 }
206234
207- fn extract_array <' t, D >(
235+ fn extract_data <' t, D >(
208236 shape: D ,
209- tensor: * mut sys:: OrtValue ,
210- ) -> Result <ndarray:: ArrayView <' t, Self , D >>
237+ _tensor_element_len: usize ,
238+ tensor_ptr: rc:: Rc <TensorPointerHolder >,
239+ ) -> Result <TensorData <' t, Self , D >>
211240 where
212241 D : ndarray:: Dimension ,
213242 {
214- extract_primitive_array( shape, tensor)
243+ extract_primitive_array( shape, tensor_ptr. tensor_ptr) . map( |v| {
244+ TensorData :: TensorPtr {
245+ ptr: tensor_ptr,
246+ array_view: v,
247+ }
248+ } )
215249 }
216250 }
217251 } ;
@@ -255,3 +289,70 @@ impl_prim_type_from_ort_trait!(i64, Int64);
255289impl_prim_type_from_ort_trait ! ( f64 , Double ) ;
256290impl_prim_type_from_ort_trait ! ( u32 , Uint32 ) ;
257291impl_prim_type_from_ort_trait ! ( u64 , Uint64 ) ;
292+
293+ impl TensorDataToType for String {
294+ fn tensor_element_data_type ( ) -> TensorElementDataType {
295+ TensorElementDataType :: String
296+ }
297+
298+ fn extract_data < ' t , D : ndarray:: Dimension > (
299+ shape : D ,
300+ tensor_element_len : usize ,
301+ tensor_ptr : rc:: Rc < TensorPointerHolder > ,
302+ ) -> Result < TensorData < ' t , Self , D > > {
303+ // Total length of string data, not including \0 suffix
304+ let mut total_length = 0_u64 ;
305+ unsafe {
306+ call_ort ( |ort| {
307+ ort. GetStringTensorDataLength . unwrap ( ) ( tensor_ptr. tensor_ptr , & mut total_length)
308+ } )
309+ . map_err ( OrtError :: GetStringTensorDataLength ) ?
310+ }
311+
312+ // In the JNI impl of this, tensor_element_len was included in addition to total_length,
313+ // but that seems contrary to the docs of GetStringTensorDataLength, and those extra bytes
314+ // don't seem to be written to in practice either.
315+ // If the string data actually did go farther, it would panic below when using the offset
316+ // data to get slices for each string.
317+ let mut string_contents = vec ! [ 0_u8 ; total_length as usize ] ;
318+ // one extra slot so that the total length can go in the last one, making all per-string
319+ // length calculations easy
320+ let mut offsets = vec ! [ 0_u64 ; tensor_element_len as usize + 1 ] ;
321+
322+ unsafe {
323+ call_ort ( |ort| {
324+ ort. GetStringTensorContent . unwrap ( ) (
325+ tensor_ptr. tensor_ptr ,
326+ string_contents. as_mut_ptr ( ) as * mut ffi:: c_void ,
327+ total_length,
328+ offsets. as_mut_ptr ( ) ,
329+ tensor_element_len as u64 ,
330+ )
331+ } )
332+ . map_err ( OrtError :: GetStringTensorContent ) ?
333+ }
334+
335+ // final offset = overall length so that per-string length calculations work for the last
336+ // string
337+ debug_assert_eq ! ( 0 , offsets[ tensor_element_len] ) ;
338+ offsets[ tensor_element_len] = total_length;
339+
340+ let strings = offsets
341+ // offsets has 1 extra offset past the end so that all windows work
342+ . windows ( 2 )
343+ . map ( |w| {
344+ let start: usize = w[ 0 ] . try_into ( ) . expect ( "Offset didn't fit into usize" ) ;
345+ let next_start: usize = w[ 1 ] . try_into ( ) . expect ( "Offset didn't fit into usize" ) ;
346+
347+ let slice = & string_contents[ start..next_start] ;
348+ String :: from_utf8 ( slice. into ( ) )
349+ } )
350+ . collect :: < result:: Result < Vec < String > , string:: FromUtf8Error > > ( )
351+ . map_err ( OrtError :: StringFromUtf8Error ) ?;
352+
353+ let array = ndarray:: Array :: from_shape_vec ( shape, strings)
354+ . expect ( "Shape extracted from tensor didn't match tensor contents" ) ;
355+
356+ Ok ( TensorData :: Strings { strings : array } )
357+ }
358+ }
0 commit comments