Skip to content

Commit 3b3593f

Browse files
committed
feat: Add some security measures
Signed-off-by: Alexandre Milesi <[email protected]>
1 parent b0aa2e6 commit 3b3593f

File tree

1 file changed

+140
-7
lines changed

1 file changed

+140
-7
lines changed

lib/llm/src/preprocessor/media/loader.rs

Lines changed: 140 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,85 @@
11
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
// SPDX-License-Identifier: Apache-2.0
33

4+
use std::collections::HashSet;
5+
use std::time::Duration;
6+
47
use anyhow::Result;
58

69
use dynamo_async_openai::types::ChatCompletionRequestUserMessageContentPart;
710

811
use super::common::EncodedMediaData;
912

10-
// TODO: make this configurable
11-
const HTTP_USER_AGENT: &str = "dynamo-ai/dynamo";
13+
/// `User-Agent` value that [`MediaFetcher::default`] configures for media requests.
const DEFAULT_HTTP_USER_AGENT: &str = "dynamo-ai/dynamo";
14+
15+
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
16+
pub struct MediaFetcher {
17+
pub user_agent: String,
18+
pub allow_direct_ip: bool,
19+
pub allow_direct_port: bool,
20+
pub allowed_media_domains: Option<HashSet<String>>,
21+
pub timeout: Option<Duration>,
22+
}
23+
24+
impl Default for MediaFetcher {
25+
fn default() -> Self {
26+
Self {
27+
user_agent: DEFAULT_HTTP_USER_AGENT.to_string(),
28+
allow_direct_ip: false,
29+
allow_direct_port: false,
30+
allowed_media_domains: None,
31+
timeout: None,
32+
}
33+
}
34+
}
1235

1336
pub struct MediaLoader {
1437
http_client: reqwest::Client,
38+
media_fetcher: MediaFetcher,
1539
// TODO: decoders, NIXL agent
1640
}
1741

1842
impl MediaLoader {
19-
pub fn new() -> Result<Self> {
20-
let http_client = reqwest::Client::builder()
21-
.user_agent(HTTP_USER_AGENT)
22-
.build()?;
43+
pub fn new(media_fetcher: MediaFetcher) -> Result<Self> {
44+
let mut http_client_builder =
45+
reqwest::Client::builder().user_agent(&media_fetcher.user_agent);
46+
47+
if let Some(timeout) = media_fetcher.timeout {
48+
http_client_builder = http_client_builder.timeout(timeout);
49+
}
50+
51+
let http_client = http_client_builder.build()?;
2352

24-
Ok(Self { http_client })
53+
Ok(Self {
54+
http_client,
55+
media_fetcher,
56+
})
57+
}
58+
59+
pub fn check_if_url_allowed(&self, url: &url::Url) -> Result<()> {
60+
if !matches!(url.scheme(), "http" | "https" | "data") {
61+
anyhow::bail!("Only HTTP(S) and data URLs are allowed");
62+
}
63+
64+
if url.scheme() == "data" {
65+
return Ok(());
66+
}
67+
68+
if !self.media_fetcher.allow_direct_ip && !matches!(url.host(), Some(url::Host::Domain(_)))
69+
{
70+
anyhow::bail!("Direct IP access is not allowed");
71+
}
72+
if !self.media_fetcher.allow_direct_port && url.port().is_some() {
73+
anyhow::bail!("Direct port access is not allowed");
74+
}
75+
if let Some(allowed_domains) = &self.media_fetcher.allowed_media_domains
76+
&& let Some(host) = url.host_str()
77+
&& !allowed_domains.contains(host)
78+
{
79+
anyhow::bail!("Domain '{host}' is not in allowed list");
80+
}
81+
82+
Ok(())
2583
}
2684

2785
pub async fn fetch_media_part(
@@ -34,10 +92,12 @@ impl MediaLoader {
3492
let data = match oai_content_part {
3593
ChatCompletionRequestUserMessageContentPart::ImageUrl(image_part) => {
3694
let url = &image_part.image_url.url;
95+
self.check_if_url_allowed(url)?;
3796
EncodedMediaData::from_url(url, &self.http_client).await?
3897
}
3998
ChatCompletionRequestUserMessageContentPart::VideoUrl(video_part) => {
4099
let url = &video_part.video_url.url;
100+
self.check_if_url_allowed(url)?;
41101
EncodedMediaData::from_url(url, &self.http_client).await?
42102
}
43103
ChatCompletionRequestUserMessageContentPart::AudioUrl(_) => {
@@ -49,3 +109,76 @@ impl MediaLoader {
49109
Ok(data)
50110
}
51111
}
112+
113+
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a loader for `fetcher`, panicking if the HTTP client cannot be created.
    fn loader_with(fetcher: MediaFetcher) -> MediaLoader {
        MediaLoader::new(fetcher).expect("building the HTTP client should not fail in tests")
    }

    /// Parses `raw` and returns the loader's verdict for it.
    fn verdict(loader: &MediaLoader, raw: &str) -> anyhow::Result<()> {
        let url = url::Url::parse(raw).expect("test URL should parse");
        loader.check_if_url_allowed(&url)
    }

    /// Asserts that `raw` is rejected with an error message containing `needle`.
    fn assert_rejected(loader: &MediaLoader, raw: &str, needle: &str) {
        let err = verdict(loader, raw).expect_err("URL should have been rejected");
        assert!(
            err.to_string().contains(needle),
            "unexpected error for {raw}: {err}"
        );
    }

    #[test]
    fn test_direct_ip_blocked() {
        let loader = loader_with(MediaFetcher {
            allow_direct_ip: false,
            ..Default::default()
        });

        assert_rejected(
            &loader,
            "http://192.168.1.1/image.jpg",
            "Direct IP access is not allowed",
        );
        // IPv6 literals are not domain names either.
        assert_rejected(
            &loader,
            "http://[::1]/image.jpg",
            "Direct IP access is not allowed",
        );
    }

    #[test]
    fn test_direct_ip_allowed_when_enabled() {
        let loader = loader_with(MediaFetcher {
            allow_direct_ip: true,
            ..Default::default()
        });

        assert!(verdict(&loader, "http://192.168.1.1/image.jpg").is_ok());
    }

    #[test]
    fn test_direct_port_blocked() {
        let loader = loader_with(MediaFetcher {
            allow_direct_port: false,
            ..Default::default()
        });

        assert_rejected(
            &loader,
            "http://example.com:8080/image.jpg",
            "Direct port access is not allowed",
        );
    }

    #[test]
    fn test_default_port_is_not_treated_as_direct_port() {
        let loader = loader_with(MediaFetcher::default());

        // `Url::port()` reports `None` when the port is the scheme's default.
        assert!(verdict(&loader, "https://example.com:443/image.jpg").is_ok());
    }

    #[test]
    fn test_direct_port_allowed_when_enabled() {
        let loader = loader_with(MediaFetcher {
            allow_direct_port: true,
            ..Default::default()
        });

        assert!(verdict(&loader, "http://example.com:8080/image.jpg").is_ok());
    }

    #[test]
    fn test_domain_allowlist() {
        let allowed_domains: HashSet<String> = ["trusted.com", "example.com"]
            .into_iter()
            .map(String::from)
            .collect();

        let loader = loader_with(MediaFetcher {
            allowed_media_domains: Some(allowed_domains),
            ..Default::default()
        });

        // Allowed domain should pass
        assert!(verdict(&loader, "https://trusted.com/image.jpg").is_ok());

        // Disallowed domain should fail
        assert_rejected(
            &loader,
            "https://untrusted.com/image.jpg",
            "not in allowed list",
        );
    }

    #[test]
    fn test_unsupported_scheme_blocked() {
        let loader = loader_with(MediaFetcher::default());

        for raw in ["ftp://example.com/image.jpg", "file:///etc/passwd"] {
            assert_rejected(&loader, raw, "Only HTTP(S) and data URLs are allowed");
        }
    }

    #[test]
    fn test_data_url_allowed_regardless_of_host_rules() {
        let allowed_domains: HashSet<String> = std::iter::once("trusted.com".to_string()).collect();

        let loader = loader_with(MediaFetcher {
            allowed_media_domains: Some(allowed_domains),
            ..Default::default()
        });

        assert!(verdict(&loader, "data:image/png;base64,AAAA").is_ok());
    }
}

0 commit comments

Comments
 (0)