Skip to content

Commit 49a1552

Browse files
committed
fix: make Clerk JWKS fetching async to prevent event loop blocking
1 parent 95c31b8 commit 49a1552

File tree

1 file changed

+105
-1
lines changed

1 file changed

+105
-1
lines changed

commitly-backend/app/core/auth.py

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,22 @@ def __init__(self, jwks_url: str, ttl_seconds: int) -> None:
5555
self._expires_at: float = 0.0
5656
self._lock = RLock()
5757

58+
async def _fetch_async(self) -> Dict[str, Any]:
59+
"""Fetch JWKS using async HTTP client."""
60+
try:
61+
async with httpx.AsyncClient(timeout=5.0) as client:
62+
response = await client.get(str(self.jwks_url))
63+
response.raise_for_status()
64+
payload = response.json()
65+
except httpx.HTTPError as exc: # pragma: no cover - network failure
66+
raise InvalidClerkToken("Failed to download Clerk JWKS") from exc
67+
68+
if not isinstance(payload, dict) or "keys" not in payload:
69+
raise InvalidClerkToken("JWKS payload is missing keys")
70+
return payload
71+
5872
def _fetch(self) -> Dict[str, Any]:
73+
"""Sync fetch for backwards compatibility."""
5974
try:
6075
with httpx.Client(timeout=5.0) as client:
6176
response = client.get(str(self.jwks_url))
@@ -68,6 +83,21 @@ def _fetch(self) -> Dict[str, Any]:
6883
raise InvalidClerkToken("JWKS payload is missing keys")
6984
return payload
7085

86+
async def _current_jwks_async(self) -> Dict[str, Any]:
87+
"""Async version that fetches JWKS if needed."""
88+
with self._lock:
89+
now = time.monotonic()
90+
if self._jwks and now < self._expires_at:
91+
return self._jwks
92+
93+
# Fetch outside the lock to avoid blocking
94+
fresh_jwks = await self._fetch_async()
95+
96+
with self._lock:
97+
self._jwks = fresh_jwks
98+
self._expires_at = time.monotonic() + max(self.ttl_seconds, 60)
99+
return self._jwks
100+
71101
def _current_jwks(self) -> Dict[str, Any]:
72102
with self._lock:
73103
now = time.monotonic()
@@ -77,6 +107,14 @@ def _current_jwks(self) -> Dict[str, Any]:
77107
self._expires_at = now + max(self.ttl_seconds, 60)
78108
return self._jwks
79109

110+
async def get_key_async(self, kid: str) -> Dict[str, Any]:
111+
"""Async version of get_key."""
112+
jwks = await self._current_jwks_async()
113+
for key in jwks.get("keys", []):
114+
if key.get("kid") == kid:
115+
return cast(Dict[str, Any], key)
116+
raise InvalidClerkToken("No matching JWK for supplied token")
117+
80118
def get_key(self, kid: str) -> Dict[str, Any]:
81119
jwks = self._current_jwks()
82120
for key in jwks.get("keys", []):
@@ -180,6 +218,71 @@ def verify_clerk_token(token: str) -> ClerkClaims:
180218
return cast(ClerkClaims, claims)
181219

182220

221+
async def verify_clerk_token_async(token: str) -> ClerkClaims:
222+
"""Async version of verify_clerk_token that doesn't block the event loop."""
223+
try:
224+
header = jwt.get_unverified_header(token)
225+
except JWTError as exc:
226+
raise InvalidClerkToken("Malformed token header") from exc
227+
228+
kid = header.get("kid")
229+
if not isinstance(kid, str):
230+
raise InvalidClerkToken("Missing key identifier")
231+
232+
# Use async version to avoid blocking
233+
jwk_data = await jwks_cache.get_key_async(kid)
234+
public_key = jwk.construct(jwk_data)
235+
236+
try:
237+
message, encoded_signature = token.rsplit(".", 1)
238+
except ValueError as exc:
239+
raise InvalidClerkToken("Token structure is invalid") from exc
240+
241+
decoded_signature = base64url_decode(encoded_signature.encode("utf-8"))
242+
if not public_key.verify(message.encode("utf-8"), decoded_signature):
243+
raise InvalidClerkToken("Signature verification failed")
244+
245+
claims = jwt.get_unverified_claims(token)
246+
now = int(time.time())
247+
248+
exp = claims.get("exp")
249+
if exp is not None and int(exp) <= now:
250+
raise InvalidClerkToken("Token has expired")
251+
252+
nbf = claims.get("nbf")
253+
if nbf is not None and now < int(nbf):
254+
raise InvalidClerkToken("Token is not yet valid")
255+
256+
issuer = claims.get("iss")
257+
if issuer != settings.clerk_issuer:
258+
raise InvalidClerkToken("Invalid issuer")
259+
260+
audience_values = _select_audience(claims.get("aud"))
261+
allowed_audiences = Settings._coerce_list(settings.clerk_audience) or [
262+
settings.clerk_audience
263+
]
264+
if audience_values:
265+
if not any(audience in audience_values for audience in allowed_audiences):
266+
raise InvalidClerkToken("Invalid audience")
267+
268+
if settings.clerk_authorized_parties:
269+
azp = claims.get("azp")
270+
if isinstance(azp, str):
271+
normalized_azp = _normalize_party(azp)
272+
allowed = {
273+
_normalize_party(party) for party in settings.clerk_authorized_parties
274+
}
275+
if "*" not in allowed and normalized_azp not in allowed:
276+
raise InvalidClerkToken("Token not issued for this application")
277+
else:
278+
raise InvalidClerkToken("Token missing authorized party")
279+
280+
if "sub" not in claims:
281+
raise InvalidClerkToken("Token is missing subject claim")
282+
283+
return cast(ClerkClaims, claims)
284+
285+
183286
def _unauthorized(detail: str) -> HTTPException:
184287
return HTTPException(
185288
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -254,7 +357,8 @@ async def dispatch(
254357

255358
if token:
256359
try:
257-
request.state.clerk_claims = verify_clerk_token(token)
360+
# Use async version to avoid blocking the event loop
361+
request.state.clerk_claims = await verify_clerk_token_async(token)
258362
except InvalidClerkToken as exc:
259363
request.state.clerk_auth_error = exc
260364

0 commit comments

Comments
 (0)