Skip to content

Commit d0330e6

Browse files
authored
Add get_[inputs|outputs]_len to CNNNetwork (#34)
This change is an adaptation of #33 that allows for the same kind of functionality. It also adds tests checking the newly-added functions. With this in place, users will be able to set input layouts with something like: ```rust for i in 0..network.get_inputs_len() { let name = network.get_input_name(i)?; network.set_input_layout(&name, Layout::NHWC)?; } ```
1 parent 917a28a commit d0330e6

File tree

3 files changed

+23
-4
lines changed

3 files changed

+23
-4
lines changed

crates/openvino/src/error.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use thiserror::Error;
33
/// Enumerate errors returned by the OpenVINO implementation. See
44
/// [IEStatusCode](https://docs.openvinotoolkit.org/latest/ie_c_api/ie__c__api_8h.html#a391683b1e8e26df8b58d7033edd9ee83).
55
/// TODO This could be auto-generated (https://github.com/intel/openvino-rs/issues/20).
6-
#[derive(Debug, Error)]
6+
#[derive(Debug, Error, PartialEq)]
77
pub enum InferenceError {
88
#[error("general error")]
99
GeneralError,

crates/openvino/src/network.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
//! Contains the network representations in OpenVINO:
2-
//! - [CNNNetwork] is the OpenVINO represenation of a neural network
2+
//! - [CNNNetwork] is the OpenVINO representation of a neural network
33
//! - [ExecutableNetwork] is the compiled representation of a [CNNNetwork] for a device.
44
55
use crate::request::InferRequest;
66
use crate::{cstr, drop_using_function, try_unsafe, util::Result};
77
use crate::{Layout, Precision, ResizeAlgorithm};
88
use openvino_sys::{
99
ie_exec_network_create_infer_request, ie_exec_network_free, ie_executable_network_t,
10-
ie_network_free, ie_network_get_input_name, ie_network_get_output_name, ie_network_name_free,
10+
ie_network_free, ie_network_get_input_name, ie_network_get_inputs_number,
11+
ie_network_get_output_name, ie_network_get_outputs_number, ie_network_name_free,
1112
ie_network_set_input_layout, ie_network_set_input_precision,
1213
ie_network_set_input_resize_algorithm, ie_network_set_output_precision, ie_network_t,
1314
};
@@ -21,6 +22,20 @@ pub struct CNNNetwork {
2122
drop_using_function!(CNNNetwork, ie_network_free);
2223

2324
impl CNNNetwork {
25+
/// Retrieve the number of network inputs.
26+
pub fn get_inputs_len(&self) -> Result<usize> {
27+
let mut num: usize = 0;
28+
try_unsafe!(ie_network_get_inputs_number(self.instance, &mut num))?;
29+
Ok(num)
30+
}
31+
32+
/// Retrieve the number of network outputs.
33+
pub fn get_outputs_len(&self) -> Result<usize> {
34+
let mut num: usize = 0;
35+
try_unsafe!(ie_network_get_outputs_number(self.instance, &mut num))?;
36+
Ok(num)
37+
}
38+
2439
/// Retrieve the name identifying the input tensor at `index`.
2540
pub fn get_input_name(&self, index: usize) -> Result<String> {
2641
let mut cname = std::ptr::null_mut();

crates/openvino/tests/setup.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,9 @@ fn read_network() {
1010
let mut core = Core::new(None).unwrap();
1111
let model = fs::read(Fixture::graph()).unwrap();
1212
let weights = fs::read(Fixture::weights()).unwrap();
13-
core.read_network_from_buffer(&model, &weights).unwrap();
13+
let network = core.read_network_from_buffer(&model, &weights).unwrap();
14+
15+
// Check the number of inputs and outputs.
16+
assert_eq!(network.get_inputs_len(), Ok(1));
17+
assert_eq!(network.get_outputs_len(), Ok(1));
1418
}

0 commit comments

Comments
 (0)