Skip to content

Commit 08e1c7e

Browse files
committed
fix: implement KAS allowlist functionality
1 parent f8406bf commit 08e1c7e

File tree

6 files changed

+706
-24
lines changed

6 files changed

+706
-24
lines changed

src/otdf_python/cli.py

Lines changed: 50 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -112,33 +112,14 @@ def load_client_credentials(creds_file_path: str) -> tuple[str, str]:
112112
) from e
113113

114114

115-
def build_sdk(args) -> SDK:
116-
"""Build SDK instance from CLI arguments."""
117-
builder = SDKBuilder()
118-
119-
if args.platform_url:
120-
builder.set_platform_endpoint(args.platform_url)
121-
122-
# Auto-detect HTTP URLs and enable plaintext mode
123-
if args.platform_url.startswith("http://") and (
124-
not hasattr(args, "plaintext") or not args.plaintext
125-
):
126-
logger.debug(
127-
f"Auto-detected HTTP URL {args.platform_url}, enabling plaintext mode"
128-
)
129-
builder.use_insecure_plaintext_connection(True)
130-
131-
if args.oidc_endpoint:
132-
builder.set_issuer_endpoint(args.oidc_endpoint)
133-
115+
def _configure_auth(builder: SDKBuilder, args) -> None:
116+
"""Configure authentication on the SDK builder."""
134117
if args.client_id and args.client_secret:
135118
builder.client_secret(args.client_id, args.client_secret)
136119
elif hasattr(args, "with_client_creds_file") and args.with_client_creds_file:
137-
# Load credentials from file
138120
client_id, client_secret = load_client_credentials(args.with_client_creds_file)
139121
builder.client_secret(client_id, client_secret)
140122
elif hasattr(args, "auth") and args.auth:
141-
# Parse combined auth string (clientId:clientSecret) - legacy support
142123
auth_parts = args.auth.split(":")
143124
if len(auth_parts) != 2:
144125
raise CLIError(
@@ -152,12 +133,49 @@ def build_sdk(args) -> SDK:
152133
"Authentication required: provide --with-client-creds-file OR --client-id and --client-secret",
153134
)
154135

136+
137+
def _configure_kas_allowlist(builder: SDKBuilder, args) -> None:
138+
"""Configure KAS allowlist on the SDK builder."""
139+
if hasattr(args, "ignore_kas_allowlist") and args.ignore_kas_allowlist:
140+
logger.warning(
141+
"KAS allowlist validation is disabled. This may leak credentials "
142+
"to malicious servers if decrypting untrusted TDF files."
143+
)
144+
builder.with_ignore_kas_allowlist(True)
145+
elif hasattr(args, "kas_allowlist") and args.kas_allowlist:
146+
kas_urls = [url.strip() for url in args.kas_allowlist.split(",") if url.strip()]
147+
logger.debug(f"Using KAS allowlist: {kas_urls}")
148+
builder.with_kas_allowlist(kas_urls)
149+
150+
151+
def build_sdk(args) -> SDK:
152+
"""Build SDK instance from CLI arguments."""
153+
builder = SDKBuilder()
154+
155+
if args.platform_url:
156+
builder.set_platform_endpoint(args.platform_url)
157+
# Auto-detect HTTP URLs and enable plaintext mode
158+
if args.platform_url.startswith("http://") and (
159+
not hasattr(args, "plaintext") or not args.plaintext
160+
):
161+
logger.debug(
162+
f"Auto-detected HTTP URL {args.platform_url}, enabling plaintext mode"
163+
)
164+
builder.use_insecure_plaintext_connection(True)
165+
166+
if args.oidc_endpoint:
167+
builder.set_issuer_endpoint(args.oidc_endpoint)
168+
169+
_configure_auth(builder, args)
170+
155171
if hasattr(args, "plaintext") and args.plaintext:
156172
builder.use_insecure_plaintext_connection(True)
157173

158174
if args.insecure:
159175
builder.use_insecure_skip_verify(True)
160176

177+
_configure_kas_allowlist(builder, args)
178+
161179
return builder.build()
162180

163181

@@ -476,6 +494,17 @@ def create_parser() -> argparse.ArgumentParser:
476494
security_group.add_argument(
477495
"--insecure", action="store_true", help="Skip TLS verification"
478496
)
497+
security_group.add_argument(
498+
"--kas-allowlist",
499+
help="Comma-separated list of trusted KAS URLs. "
500+
"By default, only the platform URL's KAS endpoint is trusted.",
501+
)
502+
security_group.add_argument(
503+
"--ignore-kas-allowlist",
504+
action="store_true",
505+
help="WARNING: Disable KAS allowlist validation. This is insecure and "
506+
"should only be used for testing. May leak credentials to malicious servers.",
507+
)
479508

480509
# Subcommands
481510
subparsers = parser.add_subparsers(dest="command", help="Available commands")

src/otdf_python/kas_allowlist.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
"""KAS Allowlist: Validates KAS URLs against a list of trusted hosts.
2+
3+
This module provides protection against SSRF attacks where malicious TDF files
4+
could contain attacker-controlled KAS URLs to steal OIDC credentials.
5+
"""
6+
7+
import logging
8+
from urllib.parse import urlparse
9+
10+
11+
class KASAllowlist:
12+
"""Validates KAS URLs against an allowlist of trusted hosts.
13+
14+
This class prevents credential theft by ensuring the SDK only sends
15+
authentication tokens to trusted KAS endpoints.
16+
17+
Example:
18+
allowlist = KASAllowlist(["https://kas.example.com"])
19+
allowlist.is_allowed("https://kas.example.com/kas") # True
20+
allowlist.is_allowed("https://evil.com/kas") # False
21+
22+
"""
23+
24+
def __init__(self, allowed_urls: list[str] | None = None, allow_all: bool = False):
25+
"""Initialize the KAS allowlist.
26+
27+
Args:
28+
allowed_urls: List of trusted KAS URLs. Each URL is normalized to
29+
its origin (scheme://host:port) for comparison.
30+
allow_all: If True, all URLs are allowed. Use only for testing.
31+
A warning is logged when this is enabled.
32+
33+
"""
34+
self._allowed_origins: set[str] = set()
35+
self._allow_all = allow_all
36+
37+
if allow_all:
38+
logging.warning(
39+
"KAS allowlist is disabled (allow_all=True). "
40+
"This is insecure and should only be used for testing."
41+
)
42+
43+
if allowed_urls:
44+
for url in allowed_urls:
45+
self.add(url)
46+
47+
def add(self, url: str) -> None:
48+
"""Add a URL to the allowlist.
49+
50+
The URL is normalized to its origin (scheme://host:port) before storage.
51+
Paths and query strings are stripped.
52+
53+
Args:
54+
url: The KAS URL to allow. Can include path components which
55+
will be stripped for origin comparison.
56+
57+
"""
58+
origin = self._get_origin(url)
59+
self._allowed_origins.add(origin)
60+
logging.debug(f"Added KAS origin to allowlist: {origin}")
61+
62+
def is_allowed(self, url: str) -> bool:
63+
"""Check if a URL is allowed by the allowlist.
64+
65+
Args:
66+
url: The KAS URL to check.
67+
68+
Returns:
69+
True if the URL's origin is in the allowlist or allow_all is True.
70+
False otherwise.
71+
72+
"""
73+
if self._allow_all:
74+
logging.debug(f"KAS URL allowed (allow_all=True): {url}")
75+
return True
76+
77+
if not self._allowed_origins:
78+
logging.debug(f"KAS URL rejected (empty allowlist): {url}")
79+
return False
80+
81+
origin = self._get_origin(url)
82+
allowed = origin in self._allowed_origins
83+
if allowed:
84+
logging.debug(f"KAS URL allowed: {url} (origin: {origin})")
85+
else:
86+
logging.debug(
87+
f"KAS URL rejected: {url} (origin: {origin}, "
88+
f"allowed: {self._allowed_origins})"
89+
)
90+
return allowed
91+
92+
def validate(self, url: str) -> None:
93+
"""Validate a URL against the allowlist, raising an exception if not allowed.
94+
95+
Args:
96+
url: The KAS URL to validate.
97+
98+
Raises:
99+
SDK.KasAllowlistException: If the URL is not in the allowlist.
100+
101+
"""
102+
if not self.is_allowed(url):
103+
# Import here to avoid circular imports
104+
from .sdk import SDK
105+
106+
raise SDK.KasAllowlistException(url, self._allowed_origins)
107+
108+
@property
109+
def allowed_origins(self) -> set[str]:
110+
"""Return the set of allowed origins (read-only copy)."""
111+
return self._allowed_origins.copy()
112+
113+
@property
114+
def allow_all(self) -> bool:
115+
"""Return whether all URLs are allowed."""
116+
return self._allow_all
117+
118+
@staticmethod
119+
def _get_origin(url: str) -> str:
120+
"""Extract the origin (scheme://host:port) from a URL.
121+
122+
This normalizes URLs for comparison by stripping paths and query strings.
123+
Default ports (80 for http, 443 for https) are included explicitly.
124+
125+
Args:
126+
url: The URL to extract the origin from.
127+
128+
Returns:
129+
Normalized origin string in format scheme://host:port
130+
131+
"""
132+
# Add scheme if missing
133+
if "://" not in url:
134+
url = "https://" + url
135+
136+
try:
137+
parsed = urlparse(url)
138+
except Exception as e:
139+
logging.warning(f"Failed to parse URL {url}: {e}")
140+
# Return the URL as-is if parsing fails
141+
return url.lower()
142+
143+
scheme = (parsed.scheme or "https").lower()
144+
hostname = (parsed.hostname or "").lower()
145+
146+
if not hostname:
147+
# URL might be malformed, return as-is
148+
logging.warning(f"Could not extract hostname from URL: {url}")
149+
return url.lower()
150+
151+
# Determine port (use explicit port or default for scheme)
152+
if parsed.port:
153+
port = parsed.port
154+
elif scheme == "http":
155+
port = 80
156+
else:
157+
port = 443
158+
159+
return f"{scheme}://{hostname}:{port}"
160+
161+
@classmethod
162+
def from_platform_url(cls, platform_url: str) -> "KASAllowlist":
163+
"""Create an allowlist from a platform URL.
164+
165+
This is the default behavior: auto-allow the platform's KAS endpoint.
166+
167+
Args:
168+
platform_url: The OpenTDF platform URL. The KAS endpoint is
169+
assumed to be at {platform_url}/kas.
170+
171+
Returns:
172+
KASAllowlist configured to allow the platform's KAS endpoint.
173+
174+
"""
175+
allowlist = cls()
176+
# Add the platform URL itself (KAS might be at root or /kas)
177+
allowlist.add(platform_url)
178+
# Also construct the /kas endpoint explicitly
179+
kas_url = platform_url.rstrip("/") + "/kas"
180+
allowlist.add(kas_url)
181+
logging.info(f"Created KAS allowlist from platform URL: {platform_url}")
182+
return allowlist

src/otdf_python/kas_client.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,26 @@ def __init__(
3838
cache=None,
3939
use_plaintext=False,
4040
verify_ssl=True,
41+
kas_allowlist=None,
4142
):
42-
"""Initialize KAS client."""
43+
"""Initialize KAS client.
44+
45+
Args:
46+
kas_url: Default KAS URL
47+
token_source: Function that returns an authentication token
48+
cache: Optional KASKeyCache for caching public keys
49+
use_plaintext: Whether to use HTTP instead of HTTPS
50+
verify_ssl: Whether to verify SSL certificates
51+
kas_allowlist: Optional KASAllowlist for URL validation. If provided,
52+
only URLs in the allowlist will be contacted.
53+
54+
"""
4355
self.kas_url = kas_url
4456
self.token_source = token_source
4557
self.cache = cache or KASKeyCache()
4658
self.use_plaintext = use_plaintext
4759
self.verify_ssl = verify_ssl
60+
self.kas_allowlist = kas_allowlist
4861
self.decryptor = None
4962
self.client_public_key = None
5063

@@ -86,15 +99,26 @@ def close(self):
8699
def _normalize_kas_url(self, url: str) -> str:
87100
"""Normalize KAS URLs based on client security settings.
88101
102+
This method also validates the URL against the KAS allowlist if one
103+
is configured. This prevents SSRF attacks where malicious TDF files
104+
could contain attacker-controlled KAS URLs to steal OIDC credentials.
105+
89106
Args:
90107
url: The KAS URL to normalize
91108
92109
Returns:
93110
Normalized URL with appropriate protocol and port
94111
112+
Raises:
113+
KASAllowlistException: If the URL is not in the allowlist
114+
95115
"""
96116
from urllib.parse import urlparse
97117

118+
# Validate against allowlist BEFORE making any requests
119+
if self.kas_allowlist is not None:
120+
self.kas_allowlist.validate(url)
121+
98122
try:
99123
# Parse the URL
100124
parsed = urlparse(url)

src/otdf_python/sdk.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def __init__(
3737
token_source=None,
3838
sdk_ssl_verify=True,
3939
use_plaintext=False,
40+
kas_allowlist=None,
4041
):
4142
"""Initialize the KAS client.
4243
@@ -45,6 +46,7 @@ def __init__(
4546
token_source: Function that returns an authentication token
4647
sdk_ssl_verify: Whether to verify SSL certificates
4748
use_plaintext: Whether to use plaintext HTTP connections instead of HTTPS
49+
kas_allowlist: Optional KASAllowlist for URL validation
4850
4951
"""
5052
from .kas_client import KASClient
@@ -54,6 +56,7 @@ def __init__(
5456
token_source=token_source,
5557
verify_ssl=sdk_ssl_verify,
5658
use_plaintext=use_plaintext,
59+
kas_allowlist=kas_allowlist,
5760
)
5861
# Store the parameters for potential use
5962
self._sdk_ssl_verify = sdk_ssl_verify
@@ -405,6 +408,33 @@ class KasBadRequestException(SDKException):
405408
class KasAllowlistException(SDKException):
406409
"""Throw when KAS allowlist check fails."""
407410

411+
def __init__(
412+
self,
413+
url: str,
414+
allowed_origins: set[str] | None = None,
415+
message: str | None = None,
416+
):
417+
"""Initialize exception.
418+
419+
Args:
420+
url: The KAS URL that was rejected
421+
allowed_origins: Set of allowed origin URLs
422+
message: Optional custom message (auto-generated if not provided)
423+
424+
"""
425+
self.url = url
426+
self.allowed_origins = allowed_origins or set()
427+
if message is None:
428+
origins_str = (
429+
", ".join(sorted(self.allowed_origins))
430+
if self.allowed_origins
431+
else "none"
432+
)
433+
message = (
434+
f"KAS URL not in allowlist: {url}. Allowed origins: {origins_str}"
435+
)
436+
super().__init__(message)
437+
408438
class AssertionException(SDKException):
409439
"""Throw when an assertion validation fails."""
410440

0 commit comments

Comments
 (0)