Skip to content

Commit e630a9f

Browse files
committed
fix: Guard against ssrf attacks when creating image input
1 parent 19747d2 commit e630a9f

File tree

7 files changed

+391
-6
lines changed

7 files changed

+391
-6
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Export the main classes for easy import
2+
from .anti_ssrf import AntiSSRF
3+
from .anti_ssrf_policy import AntiSSRFPolicy
4+
from .exceptions import AntiSSRFException
5+
6+
# Make classes available for import
7+
__all__ = ["AntiSSRF", "AntiSSRFPolicy", "AntiSSRFException"]
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Export the main classes for easy import
2+
from typing import List, Optional
3+
4+
from .anti_ssrf_policy import AntiSSRFPolicy
5+
from .exceptions import AntiSSRFException
6+
7+
8+
# Simple wrapper class that provides validate_url method
9+
class AntiSSRF:
10+
"""Anti-SSRF protection class that validates URLs and network connections."""
11+
12+
def __init__(self, policy: Optional[AntiSSRFPolicy] = None):
13+
"""Initialize AntiSSRF with an optional custom policy."""
14+
self.policy: AntiSSRFPolicy = policy if policy is not None else AntiSSRFPolicy(use_defaults=True)
15+
16+
def validate_url(self, url: str, headers={}) -> None:
17+
"""
18+
Validate a URL against the Anti-SSRF policy.
19+
20+
Args:
21+
url: The URL to validate
22+
23+
Raises:
24+
AntiSSRFException: If the URL is not allowed by the policy
25+
"""
26+
if not url:
27+
return
28+
29+
from urllib.parse import urlparse
30+
31+
# Parse the URL
32+
try:
33+
parsed_url = urlparse(url)
34+
except Exception as e:
35+
raise AntiSSRFException(f"Invalid URL format: {e}")
36+
37+
if not parsed_url.hostname:
38+
raise AntiSSRFException("URL must have a hostname")
39+
40+
# Resolve DNS and check network connections
41+
if parsed_url.hostname != "registries" and parsed_url.hostname != "location.api.azureml.ms":
42+
dns_resolved_ips = self._resolve_hostname(parsed_url.hostname)
43+
44+
if not self.policy.is_network_connection_allowed(dns_resolved_ips):
45+
raise AntiSSRFException(f"Network connection to '{parsed_url.hostname}' is not allowed")
46+
47+
# Check HTTP scheme
48+
if not self.policy.is_http_request_allowed(parsed_url.scheme, headers):
49+
raise AntiSSRFException(f"HTTP scheme '{parsed_url.scheme}' is not allowed")
50+
51+
def _resolve_hostname(self, hostname: str) -> List[str]:
52+
"""Resolve hostname to IP addresses."""
53+
import ipaddress
54+
import socket
55+
56+
# Handle localhost explicitly
57+
if hostname.lower() == "localhost":
58+
return ["127.0.0.1"]
59+
60+
# Try to parse as IP address first
61+
try:
62+
ip_address = ipaddress.ip_address(hostname)
63+
return [str(ip_address)]
64+
except ValueError:
65+
pass # Not an IP address, continue with DNS resolution
66+
67+
# Perform DNS resolution
68+
try:
69+
_, _, ip_addresses = socket.gethostbyname_ex(hostname)
70+
if not ip_addresses:
71+
raise AntiSSRFException(f"No IP addresses resolved for hostname: {hostname}")
72+
return ip_addresses
73+
except socket.gaierror as e:
74+
raise AntiSSRFException(f"DNS resolution failed for hostname '{hostname}': {e}")
75+
except Exception as e:
76+
raise AntiSSRFException(f"Error resolving hostname '{hostname}': {e}")
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
import ipaddress
2+
from typing import List, Optional
3+
4+
from .cidr_helpers import IPNetwork, try_parse_cidr_string
5+
from .exceptions import AntiSSRFException
6+
7+
8+
class AntiSSRFPolicy:
9+
def __init__(self, use_defaults: bool = True):
10+
self.AllowedAddresses: List[IPNetwork] = []
11+
self.DeniedAddresses: List[IPNetwork] = []
12+
self.DeniedHeaders: List[str] = []
13+
self.RequiredHeaders: List[str] = []
14+
self.AllowPlainTextHttp: bool = False
15+
self.AddXFFHeader: bool = True
16+
self.DenyAllUnspecifiedIPs: bool = False
17+
18+
if use_defaults:
19+
self._set_defaults()
20+
21+
def add_allowed_addresses(self, networks: List[str]) -> bool:
22+
for network in networks:
23+
outnet = try_parse_cidr_string(network)
24+
self.AllowedAddresses.append(outnet)
25+
return True
26+
27+
def add_denied_addresses(self, networks: List[str]) -> bool:
28+
if self.DenyAllUnspecifiedIPs:
29+
raise AntiSSRFException("Can't add denied networks when * is already supplied")
30+
if not networks:
31+
raise AntiSSRFException("Bad networks parameter")
32+
if len(networks) == 1 and networks[0] == "*":
33+
if len(self.DeniedAddresses) > 0:
34+
raise AntiSSRFException("Can't add * when deny list already has entries")
35+
self.DenyAllUnspecifiedIPs = True
36+
return True
37+
else:
38+
for network in networks:
39+
outnet = try_parse_cidr_string(network)
40+
self.DeniedAddresses.append(outnet)
41+
return True
42+
43+
def add_denied_headers(self, denied_headers: Optional[List[str]]) -> None:
44+
if denied_headers:
45+
self.DeniedHeaders.extend(denied_headers)
46+
47+
def add_required_headers(self, required_headers: Optional[List[str]]) -> None:
48+
if required_headers:
49+
self.RequiredHeaders.extend(required_headers)
50+
51+
def set_allow_plain_text_http(self, allow_plain_text_http: bool = False) -> None:
52+
self.AllowPlainTextHttp = allow_plain_text_http
53+
54+
def add_xff(self, add_xff: bool = True) -> None:
55+
self.AddXFFHeader = add_xff
56+
57+
# IP Addresses in Deny List can be IPv4, IPv6 or IPv4 mapped to IPv6
58+
# Accordingly, to check if an input address from DNS resolution is to be denied, we should:
59+
# 1. Check if the input IP is an IPv4 address, and then check if it is present in deny list
60+
# as a pure IPv4 or an IPv4 mapped to IPv6 format
61+
# 2. Check if the input IP is an IPv6 address, and then check if it is present in deny list
62+
# as a pure IPv6 address. This includes addresses in IPv4 mapped to IPv6 format
63+
# 3. Check if the input IP is an IPv4 mapped to IPv6, then check if it is present in the
64+
# deny list as an IPv4 mapped IPv6 address, then convert it to IPv4 and check if it is
65+
# present in the deny list as a pure IPv4 address
66+
#
67+
# For example, 169.254.169.254, if present in the deny list, should deny DNS resolved
68+
# addresses 169.254.169.254 and ::ffff:a9fe:a9fe
69+
# Likewise ::ffff:a9fe:a9fe, if present in the deny list, should deny DNS resolved
70+
# addresses ::ffff:a9fe:a9fe and 169.254.169.254
71+
#
72+
# Such case-by-case comparisons leads to a lot of branches in code leading to
73+
# sphagettification and also makes code difficult to follow and maintain
74+
# Furthermore, the complexity gets compounded if one adds an allow list to the mix
75+
#
76+
# To make things easier and efficient, we convert every IPv4 address to IPv6 across the
77+
# deny list, allow list and also the input DNS resolved addresses
78+
# The CIDR helper class is accordingly written
79+
#
80+
# As IPv6 is the future anyway, this also makes the code future proof
81+
def is_network_connection_allowed(self, dns_resolved_ip_addresses: List[str]) -> bool:
82+
for ip_str in dns_resolved_ip_addresses:
83+
ip_address = ipaddress.ip_address(ip_str)
84+
ipv6_address = (
85+
ip_address
86+
if isinstance(ip_address, ipaddress.IPv6Address)
87+
else ipaddress.IPv6Address(f"::ffff:{ip_address}")
88+
)
89+
90+
if self.DenyAllUnspecifiedIPs:
91+
# If the address is not in an allow list, it's not allowed.
92+
if not self._networks_contain_address(self.AllowedAddresses, ipv6_address):
93+
return False
94+
elif self.DeniedAddresses:
95+
# If address is in deny list and not in allow list, it's not allowed.
96+
if self._networks_contain_address(
97+
self.DeniedAddresses, ipv6_address
98+
) and not self._networks_contain_address(self.AllowedAddresses, ipv6_address):
99+
return False
100+
# No IP addresses returned by DNS resolution were denied
101+
return True
102+
103+
@staticmethod
104+
def _networks_contain_address(networks: List[IPNetwork], address: ipaddress.IPv6Address) -> bool:
105+
for network in networks:
106+
if network.contains(address):
107+
return True
108+
return False
109+
110+
def is_http_request_allowed(self, scheme: str, headers: dict) -> bool:
111+
if scheme.lower() == "http" and not self.AllowPlainTextHttp:
112+
return False
113+
114+
if self.AddXFFHeader:
115+
if "X-Forwarded-For" not in headers:
116+
headers["X-Forwarded-For"] = "true"
117+
118+
if self.DeniedHeaders:
119+
for header in self.DeniedHeaders:
120+
if header in headers:
121+
return False
122+
123+
if self.RequiredHeaders:
124+
for header in self.RequiredHeaders:
125+
if header not in headers:
126+
return False
127+
128+
return True
129+
130+
def _set_defaults(self):
131+
self.AllowedAddresses = []
132+
self.DeniedAddresses = []
133+
self.RequiredHeaders = []
134+
self.DeniedHeaders = []
135+
self.AllowPlainTextHttp = False
136+
self.DenyAllUnspecifiedIPs = False
137+
self.AddXFFHeader = True
138+
139+
self.add_denied_addresses(
140+
[
141+
# ==== IPv4 ==== #
142+
"255.255.255.255/32",
143+
"168.63.129.16/32", # Not nonroutable,
144+
# but this is the WireServer IP we should block.
145+
"192.0.0.0/24",
146+
"192.0.2.0/24",
147+
"192.88.99.0/24",
148+
"198.51.100.0/24",
149+
"203.0.113.0/24",
150+
"169.254.0.0/16",
151+
"192.168.0.0/16",
152+
"198.18.0.0/15",
153+
"172.16.0.0/12",
154+
"100.64.0.0/10", # IANA-Reserved
155+
"0.0.0.0/8",
156+
"10.0.0.0/8",
157+
"127.0.0.0/8",
158+
"25.0.0.0/8", # GNS Core
159+
"224.0.0.0/4",
160+
"240.0.0.0/4",
161+
# ==== IPv6 ==== #
162+
"::1/128", # Localhost
163+
"FC00::/7", # Unique-local
164+
"fe80::/10", # Link-local
165+
"fec0::/10", # Site-local
166+
"2001::/32", # Teredo
167+
]
168+
)
169+
self.DenyAllUnspecifiedIPs = False
170+
171+
# Deprecated method, for backward compatibility only
172+
def set_defaults(self):
173+
self._set_defaults()
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import ipaddress
2+
from typing import Union
3+
4+
from .exceptions import AntiSSRFException
5+
6+
7+
class IPNetwork:
8+
def __init__(self, ip: Union[str, ipaddress.IPv4Address, ipaddress.IPv6Address], prefix: int) -> None:
9+
self._base_address = ipaddress.ip_network(f"{ip}/{prefix}", strict=False)
10+
self._prefix_length = prefix
11+
12+
def contains(self, ip: Union[str, ipaddress.IPv4Address, ipaddress.IPv6Address]) -> bool:
13+
ip_obj = ip
14+
if isinstance(ip, str):
15+
ip_obj = ipaddress.ip_address(ip)
16+
# Convert IPv4 to IPv6-mapped if base is IPv6
17+
if self._base_address.version == 6 and ip_obj.version == 4:
18+
ip_obj = ipaddress.IPv6Address(f"::ffff:{ip_obj}")
19+
return ip_obj in self._base_address
20+
21+
22+
def _parse_ip_address(ip_string: str) -> Union[ipaddress.IPv4Address, ipaddress.IPv6Address]:
23+
"""Parse IP address from string, raising AntiSSRFException on error."""
24+
try:
25+
return ipaddress.ip_address(ip_string)
26+
except ValueError as e:
27+
raise AntiSSRFException("Bad CIDR", e)
28+
29+
30+
def _parse_prefix_length(prefix_string: str) -> int:
31+
"""Parse prefix length from string, raising AntiSSRFException on error."""
32+
try:
33+
return int(prefix_string)
34+
except ValueError as e:
35+
raise AntiSSRFException("Bad CIDR", e)
36+
37+
38+
def _create_single_ip_network(ip: Union[ipaddress.IPv4Address, ipaddress.IPv6Address]) -> IPNetwork:
39+
"""Create network for single IP address (no prefix specified)."""
40+
if ip.version == 4:
41+
# IPv4 mapped to IPv6, /128
42+
return IPNetwork(f"::ffff:{ip}", 128)
43+
elif ip.version == 6:
44+
return IPNetwork(ip, 128)
45+
else:
46+
raise AntiSSRFException("Bad CIDR")
47+
48+
49+
def _create_prefixed_network(ip: Union[ipaddress.IPv4Address, ipaddress.IPv6Address], prefix_length: int) -> IPNetwork:
50+
"""Create network for IP address with prefix."""
51+
if ip.version == 4 and 0 <= prefix_length <= 32:
52+
# IPv4 mapped to IPv6, prefix + 96
53+
return IPNetwork(f"::ffff:{ip}", prefix_length + 96)
54+
elif ip.version == 6 and 0 <= prefix_length <= 128:
55+
return IPNetwork(ip, prefix_length)
56+
else:
57+
raise AntiSSRFException("Bad CIDR")
58+
59+
60+
# Try parse CIDR string
61+
# Returns an IPNetwork object if everything went fine, or throws an exception
62+
# For easy computation of allow/deny, every IP Address is converted into an IPv6 address
63+
def try_parse_cidr_string(cidr_string: str) -> IPNetwork:
64+
parts = cidr_string.split("/")
65+
ip = _parse_ip_address(parts[0])
66+
67+
if len(parts) == 1:
68+
# e.g. "127.0.0.1" or "::ffff:909:909"
69+
return _create_single_ip_network(ip)
70+
elif len(parts) == 2:
71+
# Cases such as "127.0.0.1/2" or "::ffff:909:909/80"
72+
prefix_length = _parse_prefix_length(parts[1])
73+
return _create_prefixed_network(ip, prefix_length)
74+
else:
75+
raise AntiSSRFException("Bad CIDR")
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
class AntiSSRFException(Exception):
2+
def __init__(self, message=None, inner=None):
3+
if inner is not None:
4+
super().__init__(message, inner)
5+
elif message is not None:
6+
super().__init__(message)
7+
else:
8+
super().__init__()

src/promptflow-core/promptflow/_utils/multimedia_utils.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from promptflow._constants import MessageFormatType
1414
from promptflow._utils._errors import InvalidImageInput, InvalidMessageFormatType, LoadMultimediaDataError
15+
from promptflow._utils.anti_ssrf import AntiSSRF, AntiSSRFException
1516
from promptflow._utils.yaml_utils import load_yaml
1617
from promptflow.contracts.flow import FlowInputDefinition
1718
from promptflow.contracts.multimedia import Image, PFBytes, Text
@@ -93,8 +94,30 @@ def create_image_from_base64(base64_str: str, mime_type: str = None):
9394
return Image(image_bytes, mime_type=mime_type)
9495

9596
@staticmethod
96-
def create_image_from_url(url: str, mime_type: str = None):
97-
response = requests.get(url)
97+
def create_image_from_url(url: str, mime_type: str = None) -> Image:
98+
anti_ssrf = AntiSSRF()
99+
anti_ssrf.policy.set_allow_plain_text_http(True)
100+
101+
def block_redirect_if_ssrf(response: requests.Response, *args, **kwargs) -> None:
102+
if not response.is_redirect:
103+
return
104+
105+
anti_ssrf.validate_url(response.headers["Location"])
106+
107+
try:
108+
anti_ssrf.validate_url(url)
109+
110+
# Use the requests "response" hook to allow us to inspect each response
111+
# in a redirect chain.
112+
# See: https://requests.readthedocs.io/en/latest/user/advanced/#event-hooks
113+
response = requests.get(url, hooks={"response": block_redirect_if_ssrf})
114+
except AntiSSRFException as e:
115+
raise InvalidImageInput(
116+
message_format="Failed to fetch image from URL: {url}.",
117+
target=ErrorTarget.EXECUTOR,
118+
url=url,
119+
) from e
120+
98121
if response.status_code == 200:
99122
return Image(response.content, mime_type=mime_type, source_url=url)
100123
else:

0 commit comments

Comments
 (0)