11// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22// SPDX-License-Identifier: Apache-2.0
33
4+ use std:: io:: Cursor ;
5+
46use anyhow:: Result ;
5- use image:: GenericImageView ;
7+ use image:: { ColorType , GenericImageView , ImageFormat , ImageReader } ;
68use ndarray:: Array3 ;
79
810use super :: super :: common:: EncodedMediaData ;
9- use super :: super :: decoders:: DecodedMediaData ;
11+ use super :: super :: decoders:: { DecodedMediaData , DecodedMediaMetadata } ;
1012use super :: Decoder ;
1113
12- #[ derive( Clone , Debug , Default , serde:: Serialize , serde:: Deserialize ) ]
14+ const DEFAULT_MAX_ALLOC : u64 = 128 * 1024 * 1024 ; // 128 MB
15+
16+ #[ derive( Clone , Debug , serde:: Serialize , serde:: Deserialize ) ]
1317#[ serde( deny_unknown_fields) ]
1418pub struct ImageDecoder {
15- // maximum total size of the image in pixels
1619 #[ serde( default ) ]
17- pub max_pixels : Option < usize > ,
20+ pub ( crate ) max_image_width : Option < u32 > ,
21+ #[ serde( default ) ]
22+ pub ( crate ) max_image_height : Option < u32 > ,
23+ // maximum allowed total allocation of the decoder in bytes
24+ #[ serde( default ) ]
25+ pub ( crate ) max_alloc : Option < u64 > ,
26+ }
27+
28+ impl Default for ImageDecoder {
29+ fn default ( ) -> Self {
30+ Self {
31+ max_image_width : None ,
32+ max_image_height : None ,
33+ max_alloc : Some ( DEFAULT_MAX_ALLOC ) ,
34+ }
35+ }
36+ }
37+
38+ #[ derive( Debug ) ]
39+ pub enum ImageLayout {
40+ HWC ,
41+ }
42+
43+ #[ derive( Debug ) ]
44+ pub struct ImageMetadata {
45+ #[ allow( dead_code) ] // used in followup MR
46+ pub ( crate ) format : Option < ImageFormat > ,
47+ #[ allow( dead_code) ] // used in followup MR
48+ pub ( crate ) color_type : ColorType ,
49+ #[ allow( dead_code) ] // used in followup MR
50+ pub ( crate ) layout : ImageLayout ,
1851}
1952
2053impl Decoder for ImageDecoder {
2154 fn decode ( & self , data : EncodedMediaData ) -> Result < DecodedMediaData > {
2255 let bytes = data. into_bytes ( ) ?;
23- let img = image:: load_from_memory ( & bytes) ?;
24- let ( width, height) = img. dimensions ( ) ;
56+
57+ let mut reader = ImageReader :: new ( Cursor :: new ( bytes) ) . with_guessed_format ( ) ?;
58+ let mut limits = image:: Limits :: no_limits ( ) ;
59+ limits. max_image_width = self . max_image_width ;
60+ limits. max_image_height = self . max_image_height ;
61+ limits. max_alloc = self . max_alloc ;
62+ reader. limits ( limits) ;
63+
64+ let format = reader. format ( ) ;
65+
66+ let img = reader. decode ( ) ?;
2567 let n_channels = img. color ( ) . channel_count ( ) ;
2668
27- let max_pixels = self . max_pixels . unwrap_or ( usize:: MAX ) ;
28- let pixel_count = ( width as usize )
29- . checked_mul ( height as usize )
30- . ok_or_else ( || anyhow:: anyhow!( "Image dimensions {width}x{height} overflow usize" ) ) ?;
31- anyhow:: ensure!(
32- pixel_count <= max_pixels,
33- "Image dimensions {width}x{height} exceed max pixels {max_pixels}"
34- ) ;
35- let data = match n_channels {
36- 1 => img. to_luma8 ( ) . into_raw ( ) ,
37- 2 => img. to_luma_alpha8 ( ) . into_raw ( ) ,
38- 3 => img. to_rgb8 ( ) . into_raw ( ) ,
39- 4 => img. to_rgba8 ( ) . into_raw ( ) ,
69+ let ( data, color_type) = match n_channels {
70+ 1 => ( img. to_luma8 ( ) . into_raw ( ) , ColorType :: L8 ) ,
71+ 2 => ( img. to_luma_alpha8 ( ) . into_raw ( ) , ColorType :: La8 ) ,
72+ 3 => ( img. to_rgb8 ( ) . into_raw ( ) , ColorType :: Rgb8 ) ,
73+ 4 => ( img. to_rgba8 ( ) . into_raw ( ) , ColorType :: Rgba8 ) ,
4074 other => anyhow:: bail!( "Unsupported channel count {other}" ) ,
4175 } ;
76+
77+ let ( width, height) = img. dimensions ( ) ;
4278 let shape = ( height as usize , width as usize , n_channels as usize ) ;
4379 let array = Array3 :: from_shape_vec ( shape, data) ?;
44- Ok ( array. into ( ) )
80+ let mut decoded: DecodedMediaData = array. into ( ) ;
81+ decoded. metadata = Some ( DecodedMediaMetadata :: Image ( ImageMetadata {
82+ format,
83+ color_type,
84+ layout : ImageLayout :: HWC ,
85+ } ) ) ;
86+ Ok ( decoded)
4587 }
4688}
4789
4890#[ cfg( test) ]
4991mod tests {
92+ use super :: super :: super :: decoders:: DataType ;
5093 use super :: * ;
5194 use image:: { DynamicImage , ImageBuffer } ;
5295 use rstest:: rstest;
@@ -115,22 +158,30 @@ mod tests {
115158 decoded. shape,
116159 vec![ height as usize , width as usize , expected_channels as usize ]
117160 ) ;
118- assert_eq ! ( decoded. dtype, "uint8" ) ;
161+ assert_eq ! ( decoded. dtype, DataType :: UINT8 ) ;
119162 }
120163
121164 #[ rstest]
122- #[ case( Some ( 200 ) , 10 , 10 , image:: ImageFormat :: Png , true , "within limit" ) ]
123- #[ case( Some ( 50 ) , 10 , 10 , image:: ImageFormat :: Jpeg , false , "exceeds limit" ) ]
124- #[ case( None , 200 , 300 , image:: ImageFormat :: Png , true , "no limit" ) ]
125- fn test_pixel_limits (
126- #[ case] max_pixels : Option < usize > ,
165+ #[ case( Some ( 100 ) , None , 50 , 50 , ImageFormat :: Png , true , "width ok" ) ]
166+ #[ case( Some ( 50 ) , None , 100 , 50 , ImageFormat :: Jpeg , false , "width too large" ) ]
167+ #[ case( None , Some ( 100 ) , 50 , 100 , ImageFormat :: Png , true , "height ok" ) ]
168+ #[ case( None , Some ( 50 ) , 50 , 100 , ImageFormat :: Png , false , "height too large" ) ]
169+ #[ case( None , None , 2000 , 2000 , ImageFormat :: Png , true , "no limits" ) ]
170+ #[ case( None , None , 8000 , 8000 , ImageFormat :: Png , false , "alloc too large" ) ]
171+ fn test_limits (
172+ #[ case] max_width : Option < u32 > ,
173+ #[ case] max_height : Option < u32 > ,
127174 #[ case] width : u32 ,
128175 #[ case] height : u32 ,
129176 #[ case] format : image:: ImageFormat ,
130177 #[ case] should_succeed : bool ,
131178 #[ case] test_case : & str ,
132179 ) {
133- let decoder = ImageDecoder { max_pixels } ;
180+ let decoder = ImageDecoder {
181+ max_image_width : max_width,
182+ max_image_height : max_height,
183+ max_alloc : Some ( DEFAULT_MAX_ALLOC ) ,
184+ } ;
134185 let image_bytes = create_test_image ( width, height, 3 , format) ; // RGB
135186 let encoded_data = create_encoded_media_data ( image_bytes) ;
136187
@@ -146,7 +197,8 @@ mod tests {
146197 let decoded = result. unwrap ( ) ;
147198 assert_eq ! ( decoded. shape, vec![ height as usize , width as usize , 3 ] ) ;
148199 assert_eq ! (
149- decoded. dtype, "uint8" ,
200+ decoded. dtype,
201+ DataType :: UINT8 ,
150202 "dtype should be uint8 for case: {}" ,
151203 test_case
152204 ) ;
@@ -159,8 +211,9 @@ mod tests {
159211 ) ;
160212 let error_msg = result. unwrap_err ( ) . to_string ( ) ;
161213 assert ! (
162- error_msg. contains( "exceed max pixels" ) ,
163- "Error should mention exceeding max pixels for case: {}" ,
214+ error_msg. contains( "dimensions" ) || error_msg. contains( "limit" ) ,
215+ "Error should mention dimension limits, got: {} for case: {}" ,
216+ error_msg,
164217 test_case
165218 ) ;
166219 }
@@ -186,9 +239,11 @@ mod tests {
186239 assert_eq ! ( decoded. shape[ 0 ] , 1 , "Height should be 1" ) ;
187240 assert_eq ! ( decoded. shape[ 1 ] , 1 , "Width should be 1" ) ;
188241 assert_eq ! (
189- decoded. dtype, "uint8" ,
242+ decoded. dtype,
243+ DataType :: UINT8 ,
190244 "dtype should be uint8 for {} channels {:?}" ,
191- input_channels, format
245+ input_channels,
246+ format
192247 ) ;
193248 }
194249}
0 commit comments