Skip to content

Commit bed5bd0

Browse files
committed
feat: Address reviews, better limits
Signed-off-by: Alexandre Milesi <[email protected]>
1 parent 1ba939b commit bed5bd0

File tree

3 files changed

+110
-38
lines changed

3 files changed

+110
-38
lines changed

lib/llm/src/preprocessor/media/decoders.rs

Lines changed: 17 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,18 @@ use super::common::EncodedMediaData;
77
use ndarray::{ArrayBase, Dimension, OwnedRepr};
88
mod image;
99

10-
pub use image::ImageDecoder;
10+
pub use image::{ImageDecoder, ImageMetadata};
11+
12+
#[derive(Debug)]
13+
pub enum DecodedMediaMetadata {
14+
#[allow(dead_code)] // used in followup MR
15+
Image(ImageMetadata),
16+
}
17+
18+
#[derive(Debug, PartialEq, Eq)]
19+
pub enum DataType {
20+
UINT8,
21+
}
1122

1223
// Decoded media data (image RGB, video frames pixels, ...)
1324
#[derive(Debug)]
@@ -17,7 +28,9 @@ pub struct DecodedMediaData {
1728
#[allow(dead_code)] // used in followup MR
1829
pub(crate) shape: Vec<usize>,
1930
#[allow(dead_code)] // used in followup MR
20-
pub(crate) dtype: String,
31+
pub(crate) dtype: DataType,
32+
#[allow(dead_code)] // used in followup MR
33+
pub(crate) metadata: Option<DecodedMediaMetadata>,
2134
}
2235

2336
// convert Array{N}<u8> to DecodedMediaData
@@ -29,7 +42,8 @@ impl<D: Dimension> From<ArrayBase<OwnedRepr<u8>, D>> for DecodedMediaData {
2942
Self {
3043
data,
3144
shape,
32-
dtype: "uint8".to_string(),
45+
dtype: DataType::UINT8,
46+
metadata: None,
3347
}
3448
}
3549
}

lib/llm/src/preprocessor/media/decoders/image.rs

Lines changed: 89 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -1,52 +1,96 @@
11
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
// SPDX-License-Identifier: Apache-2.0
33

4+
use std::io::Cursor;
5+
46
use anyhow::Result;
5-
use image::GenericImageView;
7+
use image::{ColorType, GenericImageView, ImageFormat, ImageReader};
68
use ndarray::Array3;
79

810
use super::super::common::EncodedMediaData;
9-
use super::super::decoders::DecodedMediaData;
11+
use super::super::decoders::{DecodedMediaData, DecodedMediaMetadata};
1012
use super::Decoder;
1113

12-
#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
14+
const DEFAULT_MAX_ALLOC: u64 = 128 * 1024 * 1024; // 128 MB
15+
16+
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
1317
#[serde(deny_unknown_fields)]
1418
pub struct ImageDecoder {
15-
// maximum total size of the image in pixels
1619
#[serde(default)]
17-
pub max_pixels: Option<usize>,
20+
pub(crate) max_image_width: Option<u32>,
21+
#[serde(default)]
22+
pub(crate) max_image_height: Option<u32>,
23+
// maximum allowed total allocation of the decoder in bytes
24+
#[serde(default)]
25+
pub(crate) max_alloc: Option<u64>,
26+
}
27+
28+
impl Default for ImageDecoder {
29+
fn default() -> Self {
30+
Self {
31+
max_image_width: None,
32+
max_image_height: None,
33+
max_alloc: Some(DEFAULT_MAX_ALLOC),
34+
}
35+
}
36+
}
37+
38+
#[allow(clippy::upper_case_acronyms)]
39+
#[derive(Debug)]
40+
pub enum ImageLayout {
41+
HWC,
42+
}
43+
44+
#[derive(Debug)]
45+
pub struct ImageMetadata {
46+
#[allow(dead_code)] // used in followup MR
47+
pub(crate) format: Option<ImageFormat>,
48+
#[allow(dead_code)] // used in followup MR
49+
pub(crate) color_type: ColorType,
50+
#[allow(dead_code)] // used in followup MR
51+
pub(crate) layout: ImageLayout,
1852
}
1953

2054
impl Decoder for ImageDecoder {
2155
fn decode(&self, data: EncodedMediaData) -> Result<DecodedMediaData> {
2256
let bytes = data.into_bytes()?;
23-
let img = image::load_from_memory(&bytes)?;
24-
let (width, height) = img.dimensions();
57+
58+
let mut reader = ImageReader::new(Cursor::new(bytes)).with_guessed_format()?;
59+
let mut limits = image::Limits::no_limits();
60+
limits.max_image_width = self.max_image_width;
61+
limits.max_image_height = self.max_image_height;
62+
limits.max_alloc = self.max_alloc;
63+
reader.limits(limits);
64+
65+
let format = reader.format();
66+
67+
let img = reader.decode()?;
2568
let n_channels = img.color().channel_count();
2669

27-
let max_pixels = self.max_pixels.unwrap_or(usize::MAX);
28-
let pixel_count = (width as usize)
29-
.checked_mul(height as usize)
30-
.ok_or_else(|| anyhow::anyhow!("Image dimensions {width}x{height} overflow usize"))?;
31-
anyhow::ensure!(
32-
pixel_count <= max_pixels,
33-
"Image dimensions {width}x{height} exceed max pixels {max_pixels}"
34-
);
35-
let data = match n_channels {
36-
1 => img.to_luma8().into_raw(),
37-
2 => img.to_luma_alpha8().into_raw(),
38-
3 => img.to_rgb8().into_raw(),
39-
4 => img.to_rgba8().into_raw(),
70+
let (data, color_type) = match n_channels {
71+
1 => (img.to_luma8().into_raw(), ColorType::L8),
72+
2 => (img.to_luma_alpha8().into_raw(), ColorType::La8),
73+
3 => (img.to_rgb8().into_raw(), ColorType::Rgb8),
74+
4 => (img.to_rgba8().into_raw(), ColorType::Rgba8),
4075
other => anyhow::bail!("Unsupported channel count {other}"),
4176
};
77+
78+
let (width, height) = img.dimensions();
4279
let shape = (height as usize, width as usize, n_channels as usize);
4380
let array = Array3::from_shape_vec(shape, data)?;
44-
Ok(array.into())
81+
let mut decoded: DecodedMediaData = array.into();
82+
decoded.metadata = Some(DecodedMediaMetadata::Image(ImageMetadata {
83+
format,
84+
color_type,
85+
layout: ImageLayout::HWC,
86+
}));
87+
Ok(decoded)
4588
}
4689
}
4790

4891
#[cfg(test)]
4992
mod tests {
93+
use super::super::super::decoders::DataType;
5094
use super::*;
5195
use image::{DynamicImage, ImageBuffer};
5296
use rstest::rstest;
@@ -115,22 +159,30 @@ mod tests {
115159
decoded.shape,
116160
vec![height as usize, width as usize, expected_channels as usize]
117161
);
118-
assert_eq!(decoded.dtype, "uint8");
162+
assert_eq!(decoded.dtype, DataType::UINT8);
119163
}
120164

121165
#[rstest]
122-
#[case(Some(200), 10, 10, image::ImageFormat::Png, true, "within limit")]
123-
#[case(Some(50), 10, 10, image::ImageFormat::Jpeg, false, "exceeds limit")]
124-
#[case(None, 200, 300, image::ImageFormat::Png, true, "no limit")]
125-
fn test_pixel_limits(
126-
#[case] max_pixels: Option<usize>,
166+
#[case(Some(100), None, 50, 50, ImageFormat::Png, true, "width ok")]
167+
#[case(Some(50), None, 100, 50, ImageFormat::Jpeg, false, "width too large")]
168+
#[case(None, Some(100), 50, 100, ImageFormat::Png, true, "height ok")]
169+
#[case(None, Some(50), 50, 100, ImageFormat::Png, false, "height too large")]
170+
#[case(None, None, 2000, 2000, ImageFormat::Png, true, "no limits")]
171+
#[case(None, None, 8000, 8000, ImageFormat::Png, false, "alloc too large")]
172+
fn test_limits(
173+
#[case] max_width: Option<u32>,
174+
#[case] max_height: Option<u32>,
127175
#[case] width: u32,
128176
#[case] height: u32,
129177
#[case] format: image::ImageFormat,
130178
#[case] should_succeed: bool,
131179
#[case] test_case: &str,
132180
) {
133-
let decoder = ImageDecoder { max_pixels };
181+
let decoder = ImageDecoder {
182+
max_image_width: max_width,
183+
max_image_height: max_height,
184+
max_alloc: Some(DEFAULT_MAX_ALLOC),
185+
};
134186
let image_bytes = create_test_image(width, height, 3, format); // RGB
135187
let encoded_data = create_encoded_media_data(image_bytes);
136188

@@ -146,7 +198,8 @@ mod tests {
146198
let decoded = result.unwrap();
147199
assert_eq!(decoded.shape, vec![height as usize, width as usize, 3]);
148200
assert_eq!(
149-
decoded.dtype, "uint8",
201+
decoded.dtype,
202+
DataType::UINT8,
150203
"dtype should be uint8 for case: {}",
151204
test_case
152205
);
@@ -159,8 +212,9 @@ mod tests {
159212
);
160213
let error_msg = result.unwrap_err().to_string();
161214
assert!(
162-
error_msg.contains("exceed max pixels"),
163-
"Error should mention exceeding max pixels for case: {}",
215+
error_msg.contains("dimensions") || error_msg.contains("limit"),
216+
"Error should mention dimension limits, got: {} for case: {}",
217+
error_msg,
164218
test_case
165219
);
166220
}
@@ -186,9 +240,11 @@ mod tests {
186240
assert_eq!(decoded.shape[0], 1, "Height should be 1");
187241
assert_eq!(decoded.shape[1], 1, "Width should be 1");
188242
assert_eq!(
189-
decoded.dtype, "uint8",
243+
decoded.dtype,
244+
DataType::UINT8,
190245
"dtype should be uint8 for {} channels {:?}",
191-
input_channels, format
246+
input_channels,
247+
format
192248
);
193249
}
194250
}

lib/llm/src/preprocessor/media/loader.rs

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@ use super::common::EncodedMediaData;
1212
use super::decoders::{DecodedMediaData, Decoder, MediaDecoder};
1313

1414
const DEFAULT_HTTP_USER_AGENT: &str = "dynamo-ai/dynamo";
15+
const DEFAULT_HTTP_TIMEOUT: Duration = Duration::from_secs(30);
1516

1617
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
1718
pub struct MediaFetcher {
@@ -29,7 +30,7 @@ impl Default for MediaFetcher {
2930
allow_direct_ip: false,
3031
allow_direct_port: false,
3132
allowed_media_domains: None,
32-
timeout: None,
33+
timeout: Some(DEFAULT_HTTP_TIMEOUT),
3334
}
3435
}
3536
}
@@ -117,6 +118,7 @@ impl MediaLoader {
117118

118119
#[cfg(test)]
119120
mod tests {
121+
use super::super::decoders::DataType;
120122
use super::*;
121123
use dynamo_async_openai::types::{ChatCompletionRequestMessageContentPartImage, ImageUrl};
122124

@@ -156,7 +158,7 @@ mod tests {
156158
);
157159

158160
let data = result.unwrap();
159-
assert_eq!(data.dtype, "uint8");
161+
assert_eq!(data.dtype, DataType::UINT8);
160162

161163
// Verify image dimensions: 1,999px × 1,125px (width × height)
162164
// Shape format is [height, width, channels]

0 commit comments

Comments
 (0)