Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 3b754ae

Browse files
authored
Clean up caching/locking of OIDC metadata load (#9362)
Ensure that we lock correctly to prevent multiple concurrent metadata load requests, and generally clean up the way we construct the metadata cache.
1 parent 0ad0872 commit 3b754ae

File tree

5 files changed

+389
-62
lines changed

5 files changed

+389
-62
lines changed

changelog.d/9362.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Clean up the code to load the metadata for OpenID Connect identity providers.

synapse/handlers/oidc_handler.py

Lines changed: 53 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from synapse.logging.context import make_deferred_yieldable
4242
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
4343
from synapse.util import json_decoder
44+
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
4445

4546
if TYPE_CHECKING:
4647
from synapse.server import HomeServer
@@ -245,6 +246,7 @@ def __init__(
245246

246247
self._token_generator = token_generator
247248

249+
self._config = provider
248250
self._callback_url = hs.config.oidc_callback_url # type: str
249251

250252
self._scopes = provider.scopes
@@ -253,14 +255,16 @@ def __init__(
253255
provider.client_id, provider.client_secret, provider.client_auth_method,
254256
) # type: ClientAuth
255257
self._client_auth_method = provider.client_auth_method
256-
self._provider_metadata = OpenIDProviderMetadata(
257-
issuer=provider.issuer,
258-
authorization_endpoint=provider.authorization_endpoint,
259-
token_endpoint=provider.token_endpoint,
260-
userinfo_endpoint=provider.userinfo_endpoint,
261-
jwks_uri=provider.jwks_uri,
262-
) # type: OpenIDProviderMetadata
263-
self._provider_needs_discovery = provider.discover
258+
259+
# cache of metadata for the identity provider (endpoint uris, mostly). This is
260+
# loaded on-demand from the discovery endpoint (if discovery is enabled), with
261+
# possible overrides from the config. Access via `load_metadata`.
262+
self._provider_metadata = RetryOnExceptionCachedCall(self._load_metadata)
263+
264+
# cache of JWKs used by the identity provider to sign tokens. Loaded on demand
265+
# from the IdP's jwks_uri, if required.
266+
self._jwks = RetryOnExceptionCachedCall(self._load_jwks)
267+
264268
self._user_mapping_provider = provider.user_mapping_provider_class(
265269
provider.user_mapping_provider_config
266270
)
@@ -286,7 +290,7 @@ def __init__(
286290

287291
self._sso_handler.register_identity_provider(self)
288292

289-
def _validate_metadata(self):
293+
def _validate_metadata(self, m: OpenIDProviderMetadata) -> None:
290294
"""Verifies the provider metadata.
291295
292296
This checks the validity of the currently loaded provider. Not
@@ -305,7 +309,6 @@ def _validate_metadata(self):
305309
if self._skip_verification is True:
306310
return
307311

308-
m = self._provider_metadata
309312
m.validate_issuer()
310313
m.validate_authorization_endpoint()
311314
m.validate_token_endpoint()
@@ -340,11 +343,7 @@ def _validate_metadata(self):
340343
)
341344
else:
342345
# If we're not using userinfo, we need a valid jwks to validate the ID token
343-
if m.get("jwks") is None:
344-
if m.get("jwks_uri") is not None:
345-
m.validate_jwks_uri()
346-
else:
347-
raise ValueError('"jwks_uri" must be set')
346+
m.validate_jwks_uri()
348347

349348
@property
350349
def _uses_userinfo(self) -> bool:
@@ -361,30 +360,48 @@ def _uses_userinfo(self) -> bool:
361360
or self._user_profile_method == "userinfo_endpoint"
362361
)
363362

364-
async def load_metadata(self) -> OpenIDProviderMetadata:
365-
"""Load and validate the provider metadata.
363+
async def load_metadata(self, force: bool = False) -> OpenIDProviderMetadata:
364+
"""Return the provider metadata.
365+
366+
If this is the first call, the metadata is built from the config and from the
367+
metadata discovery endpoint (if enabled), and then validated. If the metadata
368+
is successfully validated, it is then cached for future use.
366369
367-
The values metadatas are discovered if ``oidc_config.discovery`` is
368-
``True`` and then cached.
370+
Args:
371+
force: If true, any cached metadata is discarded to force a reload.
369372
370373
Raises:
371374
ValueError: if something in the provider is not valid
372375
373376
Returns:
374377
The provider's metadata.
375378
"""
376-
# If we are using the OpenID Discovery documents, it needs to be loaded once
377-
# FIXME: should there be a lock here?
378-
if self._provider_needs_discovery:
379-
url = get_well_known_url(self._provider_metadata["issuer"], external=True)
379+
if force:
380+
# reset the cached call to ensure we get a new result
381+
self._provider_metadata = RetryOnExceptionCachedCall(self._load_metadata)
382+
383+
return await self._provider_metadata.get()
384+
385+
async def _load_metadata(self) -> OpenIDProviderMetadata:
386+
# init the metadata from our config
387+
metadata = OpenIDProviderMetadata(
388+
issuer=self._config.issuer,
389+
authorization_endpoint=self._config.authorization_endpoint,
390+
token_endpoint=self._config.token_endpoint,
391+
userinfo_endpoint=self._config.userinfo_endpoint,
392+
jwks_uri=self._config.jwks_uri,
393+
)
394+
395+
# load any data from the discovery endpoint, if enabled
396+
if self._config.discover:
397+
url = get_well_known_url(self._config.issuer, external=True)
380398
metadata_response = await self._http_client.get_json(url)
381399
# TODO: maybe update the other way around to let user override some values?
382-
self._provider_metadata.update(metadata_response)
383-
self._provider_needs_discovery = False
400+
metadata.update(metadata_response)
384401

385-
self._validate_metadata()
402+
self._validate_metadata(metadata)
386403

387-
return self._provider_metadata
404+
return metadata
388405

389406
async def load_jwks(self, force: bool = False) -> JWKS:
390407
"""Load the JSON Web Key Set used to sign ID tokens.
@@ -414,27 +431,27 @@ async def load_jwks(self, force: bool = False) -> JWKS:
414431
]
415432
}
416433
"""
434+
if force:
435+
# reset the cached call to ensure we get a new result
436+
self._jwks = RetryOnExceptionCachedCall(self._load_jwks)
437+
return await self._jwks.get()
438+
439+
async def _load_jwks(self) -> JWKS:
417440
if self._uses_userinfo:
418441
# We're not using jwt signing, return an empty jwk set
419442
return {"keys": []}
420443

421-
# First check if the JWKS are loaded in the provider metadata.
422-
# It can happen either if the provider gives its JWKS in the discovery
423-
# document directly or if it was already loaded once.
424444
metadata = await self.load_metadata()
425-
jwk_set = metadata.get("jwks")
426-
if jwk_set is not None and not force:
427-
return jwk_set
428445

429-
# Loading the JWKS using the `jwks_uri` metadata
446+
# Load the JWKS using the `jwks_uri` metadata.
430447
uri = metadata.get("jwks_uri")
431448
if not uri:
449+
# this should be unreachable: load_metadata validates that
450+
# there is a jwks_uri in the metadata if _uses_userinfo is unset
432451
raise RuntimeError('Missing "jwks_uri" in metadata')
433452

434453
jwk_set = await self._http_client.get_json(uri)
435454

436-
# Caching the JWKS in the provider's metadata
437-
self._provider_metadata["jwks"] = jwk_set
438455
return jwk_set
439456

440457
async def _exchange_code(self, code: str) -> Token:

synapse/util/caches/cached_call.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# -*- coding: utf-8 -*-
2+
# Copyright 2021 The Matrix.org Foundation C.I.C.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from typing import Awaitable, Callable, Generic, Optional, TypeVar, Union
17+
18+
from twisted.internet.defer import Deferred
19+
from twisted.python.failure import Failure
20+
21+
from synapse.logging.context import make_deferred_yieldable, run_in_background
22+
23+
TV = TypeVar("TV")
24+
25+
26+
class CachedCall(Generic[TV]):
27+
"""A wrapper for asynchronous calls whose results should be shared
28+
29+
This is useful for wrapping asynchronous functions, where there might be multiple
30+
callers, but we only want to call the underlying function once (and have the result
31+
returned to all callers).
32+
33+
Similar results can be achieved via a lock of some form, but that typically requires
34+
more boilerplate (and ends up being less efficient).
35+
36+
Correctly handles Synapse logcontexts (logs and resource usage for the underlying
37+
function are logged against the logcontext which is active when get() is first
38+
called).
39+
40+
Example usage:
41+
42+
_cached_val = CachedCall(_load_prop)
43+
44+
async def handle_request() -> X:
45+
# We can call this multiple times, but it will result in a single call to
46+
# _load_prop().
47+
return await _cached_val.get()
48+
49+
async def _load_prop() -> X:
50+
await difficult_operation()
51+
52+
53+
The implementation is deliberately single-shot (ie, once the call is initiated,
54+
there is no way to ask for it to be run). This keeps the implementation and
55+
semantics simple. If you want to make a new call, simply replace the whole
56+
CachedCall object.
57+
"""
58+
59+
__slots__ = ["_callable", "_deferred", "_result"]
60+
61+
def __init__(self, f: Callable[[], Awaitable[TV]]):
62+
"""
63+
Args:
64+
f: The underlying function. Only one call to this function will be alive
65+
at once (per instance of CachedCall)
66+
"""
67+
self._callable = f # type: Optional[Callable[[], Awaitable[TV]]]
68+
self._deferred = None # type: Optional[Deferred]
69+
self._result = None # type: Union[None, Failure, TV]
70+
71+
async def get(self) -> TV:
72+
"""Kick off the call if necessary, and return the result"""
73+
74+
# Fire off the callable now if this is our first time
75+
if not self._deferred:
76+
self._deferred = run_in_background(self._callable)
77+
78+
# we will never need the callable again, so make sure it can be GCed
79+
self._callable = None
80+
81+
# once the deferred completes, store the result. We cannot simply leave the
82+
# result in the deferred, since if it's a Failure, GCing the deferred
83+
# would then log a critical error about unhandled Failures.
84+
def got_result(r):
85+
self._result = r
86+
87+
self._deferred.addBoth(got_result)
88+
89+
# TODO: consider cancellation semantics. Currently, if the call to get()
90+
# is cancelled, the underlying call will continue (and any future calls
91+
# will get the result/exception), which I think is *probably* ok, modulo
92+
# the fact the underlying call may be logged to a cancelled logcontext,
93+
# and any eventual exception may not be reported.
94+
95+
# we can now await the deferred, and once it completes, return the result.
96+
await make_deferred_yieldable(self._deferred)
97+
98+
# I *think* this is the easiest way to correctly raise a Failure without having
99+
# to gut-wrench into the implementation of Deferred.
100+
d = Deferred()
101+
d.callback(self._result)
102+
return await d
103+
104+
105+
class RetryOnExceptionCachedCall(Generic[TV]):
106+
"""A wrapper around CachedCall which will retry the call if an exception is thrown
107+
108+
This is used in much the same way as CachedCall, but adds some extra functionality
109+
so that if the underlying function throws an exception, then the next call to get()
110+
will initiate another call to the underlying function. (Any calls to get() which
111+
are already pending will raise the exception.)
112+
"""
113+
114+
slots = ["_cachedcall"]
115+
116+
def __init__(self, f: Callable[[], Awaitable[TV]]):
117+
async def _wrapper() -> TV:
118+
try:
119+
return await f()
120+
except Exception:
121+
# the call raised an exception: replace the underlying CachedCall to
122+
# trigger another call next time get() is called
123+
self._cachedcall = CachedCall(_wrapper)
124+
raise
125+
126+
self._cachedcall = CachedCall(_wrapper)
127+
128+
async def get(self) -> TV:
129+
return await self._cachedcall.get()

0 commit comments

Comments
 (0)