Skip to content

Commit 22aa4a9

Browse files
committed
feat: Media HTTP fetching and b64 decoding
Signed-off-by: Alexandre Milesi <[email protected]>
1 parent 7835904 commit 22aa4a9

File tree

7 files changed

+282
-9
lines changed

7 files changed

+282
-9
lines changed

Cargo.lock

Lines changed: 36 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

lib/llm/Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,10 @@ minijinja = { version = "2.10.2", features = ["loader"] }
140140
minijinja-contrib = { version = "2.10.2", features = ["pycompat"] }
141141
json-five = { version = "0.3" }
142142

143+
# media loading in the preprocessor
144+
reqwest = { workspace = true }
145+
base64 = { version = "0.22" }
146+
143147
# Publishers
144148
zeromq = "0.4.1"
145149
rmp-serde = "1.3"
@@ -167,6 +171,7 @@ insta = { version = "1.41", features = [
167171
] }
168172

169173
lazy_static = "1.4"
174+
mockito = "1.7.0"
170175

171176
[build-dependencies]
172177
tonic-build = { version = "0.13.1" }

lib/llm/src/preprocessor.rs

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
//!
1212
//! The Preprocessor will accept any IngressRequest and transform it to a BackendRequest.
1313
14+
pub mod media;
1415
pub mod prompt;
1516
pub mod tools;
1617
use anyhow::Context;
@@ -26,11 +27,11 @@ use std::{collections::HashMap, pin::Pin, sync::Arc};
2627
use tracing;
2728

2829
use crate::model_card::{ModelDeploymentCard, ModelInfo};
30+
use crate::preprocessor::media::MediaLoader;
2931
use crate::preprocessor::prompt::OAIChatLikeRequest;
3032
use crate::protocols::common::preprocessor::{
3133
MultimodalData, MultimodalDataMap, PreprocessedRequestBuilder,
3234
};
33-
3435
use crate::tokenizers::Encoding;
3536

3637
use dynamo_parsers::{ReasoningParser, ReasoningParserType};
@@ -113,6 +114,7 @@ pub struct OpenAIPreprocessor {
113114
/// Per-model runtime configuration propagated to response generator (e.g., reasoning/tool parser)
114115
runtime_config: crate::local_model::runtime_config::ModelRuntimeConfig,
115116
tool_call_parser: Option<String>,
117+
media_loader: Option<MediaLoader>,
116118
}
117119

118120
impl OpenAIPreprocessor {
@@ -141,14 +143,15 @@ impl OpenAIPreprocessor {
141143

142144
// // Initialize runtime config from the ModelDeploymentCard
143145
let runtime_config = mdc.runtime_config.clone();
144-
146+
let media_loader = None; // TODO: enable with decoder config from MDC
145147
Ok(Arc::new(Self {
146148
formatter,
147149
tokenizer,
148150
model_info,
149151
mdcsum,
150152
runtime_config,
151153
tool_call_parser,
154+
media_loader,
152155
}))
153156
}
154157
/// Encode a string to it's tokens
@@ -162,7 +165,7 @@ impl OpenAIPreprocessor {
162165
/// Annotations evaluated by this method include:
163166
/// - `formatted_prompt`
164167
/// - `token_ids`
165-
pub fn preprocess_request<
168+
pub async fn preprocess_request<
166169
R: OAIChatLikeRequest
167170
+ AnnotationsProvider
168171
+ SamplingOptionsProvider
@@ -181,6 +184,7 @@ impl OpenAIPreprocessor {
181184
.gather_tokens(request, &mut builder, formatted_prompt)
182185
.with_context(|| "Failed to gather tokens")?;
183186
self.gather_multi_modal_data(request, &mut builder)
187+
.await
184188
.with_context(|| "Failed to gather multimodal data")?;
185189

186190
Ok((builder.build()?, annotations))
@@ -267,14 +271,15 @@ impl OpenAIPreprocessor {
267271
}
268272
}
269273

270-
pub fn gather_multi_modal_data<R: OAIChatLikeRequest>(
274+
pub async fn gather_multi_modal_data<R: OAIChatLikeRequest>(
271275
&self,
272276
request: &R,
273277
builder: &mut PreprocessedRequestBuilder,
274278
) -> Result<()> {
275279
let messages = request.messages();
276280
let message_count = messages.len().unwrap_or(0);
277281
let mut media_map: MultimodalDataMap = HashMap::new();
282+
let mut fetch_tasks = Vec::new();
278283

279284
for idx in 0..message_count {
280285
let msg = messages
@@ -307,10 +312,31 @@ impl OpenAIPreprocessor {
307312
_ => continue,
308313
};
309314

310-
let map_item = media_map.entry(type_str.clone()).or_default();
311-
map_item.push(MultimodalData::Url(url));
315+
if self.media_loader.is_some() {
316+
fetch_tasks.push((type_str, content_part.clone()));
317+
} else {
318+
// No loader, just pass the URL through
319+
media_map
320+
.entry(type_str)
321+
.or_default()
322+
.push(MultimodalData::Url(url));
323+
}
312324
}
313325
}
326+
327+
// Execute all fetch tasks
328+
if !fetch_tasks.is_empty() {
329+
let loader = self.media_loader.as_ref().unwrap();
330+
let _results = futures::future::join_all(
331+
fetch_tasks
332+
.iter()
333+
.map(|(_, content_part)| loader.fetch_media_part(content_part)),
334+
)
335+
.await;
336+
337+
// TODO: decode and pass NIXL descriptors to the media map
338+
}
339+
314340
if !media_map.is_empty() {
315341
builder.multi_modal_data(Some(media_map));
316342
}
@@ -839,7 +865,7 @@ impl
839865
let response_generator = request.response_generator(context.id().to_string());
840866

841867
// convert the chat completion request to a common completion request
842-
let (common_request, annotations) = self.preprocess_request(&request)?;
868+
let (common_request, annotations) = self.preprocess_request(&request).await?;
843869

844870
let mut response_generator = Box::new(response_generator);
845871

@@ -974,7 +1000,7 @@ impl
9741000
// convert the chat completion request to a common completion request
9751001
let mut builder = self.builder(&request)?;
9761002
let annotations = self.gather_tokens(&request, &mut builder, None)?;
977-
self.gather_multi_modal_data(&request, &mut builder)?;
1003+
self.gather_multi_modal_data(&request, &mut builder).await?;
9781004

9791005
let common_request = builder.build()?;
9801006

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
use anyhow::Result;
5+
use base64::{Engine as _, engine::general_purpose};
6+
7+
// Raw encoded media data (.png, .mp4, ...), optionally b64-encoded
8+
#[derive(Debug)]
9+
pub struct EncodedMediaData {
10+
pub(crate) bytes: Vec<u8>,
11+
pub(crate) b64_encoded: bool,
12+
}
13+
14+
impl EncodedMediaData {
15+
// Handles both web URLs (will download the bytes) and data URLs (will keep b64-encoded)
16+
// This function is kept in tokio runtime so we do not want any expensive operations
17+
pub async fn from_url(url: &url::Url, client: &reqwest::Client) -> Result<Self> {
18+
let (bytes, b64_encoded) = match url.scheme() {
19+
"data" => {
20+
let base64_data = url
21+
.as_str()
22+
.split_once(',')
23+
.ok_or_else(|| anyhow::anyhow!("Invalid media data URL format"))?
24+
.1;
25+
anyhow::ensure!(!base64_data.is_empty(), "Media data URL is empty");
26+
(base64_data.as_bytes().to_vec(), true)
27+
}
28+
"http" | "https" => {
29+
let bytes = client
30+
.get(url.to_string())
31+
.send()
32+
.await?
33+
.error_for_status()?
34+
.bytes()
35+
.await?;
36+
anyhow::ensure!(!bytes.is_empty(), "Media URL is empty");
37+
(bytes.to_vec(), false)
38+
}
39+
scheme => anyhow::bail!("Unsupported media URL scheme: {scheme}"),
40+
};
41+
42+
Ok(Self { bytes, b64_encoded })
43+
}
44+
45+
// Potentially decodes b64 bytes
46+
pub fn into_bytes(self) -> Result<Vec<u8>> {
47+
if self.b64_encoded {
48+
Ok(general_purpose::STANDARD.decode(self.bytes)?)
49+
} else {
50+
Ok(self.bytes)
51+
}
52+
}
53+
}
54+
55+
#[cfg(test)]
56+
mod tests {
57+
use super::*;
58+
59+
#[tokio::test]
60+
async fn test_from_base64() {
61+
// Simple base64 encoded "test" string: dGVzdA==
62+
let data_url = url::Url::parse("data:text/plain;base64,dGVzdA==").unwrap();
63+
let client = reqwest::Client::new();
64+
65+
let result = EncodedMediaData::from_url(&data_url, &client)
66+
.await
67+
.unwrap();
68+
69+
assert!(result.b64_encoded);
70+
assert_eq!(result.bytes, b"dGVzdA==");
71+
let decoded = result.into_bytes().unwrap();
72+
assert_eq!(decoded, b"test");
73+
}
74+
75+
#[tokio::test]
76+
async fn test_from_empty_base64() {
77+
let data_url = url::Url::parse("data:text/plain;base64,").unwrap();
78+
let client = reqwest::Client::new();
79+
80+
let result = EncodedMediaData::from_url(&data_url, &client).await;
81+
assert!(result.is_err());
82+
}
83+
84+
#[tokio::test]
85+
async fn test_from_invalid_base64() {
86+
let data_url = url::Url::parse("data:invalid").unwrap();
87+
let client = reqwest::Client::new();
88+
89+
let result = EncodedMediaData::from_url(&data_url, &client).await;
90+
assert!(result.is_err());
91+
}
92+
93+
#[tokio::test]
94+
async fn test_from_url_http() {
95+
let mut server = mockito::Server::new_async().await;
96+
let mock = server
97+
.mock("GET", "/image.png")
98+
.with_status(200)
99+
.with_body(b"test data")
100+
.create_async()
101+
.await;
102+
103+
let url = url::Url::parse(&format!("{}/image.png", server.url())).unwrap();
104+
let client = reqwest::Client::new();
105+
106+
let result = EncodedMediaData::from_url(&url, &client).await.unwrap();
107+
108+
assert!(!result.b64_encoded);
109+
assert_eq!(result.bytes, b"test data");
110+
let decoded = result.into_bytes().unwrap();
111+
assert_eq!(decoded, b"test data");
112+
113+
mock.assert_async().await;
114+
}
115+
116+
#[tokio::test]
117+
async fn test_from_url_http_404() {
118+
let mut server = mockito::Server::new_async().await;
119+
let mock = server
120+
.mock("GET", "/image.png")
121+
.with_status(404)
122+
.create_async()
123+
.await;
124+
125+
let url = url::Url::parse(&format!("{}/image.png", server.url())).unwrap();
126+
let client = reqwest::Client::new();
127+
let result = EncodedMediaData::from_url(&url, &client).await;
128+
assert!(result.is_err());
129+
130+
mock.assert_async().await;
131+
}
132+
133+
#[tokio::test]
134+
async fn test_from_unsupported_scheme() {
135+
let ftp_url = url::Url::parse("ftp://example.com/image.png").unwrap();
136+
let client = reqwest::Client::new();
137+
138+
let result = EncodedMediaData::from_url(&ftp_url, &client).await;
139+
assert!(result.is_err());
140+
assert!(
141+
result
142+
.unwrap_err()
143+
.to_string()
144+
.contains("Unsupported media URL scheme")
145+
);
146+
}
147+
}

0 commit comments

Comments
 (0)