Skip to content

Commit 4bbce0f

Browse files
committed
Ensure key identifiers are sane and verify URLs
1 parent b68e81d commit 4bbce0f

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

dnstapir/key_resolver.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import logging
2+
import re
23
from abc import abstractmethod
34
from pathlib import Path
4-
from urllib.parse import urljoin
5+
from urllib.parse import urljoin, urlparse
56

67
import httpx
78
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicKey
@@ -34,11 +35,16 @@ def key_resolver_from_client_database(client_database: str, key_cache: KeyCache
3435
class KeyResolver:
3536
def __init__(self):
3637
self.logger = logging.getLogger(__name__).getChild(self.__class__.__name__)
38+
self.key_id_validator = re.compile(r"^[a-zA-Z0-9_-]+$")
3739

3840
@abstractmethod
3941
def resolve_public_key(self, key_id: str) -> PublicKey:
4042
pass
4143

44+
def validate_key_id(self, key_id: str) -> None:
45+
if not self.key_id_validator.match(key_id):
46+
raise ValueError(f"Invalid key_id format: {key_id}")
47+
4248

4349
class CacheKeyResolver(KeyResolver):
4450
def __init__(self, key_cache: KeyCache | None):
@@ -69,6 +75,7 @@ def __init__(self, client_database_directory: str, key_cache: KeyCache | None =
6975

7076
def get_public_key_pem(self, key_id: str) -> bytes:
7177
with tracer.start_as_current_span("get_public_key_pem_from_file"):
78+
self.validate_key_id(key_id)
7279
filename = Path(self.client_database_directory) / f"{key_id}.pem"
7380
self.logger.debug("Fetching public key for %s from %s", key_id, filename)
7481
try:
@@ -87,10 +94,16 @@ def __init__(self, client_database_base_url: str, key_cache: KeyCache | None = N
8794

8895
def get_public_key_pem(self, key_id: str) -> bytes:
8996
with tracer.start_as_current_span("get_public_key_pem_from_url"):
97+
self.validate_key_id(key_id)
98+
9099
if self.key_id_pattern in self.client_database_base_url:
91100
public_key_url = self.client_database_base_url.replace(self.key_id_pattern, key_id)
92101
else:
93102
public_key_url = urljoin(self.client_database_base_url, f"{key_id}.pem")
103+
104+
if urlparse(public_key_url).scheme not in ("http", "https"):
105+
raise ValueError(f"Invalid URL constructed: {public_key_url}")
106+
94107
self.logger.debug("Fetching public key for %s from %s", key_id, public_key_url)
95108
try:
96109
response = self.httpx_client.get(public_key_url)

tests/test_key_resolver.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ def test_url_key_resolver(httpx_mock: HTTPXMock):
4646
request = httpx_mock.get_request()
4747
assert request.headers["Accept"] == "application/x-pem-file"
4848

49+
with pytest.raises(ValueError):
50+
_ = resolver.resolve_public_key("🔐")
51+
4952
with pytest.raises(KeyError):
5053
_ = resolver.resolve_public_key("unknown")
5154

0 commit comments

Comments
 (0)