Skip to content

Commit 0469b3c

Browse files
Add nchw conversion option (#173)
* add nchw conversion option * fix test, cargo fmt
1 parent 2f6b500 commit 0469b3c

File tree

3 files changed

+47
-6
lines changed

3 files changed

+47
-6
lines changed

crates/openvino-tensor-converter/src/lib.rs

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,34 @@ use opencv::core::{MatTraitConst, Scalar_};
1414
use std::convert::TryInto;
1515
use std::{num::ParseIntError, path::Path, str::FromStr};
1616

17+
/// Convert an image from NHWC format to NCHW format.
18+
fn nhwc_to_nchw(data: &[u8], dimensions: &Dimensions) -> Vec<u8> {
19+
let mut nchw_data = vec![0; data.len()];
20+
let (height, width, channels) = (
21+
dimensions.height as usize,
22+
dimensions.width as usize,
23+
dimensions.channels as usize,
24+
);
25+
assert_eq!(
26+
data.len(),
27+
height * width * channels * dimensions.precision.bytes()
28+
);
29+
for h in 0..height {
30+
for w in 0..width {
31+
for c in 0..channels {
32+
let nhwc_index =
33+
(h * width * channels + w * channels + c) * dimensions.precision.bytes();
34+
let nchw_index =
35+
(c * height * width + h * width + w) * dimensions.precision.bytes();
36+
for b in 0..dimensions.precision.bytes() {
37+
nchw_data[nchw_index + b] = data[nhwc_index + b];
38+
}
39+
}
40+
}
41+
}
42+
nchw_data
43+
}
44+
1745
/// Convert an image a path to a resized sequence of bytes.
1846
///
1947
/// # Errors
@@ -23,6 +51,7 @@ use std::{num::ParseIntError, path::Path, str::FromStr};
2351
pub fn convert<P: AsRef<Path>>(
2452
path: P,
2553
dimensions: &Dimensions,
54+
format: &str,
2655
) -> Result<Vec<u8>, ConversionError> {
2756
let path = path.as_ref();
2857
info!("Converting {} to {:?}", path.display(), dimensions);
@@ -77,7 +106,12 @@ pub fn convert<P: AsRef<Path>>(
77106

78107
// Copy the bytes of the Mat out to a Vec<u8>.
79108
let dst_slice = unsafe { slice::from_raw_parts(dst.data(), dimensions.bytes()) };
80-
Ok(dst_slice.to_vec())
109+
let nhwc_data = dst_slice.to_vec();
110+
match format {
111+
"nchw" => Ok(nhwc_to_nchw(&nhwc_data, dimensions)),
112+
"nhwc" => Ok(nhwc_data),
113+
_ => Err(ConversionError("Invalid format specified.".to_string())),
114+
}
81115
}
82116

83117
/// Container for the reasons a conversion can fail.

crates/openvino-tensor-converter/src/main.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ fn main() {
66
env_logger::init();
77
let options = Options::from_args();
88
let dimensions = Dimensions::from_str(&options.dimensions).expect("Failed to parse dimensions");
9-
let tensor_data = convert(options.input, &dimensions).expect("Failed to convert image");
9+
let tensor_data =
10+
convert(options.input, &dimensions, &options.format).expect("Failed to convert image");
1011
fs::write(options.output, tensor_data).expect("Failed to write tensor")
1112
}
1213

@@ -27,4 +28,8 @@ struct Options {
2728
/// The dimensions of the output file as "[height]x[width]x[channels]x[precision]"; e.g. 300x300x3xfp32.
2829
#[structopt(name = "OUTPUT DIMENSIONS")]
2930
dimensions: String,
31+
32+
/// Format of the output tensor: "nchw" or "nhwc".
33+
#[structopt(name = "OUTPUT FORMAT", default_value = "nchw")]
34+
format: String,
3035
}

crates/openvino-tensor-converter/tests/conversion.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@ use openvino_tensor_converter::{convert, Dimensions, Precision};
55
fn same_result_twice_u8() {
66
let input = "tests/test.jpg";
77
let dimensions = Dimensions::new(227, 227, 3, Precision::U8);
8+
let format = "nchw";
89

9-
let first = convert(input, &dimensions).unwrap();
10-
let second = convert(input, &dimensions).unwrap();
10+
let first = convert(input, &dimensions, &format).unwrap();
11+
let second = convert(input, &dimensions, &format).unwrap();
1112
assert_same_bytes(&first, &second);
1213
}
1314

@@ -16,9 +17,10 @@ fn same_result_twice_fp32() {
1617
env_logger::init();
1718
let input = "tests/test.jpg";
1819
let dimensions = Dimensions::new(227, 227, 3, Precision::FP32);
20+
let format = "nhwc";
1921

20-
let first = convert(input, &dimensions).unwrap();
21-
let second = convert(input, &dimensions).unwrap();
22+
let first = convert(input, &dimensions, &format).unwrap();
23+
let second = convert(input, &dimensions, &format).unwrap();
2224
assert_same_bytes(&first, &second);
2325
}
2426

0 commit comments

Comments
 (0)