Skip to content

Commit 3b50b5b

Browse files
authored
Merge pull request #56 from ccbrown/dynamic-size-fix
fix dynamically sized outputs
2 parents 9bc0a49 + 7b28720 commit 3b50b5b

File tree

4 files changed

+126
-44
lines changed

4 files changed

+126
-44
lines changed

onnxruntime/src/error.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ pub enum OrtError {
6262
/// Error occurred when checking if ONNX tensor was properly initialized
6363
#[error("Failed to check if tensor: {0}")]
6464
IsTensor(OrtApiError),
65+
/// Error occurred when getting tensor type and shape
66+
#[error("Failed to get tensor type and shape: {0}")]
67+
GetTensorTypeAndShape(OrtApiError),
6568
/// Error occurred when ONNX inference operation was called
6669
#[error("Failed to run: {0}")]
6770
Run(OrtApiError),

onnxruntime/src/session.rs

Lines changed: 35 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -405,32 +405,8 @@ impl<'a> Session<'a> {
405405
.map(|n| n.as_ptr() as *const i8)
406406
.collect();
407407

408-
let output_shapes: Vec<Vec<usize>> = {
409-
let mut tmp = Vec::new();
410-
for (idx, output) in self.outputs.iter().enumerate() {
411-
let v: Vec<_> = output
412-
.dimensions
413-
.iter()
414-
.enumerate()
415-
.map(|(jdx, dim)| match dim {
416-
None => input_arrays[idx].shape()[jdx],
417-
Some(d) => *d as usize,
418-
})
419-
.collect();
420-
tmp.push(v);
421-
}
422-
tmp
423-
};
424-
let memory_info_ref = &self.memory_info;
425-
let output_tensor_extractors: Vec<OrtOwnedTensorExtractor<ndarray::IxDyn>> = output_shapes
426-
.iter()
427-
.map(|output_shape| {
428-
OrtOwnedTensorExtractor::new(memory_info_ref, ndarray::IxDyn(output_shape))
429-
})
430-
.collect();
431-
432408
let mut output_tensor_extractors_ptrs: Vec<*mut sys::OrtValue> =
433-
vec![std::ptr::null_mut(); output_tensor_extractors.len()];
409+
vec![std::ptr::null_mut(); self.outputs.len()];
434410

435411
// The C API expects pointers for the arrays (pointers to C-arrays)
436412
let input_ort_tensors: Vec<OrtTensor<TIn, D>> = input_arrays
@@ -458,11 +434,23 @@ impl<'a> Session<'a> {
458434
};
459435
status_to_result(status).map_err(OrtError::Run)?;
460436

437+
let memory_info_ref = &self.memory_info;
461438
let outputs: Result<Vec<OrtOwnedTensor<TOut, ndarray::Dim<ndarray::IxDynImpl>>>> =
462-
output_tensor_extractors
439+
output_tensor_extractors_ptrs
463440
.into_iter()
464-
.zip(output_tensor_extractors_ptrs.into_iter())
465-
.map(|(mut output_tensor_extractor, ptr)| {
441+
.map(|ptr| {
442+
let mut tensor_info_ptr: *mut sys::OrtTensorTypeAndShapeInfo =
443+
std::ptr::null_mut();
444+
let status = unsafe {
445+
g_ort().GetTensorTypeAndShape.unwrap()(ptr, &mut tensor_info_ptr as _)
446+
};
447+
status_to_result(status).map_err(OrtError::GetTensorTypeAndShape)?;
448+
let dims = unsafe { get_tensor_dimensions(tensor_info_ptr) };
449+
unsafe { g_ort().ReleaseTensorTypeAndShapeInfo.unwrap()(tensor_info_ptr) };
450+
let dims: Vec<_> = dims?.iter().map(|&n| n as usize).collect();
451+
452+
let mut output_tensor_extractor =
453+
OrtOwnedTensorExtractor::new(memory_info_ref, ndarray::IxDyn(&dims));
466454
output_tensor_extractor.tensor_ptr = ptr;
467455
output_tensor_extractor.extract::<TOut>()
468456
})
@@ -560,6 +548,24 @@ impl<'a> Session<'a> {
560548
}
561549
}
562550

551+
unsafe fn get_tensor_dimensions(
552+
tensor_info_ptr: *const sys::OrtTensorTypeAndShapeInfo,
553+
) -> Result<Vec<i64>> {
554+
let mut num_dims = 0;
555+
let status = g_ort().GetDimensionsCount.unwrap()(tensor_info_ptr, &mut num_dims);
556+
status_to_result(status).map_err(OrtError::GetDimensionsCount)?;
557+
assert_ne!(num_dims, 0);
558+
559+
let mut node_dims: Vec<i64> = vec![0; num_dims as usize];
560+
let status = g_ort().GetDimensions.unwrap()(
561+
tensor_info_ptr,
562+
node_dims.as_mut_ptr(), // FIXME: UB?
563+
num_dims,
564+
);
565+
status_to_result(status).map_err(OrtError::GetDimensions)?;
566+
Ok(node_dims)
567+
}
568+
563569
/// This module contains dangerous functions working on raw pointers.
564570
/// Those functions are only to be used from inside the
565571
/// `SessionBuilder::with_model_from_file()` method.
@@ -694,22 +700,7 @@ mod dangerous {
694700

695701
// info!("{} : type={}", i, type_);
696702

697-
// print input shapes/dims
698-
let mut num_dims = 0;
699-
let status = unsafe { g_ort().GetDimensionsCount.unwrap()(tensor_info_ptr, &mut num_dims) };
700-
status_to_result(status).map_err(OrtError::GetDimensionsCount)?;
701-
assert_ne!(num_dims, 0);
702-
703-
// info!("{} : num_dims={}", i, num_dims);
704-
let mut node_dims: Vec<i64> = vec![0; num_dims as usize];
705-
let status = unsafe {
706-
g_ort().GetDimensions.unwrap()(
707-
tensor_info_ptr,
708-
node_dims.as_mut_ptr(), // FIXME: UB?
709-
num_dims,
710-
)
711-
};
712-
status_to_result(status).map_err(OrtError::GetDimensions)?;
703+
let node_dims = unsafe { get_tensor_dimensions(tensor_info_ptr)? };
713704

714705
// for j in 0..num_dims {
715706
// info!("{} : dim {}={}", i, j, node_dims[j as usize]);
1.82 KB
Binary file not shown.

onnxruntime/tests/integration_tests.rs

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,94 @@ mod download {
204204
IMAGE_TO_LOAD, probabilities[0].0
205205
);
206206
}
207+
208+
// This test verifies that dynamically sized inputs and outputs work. It loads and runs
209+
// upsample.onnx, which was produced via:
210+
//
211+
// ```
212+
// import subprocess
213+
// from tensorflow import keras
214+
//
215+
// m = keras.Sequential([
216+
// keras.layers.UpSampling2D(size=2)
217+
// ])
218+
// m.build(input_shape=(None, None, None, 3))
219+
// m.summary()
220+
// m.save('saved_model')
221+
//
222+
// subprocess.check_call([
223+
// 'python', '-m', 'tf2onnx.convert',
224+
// '--saved-model', 'saved_model',
225+
// '--opset', '12',
226+
// '--output', 'upsample.onnx',
227+
// ])
228+
// ```
229+
#[test]
230+
fn upsample() {
231+
const IMAGE_TO_LOAD: &str = "mushroom.png";
232+
233+
let environment = Environment::builder()
234+
.with_name("integration_test")
235+
.with_log_level(LoggingLevel::Warning)
236+
.build()
237+
.unwrap();
238+
239+
let mut session = environment
240+
.new_session_builder()
241+
.unwrap()
242+
.with_optimization_level(GraphOptimizationLevel::Basic)
243+
.unwrap()
244+
.with_number_threads(1)
245+
.unwrap()
246+
.with_model_from_file(
247+
Path::new(env!("CARGO_MANIFEST_DIR"))
248+
.join("tests")
249+
.join("data")
250+
.join("upsample.onnx"),
251+
)
252+
.expect("Could not open model from file");
253+
254+
assert_eq!(
255+
session.inputs[0].dimensions().collect::<Vec<_>>(),
256+
[None, None, None, Some(3)]
257+
);
258+
assert_eq!(
259+
session.outputs[0].dimensions().collect::<Vec<_>>(),
260+
[None, None, None, Some(3)]
261+
);
262+
263+
// Load image, converting to RGB format
264+
let image_buffer: ImageBuffer<Rgb<u8>, Vec<u8>> = image::open(
265+
Path::new(env!("CARGO_MANIFEST_DIR"))
266+
.join("tests")
267+
.join("data")
268+
.join(IMAGE_TO_LOAD),
269+
)
270+
.unwrap()
271+
.to_rgb();
272+
273+
let array = ndarray::Array::from_shape_fn((1, 224, 224, 3), |(_, j, i, c)| {
274+
let pixel = image_buffer.get_pixel(i as u32, j as u32);
275+
let channels = pixel.channels();
276+
277+
// range [0, 255] -> range [0, 1]
278+
(channels[c] as f32) / 255.0
279+
});
280+
281+
// Just one input
282+
let input_tensor_values = vec![array];
283+
284+
// Perform the inference
285+
let outputs: Vec<
286+
onnxruntime::tensor::OrtOwnedTensor<f32, ndarray::Dim<ndarray::IxDynImpl>>,
287+
> = session.run(input_tensor_values).unwrap();
288+
289+
assert_eq!(outputs.len(), 1);
290+
let output = &outputs[0];
291+
292+
// The image should have doubled in size
293+
assert_eq!(output.shape(), [1, 448, 448, 3]);
294+
}
207295
}
208296

209297
fn get_imagenet_labels() -> Result<Vec<String>, io::Error> {

0 commit comments

Comments
 (0)