Skip to content

Commit 4dbec11

Browse files
committed
feat: Address reviews, better limits
Signed-off-by: Alexandre Milesi <[email protected]>
1 parent eb75d6f commit 4dbec11

File tree

3 files changed

+109
-38
lines changed

3 files changed

+109
-38
lines changed

lib/llm/src/preprocessor/media/decoders.rs

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,18 @@ use super::common::EncodedMediaData;
77
use ndarray::{ArrayBase, Dimension, OwnedRepr};
88
mod image;
99

10-
pub use image::ImageDecoder;
10+
pub use image::{ImageDecoder, ImageMetadata};
11+
12+
#[derive(Debug)]
13+
pub enum DecodedMediaMetadata {
14+
#[allow(dead_code)] // used in followup MR
15+
Image(ImageMetadata),
16+
}
17+
18+
#[derive(Debug, PartialEq, Eq)]
19+
pub enum DataType {
20+
UINT8,
21+
}
1122

1223
// Decoded media data (image RGB, video frames pixels, ...)
1324
#[derive(Debug)]
@@ -17,7 +28,9 @@ pub struct DecodedMediaData {
1728
#[allow(dead_code)] // used in followup MR
1829
pub(crate) shape: Vec<usize>,
1930
#[allow(dead_code)] // used in followup MR
20-
pub(crate) dtype: String,
31+
pub(crate) dtype: DataType,
32+
#[allow(dead_code)] // used in followup MR
33+
pub(crate) metadata: Option<DecodedMediaMetadata>,
2134
}
2235

2336
// convert Array{N}<u8> to DecodedMediaData
@@ -29,7 +42,8 @@ impl<D: Dimension> From<ArrayBase<OwnedRepr<u8>, D>> for DecodedMediaData {
2942
Self {
3043
data,
3144
shape,
32-
dtype: "uint8".to_string(),
45+
dtype: DataType::UINT8,
46+
metadata: None,
3347
}
3448
}
3549
}

lib/llm/src/preprocessor/media/decoders/image.rs

Lines changed: 88 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,95 @@
11
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
// SPDX-License-Identifier: Apache-2.0
33

4+
use std::io::Cursor;
5+
46
use anyhow::Result;
5-
use image::GenericImageView;
7+
use image::{ColorType, GenericImageView, ImageFormat, ImageReader};
68
use ndarray::Array3;
79

810
use super::super::common::EncodedMediaData;
9-
use super::super::decoders::DecodedMediaData;
11+
use super::super::decoders::{DecodedMediaData, DecodedMediaMetadata};
1012
use super::Decoder;
1113

12-
#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
14+
const DEFAULT_MAX_ALLOC: u64 = 128 * 1024 * 1024; // 128 MB
15+
16+
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
1317
#[serde(deny_unknown_fields)]
1418
pub struct ImageDecoder {
15-
// maximum total size of the image in pixels
1619
#[serde(default)]
17-
pub max_pixels: Option<usize>,
20+
pub(crate) max_image_width: Option<u32>,
21+
#[serde(default)]
22+
pub(crate) max_image_height: Option<u32>,
23+
// maximum allowed total allocation of the decoder in bytes
24+
#[serde(default)]
25+
pub(crate) max_alloc: Option<u64>,
26+
}
27+
28+
impl Default for ImageDecoder {
29+
fn default() -> Self {
30+
Self {
31+
max_image_width: None,
32+
max_image_height: None,
33+
max_alloc: Some(DEFAULT_MAX_ALLOC),
34+
}
35+
}
36+
}
37+
38+
#[derive(Debug)]
39+
pub enum ImageLayout {
40+
HWC,
41+
}
42+
43+
#[derive(Debug)]
44+
pub struct ImageMetadata {
45+
#[allow(dead_code)] // used in followup MR
46+
pub(crate) format: Option<ImageFormat>,
47+
#[allow(dead_code)] // used in followup MR
48+
pub(crate) color_type: ColorType,
49+
#[allow(dead_code)] // used in followup MR
50+
pub(crate) layout: ImageLayout,
1851
}
1952

2053
impl Decoder for ImageDecoder {
2154
fn decode(&self, data: EncodedMediaData) -> Result<DecodedMediaData> {
2255
let bytes = data.into_bytes()?;
23-
let img = image::load_from_memory(&bytes)?;
24-
let (width, height) = img.dimensions();
56+
57+
let mut reader = ImageReader::new(Cursor::new(bytes)).with_guessed_format()?;
58+
let mut limits = image::Limits::no_limits();
59+
limits.max_image_width = self.max_image_width;
60+
limits.max_image_height = self.max_image_height;
61+
limits.max_alloc = self.max_alloc;
62+
reader.limits(limits);
63+
64+
let format = reader.format();
65+
66+
let img = reader.decode()?;
2567
let n_channels = img.color().channel_count();
2668

27-
let max_pixels = self.max_pixels.unwrap_or(usize::MAX);
28-
let pixel_count = (width as usize)
29-
.checked_mul(height as usize)
30-
.ok_or_else(|| anyhow::anyhow!("Image dimensions {width}x{height} overflow usize"))?;
31-
anyhow::ensure!(
32-
pixel_count <= max_pixels,
33-
"Image dimensions {width}x{height} exceed max pixels {max_pixels}"
34-
);
35-
let data = match n_channels {
36-
1 => img.to_luma8().into_raw(),
37-
2 => img.to_luma_alpha8().into_raw(),
38-
3 => img.to_rgb8().into_raw(),
39-
4 => img.to_rgba8().into_raw(),
69+
let (data, color_type) = match n_channels {
70+
1 => (img.to_luma8().into_raw(), ColorType::L8),
71+
2 => (img.to_luma_alpha8().into_raw(), ColorType::La8),
72+
3 => (img.to_rgb8().into_raw(), ColorType::Rgb8),
73+
4 => (img.to_rgba8().into_raw(), ColorType::Rgba8),
4074
other => anyhow::bail!("Unsupported channel count {other}"),
4175
};
76+
77+
let (width, height) = img.dimensions();
4278
let shape = (height as usize, width as usize, n_channels as usize);
4379
let array = Array3::from_shape_vec(shape, data)?;
44-
Ok(array.into())
80+
let mut decoded: DecodedMediaData = array.into();
81+
decoded.metadata = Some(DecodedMediaMetadata::Image(ImageMetadata {
82+
format,
83+
color_type,
84+
layout: ImageLayout::HWC,
85+
}));
86+
Ok(decoded)
4587
}
4688
}
4789

4890
#[cfg(test)]
4991
mod tests {
92+
use super::super::super::decoders::DataType;
5093
use super::*;
5194
use image::{DynamicImage, ImageBuffer};
5295
use rstest::rstest;
@@ -115,22 +158,30 @@ mod tests {
115158
decoded.shape,
116159
vec![height as usize, width as usize, expected_channels as usize]
117160
);
118-
assert_eq!(decoded.dtype, "uint8");
161+
assert_eq!(decoded.dtype, DataType::UINT8);
119162
}
120163

121164
#[rstest]
122-
#[case(Some(200), 10, 10, image::ImageFormat::Png, true, "within limit")]
123-
#[case(Some(50), 10, 10, image::ImageFormat::Jpeg, false, "exceeds limit")]
124-
#[case(None, 200, 300, image::ImageFormat::Png, true, "no limit")]
125-
fn test_pixel_limits(
126-
#[case] max_pixels: Option<usize>,
165+
#[case(Some(100), None, 50, 50, ImageFormat::Png, true, "width ok")]
166+
#[case(Some(50), None, 100, 50, ImageFormat::Jpeg, false, "width too large")]
167+
#[case(None, Some(100), 50, 100, ImageFormat::Png, true, "height ok")]
168+
#[case(None, Some(50), 50, 100, ImageFormat::Png, false, "height too large")]
169+
#[case(None, None, 2000, 2000, ImageFormat::Png, true, "no limits")]
170+
#[case(None, None, 8000, 8000, ImageFormat::Png, false, "alloc too large")]
171+
fn test_limits(
172+
#[case] max_width: Option<u32>,
173+
#[case] max_height: Option<u32>,
127174
#[case] width: u32,
128175
#[case] height: u32,
129176
#[case] format: image::ImageFormat,
130177
#[case] should_succeed: bool,
131178
#[case] test_case: &str,
132179
) {
133-
let decoder = ImageDecoder { max_pixels };
180+
let decoder = ImageDecoder {
181+
max_image_width: max_width,
182+
max_image_height: max_height,
183+
max_alloc: Some(DEFAULT_MAX_ALLOC),
184+
};
134185
let image_bytes = create_test_image(width, height, 3, format); // RGB
135186
let encoded_data = create_encoded_media_data(image_bytes);
136187

@@ -146,7 +197,8 @@ mod tests {
146197
let decoded = result.unwrap();
147198
assert_eq!(decoded.shape, vec![height as usize, width as usize, 3]);
148199
assert_eq!(
149-
decoded.dtype, "uint8",
200+
decoded.dtype,
201+
DataType::UINT8,
150202
"dtype should be uint8 for case: {}",
151203
test_case
152204
);
@@ -159,8 +211,9 @@ mod tests {
159211
);
160212
let error_msg = result.unwrap_err().to_string();
161213
assert!(
162-
error_msg.contains("exceed max pixels"),
163-
"Error should mention exceeding max pixels for case: {}",
214+
error_msg.contains("dimensions") || error_msg.contains("limit"),
215+
"Error should mention dimension limits, got: {} for case: {}",
216+
error_msg,
164217
test_case
165218
);
166219
}
@@ -186,9 +239,11 @@ mod tests {
186239
assert_eq!(decoded.shape[0], 1, "Height should be 1");
187240
assert_eq!(decoded.shape[1], 1, "Width should be 1");
188241
assert_eq!(
189-
decoded.dtype, "uint8",
242+
decoded.dtype,
243+
DataType::UINT8,
190244
"dtype should be uint8 for {} channels {:?}",
191-
input_channels, format
245+
input_channels,
246+
format
192247
);
193248
}
194249
}

lib/llm/src/preprocessor/media/loader.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use super::common::EncodedMediaData;
1212
use super::decoders::{DecodedMediaData, Decoder, MediaDecoder};
1313

1414
const DEFAULT_HTTP_USER_AGENT: &str = "dynamo-ai/dynamo";
15+
const DEFAULT_HTTP_TIMEOUT: Duration = Duration::from_secs(30);
1516

1617
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
1718
pub struct MediaFetcher {
@@ -29,7 +30,7 @@ impl Default for MediaFetcher {
2930
allow_direct_ip: false,
3031
allow_direct_port: false,
3132
allowed_media_domains: None,
32-
timeout: None,
33+
timeout: Some(DEFAULT_HTTP_TIMEOUT),
3334
}
3435
}
3536
}
@@ -117,6 +118,7 @@ impl MediaLoader {
117118

118119
#[cfg(test)]
119120
mod tests {
121+
use super::super::decoders::DataType;
120122
use super::*;
121123
use dynamo_async_openai::types::{ChatCompletionRequestMessageContentPartImage, ImageUrl};
122124

@@ -156,7 +158,7 @@ mod tests {
156158
);
157159

158160
let data = result.unwrap();
159-
assert_eq!(data.dtype, "uint8");
161+
assert_eq!(data.dtype, DataType::UINT8);
160162

161163
// Verify image dimensions: 1,999px × 1,125px (width × height)
162164
// Shape format is [height, width, channels]

0 commit comments

Comments
 (0)