Skip to content

Commit 788bd3a

Browse files
committed
Return an Err on shape mismatch instead of panic
This way we let the user of the library decide how to handle the issue.
1 parent 3b5fcd3 commit 788bd3a

File tree

2 files changed

+34
-8
lines changed

2 files changed

+34
-8
lines changed

onnxruntime/src/error.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,14 @@ pub enum NonMatchingDimensionsError {
119119
/// Input dimensions defined in model
120120
model_input: Vec<Vec<Option<u32>>>,
121121
},
122+
/// Inputs length from model does not match the expected input from inference call
123+
#[error("Different input lengths: Expected Input: {model_input:?} vs Received Input: {inference_input:?}")]
124+
InputsLength {
125+
/// Input dimensions used by inference call
126+
inference_input: Vec<Vec<usize>>,
127+
/// Input dimensions defined in model
128+
model_input: Vec<Vec<Option<u32>>>,
129+
},
122130
}
123131

124132
/// Error details when ONNX C API fail

onnxruntime/src/session.rs

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -520,10 +520,19 @@ impl<'a> Session<'a> {
520520
"Different input lengths: {:?} vs {:?}",
521521
self.inputs, input_arrays
522522
);
523-
panic!(
524-
"Different input lengths: {:?} vs {:?}",
525-
self.inputs, input_arrays
526-
);
523+
return Err(OrtError::NonMatchingDimensions(
524+
NonMatchingDimensionsError::InputsLength {
525+
inference_input: input_arrays
526+
.iter()
527+
.map(|input_array| input_array.shape().to_vec())
528+
.collect(),
529+
model_input: self
530+
.inputs
531+
.iter()
532+
.map(|input| input.dimensions.clone())
533+
.collect(),
534+
},
535+
));
527536
}
528537

529538
// Verify shape of each individual inputs
@@ -540,10 +549,19 @@ impl<'a> Session<'a> {
540549
"Different input lengths: {:?} vs {:?}",
541550
self.inputs, input_arrays
542551
);
543-
panic!(
544-
"Different input lengths: {:?} vs {:?}",
545-
self.inputs, input_arrays
546-
);
552+
return Err(OrtError::NonMatchingDimensions(
553+
NonMatchingDimensionsError::InputsLength {
554+
inference_input: input_arrays
555+
.iter()
556+
.map(|input_array| input_array.shape().to_vec())
557+
.collect(),
558+
model_input: self
559+
.inputs
560+
.iter()
561+
.map(|input| input.dimensions.clone())
562+
.collect(),
563+
},
564+
));
547565
}
548566

549567
Ok(())

0 commit comments

Comments
 (0)