2121
2222import abc
2323import logging
24+ from contextlib import ExitStack
2425from typing import TYPE_CHECKING , Callable , Iterable
2526
2627import attr
@@ -150,57 +151,81 @@ class Keyring:
150151 """
151152
152153 def __init__ (
153- self , hs : "HomeServer" , key_fetchers : "Iterable[KeyFetcher] | None" = None
154+ self ,
155+ hs : "HomeServer" ,
156+ test_only_key_fetchers : "list[KeyFetcher] | None" = None ,
154157 ):
155- self .server_name = hs .hostname
158+ """
159+ Args:
160+ hs: The HomeServer instance
161+ test_only_key_fetchers: Dependency injection for tests only. If provided,
162+ these key fetchers will be used instead of the default ones.
163+ """
164+ # Clean-up to avoid partial initialization leaving behind references.
165+ with ExitStack () as exit :
166+ self .server_name = hs .hostname
167+
168+ self ._key_fetchers : list [KeyFetcher ] = []
169+ if test_only_key_fetchers is None :
170+ # Always fetch keys from the database.
171+ store_key_fetcher = StoreKeyFetcher (hs )
172+ exit .callback (store_key_fetcher .shutdown )
173+ self ._key_fetchers .append (store_key_fetcher )
174+
175+ # Fetch keys from configured trusted key servers, if any exist.
176+ key_servers = hs .config .key .key_servers
177+ if key_servers :
178+ perspectives_key_fetcher = PerspectivesKeyFetcher (hs )
179+ exit .callback (perspectives_key_fetcher .shutdown )
180+ self ._key_fetchers .append (perspectives_key_fetcher )
181+
182+ # Finally, fetch keys from the origin server directly.
183+ server_key_fetcher = ServerKeyFetcher (hs )
184+ exit .callback (server_key_fetcher .shutdown )
185+ self ._key_fetchers .append (server_key_fetcher )
186+ else :
187+ self ._key_fetchers = test_only_key_fetchers
188+
189+ self ._fetch_keys_queue : BatchingQueue [
190+ _FetchKeyRequest , dict [str , dict [str , FetchKeyResult ]]
191+ ] = BatchingQueue (
192+ name = "keyring_server" ,
193+ hs = hs ,
194+ clock = hs .get_clock (),
195+ # The method called to fetch each key
196+ process_batch_callback = self ._inner_fetch_key_requests ,
197+ )
198+ exit .callback (self ._fetch_keys_queue .shutdown )
156199
157- if key_fetchers is None :
158- # Always fetch keys from the database.
159- mutable_key_fetchers : list [KeyFetcher ] = [StoreKeyFetcher (hs )]
160- # Fetch keys from configured trusted key servers, if any exist.
161- key_servers = hs .config .key .key_servers
162- if key_servers :
163- mutable_key_fetchers .append (PerspectivesKeyFetcher (hs ))
164- # Finally, fetch keys from the origin server directly.
165- mutable_key_fetchers .append (ServerKeyFetcher (hs ))
166-
167- self ._key_fetchers : Iterable [KeyFetcher ] = tuple (mutable_key_fetchers )
168- else :
169- self ._key_fetchers = key_fetchers
170-
171- self ._fetch_keys_queue : BatchingQueue [
172- _FetchKeyRequest , dict [str , dict [str , FetchKeyResult ]]
173- ] = BatchingQueue (
174- name = "keyring_server" ,
175- hs = hs ,
176- clock = hs .get_clock (),
177- # The method called to fetch each key
178- process_batch_callback = self ._inner_fetch_key_requests ,
179- )
200+ self ._is_mine_server_name = hs .is_mine_server_name
180201
181- self ._is_mine_server_name = hs .is_mine_server_name
202+ # build a FetchKeyResult for each of our own keys, to shortcircuit the
203+ # fetcher.
204+ self ._local_verify_keys : dict [str , FetchKeyResult ] = {}
205+ for key_id , key in hs .config .key .old_signing_keys .items ():
206+ self ._local_verify_keys [key_id ] = FetchKeyResult (
207+ verify_key = key , valid_until_ts = key .expired
208+ )
182209
183- # build a FetchKeyResult for each of our own keys, to shortcircuit the
184- # fetcher.
185- self ._local_verify_keys : dict [str , FetchKeyResult ] = {}
186- for key_id , key in hs .config .key .old_signing_keys .items ():
187- self ._local_verify_keys [key_id ] = FetchKeyResult (
188- verify_key = key , valid_until_ts = key .expired
210+ vk = get_verify_key (hs .signing_key )
211+ self ._local_verify_keys [f"{ vk .alg } :{ vk .version } " ] = FetchKeyResult (
212+ verify_key = vk ,
213+ valid_until_ts = 2 ** 63 , # fake future timestamp
189214 )
190215
191- vk = get_verify_key (hs .signing_key )
192- self ._local_verify_keys [f"{ vk .alg } :{ vk .version } " ] = FetchKeyResult (
193- verify_key = vk ,
194- valid_until_ts = 2 ** 63 , # fake future timestamp
195- )
216+ # We reached the end of the block which means everything was successful, so
217+ # no exit handlers are needed (remove them all).
218+ exit .pop_all ()
196219
197220 def shutdown (self ) -> None :
198221 """
199222 Prepares the KeyRing for garbage collection by shutting down it's queues.
200223 """
201224 self ._fetch_keys_queue .shutdown ()
225+
202226 for key_fetcher in self ._key_fetchers :
203227 key_fetcher .shutdown ()
228+ self ._key_fetchers .clear ()
204229
205230 async def verify_json_for_server (
206231 self ,
@@ -521,9 +546,21 @@ class StoreKeyFetcher(KeyFetcher):
521546 """KeyFetcher impl which fetches keys from our data store"""
522547
523548 def __init__ (self , hs : "HomeServer" ):
524- super ().__init__ (hs )
525-
526- self .store = hs .get_datastores ().main
549+ # Clean-up to avoid partial initialization leaving behind references.
550+ with ExitStack () as exit :
551+ super ().__init__ (hs )
552+ # `KeyFetcher` keeps a reference to `hs` which we need to clean up if
553+ # something goes wrong so we can cleanly shutdown the homeserver.
554+ exit .callback (super ().shutdown )
555+
556+ # An error can be raised here if someone tried to create a `StoreKeyFetcher`
557+ # before the homeserver is fully set up (`HomeServerNotSetupException:
558+ # HomeServer.setup must be called before getting datastores`).
559+ self .store = hs .get_datastores ().main
560+
561+ # We reached the end of the block which means everything was successful, so
562+ # no exit handlers are needed (remove them all).
563+ exit .pop_all ()
527564
528565 async def _fetch_keys (
529566 self , keys_to_fetch : list [_FetchKeyRequest ]
@@ -543,9 +580,21 @@ async def _fetch_keys(
543580
544581class BaseV2KeyFetcher (KeyFetcher ):
545582 def __init__ (self , hs : "HomeServer" ):
546- super ().__init__ (hs )
547-
548- self .store = hs .get_datastores ().main
583+ # Clean-up to avoid partial initialization leaving behind references.
584+ with ExitStack () as exit :
585+ super ().__init__ (hs )
586+ # `KeyFetcher` keeps a reference to `hs` which we need to clean up if
587+ # something goes wrong so we can cleanly shutdown the homeserver.
588+ exit .callback (super ().shutdown )
589+
590+ # An error can be raised here if someone tried to create a `StoreKeyFetcher`
591+ # before the homeserver is fully set up (`HomeServerNotSetupException:
592+ # HomeServer.setup must be called before getting datastores`).
593+ self .store = hs .get_datastores ().main
594+
595+ # We reached the end of the block which means everything was successful, so
596+ # no exit handlers are needed (remove them all).
597+ exit .pop_all ()
549598
550599 async def process_v2_response (
551600 self , from_server : str , response_json : JsonDict , time_added_ms : int
0 commit comments