@@ -405,32 +405,8 @@ impl<'a> Session<'a> {
405405 . map ( |n| n. as_ptr ( ) as * const i8 )
406406 . collect ( ) ;
407407
408- let output_shapes: Vec < Vec < usize > > = {
409- let mut tmp = Vec :: new ( ) ;
410- for ( idx, output) in self . outputs . iter ( ) . enumerate ( ) {
411- let v: Vec < _ > = output
412- . dimensions
413- . iter ( )
414- . enumerate ( )
415- . map ( |( jdx, dim) | match dim {
416- None => input_arrays[ idx] . shape ( ) [ jdx] ,
417- Some ( d) => * d as usize ,
418- } )
419- . collect ( ) ;
420- tmp. push ( v) ;
421- }
422- tmp
423- } ;
424- let memory_info_ref = & self . memory_info ;
425- let output_tensor_extractors: Vec < OrtOwnedTensorExtractor < ndarray:: IxDyn > > = output_shapes
426- . iter ( )
427- . map ( |output_shape| {
428- OrtOwnedTensorExtractor :: new ( memory_info_ref, ndarray:: IxDyn ( output_shape) )
429- } )
430- . collect ( ) ;
431-
432408 let mut output_tensor_extractors_ptrs: Vec < * mut sys:: OrtValue > =
433- vec ! [ std:: ptr:: null_mut( ) ; output_tensor_extractors . len( ) ] ;
409+ vec ! [ std:: ptr:: null_mut( ) ; self . outputs . len( ) ] ;
434410
435411 // The C API expects pointers for the arrays (pointers to C-arrays)
436412 let input_ort_tensors: Vec < OrtTensor < TIn , D > > = input_arrays
@@ -458,11 +434,23 @@ impl<'a> Session<'a> {
458434 } ;
459435 status_to_result ( status) . map_err ( OrtError :: Run ) ?;
460436
437+ let memory_info_ref = & self . memory_info ;
461438 let outputs: Result < Vec < OrtOwnedTensor < TOut , ndarray:: Dim < ndarray:: IxDynImpl > > > > =
462- output_tensor_extractors
439+ output_tensor_extractors_ptrs
463440 . into_iter ( )
464- . zip ( output_tensor_extractors_ptrs. into_iter ( ) )
465- . map ( |( mut output_tensor_extractor, ptr) | {
441+ . map ( |ptr| {
442+ let mut tensor_info_ptr: * mut sys:: OrtTensorTypeAndShapeInfo =
443+ std:: ptr:: null_mut ( ) ;
444+ let status = unsafe {
445+ g_ort ( ) . GetTensorTypeAndShape . unwrap ( ) ( ptr, & mut tensor_info_ptr as _ )
446+ } ;
447+ status_to_result ( status) . map_err ( OrtError :: GetTensorTypeAndShape ) ?;
448+ let dims = unsafe { get_tensor_dimensions ( tensor_info_ptr) } ;
449+ unsafe { g_ort ( ) . ReleaseTensorTypeAndShapeInfo . unwrap ( ) ( tensor_info_ptr) } ;
450+ let dims: Vec < _ > = dims?. iter ( ) . map ( |& n| n as usize ) . collect ( ) ;
451+
452+ let mut output_tensor_extractor =
453+ OrtOwnedTensorExtractor :: new ( memory_info_ref, ndarray:: IxDyn ( & dims) ) ;
466454 output_tensor_extractor. tensor_ptr = ptr;
467455 output_tensor_extractor. extract :: < TOut > ( )
468456 } )
@@ -560,6 +548,24 @@ impl<'a> Session<'a> {
560548 }
561549}
562550
551+ unsafe fn get_tensor_dimensions (
552+ tensor_info_ptr : * const sys:: OrtTensorTypeAndShapeInfo ,
553+ ) -> Result < Vec < i64 > > {
554+ let mut num_dims = 0 ;
555+ let status = g_ort ( ) . GetDimensionsCount . unwrap ( ) ( tensor_info_ptr, & mut num_dims) ;
556+ status_to_result ( status) . map_err ( OrtError :: GetDimensionsCount ) ?;
557+ assert_ne ! ( num_dims, 0 ) ;
558+
559+ let mut node_dims: Vec < i64 > = vec ! [ 0 ; num_dims as usize ] ;
560+ let status = g_ort ( ) . GetDimensions . unwrap ( ) (
561+ tensor_info_ptr,
562+ node_dims. as_mut_ptr ( ) , // FIXME: UB?
563+ num_dims,
564+ ) ;
565+ status_to_result ( status) . map_err ( OrtError :: GetDimensions ) ?;
566+ Ok ( node_dims)
567+ }
568+
563569/// This module contains dangerous functions working on raw pointers.
564570/// Those functions are only to be used from inside the
565571/// `SessionBuilder::with_model_from_file()` method.
@@ -694,22 +700,7 @@ mod dangerous {
694700
695701 // info!("{} : type={}", i, type_);
696702
697- // print input shapes/dims
698- let mut num_dims = 0 ;
699- let status = unsafe { g_ort ( ) . GetDimensionsCount . unwrap ( ) ( tensor_info_ptr, & mut num_dims) } ;
700- status_to_result ( status) . map_err ( OrtError :: GetDimensionsCount ) ?;
701- assert_ne ! ( num_dims, 0 ) ;
702-
703- // info!("{} : num_dims={}", i, num_dims);
704- let mut node_dims: Vec < i64 > = vec ! [ 0 ; num_dims as usize ] ;
705- let status = unsafe {
706- g_ort ( ) . GetDimensions . unwrap ( ) (
707- tensor_info_ptr,
708- node_dims. as_mut_ptr ( ) , // FIXME: UB?
709- num_dims,
710- )
711- } ;
712- status_to_result ( status) . map_err ( OrtError :: GetDimensions ) ?;
703+ let node_dims = unsafe { get_tensor_dimensions ( tensor_info_ptr) ? } ;
713704
714705 // for j in 0..num_dims {
715706 // info!("{} : dim {}={}", i, j, node_dims[j as usize]);
0 commit comments