Skip to content

Commit 8ee61fb

Browse files
authored
Merge pull request #77 from WallarooLabs/fix_shape_panic
Return an Err on shape mismatch instead of panic
2 parents 445abca + 788bd3a commit 8ee61fb

File tree

2 files changed

+34
-8
lines changed

2 files changed

+34
-8
lines changed

onnxruntime/src/error.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,14 @@ pub enum NonMatchingDimensionsError {
119119
/// Input dimensions defined in model
120120
model_input: Vec<Vec<Option<u32>>>,
121121
},
122+
/// Inputs length from model does not match the expected input from inference call
123+
#[error("Different input lengths: Expected Input: {model_input:?} vs Received Input: {inference_input:?}")]
124+
InputsLength {
125+
/// Input dimensions used by inference call
126+
inference_input: Vec<Vec<usize>>,
127+
/// Input dimensions defined in model
128+
model_input: Vec<Vec<Option<u32>>>,
129+
},
122130
}
123131

124132
/// Error details when ONNX C API fail

onnxruntime/src/session.rs

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -520,10 +520,19 @@ impl<'a> Session<'a> {
520520
"Different input lengths: {:?} vs {:?}",
521521
self.inputs, input_arrays
522522
);
523-
panic!(
524-
"Different input lengths: {:?} vs {:?}",
525-
self.inputs, input_arrays
526-
);
523+
return Err(OrtError::NonMatchingDimensions(
524+
NonMatchingDimensionsError::InputsLength {
525+
inference_input: input_arrays
526+
.iter()
527+
.map(|input_array| input_array.shape().to_vec())
528+
.collect(),
529+
model_input: self
530+
.inputs
531+
.iter()
532+
.map(|input| input.dimensions.clone())
533+
.collect(),
534+
},
535+
));
527536
}
528537

529538
// Verify shape of each individual inputs
@@ -540,10 +549,19 @@ impl<'a> Session<'a> {
540549
"Different input lengths: {:?} vs {:?}",
541550
self.inputs, input_arrays
542551
);
543-
panic!(
544-
"Different input lengths: {:?} vs {:?}",
545-
self.inputs, input_arrays
546-
);
552+
return Err(OrtError::NonMatchingDimensions(
553+
NonMatchingDimensionsError::InputsLength {
554+
inference_input: input_arrays
555+
.iter()
556+
.map(|input_array| input_array.shape().to_vec())
557+
.collect(),
558+
model_input: self
559+
.inputs
560+
.iter()
561+
.map(|input| input.dimensions.clone())
562+
.collect(),
563+
},
564+
));
547565
}
548566

549567
Ok(())

0 commit comments

Comments
 (0)