Skip to content

Commit b3e1a26

Browse files
committed
More consistent error handling
Error handling should be left up to the user of the library in general. Most of this change (aside from some integration test compile problems), is removing panics. In asserting a null pointer or not, we return a Result instead of panicking. In drops, instead of asserting a non null pointer, we just check `is_null` on the object before running the C deallocation function for the pointer and log an error if its null when we're trying to deallocate.
1 parent 8ee61fb commit b3e1a26

File tree

7 files changed

+105
-51
lines changed

7 files changed

+105
-51
lines changed

onnxruntime/src/environment.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use std::{
66
};
77

88
use lazy_static::lazy_static;
9-
use tracing::{debug, warn};
9+
use tracing::{debug, error, warn};
1010

1111
use onnxruntime_sys as sys;
1212

@@ -182,7 +182,11 @@ impl Drop for Environment {
182182
);
183183

184184
assert_ne!(env_ptr, std::ptr::null_mut());
185-
unsafe { release_env(env_ptr) };
185+
if env_ptr.is_null() {
186+
error!("Environment pointer is null, not dropping!");
187+
} else {
188+
unsafe { release_env(env_ptr) };
189+
}
186190

187191
environment_guard.env_ptr = AtomicPtr::new(std::ptr::null_mut());
188192
environment_guard.name = String::from("uninitialized");

onnxruntime/src/error.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,21 @@ pub enum OrtError {
100100
/// Attempt to build a Rust `CString` from a null pointer
101101
#[error("Failed to build CString when original contains null: {0}")]
102102
CStringNulError(#[from] std::ffi::NulError),
103+
#[error("{0} pointer should be null")]
104+
/// Ort Pointer should have been null
105+
PointerShouldBeNull(String),
106+
/// Ort pointer should not have been null
107+
#[error("{0} pointer should not be null")]
108+
PointerShouldNotBeNull(String),
109+
/// ONNX Model has invalid dimensions
110+
#[error("Invalid dimensions")]
111+
InvalidDimensions,
112+
/// The runtime type was undefined
113+
#[error("Undefined Tensor Element Type")]
114+
UndefinedTensorElementType,
115+
/// Error occurred when checking if ONNX tensor was properly initialized
116+
#[error("Failed to check if tensor")]
117+
IsTensorCheck,
103118
}
104119

105120
/// Error used when dimensions of input (from model and from inference call)
@@ -176,6 +191,18 @@ impl From<*const sys::OrtStatus> for OrtStatusWrapper {
176191
}
177192
}
178193

194+
pub(crate) fn assert_null_pointer<T>(ptr: *const T, name: &str) -> Result<()> {
195+
ptr.is_null()
196+
.then(|| ())
197+
.ok_or_else(|| OrtError::PointerShouldBeNull(name.to_owned()))
198+
}
199+
200+
pub(crate) fn assert_not_null_pointer<T>(ptr: *const T, name: &str) -> Result<()> {
201+
(!ptr.is_null())
202+
.then(|| ())
203+
.ok_or_else(|| OrtError::PointerShouldBeNull(name.to_owned()))
204+
}
205+
179206
impl From<OrtStatusWrapper> for std::result::Result<(), OrtApiError> {
180207
fn from(status: OrtStatusWrapper) -> Self {
181208
if status.0.is_null() {

onnxruntime/src/memory.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@ use tracing::debug;
22

33
use onnxruntime_sys as sys;
44

5+
use tracing::error;
6+
57
use crate::{
6-
error::{status_to_result, OrtError, Result},
8+
error::{assert_not_null_pointer, status_to_result, OrtError, Result},
79
g_ort, AllocatorType, MemType,
810
};
911

@@ -25,7 +27,7 @@ impl MemoryInfo {
2527
)
2628
};
2729
status_to_result(status).map_err(OrtError::CreateCpuMemoryInfo)?;
28-
assert_ne!(memory_info_ptr, std::ptr::null_mut());
30+
assert_not_null_pointer(memory_info_ptr, "MemoryInfo")?;
2931

3032
Ok(Self {
3133
ptr: memory_info_ptr,
@@ -36,10 +38,12 @@ impl MemoryInfo {
3638
impl Drop for MemoryInfo {
3739
#[tracing::instrument]
3840
fn drop(&mut self) {
39-
debug!("Dropping the memory information.");
40-
assert_ne!(self.ptr, std::ptr::null_mut());
41-
42-
unsafe { g_ort().ReleaseMemoryInfo.unwrap()(self.ptr) };
41+
if self.ptr.is_null() {
42+
error!("MemoryInfo pointer is null, not dropping.");
43+
} else {
44+
debug!("Dropping the memory information.");
45+
unsafe { g_ort().ReleaseMemoryInfo.unwrap()(self.ptr) };
46+
}
4347

4448
self.ptr = std::ptr::null_mut();
4549
}

onnxruntime/src/session.rs

Lines changed: 43 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@ use onnxruntime_sys as sys;
1818
use crate::{
1919
char_p_to_string,
2020
environment::Environment,
21-
error::{status_to_result, NonMatchingDimensionsError, OrtError, Result},
21+
error::{
22+
assert_not_null_pointer, assert_null_pointer, status_to_result, NonMatchingDimensionsError,
23+
OrtApiError, OrtError, Result,
24+
},
2225
g_ort,
2326
memory::MemoryInfo,
2427
tensor::{
@@ -73,9 +76,12 @@ pub struct SessionBuilder<'a> {
7376
impl<'a> Drop for SessionBuilder<'a> {
7477
#[tracing::instrument]
7578
fn drop(&mut self) {
76-
debug!("Dropping the session options.");
77-
assert_ne!(self.session_options_ptr, std::ptr::null_mut());
78-
unsafe { g_ort().ReleaseSessionOptions.unwrap()(self.session_options_ptr) };
79+
if self.session_options_ptr.is_null() {
80+
error!("Session options pointer is null, not dropping");
81+
} else {
82+
debug!("Dropping the session options.");
83+
unsafe { g_ort().ReleaseSessionOptions.unwrap()(self.session_options_ptr) };
84+
}
7985
}
8086
}
8187

@@ -85,8 +91,8 @@ impl<'a> SessionBuilder<'a> {
8591
let status = unsafe { g_ort().CreateSessionOptions.unwrap()(&mut session_options_ptr) };
8692

8793
status_to_result(status).map_err(OrtError::SessionOptions)?;
88-
assert_eq!(status, std::ptr::null_mut());
89-
assert_ne!(session_options_ptr, std::ptr::null_mut());
94+
assert_null_pointer(status, "SessionStatus")?;
95+
assert_not_null_pointer(session_options_ptr, "SessionOptions")?;
9096

9197
Ok(SessionBuilder {
9298
env,
@@ -105,7 +111,7 @@ impl<'a> SessionBuilder<'a> {
105111
let status =
106112
unsafe { g_ort().SetIntraOpNumThreads.unwrap()(self.session_options_ptr, num_threads) };
107113
status_to_result(status).map_err(OrtError::SessionOptions)?;
108-
assert_eq!(status, std::ptr::null_mut());
114+
assert_null_pointer(status, "SessionStatus")?;
109115
Ok(self)
110116
}
111117

@@ -199,14 +205,14 @@ impl<'a> SessionBuilder<'a> {
199205
)
200206
};
201207
status_to_result(status).map_err(OrtError::Session)?;
202-
assert_eq!(status, std::ptr::null_mut());
203-
assert_ne!(session_ptr, std::ptr::null_mut());
208+
assert_null_pointer(status, "SessionStatus")?;
209+
assert_not_null_pointer(session_ptr, "Session")?;
204210

205211
let mut allocator_ptr: *mut sys::OrtAllocator = std::ptr::null_mut();
206212
let status = unsafe { g_ort().GetAllocatorWithDefaultOptions.unwrap()(&mut allocator_ptr) };
207213
status_to_result(status).map_err(OrtError::Allocator)?;
208-
assert_eq!(status, std::ptr::null_mut());
209-
assert_ne!(allocator_ptr, std::ptr::null_mut());
214+
assert_null_pointer(status, "SessionStatus")?;
215+
assert_not_null_pointer(allocator_ptr, "Allocator")?;
210216

211217
let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default)?;
212218

@@ -255,14 +261,14 @@ impl<'a> SessionBuilder<'a> {
255261
)
256262
};
257263
status_to_result(status).map_err(OrtError::Session)?;
258-
assert_eq!(status, std::ptr::null_mut());
259-
assert_ne!(session_ptr, std::ptr::null_mut());
264+
assert_null_pointer(status, "SessionStatus")?;
265+
assert_not_null_pointer(session_ptr, "Session")?;
260266

261267
let mut allocator_ptr: *mut sys::OrtAllocator = std::ptr::null_mut();
262268
let status = unsafe { g_ort().GetAllocatorWithDefaultOptions.unwrap()(&mut allocator_ptr) };
263269
status_to_result(status).map_err(OrtError::Allocator)?;
264-
assert_eq!(status, std::ptr::null_mut());
265-
assert_ne!(allocator_ptr, std::ptr::null_mut());
270+
assert_null_pointer(status, "SessionStatus")?;
271+
assert_not_null_pointer(allocator_ptr, "Allocator")?;
266272

267273
let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default)?;
268274

@@ -352,7 +358,11 @@ impl<'a> Drop for Session<'a> {
352358
#[tracing::instrument]
353359
fn drop(&mut self) {
354360
debug!("Dropping the session.");
355-
unsafe { g_ort().ReleaseSession.unwrap()(self.session_ptr) };
361+
if self.session_ptr.is_null() {
362+
error!("Session pointer is null, not dropping.");
363+
} else {
364+
unsafe { g_ort().ReleaseSession.unwrap()(self.session_ptr) };
365+
}
356366
// FIXME: There is no C function to release the allocator?
357367

358368
self.session_ptr = std::ptr::null_mut();
@@ -459,13 +469,14 @@ impl<'a> Session<'a> {
459469
.collect();
460470

461471
// Reconvert to CString so drop impl is called and memory is freed
462-
let _: Vec<CString> = input_names_ptr
472+
let cstrings: Result<Vec<CString>> = input_names_ptr
463473
.into_iter()
464474
.map(|p| {
465-
assert_ne!(p, std::ptr::null());
466-
unsafe { CString::from_raw(p as *mut i8) }
475+
assert_not_null_pointer(p, "i8 for CString")?;
476+
unsafe { Ok(CString::from_raw(p as *mut i8)) }
467477
})
468478
.collect();
479+
cstrings?;
469480

470481
outputs
471482
}
@@ -574,7 +585,9 @@ unsafe fn get_tensor_dimensions(
574585
let mut num_dims = 0;
575586
let status = g_ort().GetDimensionsCount.unwrap()(tensor_info_ptr, &mut num_dims);
576587
status_to_result(status).map_err(OrtError::GetDimensionsCount)?;
577-
assert_ne!(num_dims, 0);
588+
(num_dims != 0)
589+
.then(|| ())
590+
.ok_or(OrtError::InvalidDimensions)?;
578591

579592
let mut node_dims: Vec<i64> = vec![0; num_dims as usize];
580593
let status = g_ort().GetDimensions.unwrap()(
@@ -609,8 +622,10 @@ mod dangerous {
609622
let mut num_nodes: usize = 0;
610623
let status = unsafe { f(session_ptr, &mut num_nodes) };
611624
status_to_result(status).map_err(OrtError::InOutCount)?;
612-
assert_eq!(status, std::ptr::null_mut());
613-
assert_ne!(num_nodes, 0);
625+
assert_null_pointer(status, "SessionStatus")?;
626+
(num_nodes != 0).then(|| ()).ok_or_else(|| {
627+
OrtError::InOutCount(OrtApiError::Msg("No nodes in model".to_owned()))
628+
})?;
614629
Ok(num_nodes)
615630
}
616631

@@ -647,7 +662,7 @@ mod dangerous {
647662

648663
let status = unsafe { f(session_ptr, i, allocator_ptr, &mut name_bytes) };
649664
status_to_result(status).map_err(OrtError::InputName)?;
650-
assert_ne!(name_bytes, std::ptr::null_mut());
665+
assert_not_null_pointer(name_bytes, "InputName")?;
651666

652667
// FIXME: Is it safe to keep ownership of the memory?
653668
let name = char_p_to_string(name_bytes)?;
@@ -698,23 +713,22 @@ mod dangerous {
698713

699714
let status = unsafe { f(session_ptr, i, &mut typeinfo_ptr) };
700715
status_to_result(status).map_err(OrtError::GetTypeInfo)?;
701-
assert_ne!(typeinfo_ptr, std::ptr::null_mut());
716+
assert_not_null_pointer(typeinfo_ptr, "TypeInfo")?;
702717

703718
let mut tensor_info_ptr: *const sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut();
704719
let status = unsafe {
705720
g_ort().CastTypeInfoToTensorInfo.unwrap()(typeinfo_ptr, &mut tensor_info_ptr)
706721
};
707722
status_to_result(status).map_err(OrtError::CastTypeInfoToTensorInfo)?;
708-
assert_ne!(tensor_info_ptr, std::ptr::null_mut());
723+
assert_not_null_pointer(tensor_info_ptr, "TensorInfo")?;
709724

710725
let mut type_sys = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
711726
let status =
712727
unsafe { g_ort().GetTensorElementType.unwrap()(tensor_info_ptr, &mut type_sys) };
713728
status_to_result(status).map_err(OrtError::TensorElementType)?;
714-
assert_ne!(
715-
type_sys,
716-
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED
717-
);
729+
(type_sys != sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED)
730+
.then(|| ())
731+
.ok_or(OrtError::UndefinedTensorElementType)?;
718732
// This transmute should be safe since its value is read from GetTensorElementType which we must trust.
719733
let io_type: TensorElementDataType = unsafe { std::mem::transmute(type_sys) };
720734

onnxruntime/src/tensor/ort_owned_tensor.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,9 @@ where
9595
let mut is_tensor = 0;
9696
let status = unsafe { g_ort().IsTensor.unwrap()(self.tensor_ptr, &mut is_tensor) };
9797
status_to_result(status).map_err(OrtError::IsTensor)?;
98-
assert_eq!(is_tensor, 1);
98+
(is_tensor == 1)
99+
.then(|| ())
100+
.ok_or(OrtError::IsTensorCheck)?;
99101

100102
// Get pointer to output tensor float values
101103
let mut output_array_ptr: *mut T = std::ptr::null_mut();

onnxruntime/src/tensor/ort_tensor.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@ use tracing::{debug, error};
88
use onnxruntime_sys as sys;
99

1010
use crate::{
11-
error::call_ort, error::status_to_result, g_ort, memory::MemoryInfo,
12-
tensor::ndarray_tensor::NdArrayTensor, OrtError, Result, TensorElementDataType,
13-
TypeToTensorElementDataType,
11+
error::{assert_not_null_pointer, call_ort, status_to_result},
12+
g_ort,
13+
memory::MemoryInfo,
14+
tensor::ndarray_tensor::NdArrayTensor,
15+
OrtError, Result, TensorElementDataType, TypeToTensorElementDataType,
1416
};
1517

1618
/// Owned tensor, backed by an [`ndarray::Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html)
@@ -67,7 +69,7 @@ where
6769
// onnxruntime as is
6870
let tensor_values_ptr: *mut std::ffi::c_void =
6971
array.as_mut_ptr() as *mut std::ffi::c_void;
70-
assert_ne!(tensor_values_ptr, std::ptr::null_mut());
72+
assert_not_null_pointer(tensor_values_ptr, "TensorValues")?;
7173

7274
unsafe {
7375
call_ort(|ort| {
@@ -83,7 +85,7 @@ where
8385
})
8486
}
8587
.map_err(OrtError::CreateTensorWithData)?;
86-
assert_ne!(tensor_ptr, std::ptr::null_mut());
88+
assert_not_null_pointer(tensor_ptr, "Tensor")?;
8789

8890
let mut is_tensor = 0;
8991
let status = unsafe { g_ort().IsTensor.unwrap()(tensor_ptr, &mut is_tensor) };
@@ -134,7 +136,7 @@ where
134136
}
135137
}
136138

137-
assert_ne!(tensor_ptr, std::ptr::null_mut());
139+
assert_not_null_pointer(tensor_ptr, "Tensor")?;
138140

139141
Ok(OrtTensor {
140142
c_ptr: tensor_ptr,

onnxruntime/tests/integration_tests.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use std::{
66
};
77

88
use onnxruntime::error::OrtDownloadError;
9+
use onnxruntime::tensor::OrtOwnedTensor;
910

1011
mod download {
1112
use super::*;
@@ -64,7 +65,7 @@ mod download {
6465
input0_shape[3] as u32,
6566
FilterType::Nearest,
6667
)
67-
.to_rgb();
68+
.to_rgb8();
6869

6970
// Python:
7071
// # image[y, x, RGB]
@@ -101,10 +102,10 @@ mod download {
101102

102103
// Downloaded model does not have a softmax as final layer; call softmax on second axis
103104
// and iterate on resulting probabilities, creating an index to later access labels.
104-
let mut probabilities: Vec<(usize, f32)> = outputs[0]
105+
let output: &OrtOwnedTensor<f32, _> = &outputs[0];
106+
let mut probabilities: Vec<(usize, f32)> = output
105107
.softmax(ndarray::Axis(1))
106108
.into_iter()
107-
.copied()
108109
.enumerate()
109110
.collect::<Vec<_>>();
110111
// Sort probabilities so highest is at beginning of vector.
@@ -172,7 +173,7 @@ mod download {
172173
input0_shape[3] as u32,
173174
FilterType::Nearest,
174175
)
175-
.to_luma();
176+
.to_luma8();
176177

177178
let array = ndarray::Array::from_shape_fn((1, 1, 28, 28), |(_, c, j, i)| {
178179
let pixel = image_buffer.get_pixel(i as u32, j as u32);
@@ -190,10 +191,10 @@ mod download {
190191
onnxruntime::tensor::OrtOwnedTensor<f32, ndarray::Dim<ndarray::IxDynImpl>>,
191192
> = session.run(input_tensor_values).unwrap();
192193

193-
let mut probabilities: Vec<(usize, f32)> = outputs[0]
194+
let output: &OrtOwnedTensor<f32, _> = &outputs[0];
195+
let mut probabilities: Vec<(usize, f32)> = output
194196
.softmax(ndarray::Axis(1))
195197
.into_iter()
196-
.copied()
197198
.enumerate()
198199
.collect::<Vec<_>>();
199200

@@ -270,7 +271,7 @@ mod download {
270271
.join(IMAGE_TO_LOAD),
271272
)
272273
.unwrap()
273-
.to_rgb();
274+
.to_rgb8();
274275

275276
let array = ndarray::Array::from_shape_fn((1, 224, 224, 3), |(_, j, i, c)| {
276277
let pixel = image_buffer.get_pixel(i as u32, j as u32);

0 commit comments

Comments
 (0)