11// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22// SPDX-License-Identifier: Apache-2.0
33
4+ use std:: io:: Cursor ;
5+
46use anyhow:: Result ;
5- use image:: GenericImageView ;
7+ use image:: { ColorType , GenericImageView , ImageFormat , ImageReader } ;
68use ndarray:: Array3 ;
79
810use super :: super :: common:: EncodedMediaData ;
9- use super :: super :: decoders:: DecodedMediaData ;
11+ use super :: super :: decoders:: { DecodedMediaData , DecodedMediaMetadata } ;
1012use super :: Decoder ;
1113
12- #[ derive( Clone , Debug , Default , serde:: Serialize , serde:: Deserialize ) ]
14+ const DEFAULT_MAX_ALLOC : u64 = 128 * 1024 * 1024 ; // 128 MB
15+
16+ #[ derive( Clone , Debug , serde:: Serialize , serde:: Deserialize ) ]
1317#[ serde( deny_unknown_fields) ]
1418pub struct ImageDecoder {
15- // maximum total size of the image in pixels
1619 #[ serde( default ) ]
17- pub max_pixels : Option < usize > ,
20+ pub ( crate ) max_image_width : Option < u32 > ,
21+ #[ serde( default ) ]
22+ pub ( crate ) max_image_height : Option < u32 > ,
23+ // maximum allowed total allocation of the decoder in bytes
24+ #[ serde( default ) ]
25+ pub ( crate ) max_alloc : Option < u64 > ,
26+ }
27+
28+ impl Default for ImageDecoder {
29+ fn default ( ) -> Self {
30+ Self {
31+ max_image_width : None ,
32+ max_image_height : None ,
33+ max_alloc : Some ( DEFAULT_MAX_ALLOC ) ,
34+ }
35+ }
36+ }
37+
38+ #[ allow( clippy:: upper_case_acronyms) ]
39+ #[ derive( Debug ) ]
40+ pub enum ImageLayout {
41+ HWC ,
42+ }
43+
44+ #[ derive( Debug ) ]
45+ pub struct ImageMetadata {
46+ #[ allow( dead_code) ] // used in followup MR
47+ pub ( crate ) format : Option < ImageFormat > ,
48+ #[ allow( dead_code) ] // used in followup MR
49+ pub ( crate ) color_type : ColorType ,
50+ #[ allow( dead_code) ] // used in followup MR
51+ pub ( crate ) layout : ImageLayout ,
1852}
1953
2054impl Decoder for ImageDecoder {
2155 fn decode ( & self , data : EncodedMediaData ) -> Result < DecodedMediaData > {
2256 let bytes = data. into_bytes ( ) ?;
23- let img = image:: load_from_memory ( & bytes) ?;
24- let ( width, height) = img. dimensions ( ) ;
57+
58+ let mut reader = ImageReader :: new ( Cursor :: new ( bytes) ) . with_guessed_format ( ) ?;
59+ let mut limits = image:: Limits :: no_limits ( ) ;
60+ limits. max_image_width = self . max_image_width ;
61+ limits. max_image_height = self . max_image_height ;
62+ limits. max_alloc = self . max_alloc ;
63+ reader. limits ( limits) ;
64+
65+ let format = reader. format ( ) ;
66+
67+ let img = reader. decode ( ) ?;
2568 let n_channels = img. color ( ) . channel_count ( ) ;
2669
27- let max_pixels = self . max_pixels . unwrap_or ( usize:: MAX ) ;
28- let pixel_count = ( width as usize )
29- . checked_mul ( height as usize )
30- . ok_or_else ( || anyhow:: anyhow!( "Image dimensions {width}x{height} overflow usize" ) ) ?;
31- anyhow:: ensure!(
32- pixel_count <= max_pixels,
33- "Image dimensions {width}x{height} exceed max pixels {max_pixels}"
34- ) ;
35- let data = match n_channels {
36- 1 => img. to_luma8 ( ) . into_raw ( ) ,
37- 2 => img. to_luma_alpha8 ( ) . into_raw ( ) ,
38- 3 => img. to_rgb8 ( ) . into_raw ( ) ,
39- 4 => img. to_rgba8 ( ) . into_raw ( ) ,
70+ let ( data, color_type) = match n_channels {
71+ 1 => ( img. to_luma8 ( ) . into_raw ( ) , ColorType :: L8 ) ,
72+ 2 => ( img. to_luma_alpha8 ( ) . into_raw ( ) , ColorType :: La8 ) ,
73+ 3 => ( img. to_rgb8 ( ) . into_raw ( ) , ColorType :: Rgb8 ) ,
74+ 4 => ( img. to_rgba8 ( ) . into_raw ( ) , ColorType :: Rgba8 ) ,
4075 other => anyhow:: bail!( "Unsupported channel count {other}" ) ,
4176 } ;
77+
78+ let ( width, height) = img. dimensions ( ) ;
4279 let shape = ( height as usize , width as usize , n_channels as usize ) ;
4380 let array = Array3 :: from_shape_vec ( shape, data) ?;
44- Ok ( array. into ( ) )
81+ let mut decoded: DecodedMediaData = array. into ( ) ;
82+ decoded. metadata = Some ( DecodedMediaMetadata :: Image ( ImageMetadata {
83+ format,
84+ color_type,
85+ layout : ImageLayout :: HWC ,
86+ } ) ) ;
87+ Ok ( decoded)
4588 }
4689}
4790
4891#[ cfg( test) ]
4992mod tests {
93+ use super :: super :: super :: decoders:: DataType ;
5094 use super :: * ;
5195 use image:: { DynamicImage , ImageBuffer } ;
5296 use rstest:: rstest;
@@ -115,22 +159,30 @@ mod tests {
115159 decoded. shape,
116160 vec![ height as usize , width as usize , expected_channels as usize ]
117161 ) ;
118- assert_eq ! ( decoded. dtype, "uint8" ) ;
162+ assert_eq ! ( decoded. dtype, DataType :: UINT8 ) ;
119163 }
120164
121165 #[ rstest]
122- #[ case( Some ( 200 ) , 10 , 10 , image:: ImageFormat :: Png , true , "within limit" ) ]
123- #[ case( Some ( 50 ) , 10 , 10 , image:: ImageFormat :: Jpeg , false , "exceeds limit" ) ]
124- #[ case( None , 200 , 300 , image:: ImageFormat :: Png , true , "no limit" ) ]
125- fn test_pixel_limits (
126- #[ case] max_pixels : Option < usize > ,
166+ #[ case( Some ( 100 ) , None , 50 , 50 , ImageFormat :: Png , true , "width ok" ) ]
167+ #[ case( Some ( 50 ) , None , 100 , 50 , ImageFormat :: Jpeg , false , "width too large" ) ]
168+ #[ case( None , Some ( 100 ) , 50 , 100 , ImageFormat :: Png , true , "height ok" ) ]
169+ #[ case( None , Some ( 50 ) , 50 , 100 , ImageFormat :: Png , false , "height too large" ) ]
170+ #[ case( None , None , 2000 , 2000 , ImageFormat :: Png , true , "no limits" ) ]
171+ #[ case( None , None , 8000 , 8000 , ImageFormat :: Png , false , "alloc too large" ) ]
172+ fn test_limits (
173+ #[ case] max_width : Option < u32 > ,
174+ #[ case] max_height : Option < u32 > ,
127175 #[ case] width : u32 ,
128176 #[ case] height : u32 ,
129177 #[ case] format : image:: ImageFormat ,
130178 #[ case] should_succeed : bool ,
131179 #[ case] test_case : & str ,
132180 ) {
133- let decoder = ImageDecoder { max_pixels } ;
181+ let decoder = ImageDecoder {
182+ max_image_width : max_width,
183+ max_image_height : max_height,
184+ max_alloc : Some ( DEFAULT_MAX_ALLOC ) ,
185+ } ;
134186 let image_bytes = create_test_image ( width, height, 3 , format) ; // RGB
135187 let encoded_data = create_encoded_media_data ( image_bytes) ;
136188
@@ -146,7 +198,8 @@ mod tests {
146198 let decoded = result. unwrap ( ) ;
147199 assert_eq ! ( decoded. shape, vec![ height as usize , width as usize , 3 ] ) ;
148200 assert_eq ! (
149- decoded. dtype, "uint8" ,
201+ decoded. dtype,
202+ DataType :: UINT8 ,
150203 "dtype should be uint8 for case: {}" ,
151204 test_case
152205 ) ;
@@ -159,8 +212,9 @@ mod tests {
159212 ) ;
160213 let error_msg = result. unwrap_err ( ) . to_string ( ) ;
161214 assert ! (
162- error_msg. contains( "exceed max pixels" ) ,
163- "Error should mention exceeding max pixels for case: {}" ,
215+ error_msg. contains( "dimensions" ) || error_msg. contains( "limit" ) ,
216+ "Error should mention dimension limits, got: {} for case: {}" ,
217+ error_msg,
164218 test_case
165219 ) ;
166220 }
@@ -186,9 +240,11 @@ mod tests {
186240 assert_eq ! ( decoded. shape[ 0 ] , 1 , "Height should be 1" ) ;
187241 assert_eq ! ( decoded. shape[ 1 ] , 1 , "Width should be 1" ) ;
188242 assert_eq ! (
189- decoded. dtype, "uint8" ,
243+ decoded. dtype,
244+ DataType :: UINT8 ,
190245 "dtype should be uint8 for {} channels {:?}" ,
191- input_channels, format
246+ input_channels,
247+ format
192248 ) ;
193249 }
194250}
0 commit comments