diff --git a/Cargo.lock b/Cargo.lock index d88ddd271c..64a7cbbb8d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -210,6 +210,16 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b0f477b951e452a0b6b4a10b53ccd569042d1d01729b519e02074a9c0958a063" +[[package]] +name = "assert-json-diff" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e4f2b81832e72834d7518d8487a0396a28cc408186a2e8854c0f98011faf12" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "assert_matches" version = "1.5.0" @@ -2170,6 +2180,7 @@ dependencies = [ "async_zmq", "axum 0.8.4", "axum-server", + "base64 0.22.1", "bincode 2.0.1", "bitflags 2.9.4", "blake3", @@ -2201,6 +2212,7 @@ dependencies = [ "lazy_static", "minijinja", "minijinja-contrib", + "mockito", "modelexpress-client", "modelexpress-common", "ndarray", @@ -4911,6 +4923,30 @@ dependencies = [ "rayon", ] +[[package]] +name = "mockito" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7760e0e418d9b7e5777c0374009ca4c93861b9066f18cb334a20ce50ab63aa48" +dependencies = [ + "assert-json-diff", + "bytes", + "colored", + "futures-util", + "http 1.3.1", + "http-body 1.0.1", + "http-body-util", + "hyper 1.7.0", + "hyper-util", + "log", + "rand 0.9.2", + "regex", + "serde_json", + "serde_urlencoded", + "similar", + "tokio", +] + [[package]] name = "modelexpress-client" version = "0.2.0" diff --git a/lib/llm/Cargo.toml b/lib/llm/Cargo.toml index d045152ac9..27abe8bc6c 100644 --- a/lib/llm/Cargo.toml +++ b/lib/llm/Cargo.toml @@ -140,6 +140,10 @@ minijinja = { version = "2.10.2", features = ["loader"] } minijinja-contrib = { version = "2.10.2", features = ["pycompat"] } json-five = { version = "0.3" } +# media loading in the preprocessor +reqwest = { workspace = true } +base64 = { version = "0.22" } + # Publishers zeromq = "0.4.1" rmp-serde = "1.3" @@ -167,6 +171,7 @@ insta = { version = "1.41", features = [ ] } lazy_static = "1.4" +mockito = "1.7.0" [build-dependencies] tonic-build = { version = "0.13.1" } diff --git a/lib/llm/src/preprocessor.rs b/lib/llm/src/preprocessor.rs index 258358e8fa..04fd1cd230 100644 --- a/lib/llm/src/preprocessor.rs +++ b/lib/llm/src/preprocessor.rs @@ -11,6 +11,7 @@ //! //! The Preprocessor will accept any IngressRequest and transform it to a BackendRequest. +pub mod media; pub mod prompt; pub mod tools; use anyhow::Context; @@ -26,11 +27,11 @@ use std::{collections::HashMap, pin::Pin, sync::Arc}; use tracing; use crate::model_card::{ModelDeploymentCard, ModelInfo}; +use crate::preprocessor::media::MediaLoader; use crate::preprocessor::prompt::OAIChatLikeRequest; use crate::protocols::common::preprocessor::{ MultimodalData, MultimodalDataMap, PreprocessedRequestBuilder, }; - use crate::tokenizers::Encoding; use dynamo_parsers::{ReasoningParser, ReasoningParserType}; @@ -113,6 +114,7 @@ pub struct OpenAIPreprocessor { /// Per-model runtime configuration propagated to response generator (e.g., reasoning/tool parser) runtime_config: crate::local_model::runtime_config::ModelRuntimeConfig, tool_call_parser: Option, + media_loader: Option, } impl OpenAIPreprocessor { @@ -141,7 +143,7 @@ impl OpenAIPreprocessor { // // Initialize runtime config from the ModelDeploymentCard let runtime_config = mdc.runtime_config.clone(); - + let media_loader = None; // TODO: enable with decoder config from MDC Ok(Arc::new(Self { formatter, tokenizer, @@ -149,6 +151,7 @@ impl OpenAIPreprocessor { mdcsum, runtime_config, tool_call_parser, + media_loader, })) } /// Encode a string to it's tokens @@ -162,7 +165,7 @@ impl OpenAIPreprocessor { /// Annotations evaluated by this method include: /// - `formatted_prompt` /// - `token_ids` - pub fn preprocess_request< + pub async fn preprocess_request< R: OAIChatLikeRequest + AnnotationsProvider + SamplingOptionsProvider @@ -181,6 +184,7 @@ impl OpenAIPreprocessor { .gather_tokens(request, &mut builder, formatted_prompt) .with_context(|| "Failed to gather tokens")?; self.gather_multi_modal_data(request, &mut builder) + .await .with_context(|| "Failed to gather multimodal data")?; Ok((builder.build()?, annotations)) @@ -267,7 +271,7 @@ impl OpenAIPreprocessor { } } - pub fn gather_multi_modal_data( + pub async fn gather_multi_modal_data( &self, request: &R, builder: &mut PreprocessedRequestBuilder, @@ -275,6 +279,7 @@ impl OpenAIPreprocessor { let messages = request.messages(); let message_count = messages.len().unwrap_or(0); let mut media_map: MultimodalDataMap = HashMap::new(); + let mut fetch_tasks = Vec::new(); for idx in 0..message_count { let msg = messages @@ -307,10 +312,31 @@ impl OpenAIPreprocessor { _ => continue, }; - let map_item = media_map.entry(type_str.clone()).or_default(); - map_item.push(MultimodalData::Url(url)); + if self.media_loader.is_some() { + fetch_tasks.push((type_str, content_part.clone())); + } else { + // No loader, just pass the URL through + media_map + .entry(type_str) + .or_default() + .push(MultimodalData::Url(url)); + } } } + + // Execute all fetch tasks + if !fetch_tasks.is_empty() { + let loader = self.media_loader.as_ref().unwrap(); + let _results = futures::future::join_all( + fetch_tasks + .iter() + .map(|(_, content_part)| loader.fetch_media_part(content_part)), + ) + .await; + + // TODO: decode and pass NIXL descriptors to the media map + } + if !media_map.is_empty() { builder.multi_modal_data(Some(media_map)); } @@ -839,7 +865,7 @@ impl let response_generator = request.response_generator(context.id().to_string()); // convert the chat completion request to a common completion request - let (common_request, annotations) = self.preprocess_request(&request)?; + let (common_request, annotations) = self.preprocess_request(&request).await?; let mut response_generator = Box::new(response_generator); @@ -974,7 +1000,7 @@ impl // convert the chat completion request to a common completion request let mut builder = self.builder(&request)?; let annotations = self.gather_tokens(&request, &mut builder, None)?; - self.gather_multi_modal_data(&request, &mut builder)?; + self.gather_multi_modal_data(&request, &mut builder).await?; let common_request = builder.build()?; diff --git a/lib/llm/src/preprocessor/media.rs b/lib/llm/src/preprocessor/media.rs new file mode 100644 index 0000000000..9b4af1f64b --- /dev/null +++ b/lib/llm/src/preprocessor/media.rs @@ -0,0 +1,8 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +mod common; +mod loader; + +pub use common::EncodedMediaData; +pub use loader::MediaLoader; diff --git a/lib/llm/src/preprocessor/media/common.rs b/lib/llm/src/preprocessor/media/common.rs new file mode 100644 index 0000000000..10e5406241 --- /dev/null +++ b/lib/llm/src/preprocessor/media/common.rs @@ -0,0 +1,146 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use anyhow::Result; +use base64::{Engine as _, engine::general_purpose}; + +// Raw encoded media data (.png, .mp4, ...), optionally b64-encoded +#[derive(Debug)] +pub struct EncodedMediaData { + pub(crate) bytes: Vec, + pub(crate) b64_encoded: bool, +} + +impl EncodedMediaData { + // Handles both web URLs (will download the bytes) and data URLs (will keep b64-encoded) + pub async fn from_url(url: &url::Url, client: &reqwest::Client) -> Result { + let (bytes, b64_encoded) = match url.scheme() { + "data" => { + let base64_data = url + .as_str() + .split_once(',') + .ok_or_else(|| anyhow::anyhow!("Invalid media data URL format"))? + .1; + anyhow::ensure!(!base64_data.is_empty(), "Media data URL is empty"); + (base64_data.as_bytes().to_vec(), true) + } + "http" | "https" => { + let bytes = client + .get(url.to_string()) + .send() + .await? + .error_for_status()? + .bytes() + .await?; + anyhow::ensure!(!bytes.is_empty(), "Media URL is empty"); + (bytes.to_vec(), false) + } + scheme => anyhow::bail!("Unsupported media URL scheme: {scheme}"), + }; + + Ok(Self { bytes, b64_encoded }) + } + + // Potentially decodes b64 bytes + pub fn into_bytes(self) -> Result> { + if self.b64_encoded { + Ok(general_purpose::STANDARD.decode(self.bytes)?) + } else { + Ok(self.bytes) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_from_base64() { + // Simple base64 encoded "test" string: dGVzdA== + let data_url = url::Url::parse("data:text/plain;base64,dGVzdA==").unwrap(); + let client = reqwest::Client::new(); + + let result = EncodedMediaData::from_url(&data_url, &client) + .await + .unwrap(); + + assert!(result.b64_encoded); + assert_eq!(result.bytes, b"dGVzdA=="); + let decoded = result.into_bytes().unwrap(); + assert_eq!(decoded, b"test"); + } + + #[tokio::test] + async fn test_from_empty_base64() { + let data_url = url::Url::parse("data:text/plain;base64,").unwrap(); + let client = reqwest::Client::new(); + + let result = EncodedMediaData::from_url(&data_url, &client).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_from_invalid_base64() { + let data_url = url::Url::parse("data:invalid").unwrap(); + let client = reqwest::Client::new(); + + let result = EncodedMediaData::from_url(&data_url, &client).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_from_url_http() { + let mut server = mockito::Server::new_async().await; + let mock = server + .mock("GET", "/image.png") + .with_status(200) + .with_body(b"test data") + .create_async() + .await; + + let url = url::Url::parse(&format!("{}/image.png", server.url())).unwrap(); + let client = reqwest::Client::new(); + + let result = EncodedMediaData::from_url(&url, &client).await.unwrap(); + + assert!(!result.b64_encoded); + assert_eq!(result.bytes, b"test data"); + let decoded = result.into_bytes().unwrap(); + assert_eq!(decoded, b"test data"); + + mock.assert_async().await; + } + + #[tokio::test] + async fn test_from_url_http_404() { + let mut server = mockito::Server::new_async().await; + let mock = server + .mock("GET", "/image.png") + .with_status(404) + .create_async() + .await; + + let url = url::Url::parse(&format!("{}/image.png", server.url())).unwrap(); + let client = reqwest::Client::new(); + let result = EncodedMediaData::from_url(&url, &client).await; + assert!(result.is_err()); + + mock.assert_async().await; + } + + #[tokio::test] + async fn test_from_unsupported_scheme() { + let ftp_url = url::Url::parse("ftp://example.com/image.png").unwrap(); + let client = reqwest::Client::new(); + + let result = EncodedMediaData::from_url(&ftp_url, &client).await; + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Unsupported media URL scheme") + ); + } +} diff --git a/lib/llm/src/preprocessor/media/loader.rs b/lib/llm/src/preprocessor/media/loader.rs new file mode 100644 index 0000000000..47b2516e66 --- /dev/null +++ b/lib/llm/src/preprocessor/media/loader.rs @@ -0,0 +1,184 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use std::collections::HashSet; +use std::time::Duration; + +use anyhow::Result; + +use dynamo_async_openai::types::ChatCompletionRequestUserMessageContentPart; + +use super::common::EncodedMediaData; + +const DEFAULT_HTTP_USER_AGENT: &str = "dynamo-ai/dynamo"; + +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +pub struct MediaFetcher { + pub user_agent: String, + pub allow_direct_ip: bool, + pub allow_direct_port: bool, + pub allowed_media_domains: Option>, + pub timeout: Option, +} + +impl Default for MediaFetcher { + fn default() -> Self { + Self { + user_agent: DEFAULT_HTTP_USER_AGENT.to_string(), + allow_direct_ip: false, + allow_direct_port: false, + allowed_media_domains: None, + timeout: None, + } + } +} + +pub struct MediaLoader { + http_client: reqwest::Client, + media_fetcher: MediaFetcher, + // TODO: decoders, NIXL agent +} + +impl MediaLoader { + pub fn new(media_fetcher: MediaFetcher) -> Result { + let mut http_client_builder = + reqwest::Client::builder().user_agent(&media_fetcher.user_agent); + + if let Some(timeout) = media_fetcher.timeout { + http_client_builder = http_client_builder.timeout(timeout); + } + + let http_client = http_client_builder.build()?; + + Ok(Self { + http_client, + media_fetcher, + }) + } + + pub fn check_if_url_allowed(&self, url: &url::Url) -> Result<()> { + if !matches!(url.scheme(), "http" | "https" | "data") { + anyhow::bail!("Only HTTP(S) and data URLs are allowed"); + } + + if url.scheme() == "data" { + return Ok(()); + } + + if !self.media_fetcher.allow_direct_ip && !matches!(url.host(), Some(url::Host::Domain(_))) + { + anyhow::bail!("Direct IP access is not allowed"); + } + if !self.media_fetcher.allow_direct_port && url.port().is_some() { + anyhow::bail!("Direct port access is not allowed"); + } + if let Some(allowed_domains) = &self.media_fetcher.allowed_media_domains + && let Some(host) = url.host_str() + && !allowed_domains.contains(host) + { + anyhow::bail!("Domain '{host}' is not in allowed list"); + } + + Ok(()) + } + + pub async fn fetch_media_part( + &self, + oai_content_part: &ChatCompletionRequestUserMessageContentPart, + // TODO: request-level options + ) -> Result { + // fetch the media + // TODO: decode and NIXL-register + let data = match oai_content_part { + ChatCompletionRequestUserMessageContentPart::ImageUrl(image_part) => { + let url = &image_part.image_url.url; + self.check_if_url_allowed(url)?; + EncodedMediaData::from_url(url, &self.http_client).await? + } + ChatCompletionRequestUserMessageContentPart::VideoUrl(video_part) => { + let url = &video_part.video_url.url; + self.check_if_url_allowed(url)?; + EncodedMediaData::from_url(url, &self.http_client).await? + } + ChatCompletionRequestUserMessageContentPart::AudioUrl(_) => { + anyhow::bail!("Audio decoding is not supported yet"); + } + _ => anyhow::bail!("Unsupported media type"), + }; + + Ok(data) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_direct_ip_blocked() { + let fetcher = MediaFetcher { + allow_direct_ip: false, + ..Default::default() + }; + let loader = MediaLoader::new(fetcher).unwrap(); + + let url = url::Url::parse("http://192.168.1.1/image.jpg").unwrap(); + let result = loader.check_if_url_allowed(&url); + + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Direct IP access is not allowed") + ); + } + + #[test] + fn test_direct_port_blocked() { + let fetcher = MediaFetcher { + allow_direct_port: false, + ..Default::default() + }; + let loader = MediaLoader::new(fetcher).unwrap(); + + let url = url::Url::parse("http://example.com:8080/image.jpg").unwrap(); + let result = loader.check_if_url_allowed(&url); + + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Direct port access is not allowed") + ); + } + + #[test] + fn test_domain_allowlist() { + let mut allowed_domains = HashSet::new(); + allowed_domains.insert("trusted.com".to_string()); + allowed_domains.insert("example.com".to_string()); + + let fetcher = MediaFetcher { + allowed_media_domains: Some(allowed_domains), + ..Default::default() + }; + let loader = MediaLoader::new(fetcher).unwrap(); + + // Allowed domain should pass + let url = url::Url::parse("https://trusted.com/image.jpg").unwrap(); + assert!(loader.check_if_url_allowed(&url).is_ok()); + + // Disallowed domain should fail + let url = url::Url::parse("https://untrusted.com/image.jpg").unwrap(); + let result = loader.check_if_url_allowed(&url); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("not in allowed list") + ); + } +} diff --git a/lib/llm/tests/preprocessor.rs b/lib/llm/tests/preprocessor.rs index 84a108ebc8..6d4b350294 100644 --- a/lib/llm/tests/preprocessor.rs +++ b/lib/llm/tests/preprocessor.rs @@ -551,7 +551,7 @@ async fn test_media_url_passthrough(#[case] media_chunks: &[(&str, usize)]) { let message = build_message("Test multimodal content", media_chunks); let request = Request::from(&message, None, None, mdc.slug().to_string()); - let (preprocessed, _annotations) = preprocessor.preprocess_request(&request).unwrap(); + let (preprocessed, _annotations) = preprocessor.preprocess_request(&request).await.unwrap(); // Verify multimodal data handling if media_chunks.is_empty() {