11import logging
2+ import re
23from abc import abstractmethod
34from pathlib import Path
4- from urllib .parse import urljoin
5+ from urllib .parse import urljoin , urlparse
56
67import httpx
78from cryptography .hazmat .primitives .asymmetric .ec import EllipticCurvePublicKey
@@ -34,11 +35,16 @@ def key_resolver_from_client_database(client_database: str, key_cache: KeyCache
3435class KeyResolver :
3536 def __init__ (self ):
3637 self .logger = logging .getLogger (__name__ ).getChild (self .__class__ .__name__ )
38+ self .key_id_validator = re .compile (r"^[a-zA-Z0-9_-]+$" )
3739
3840 @abstractmethod
3941 def resolve_public_key (self , key_id : str ) -> PublicKey :
4042 pass
4143
44+ def validate_key_id (self , key_id : str ) -> None :
45+ if not self .key_id_validator .match (key_id ):
46+ raise ValueError (f"Invalid key_id format: { key_id } " )
47+
4248
4349class CacheKeyResolver (KeyResolver ):
4450 def __init__ (self , key_cache : KeyCache | None ):
@@ -69,6 +75,7 @@ def __init__(self, client_database_directory: str, key_cache: KeyCache | None =
6975
7076 def get_public_key_pem (self , key_id : str ) -> bytes :
7177 with tracer .start_as_current_span ("get_public_key_pem_from_file" ):
78+ self .validate_key_id (key_id )
7279 filename = Path (self .client_database_directory ) / f"{ key_id } .pem"
7380 self .logger .debug ("Fetching public key for %s from %s" , key_id , filename )
7481 try :
@@ -87,10 +94,16 @@ def __init__(self, client_database_base_url: str, key_cache: KeyCache | None = N
8794
8895 def get_public_key_pem (self , key_id : str ) -> bytes :
8996 with tracer .start_as_current_span ("get_public_key_pem_from_url" ):
97+ self .validate_key_id (key_id )
98+
9099 if self .key_id_pattern in self .client_database_base_url :
91100 public_key_url = self .client_database_base_url .replace (self .key_id_pattern , key_id )
92101 else :
93102 public_key_url = urljoin (self .client_database_base_url , f"{ key_id } .pem" )
103+
104+ if urlparse (public_key_url ).scheme not in ("http" , "https" ):
105+ raise ValueError (f"Invalid URL constructed: { public_key_url } " )
106+
94107 self .logger .debug ("Fetching public key for %s from %s" , key_id , public_key_url )
95108 try :
96109 response = self .httpx_client .get (public_key_url )
0 commit comments