@@ -615,7 +615,11 @@ async def _get_events_from_cache_or_db(
615615 Returns:
616616 map from event id to result
617617 """
618- event_entry_map = await self ._get_events_from_cache (
618+ # Shortcut: check if we have any events in the *in memory* cache - this function
619+ # may be called repeatedly for the same event so at this point we cannot reach
620+ # out to any external cache for performance reasons. The external cache is
621+ # checked later on in the `get_missing_events_from_cache_or_db` function below.
622+ event_entry_map = self ._get_events_from_local_cache (
619623 event_ids ,
620624 )
621625
@@ -647,7 +651,9 @@ async def _get_events_from_cache_or_db(
647651
648652 if missing_events_ids :
649653
650- async def get_missing_events_from_db () -> Dict [str , EventCacheEntry ]:
654+ async def get_missing_events_from_cache_or_db () -> Dict [
655+ str , EventCacheEntry
656+ ]:
651657 """Fetches the events in `missing_event_ids` from the database.
652658
653659 Also creates entries in `self._current_event_fetches` to allow
@@ -672,10 +678,18 @@ async def get_missing_events_from_db() -> Dict[str, EventCacheEntry]:
672678 # the events have been redacted, and if so pulling the redaction event
673679 # out of the database to check it.
674680 #
681+ missing_events = {}
675682 try :
676- missing_events = await self ._get_events_from_db (
683+ # Try to fetch from any external cache. We already checked the
684+ # in-memory cache above.
685+ missing_events = await self ._get_events_from_external_cache (
677686 missing_events_ids ,
678687 )
688+ # Now actually fetch any remaining events from the DB
689+ db_missing_events = await self ._get_events_from_db (
690+ missing_events_ids - missing_events .keys (),
691+ )
692+ missing_events .update (db_missing_events )
679693 except Exception as e :
680694 with PreserveLoggingContext ():
681695 fetching_deferred .errback (e )
@@ -694,7 +708,7 @@ async def get_missing_events_from_db() -> Dict[str, EventCacheEntry]:
694708 # cancellations, since multiple `_get_events_from_cache_or_db` calls can
695709 # reuse the same fetch.
696710 missing_events : Dict [str , EventCacheEntry ] = await delay_cancellation (
697- get_missing_events_from_db ()
711+ get_missing_events_from_cache_or_db ()
698712 )
699713 event_entry_map .update (missing_events )
700714
@@ -769,7 +783,54 @@ def _invalidate_local_get_event_cache(self, event_id: str) -> None:
769783 async def _get_events_from_cache (
770784 self , events : Iterable [str ], update_metrics : bool = True
771785 ) -> Dict [str , EventCacheEntry ]:
772- """Fetch events from the caches.
786+ """Fetch events from the caches, both in memory and any external.
787+
788+ May return rejected events.
789+
790+ Args:
791+ events: list of event_ids to fetch
792+ update_metrics: Whether to update the cache hit ratio metrics
793+ """
794+ event_map = self ._get_events_from_local_cache (
795+ events , update_metrics = update_metrics
796+ )
797+
798+ missing_event_ids = (e for e in events if e not in event_map )
799+ event_map .update (
800+ await self ._get_events_from_external_cache (
801+ events = missing_event_ids ,
802+ update_metrics = update_metrics ,
803+ )
804+ )
805+
806+ return event_map
807+
808+ async def _get_events_from_external_cache (
809+ self , events : Iterable [str ], update_metrics : bool = True
810+ ) -> Dict [str , EventCacheEntry ]:
811+ """Fetch events from any configured external cache.
812+
813+ May return rejected events.
814+
815+ Args:
816+ events: list of event_ids to fetch
817+ update_metrics: Whether to update the cache hit ratio metrics
818+ """
819+ event_map = {}
820+
821+ for event_id in events :
822+ ret = await self ._get_event_cache .get_external (
823+ (event_id ,), None , update_metrics = update_metrics
824+ )
825+ if ret :
826+ event_map [event_id ] = ret
827+
828+ return event_map
829+
830+ def _get_events_from_local_cache (
831+ self , events : Iterable [str ], update_metrics : bool = True
832+ ) -> Dict [str , EventCacheEntry ]:
833+ """Fetch events from the local, in memory, caches.
773834
774835 May return rejected events.
775836
@@ -781,7 +842,7 @@ async def _get_events_from_cache(
781842
782843 for event_id in events :
783844 # First check if it's in the event cache
784- ret = await self ._get_event_cache .get (
845+ ret = self ._get_event_cache .get_local (
785846 (event_id ,), None , update_metrics = update_metrics
786847 )
787848 if ret :
@@ -803,7 +864,7 @@ async def _get_events_from_cache(
803864
804865 # We add the entry back into the cache as we want to keep
805866 # recently queried events in the cache.
806- await self ._get_event_cache .set ((event_id ,), cache_entry )
867+ self ._get_event_cache .set_local ((event_id ,), cache_entry )
807868
808869 return event_map
809870
0 commit comments