@@ -6,14 +6,15 @@ use std::io::Cursor;
66use anyhow:: Result ;
77use image:: { ColorType , GenericImageView , ImageFormat , ImageReader } ;
88use ndarray:: Array3 ;
9+ use serde:: { Deserialize , Serialize } ;
910
1011use super :: super :: common:: EncodedMediaData ;
11- use super :: super :: decoders :: { DecodedMediaData , DecodedMediaMetadata } ;
12- use super :: Decoder ;
12+ use super :: super :: rdma :: DecodedMediaData ;
13+ use super :: { DecodedMediaMetadata , Decoder } ;
1314
1415const DEFAULT_MAX_ALLOC : u64 = 128 * 1024 * 1024 ; // 128 MB
1516
16- #[ derive( Clone , Debug , serde :: Serialize , serde :: Deserialize ) ]
17+ #[ derive( Clone , Debug , Serialize , Deserialize ) ]
1718#[ serde( deny_unknown_fields) ]
1819pub struct ImageDecoder {
1920 #[ serde( default ) ]
@@ -36,12 +37,12 @@ impl Default for ImageDecoder {
3637}
3738
3839#[ allow( clippy:: upper_case_acronyms) ]
39- #[ derive( Debug ) ]
40+ #[ derive( Serialize , Deserialize , Clone , Copy , Debug ) ]
4041pub enum ImageLayout {
4142 HWC ,
4243}
4344
44- #[ derive( Debug ) ]
45+ #[ derive( Serialize , Deserialize , Clone , Copy , Debug ) ]
4546pub struct ImageMetadata {
4647 #[ allow( dead_code) ] // used in followup MR
4748 pub ( crate ) format : Option < ImageFormat > ,
@@ -78,8 +79,8 @@ impl Decoder for ImageDecoder {
7879 let ( width, height) = img. dimensions ( ) ;
7980 let shape = ( height as usize , width as usize , n_channels as usize ) ;
8081 let array = Array3 :: from_shape_vec ( shape, data) ?;
81- let mut decoded: DecodedMediaData = array. into ( ) ;
82- decoded. metadata = Some ( DecodedMediaMetadata :: Image ( ImageMetadata {
82+ let mut decoded: DecodedMediaData = array. try_into ( ) ? ;
83+ decoded. tensor_info . metadata = Some ( DecodedMediaMetadata :: Image ( ImageMetadata {
8384 format,
8485 color_type,
8586 layout : ImageLayout :: HWC ,
@@ -90,7 +91,7 @@ impl Decoder for ImageDecoder {
9091
9192#[ cfg( test) ]
9293mod tests {
93- use super :: super :: super :: decoders :: DataType ;
94+ use super :: super :: super :: rdma :: DataType ;
9495 use super :: * ;
9596 use image:: { DynamicImage , ImageBuffer } ;
9697 use rstest:: rstest;
@@ -156,10 +157,10 @@ mod tests {
156157
157158 let decoded = result. unwrap ( ) ;
158159 assert_eq ! (
159- decoded. shape,
160+ decoded. tensor_info . shape,
160161 vec![ height as usize , width as usize , expected_channels as usize ]
161162 ) ;
162- assert_eq ! ( decoded. dtype, DataType :: UINT8 ) ;
163+ assert_eq ! ( decoded. tensor_info . dtype, DataType :: UINT8 ) ;
163164 }
164165
165166 #[ rstest]
@@ -196,9 +197,12 @@ mod tests {
196197 format
197198 ) ;
198199 let decoded = result. unwrap ( ) ;
199- assert_eq ! ( decoded. shape, vec![ height as usize , width as usize , 3 ] ) ;
200200 assert_eq ! (
201- decoded. dtype,
201+ decoded. tensor_info. shape,
202+ vec![ height as usize , width as usize , 3 ]
203+ ) ;
204+ assert_eq ! (
205+ decoded. tensor_info. dtype,
202206 DataType :: UINT8 ,
203207 "dtype should be uint8 for case: {}" ,
204208 test_case
@@ -236,11 +240,15 @@ mod tests {
236240 ) ;
237241
238242 let decoded = result. unwrap ( ) ;
239- assert_eq ! ( decoded. shape. len( ) , 3 , "Should have 3 dimensions" ) ;
240- assert_eq ! ( decoded. shape[ 0 ] , 1 , "Height should be 1" ) ;
241- assert_eq ! ( decoded. shape[ 1 ] , 1 , "Width should be 1" ) ;
242243 assert_eq ! (
243- decoded. dtype,
244+ decoded. tensor_info. shape. len( ) ,
245+ 3 ,
246+ "Should have 3 dimensions"
247+ ) ;
248+ assert_eq ! ( decoded. tensor_info. shape[ 0 ] , 1 , "Height should be 1" ) ;
249+ assert_eq ! ( decoded. tensor_info. shape[ 1 ] , 1 , "Width should be 1" ) ;
250+ assert_eq ! (
251+ decoded. tensor_info. dtype,
244252 DataType :: UINT8 ,
245253 "dtype should be uint8 for {} channels {:?}" ,
246254 input_channels,
0 commit comments