4141from synapse .logging .context import make_deferred_yieldable
4242from synapse .types import JsonDict , UserID , map_username_to_mxid_localpart
4343from synapse .util import json_decoder
44+ from synapse .util .caches .cached_call import RetryOnExceptionCachedCall
4445
4546if TYPE_CHECKING :
4647 from synapse .server import HomeServer
@@ -245,6 +246,7 @@ def __init__(
245246
246247 self ._token_generator = token_generator
247248
249+ self ._config = provider
248250 self ._callback_url = hs .config .oidc_callback_url # type: str
249251
250252 self ._scopes = provider .scopes
@@ -253,14 +255,16 @@ def __init__(
253255 provider .client_id , provider .client_secret , provider .client_auth_method ,
254256 ) # type: ClientAuth
255257 self ._client_auth_method = provider .client_auth_method
256- self ._provider_metadata = OpenIDProviderMetadata (
257- issuer = provider .issuer ,
258- authorization_endpoint = provider .authorization_endpoint ,
259- token_endpoint = provider .token_endpoint ,
260- userinfo_endpoint = provider .userinfo_endpoint ,
261- jwks_uri = provider .jwks_uri ,
262- ) # type: OpenIDProviderMetadata
263- self ._provider_needs_discovery = provider .discover
258+
259+ # cache of metadata for the identity provider (endpoint uris, mostly). This is
260+ # loaded on-demand from the discovery endpoint (if discovery is enabled), with
261+ # possible overrides from the config. Access via `load_metadata`.
262+ self ._provider_metadata = RetryOnExceptionCachedCall (self ._load_metadata )
263+
264+ # cache of JWKs used by the identity provider to sign tokens. Loaded on demand
265+ # from the IdP's jwks_uri, if required.
266+ self ._jwks = RetryOnExceptionCachedCall (self ._load_jwks )
267+
264268 self ._user_mapping_provider = provider .user_mapping_provider_class (
265269 provider .user_mapping_provider_config
266270 )
@@ -286,7 +290,7 @@ def __init__(
286290
287291 self ._sso_handler .register_identity_provider (self )
288292
289- def _validate_metadata (self ) :
293+ def _validate_metadata (self , m : OpenIDProviderMetadata ) -> None :
290294 """Verifies the provider metadata.
291295
292296 This checks the validity of the currently loaded provider. Not
@@ -305,7 +309,6 @@ def _validate_metadata(self):
305309 if self ._skip_verification is True :
306310 return
307311
308- m = self ._provider_metadata
309312 m .validate_issuer ()
310313 m .validate_authorization_endpoint ()
311314 m .validate_token_endpoint ()
@@ -340,11 +343,7 @@ def _validate_metadata(self):
340343 )
341344 else :
342345 # If we're not using userinfo, we need a valid jwks to validate the ID token
343- if m .get ("jwks" ) is None :
344- if m .get ("jwks_uri" ) is not None :
345- m .validate_jwks_uri ()
346- else :
347- raise ValueError ('"jwks_uri" must be set' )
346+ m .validate_jwks_uri ()
348347
349348 @property
350349 def _uses_userinfo (self ) -> bool :
@@ -361,30 +360,48 @@ def _uses_userinfo(self) -> bool:
361360 or self ._user_profile_method == "userinfo_endpoint"
362361 )
363362
364- async def load_metadata (self ) -> OpenIDProviderMetadata :
365- """Load and validate the provider metadata.
363+ async def load_metadata (self , force : bool = False ) -> OpenIDProviderMetadata :
364+ """Return the provider metadata.
365+
366+ If this is the first call, the metadata is built from the config and from the
367+ metadata discovery endpoint (if enabled), and then validated. If the metadata
368+ is successfully validated, it is then cached for future use.
366369
367- The values metadatas are discovered if ``oidc_config.discovery`` is
368- ``True`` and then cached.
370+ Args:
371+ force: If true, any cached metadata is discarded to force a reload .
369372
370373 Raises:
371374 ValueError: if something in the provider is not valid
372375
373376 Returns:
374377 The provider's metadata.
375378 """
376- # If we are using the OpenID Discovery documents, it needs to be loaded once
377- # FIXME: should there be a lock here?
378- if self ._provider_needs_discovery :
379- url = get_well_known_url (self ._provider_metadata ["issuer" ], external = True )
379+ if force :
380+ # reset the cached call to ensure we get a new result
381+ self ._provider_metadata = RetryOnExceptionCachedCall (self ._load_metadata )
382+
383+ return await self ._provider_metadata .get ()
384+
385+ async def _load_metadata (self ) -> OpenIDProviderMetadata :
386+ # init the metadata from our config
387+ metadata = OpenIDProviderMetadata (
388+ issuer = self ._config .issuer ,
389+ authorization_endpoint = self ._config .authorization_endpoint ,
390+ token_endpoint = self ._config .token_endpoint ,
391+ userinfo_endpoint = self ._config .userinfo_endpoint ,
392+ jwks_uri = self ._config .jwks_uri ,
393+ )
394+
395+ # load any data from the discovery endpoint, if enabled
396+ if self ._config .discover :
397+ url = get_well_known_url (self ._config .issuer , external = True )
380398 metadata_response = await self ._http_client .get_json (url )
381399 # TODO: maybe update the other way around to let user override some values?
382- self ._provider_metadata .update (metadata_response )
383- self ._provider_needs_discovery = False
400+ metadata .update (metadata_response )
384401
385- self ._validate_metadata ()
402+ self ._validate_metadata (metadata )
386403
387- return self . _provider_metadata
404+ return metadata
388405
389406 async def load_jwks (self , force : bool = False ) -> JWKS :
390407 """Load the JSON Web Key Set used to sign ID tokens.
@@ -414,27 +431,27 @@ async def load_jwks(self, force: bool = False) -> JWKS:
414431 ]
415432 }
416433 """
434+ if force :
435+ # reset the cached call to ensure we get a new result
436+ self ._jwks = RetryOnExceptionCachedCall (self ._load_jwks )
437+ return await self ._jwks .get ()
438+
439+ async def _load_jwks (self ) -> JWKS :
417440 if self ._uses_userinfo :
418441 # We're not using jwt signing, return an empty jwk set
419442 return {"keys" : []}
420443
421- # First check if the JWKS are loaded in the provider metadata.
422- # It can happen either if the provider gives its JWKS in the discovery
423- # document directly or if it was already loaded once.
424444 metadata = await self .load_metadata ()
425- jwk_set = metadata .get ("jwks" )
426- if jwk_set is not None and not force :
427- return jwk_set
428445
429- # Loading the JWKS using the `jwks_uri` metadata
446+ # Load the JWKS using the `jwks_uri` metadata.
430447 uri = metadata .get ("jwks_uri" )
431448 if not uri :
449+ # this should be unreachable: load_metadata validates that
450+ # there is a jwks_uri in the metadata if _uses_userinfo is unset
432451 raise RuntimeError ('Missing "jwks_uri" in metadata' )
433452
434453 jwk_set = await self ._http_client .get_json (uri )
435454
436- # Caching the JWKS in the provider's metadata
437- self ._provider_metadata ["jwks" ] = jwk_set
438455 return jwk_set
439456
440457 async def _exchange_code (self , code : str ) -> Token :
0 commit comments