@@ -18,7 +18,10 @@ use onnxruntime_sys as sys;
1818use crate :: {
1919 char_p_to_string,
2020 environment:: Environment ,
21- error:: { status_to_result, NonMatchingDimensionsError , OrtError , Result } ,
21+ error:: {
22+ assert_not_null_pointer, assert_null_pointer, status_to_result, NonMatchingDimensionsError ,
23+ OrtApiError , OrtError , Result ,
24+ } ,
2225 g_ort,
2326 memory:: MemoryInfo ,
2427 tensor:: {
@@ -73,9 +76,12 @@ pub struct SessionBuilder<'a> {
7376impl < ' a > Drop for SessionBuilder < ' a > {
7477 #[ tracing:: instrument]
7578 fn drop ( & mut self ) {
76- debug ! ( "Dropping the session options." ) ;
77- assert_ne ! ( self . session_options_ptr, std:: ptr:: null_mut( ) ) ;
78- unsafe { g_ort ( ) . ReleaseSessionOptions . unwrap ( ) ( self . session_options_ptr ) } ;
79+ if self . session_options_ptr . is_null ( ) {
80+ error ! ( "Session options pointer is null, not dropping" ) ;
81+ } else {
82+ debug ! ( "Dropping the session options." ) ;
83+ unsafe { g_ort ( ) . ReleaseSessionOptions . unwrap ( ) ( self . session_options_ptr ) } ;
84+ }
7985 }
8086}
8187
@@ -85,8 +91,8 @@ impl<'a> SessionBuilder<'a> {
8591 let status = unsafe { g_ort ( ) . CreateSessionOptions . unwrap ( ) ( & mut session_options_ptr) } ;
8692
8793 status_to_result ( status) . map_err ( OrtError :: SessionOptions ) ?;
88- assert_eq ! ( status, std :: ptr :: null_mut ( ) ) ;
89- assert_ne ! ( session_options_ptr, std :: ptr :: null_mut ( ) ) ;
94+ assert_null_pointer ( status, "SessionStatus" ) ? ;
95+ assert_not_null_pointer ( session_options_ptr, "SessionOptions" ) ? ;
9096
9197 Ok ( SessionBuilder {
9298 env,
@@ -105,7 +111,7 @@ impl<'a> SessionBuilder<'a> {
105111 let status =
106112 unsafe { g_ort ( ) . SetIntraOpNumThreads . unwrap ( ) ( self . session_options_ptr , num_threads) } ;
107113 status_to_result ( status) . map_err ( OrtError :: SessionOptions ) ?;
108- assert_eq ! ( status, std :: ptr :: null_mut ( ) ) ;
114+ assert_null_pointer ( status, "SessionStatus" ) ? ;
109115 Ok ( self )
110116 }
111117
@@ -199,14 +205,14 @@ impl<'a> SessionBuilder<'a> {
199205 )
200206 } ;
201207 status_to_result ( status) . map_err ( OrtError :: Session ) ?;
202- assert_eq ! ( status, std :: ptr :: null_mut ( ) ) ;
203- assert_ne ! ( session_ptr, std :: ptr :: null_mut ( ) ) ;
208+ assert_null_pointer ( status, "SessionStatus" ) ? ;
209+ assert_not_null_pointer ( session_ptr, "Session" ) ? ;
204210
205211 let mut allocator_ptr: * mut sys:: OrtAllocator = std:: ptr:: null_mut ( ) ;
206212 let status = unsafe { g_ort ( ) . GetAllocatorWithDefaultOptions . unwrap ( ) ( & mut allocator_ptr) } ;
207213 status_to_result ( status) . map_err ( OrtError :: Allocator ) ?;
208- assert_eq ! ( status, std :: ptr :: null_mut ( ) ) ;
209- assert_ne ! ( allocator_ptr, std :: ptr :: null_mut ( ) ) ;
214+ assert_null_pointer ( status, "SessionStatus" ) ? ;
215+ assert_not_null_pointer ( allocator_ptr, "Allocator" ) ? ;
210216
211217 let memory_info = MemoryInfo :: new ( AllocatorType :: Arena , MemType :: Default ) ?;
212218
@@ -255,14 +261,14 @@ impl<'a> SessionBuilder<'a> {
255261 )
256262 } ;
257263 status_to_result ( status) . map_err ( OrtError :: Session ) ?;
258- assert_eq ! ( status, std :: ptr :: null_mut ( ) ) ;
259- assert_ne ! ( session_ptr, std :: ptr :: null_mut ( ) ) ;
264+ assert_null_pointer ( status, "SessionStatus" ) ? ;
265+ assert_not_null_pointer ( session_ptr, "Session" ) ? ;
260266
261267 let mut allocator_ptr: * mut sys:: OrtAllocator = std:: ptr:: null_mut ( ) ;
262268 let status = unsafe { g_ort ( ) . GetAllocatorWithDefaultOptions . unwrap ( ) ( & mut allocator_ptr) } ;
263269 status_to_result ( status) . map_err ( OrtError :: Allocator ) ?;
264- assert_eq ! ( status, std :: ptr :: null_mut ( ) ) ;
265- assert_ne ! ( allocator_ptr, std :: ptr :: null_mut ( ) ) ;
270+ assert_null_pointer ( status, "SessionStatus" ) ? ;
271+ assert_not_null_pointer ( allocator_ptr, "Allocator" ) ? ;
266272
267273 let memory_info = MemoryInfo :: new ( AllocatorType :: Arena , MemType :: Default ) ?;
268274
@@ -352,7 +358,11 @@ impl<'a> Drop for Session<'a> {
352358 #[ tracing:: instrument]
353359 fn drop ( & mut self ) {
354360 debug ! ( "Dropping the session." ) ;
355- unsafe { g_ort ( ) . ReleaseSession . unwrap ( ) ( self . session_ptr ) } ;
361+ if self . session_ptr . is_null ( ) {
362+ error ! ( "Session pointer is null, not dropping." ) ;
363+ } else {
364+ unsafe { g_ort ( ) . ReleaseSession . unwrap ( ) ( self . session_ptr ) } ;
365+ }
356366 // FIXME: There is no C function to release the allocator?
357367
358368 self . session_ptr = std:: ptr:: null_mut ( ) ;
@@ -453,13 +463,14 @@ impl<'a> Session<'a> {
453463 . collect ( ) ;
454464
455465 // Reconvert to CString so drop impl is called and memory is freed
456- let _ : Vec < CString > = input_names_ptr
466+ let cstrings : Result < Vec < CString > > = input_names_ptr
457467 . into_iter ( )
458468 . map ( |p| {
459- assert_ne ! ( p, std :: ptr :: null ( ) ) ;
460- unsafe { CString :: from_raw ( p as * mut i8 ) }
469+ assert_not_null_pointer ( p, "i8 for CString" ) ? ;
470+ unsafe { Ok ( CString :: from_raw ( p as * mut i8 ) ) }
461471 } )
462472 . collect ( ) ;
473+ cstrings?;
463474
464475 outputs
465476 }
@@ -568,7 +579,9 @@ unsafe fn get_tensor_dimensions(
568579 let mut num_dims = 0 ;
569580 let status = g_ort ( ) . GetDimensionsCount . unwrap ( ) ( tensor_info_ptr, & mut num_dims) ;
570581 status_to_result ( status) . map_err ( OrtError :: GetDimensionsCount ) ?;
571- assert_ne ! ( num_dims, 0 ) ;
582+ ( num_dims != 0 )
583+ . then ( || ( ) )
584+ . ok_or ( OrtError :: InvalidDimensions ) ?;
572585
573586 let mut node_dims: Vec < i64 > = vec ! [ 0 ; num_dims as usize ] ;
574587 let status = g_ort ( ) . GetDimensions . unwrap ( ) (
@@ -603,8 +616,10 @@ mod dangerous {
603616 let mut num_nodes: usize = 0 ;
604617 let status = unsafe { f ( session_ptr, & mut num_nodes) } ;
605618 status_to_result ( status) . map_err ( OrtError :: InOutCount ) ?;
606- assert_eq ! ( status, std:: ptr:: null_mut( ) ) ;
607- assert_ne ! ( num_nodes, 0 ) ;
619+ assert_null_pointer ( status, "SessionStatus" ) ?;
620+ ( num_nodes != 0 ) . then ( || ( ) ) . ok_or_else ( || {
621+ OrtError :: InOutCount ( OrtApiError :: Msg ( "No nodes in model" . to_owned ( ) ) )
622+ } ) ?;
608623 Ok ( num_nodes)
609624 }
610625
@@ -641,7 +656,7 @@ mod dangerous {
641656
642657 let status = unsafe { f ( session_ptr, i, allocator_ptr, & mut name_bytes) } ;
643658 status_to_result ( status) . map_err ( OrtError :: InputName ) ?;
644- assert_ne ! ( name_bytes, std :: ptr :: null_mut ( ) ) ;
659+ assert_not_null_pointer ( name_bytes, "InputName" ) ? ;
645660
646661 // FIXME: Is it safe to keep ownership of the memory?
647662 let name = char_p_to_string ( name_bytes) ?;
@@ -692,23 +707,22 @@ mod dangerous {
692707
693708 let status = unsafe { f ( session_ptr, i, & mut typeinfo_ptr) } ;
694709 status_to_result ( status) . map_err ( OrtError :: GetTypeInfo ) ?;
695- assert_ne ! ( typeinfo_ptr, std :: ptr :: null_mut ( ) ) ;
710+ assert_not_null_pointer ( typeinfo_ptr, "TypeInfo" ) ? ;
696711
697712 let mut tensor_info_ptr: * const sys:: OrtTensorTypeAndShapeInfo = std:: ptr:: null_mut ( ) ;
698713 let status = unsafe {
699714 g_ort ( ) . CastTypeInfoToTensorInfo . unwrap ( ) ( typeinfo_ptr, & mut tensor_info_ptr)
700715 } ;
701716 status_to_result ( status) . map_err ( OrtError :: CastTypeInfoToTensorInfo ) ?;
702- assert_ne ! ( tensor_info_ptr, std :: ptr :: null_mut ( ) ) ;
717+ assert_not_null_pointer ( tensor_info_ptr, "TensorInfo" ) ? ;
703718
704719 let mut type_sys = sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED ;
705720 let status =
706721 unsafe { g_ort ( ) . GetTensorElementType . unwrap ( ) ( tensor_info_ptr, & mut type_sys) } ;
707722 status_to_result ( status) . map_err ( OrtError :: TensorElementType ) ?;
708- assert_ne ! (
709- type_sys,
710- sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED
711- ) ;
723+ ( type_sys != sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED )
724+ . then ( || ( ) )
725+ . ok_or ( OrtError :: UndefinedTensorElementType ) ?;
712726 // This transmute should be safe since its value is read from GetTensorElementType which we must trust.
713727 let io_type: TensorElementDataType = unsafe { std:: mem:: transmute ( type_sys) } ;
714728
0 commit comments