@@ -18,7 +18,10 @@ use onnxruntime_sys as sys;
1818use crate :: {
1919 char_p_to_string,
2020 environment:: Environment ,
21- error:: { status_to_result, NonMatchingDimensionsError , OrtError , Result } ,
21+ error:: {
22+ assert_not_null_pointer, assert_null_pointer, status_to_result, NonMatchingDimensionsError ,
23+ OrtApiError , OrtError , Result ,
24+ } ,
2225 g_ort,
2326 memory:: MemoryInfo ,
2427 tensor:: {
@@ -73,9 +76,12 @@ pub struct SessionBuilder<'a> {
7376impl < ' a > Drop for SessionBuilder < ' a > {
7477 #[ tracing:: instrument]
7578 fn drop ( & mut self ) {
76- debug ! ( "Dropping the session options." ) ;
77- assert_ne ! ( self . session_options_ptr, std:: ptr:: null_mut( ) ) ;
78- unsafe { g_ort ( ) . ReleaseSessionOptions . unwrap ( ) ( self . session_options_ptr ) } ;
79+ if self . session_options_ptr . is_null ( ) {
80+ error ! ( "Session options pointer is null, not dropping" ) ;
81+ } else {
82+ debug ! ( "Dropping the session options." ) ;
83+ unsafe { g_ort ( ) . ReleaseSessionOptions . unwrap ( ) ( self . session_options_ptr ) } ;
84+ }
7985 }
8086}
8187
@@ -85,8 +91,8 @@ impl<'a> SessionBuilder<'a> {
8591 let status = unsafe { g_ort ( ) . CreateSessionOptions . unwrap ( ) ( & mut session_options_ptr) } ;
8692
8793 status_to_result ( status) . map_err ( OrtError :: SessionOptions ) ?;
88- assert_eq ! ( status, std :: ptr :: null_mut ( ) ) ;
89- assert_ne ! ( session_options_ptr, std :: ptr :: null_mut ( ) ) ;
94+ assert_null_pointer ( status, "SessionStatus" ) ? ;
95+ assert_not_null_pointer ( session_options_ptr, "SessionOptions" ) ? ;
9096
9197 Ok ( SessionBuilder {
9298 env,
@@ -105,7 +111,7 @@ impl<'a> SessionBuilder<'a> {
105111 let status =
106112 unsafe { g_ort ( ) . SetIntraOpNumThreads . unwrap ( ) ( self . session_options_ptr , num_threads) } ;
107113 status_to_result ( status) . map_err ( OrtError :: SessionOptions ) ?;
108- assert_eq ! ( status, std :: ptr :: null_mut ( ) ) ;
114+ assert_null_pointer ( status, "SessionStatus" ) ? ;
109115 Ok ( self )
110116 }
111117
@@ -199,14 +205,14 @@ impl<'a> SessionBuilder<'a> {
199205 )
200206 } ;
201207 status_to_result ( status) . map_err ( OrtError :: Session ) ?;
202- assert_eq ! ( status, std :: ptr :: null_mut ( ) ) ;
203- assert_ne ! ( session_ptr, std :: ptr :: null_mut ( ) ) ;
208+ assert_null_pointer ( status, "SessionStatus" ) ? ;
209+ assert_not_null_pointer ( session_ptr, "Session" ) ? ;
204210
205211 let mut allocator_ptr: * mut sys:: OrtAllocator = std:: ptr:: null_mut ( ) ;
206212 let status = unsafe { g_ort ( ) . GetAllocatorWithDefaultOptions . unwrap ( ) ( & mut allocator_ptr) } ;
207213 status_to_result ( status) . map_err ( OrtError :: Allocator ) ?;
208- assert_eq ! ( status, std :: ptr :: null_mut ( ) ) ;
209- assert_ne ! ( allocator_ptr, std :: ptr :: null_mut ( ) ) ;
214+ assert_null_pointer ( status, "SessionStatus" ) ? ;
215+ assert_not_null_pointer ( allocator_ptr, "Allocator" ) ? ;
210216
211217 let memory_info = MemoryInfo :: new ( AllocatorType :: Arena , MemType :: Default ) ?;
212218
@@ -255,14 +261,14 @@ impl<'a> SessionBuilder<'a> {
255261 )
256262 } ;
257263 status_to_result ( status) . map_err ( OrtError :: Session ) ?;
258- assert_eq ! ( status, std :: ptr :: null_mut ( ) ) ;
259- assert_ne ! ( session_ptr, std :: ptr :: null_mut ( ) ) ;
264+ assert_null_pointer ( status, "SessionStatus" ) ? ;
265+ assert_not_null_pointer ( session_ptr, "Session" ) ? ;
260266
261267 let mut allocator_ptr: * mut sys:: OrtAllocator = std:: ptr:: null_mut ( ) ;
262268 let status = unsafe { g_ort ( ) . GetAllocatorWithDefaultOptions . unwrap ( ) ( & mut allocator_ptr) } ;
263269 status_to_result ( status) . map_err ( OrtError :: Allocator ) ?;
264- assert_eq ! ( status, std :: ptr :: null_mut ( ) ) ;
265- assert_ne ! ( allocator_ptr, std :: ptr :: null_mut ( ) ) ;
270+ assert_null_pointer ( status, "SessionStatus" ) ? ;
271+ assert_not_null_pointer ( allocator_ptr, "Allocator" ) ? ;
266272
267273 let memory_info = MemoryInfo :: new ( AllocatorType :: Arena , MemType :: Default ) ?;
268274
@@ -352,7 +358,11 @@ impl<'a> Drop for Session<'a> {
352358 #[ tracing:: instrument]
353359 fn drop ( & mut self ) {
354360 debug ! ( "Dropping the session." ) ;
355- unsafe { g_ort ( ) . ReleaseSession . unwrap ( ) ( self . session_ptr ) } ;
361+ if self . session_ptr . is_null ( ) {
362+ error ! ( "Session pointer is null, not dropping." ) ;
363+ } else {
364+ unsafe { g_ort ( ) . ReleaseSession . unwrap ( ) ( self . session_ptr ) } ;
365+ }
356366 // FIXME: There is no C function to release the allocator?
357367
358368 self . session_ptr = std:: ptr:: null_mut ( ) ;
@@ -459,13 +469,14 @@ impl<'a> Session<'a> {
459469 . collect ( ) ;
460470
461471 // Reconvert to CString so drop impl is called and memory is freed
462- let _ : Vec < CString > = input_names_ptr
472+ let cstrings : Result < Vec < CString > > = input_names_ptr
463473 . into_iter ( )
464474 . map ( |p| {
465- assert_ne ! ( p, std :: ptr :: null ( ) ) ;
466- unsafe { CString :: from_raw ( p as * mut i8 ) }
475+ assert_not_null_pointer ( p, "i8 for CString" ) ? ;
476+ unsafe { Ok ( CString :: from_raw ( p as * mut i8 ) ) }
467477 } )
468478 . collect ( ) ;
479+ cstrings?;
469480
470481 outputs
471482 }
@@ -574,7 +585,9 @@ unsafe fn get_tensor_dimensions(
574585 let mut num_dims = 0 ;
575586 let status = g_ort ( ) . GetDimensionsCount . unwrap ( ) ( tensor_info_ptr, & mut num_dims) ;
576587 status_to_result ( status) . map_err ( OrtError :: GetDimensionsCount ) ?;
577- assert_ne ! ( num_dims, 0 ) ;
588+ ( num_dims != 0 )
589+ . then ( || ( ) )
590+ . ok_or ( OrtError :: InvalidDimensions ) ?;
578591
579592 let mut node_dims: Vec < i64 > = vec ! [ 0 ; num_dims as usize ] ;
580593 let status = g_ort ( ) . GetDimensions . unwrap ( ) (
@@ -609,8 +622,10 @@ mod dangerous {
609622 let mut num_nodes: usize = 0 ;
610623 let status = unsafe { f ( session_ptr, & mut num_nodes) } ;
611624 status_to_result ( status) . map_err ( OrtError :: InOutCount ) ?;
612- assert_eq ! ( status, std:: ptr:: null_mut( ) ) ;
613- assert_ne ! ( num_nodes, 0 ) ;
625+ assert_null_pointer ( status, "SessionStatus" ) ?;
626+ ( num_nodes != 0 ) . then ( || ( ) ) . ok_or_else ( || {
627+ OrtError :: InOutCount ( OrtApiError :: Msg ( "No nodes in model" . to_owned ( ) ) )
628+ } ) ?;
614629 Ok ( num_nodes)
615630 }
616631
@@ -647,7 +662,7 @@ mod dangerous {
647662
648663 let status = unsafe { f ( session_ptr, i, allocator_ptr, & mut name_bytes) } ;
649664 status_to_result ( status) . map_err ( OrtError :: InputName ) ?;
650- assert_ne ! ( name_bytes, std :: ptr :: null_mut ( ) ) ;
665+ assert_not_null_pointer ( name_bytes, "InputName" ) ? ;
651666
652667 // FIXME: Is it safe to keep ownership of the memory?
653668 let name = char_p_to_string ( name_bytes) ?;
@@ -698,23 +713,22 @@ mod dangerous {
698713
699714 let status = unsafe { f ( session_ptr, i, & mut typeinfo_ptr) } ;
700715 status_to_result ( status) . map_err ( OrtError :: GetTypeInfo ) ?;
701- assert_ne ! ( typeinfo_ptr, std :: ptr :: null_mut ( ) ) ;
716+ assert_not_null_pointer ( typeinfo_ptr, "TypeInfo" ) ? ;
702717
703718 let mut tensor_info_ptr: * const sys:: OrtTensorTypeAndShapeInfo = std:: ptr:: null_mut ( ) ;
704719 let status = unsafe {
705720 g_ort ( ) . CastTypeInfoToTensorInfo . unwrap ( ) ( typeinfo_ptr, & mut tensor_info_ptr)
706721 } ;
707722 status_to_result ( status) . map_err ( OrtError :: CastTypeInfoToTensorInfo ) ?;
708- assert_ne ! ( tensor_info_ptr, std :: ptr :: null_mut ( ) ) ;
723+ assert_not_null_pointer ( tensor_info_ptr, "TensorInfo" ) ? ;
709724
710725 let mut type_sys = sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED ;
711726 let status =
712727 unsafe { g_ort ( ) . GetTensorElementType . unwrap ( ) ( tensor_info_ptr, & mut type_sys) } ;
713728 status_to_result ( status) . map_err ( OrtError :: TensorElementType ) ?;
714- assert_ne ! (
715- type_sys,
716- sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED
717- ) ;
729+ ( type_sys != sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED )
730+ . then ( || ( ) )
731+ . ok_or ( OrtError :: UndefinedTensorElementType ) ?;
718732 // This transmute should be safe since its value is read from GetTensorElementType which we must trust.
719733 let io_type: TensorElementDataType = unsafe { std:: mem:: transmute ( type_sys) } ;
720734
0 commit comments