11// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22// SPDX-License-Identifier: Apache-2.0
33
4+ use std:: collections:: HashSet ;
5+ use std:: time:: Duration ;
6+
47use anyhow:: Result ;
58
69use dynamo_async_openai:: types:: ChatCompletionRequestUserMessageContentPart ;
710
811use super :: common:: EncodedMediaData ;
912
/// User-Agent string used by [`MediaFetcher::default`] when the caller does
/// not configure one.
const DEFAULT_HTTP_USER_AGENT: &str = "dynamo-ai/dynamo";
14+
15+ #[ derive( Clone , Debug , serde:: Serialize , serde:: Deserialize ) ]
16+ pub struct MediaFetcher {
17+ pub user_agent : String ,
18+ pub allow_direct_ip : bool ,
19+ pub allow_direct_port : bool ,
20+ pub allowed_media_domains : Option < HashSet < String > > ,
21+ pub timeout : Option < Duration > ,
22+ }
23+
24+ impl Default for MediaFetcher {
25+ fn default ( ) -> Self {
26+ Self {
27+ user_agent : DEFAULT_HTTP_USER_AGENT . to_string ( ) ,
28+ allow_direct_ip : false ,
29+ allow_direct_port : false ,
30+ allowed_media_domains : None ,
31+ timeout : None ,
32+ }
33+ }
34+ }
1235
1336pub struct MediaLoader {
1437 http_client : reqwest:: Client ,
38+ media_fetcher : MediaFetcher ,
1539 // TODO: decoders, NIXL agent
1640}
1741
1842impl MediaLoader {
19- pub fn new ( ) -> Result < Self > {
20- let http_client = reqwest:: Client :: builder ( )
21- . user_agent ( HTTP_USER_AGENT )
22- . build ( ) ?;
43+ pub fn new ( media_fetcher : MediaFetcher ) -> Result < Self > {
44+ let mut http_client_builder =
45+ reqwest:: Client :: builder ( ) . user_agent ( & media_fetcher. user_agent ) ;
46+
47+ if let Some ( timeout) = media_fetcher. timeout {
48+ http_client_builder = http_client_builder. timeout ( timeout) ;
49+ }
50+
51+ let http_client = http_client_builder. build ( ) ?;
2352
24- Ok ( Self { http_client } )
53+ Ok ( Self {
54+ http_client,
55+ media_fetcher,
56+ } )
57+ }
58+
59+ pub fn check_if_url_allowed ( & self , url : & url:: Url ) -> Result < ( ) > {
60+ if !matches ! ( url. scheme( ) , "http" | "https" | "data" ) {
61+ anyhow:: bail!( "Only HTTP(S) and data URLs are allowed" ) ;
62+ }
63+
64+ if url. scheme ( ) == "data" {
65+ return Ok ( ( ) ) ;
66+ }
67+
68+ if !self . media_fetcher . allow_direct_ip && !matches ! ( url. host( ) , Some ( url:: Host :: Domain ( _) ) )
69+ {
70+ anyhow:: bail!( "Direct IP access is not allowed" ) ;
71+ }
72+ if !self . media_fetcher . allow_direct_port && url. port ( ) . is_some ( ) {
73+ anyhow:: bail!( "Direct port access is not allowed" ) ;
74+ }
75+ if let Some ( allowed_domains) = & self . media_fetcher . allowed_media_domains
76+ && let Some ( host) = url. host_str ( )
77+ && !allowed_domains. contains ( host)
78+ {
79+ anyhow:: bail!( "Domain '{host}' is not in allowed list" ) ;
80+ }
81+
82+ Ok ( ( ) )
2583 }
2684
2785 pub async fn fetch_media_part (
@@ -34,10 +92,12 @@ impl MediaLoader {
3492 let data = match oai_content_part {
3593 ChatCompletionRequestUserMessageContentPart :: ImageUrl ( image_part) => {
3694 let url = & image_part. image_url . url ;
95+ self . check_if_url_allowed ( url) ?;
3796 EncodedMediaData :: from_url ( url, & self . http_client ) . await ?
3897 }
3998 ChatCompletionRequestUserMessageContentPart :: VideoUrl ( video_part) => {
4099 let url = & video_part. video_url . url ;
100+ self . check_if_url_allowed ( url) ?;
41101 EncodedMediaData :: from_url ( url, & self . http_client ) . await ?
42102 }
43103 ChatCompletionRequestUserMessageContentPart :: AudioUrl ( _) => {
@@ -49,3 +109,76 @@ impl MediaLoader {
49109 Ok ( data)
50110 }
51111}
112+
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a loader from `fetcher`, panicking if the HTTP client cannot be created.
    fn loader_for(fetcher: MediaFetcher) -> MediaLoader {
        MediaLoader::new(fetcher).unwrap()
    }

    /// Runs the policy check on `raw_url` and returns the rejection message.
    /// Panics if the URL does not parse or is accepted.
    fn rejection_text(loader: &MediaLoader, raw_url: &str) -> String {
        let parsed = url::Url::parse(raw_url).unwrap();
        loader
            .check_if_url_allowed(&parsed)
            .expect_err("URL should have been rejected")
            .to_string()
    }

    #[test]
    fn test_direct_ip_blocked() {
        let loader = loader_for(MediaFetcher {
            allow_direct_ip: false,
            ..Default::default()
        });

        let message = rejection_text(&loader, "http://192.168.1.1/image.jpg");

        assert!(message.contains("Direct IP access is not allowed"));
    }

    #[test]
    fn test_direct_port_blocked() {
        let loader = loader_for(MediaFetcher {
            allow_direct_port: false,
            ..Default::default()
        });

        let message = rejection_text(&loader, "http://example.com:8080/image.jpg");

        assert!(message.contains("Direct port access is not allowed"));
    }

    #[test]
    fn test_domain_allowlist() {
        let allowed_domains: HashSet<String> = ["trusted.com", "example.com"]
            .into_iter()
            .map(String::from)
            .collect();
        let loader = loader_for(MediaFetcher {
            allowed_media_domains: Some(allowed_domains),
            ..Default::default()
        });

        // A listed domain is accepted.
        let trusted = url::Url::parse("https://trusted.com/image.jpg").unwrap();
        assert!(loader.check_if_url_allowed(&trusted).is_ok());

        // An unlisted domain is rejected with the allowlist message.
        let message = rejection_text(&loader, "https://untrusted.com/image.jpg");
        assert!(message.contains("not in allowed list"));
    }
}
0 commit comments