@@ -600,7 +600,11 @@ async def _get_events_from_cache_or_db(
600600 Returns:
601601 map from event id to result
602602 """
603- event_entry_map = await self ._get_events_from_cache (
603+ # Shortcut: check if we have any events in the *in memory* cache - this function
604+ # may be called repeatedly for the same event so at this point we cannot reach
605+ # out to any external cache for performance reasons. The external cache is
606+ # checked later on in the `get_missing_events_from_cache_or_db` function below.
607+ event_entry_map = self ._get_events_from_local_cache (
604608 event_ids ,
605609 )
606610
@@ -632,7 +636,9 @@ async def _get_events_from_cache_or_db(
632636
633637 if missing_events_ids :
634638
635- async def get_missing_events_from_db () -> Dict [str , EventCacheEntry ]:
639+ async def get_missing_events_from_cache_or_db () -> Dict [
640+ str , EventCacheEntry
641+ ]:
636642 """Fetches the events in `missing_event_ids` from the database.
637643
638644 Also creates entries in `self._current_event_fetches` to allow
@@ -657,10 +663,18 @@ async def get_missing_events_from_db() -> Dict[str, EventCacheEntry]:
657663 # the events have been redacted, and if so pulling the redaction event
658664 # out of the database to check it.
659665 #
666+ missing_events = {}
660667 try :
661- missing_events = await self ._get_events_from_db (
668+ # Try to fetch from any external cache. We already checked the
669+ # in-memory cache above.
670+ missing_events = await self ._get_events_from_external_cache (
662671 missing_events_ids ,
663672 )
673+ # Now actually fetch any remaining events from the DB
674+ db_missing_events = await self ._get_events_from_db (
675+ missing_events_ids - missing_events .keys (),
676+ )
677+ missing_events .update (db_missing_events )
664678 except Exception as e :
665679 with PreserveLoggingContext ():
666680 fetching_deferred .errback (e )
@@ -679,7 +693,7 @@ async def get_missing_events_from_db() -> Dict[str, EventCacheEntry]:
679693 # cancellations, since multiple `_get_events_from_cache_or_db` calls can
680694 # reuse the same fetch.
681695 missing_events : Dict [str , EventCacheEntry ] = await delay_cancellation (
682- get_missing_events_from_db ()
696+ get_missing_events_from_cache_or_db ()
683697 )
684698 event_entry_map .update (missing_events )
685699
@@ -754,7 +768,54 @@ def _invalidate_local_get_event_cache(self, event_id: str) -> None:
754768 async def _get_events_from_cache (
755769 self , events : Iterable [str ], update_metrics : bool = True
756770 ) -> Dict [str , EventCacheEntry ]:
757- """Fetch events from the caches.
771+ """Fetch events from the caches, both in memory and any external.
772+
773+ May return rejected events.
774+
775+ Args:
776+ events: list of event_ids to fetch
777+ update_metrics: Whether to update the cache hit ratio metrics
778+ """
779+ event_map = self ._get_events_from_local_cache (
780+ events , update_metrics = update_metrics
781+ )
782+
783+ missing_event_ids = (e for e in events if e not in event_map )
784+ event_map .update (
785+ await self ._get_events_from_external_cache (
786+ events = missing_event_ids ,
787+ update_metrics = update_metrics ,
788+ )
789+ )
790+
791+ return event_map
792+
793+ async def _get_events_from_external_cache (
794+ self , events : Iterable [str ], update_metrics : bool = True
795+ ) -> Dict [str , EventCacheEntry ]:
796+ """Fetch events from any configured external cache.
797+
798+ May return rejected events.
799+
800+ Args:
801+ events: list of event_ids to fetch
802+ update_metrics: Whether to update the cache hit ratio metrics
803+ """
804+ event_map = {}
805+
806+ for event_id in events :
807+ ret = await self ._get_event_cache .get_external (
808+ (event_id ,), None , update_metrics = update_metrics
809+ )
810+ if ret :
811+ event_map [event_id ] = ret
812+
813+ return event_map
814+
815+ def _get_events_from_local_cache (
816+ self , events : Iterable [str ], update_metrics : bool = True
817+ ) -> Dict [str , EventCacheEntry ]:
818+ """Fetch events from the local, in memory, caches.
758819
759820 May return rejected events.
760821
@@ -766,7 +827,7 @@ async def _get_events_from_cache(
766827
767828 for event_id in events :
768829 # First check if it's in the event cache
769- ret = await self ._get_event_cache .get (
830+ ret = self ._get_event_cache .get_local (
770831 (event_id ,), None , update_metrics = update_metrics
771832 )
772833 if ret :
@@ -788,7 +849,7 @@ async def _get_events_from_cache(
788849
789850 # We add the entry back into the cache as we want to keep
790851 # recently queried events in the cache.
791- await self ._get_event_cache .set ((event_id ,), cache_entry )
852+ self ._get_event_cache .set_local ((event_id ,), cache_entry )
792853
793854 return event_map
794855
0 commit comments