1414
1515import logging
1616import threading
17- from collections import namedtuple
1817from typing import (
1918 Collection ,
2019 Container ,
2726 overload ,
2827)
2928
29+ import attr
3030from constantly import NamedConstant , Names
3131from typing_extensions import Literal
3232
4242from synapse .events import EventBase , make_event_from_dict
4343from synapse .events .snapshot import EventContext
4444from synapse .events .utils import prune_event
45- from synapse .logging .context import PreserveLoggingContext , current_context
45+ from synapse .logging .context import (
46+ PreserveLoggingContext ,
47+ current_context ,
48+ make_deferred_yieldable ,
49+ )
4650from synapse .metrics .background_process_metrics import (
4751 run_as_background_process ,
4852 wrap_as_background_process ,
5660from synapse .storage .util .id_generators import MultiWriterIdGenerator , StreamIdGenerator
5761from synapse .storage .util .sequence import build_sequence_generator
5862from synapse .types import JsonDict , get_domain_from_id
63+ from synapse .util import unwrapFirstError
64+ from synapse .util .async_helpers import ObservableDeferred
5965from synapse .util .caches .descriptors import cached , cachedList
6066from synapse .util .caches .lrucache import LruCache
6167from synapse .util .iterutils import batch_iter
# Timeout (in seconds) when waiting for requests for events
EVENT_QUEUE_TIMEOUT_S = 0.1  # Timeout when waiting for requests for events
7581
7682
@attr.s(slots=True, auto_attribs=True)
class _EventCacheEntry:
    """An entry in the event cache: the event itself plus, if the event has
    been redacted, the pruned (redacted) copy of it.
    """

    # The (possibly rejected) event as read from the database.
    event: EventBase

    # The pruned copy of `event` if it has been redacted, otherwise None.
    redacted_event: Optional[EventBase]
7887
7988
8089class EventRedactBehaviour (Names ):
@@ -161,6 +170,13 @@ def __init__(self, database: DatabasePool, db_conn, hs):
161170 max_size = hs .config .caches .event_cache_size ,
162171 )
163172
173+ # Map from event ID to a deferred that will result in a map from event
174+ # ID to cache entry. Note that the returned dict may not have the
175+ # requested event in it if the event isn't in the DB.
176+ self ._current_event_fetches : Dict [
177+ str , ObservableDeferred [Dict [str , _EventCacheEntry ]]
178+ ] = {}
179+
164180 self ._event_fetch_lock = threading .Condition ()
165181 self ._event_fetch_list = []
166182 self ._event_fetch_ongoing = 0
@@ -476,7 +492,9 @@ async def get_events_as_list(
476492
477493 return events
478494
    async def _get_events_from_cache_or_db(
        self, event_ids: Iterable[str], allow_rejected: bool = False
    ) -> Dict[str, _EventCacheEntry]:
        """Fetch a bunch of events from the cache or the database.

        If events are pulled from the database, they will be cached for future
        lookups. Concurrent calls asking for the same event share a single DB
        fetch via `self._current_event_fetches`.

        Unknown events are omitted from the response.

        Args:
            event_ids: The event_ids of the events to fetch

            allow_rejected: Whether to include rejected events. If False,
                rejected events are omitted from the response.

        Returns:
            map from event id to result. NOTE(review): may contain extra
            events that weren't asked for, since shared fetches resolve to
            everything that fetch pulled from the DB.
        """
        # The cache may return rejected events; they are filtered out at the
        # end of this function if `allow_rejected` is False.
        event_entry_map = self._get_events_from_cache(
            event_ids,
        )

        missing_events_ids = {e for e in event_ids if e not in event_entry_map}

        # We now look up if we're already fetching some of the events in the DB,
        # if so we wait for those lookups to finish instead of pulling the same
        # events out of the DB multiple times.
        already_fetching: Dict[str, defer.Deferred] = {}

        for event_id in missing_events_ids:
            deferred = self._current_event_fetches.get(event_id)
            if deferred is not None:
                # We're already pulling the event out of the DB. Add the deferred
                # to the collection of deferreds to wait on.
                already_fetching[event_id] = deferred.observe()

        # Only fetch from the DB the events nobody else is already fetching.
        missing_events_ids.difference_update(already_fetching)

        if missing_events_ids:
            log_ctx = current_context()
            log_ctx.record_event_fetch(len(missing_events_ids))

            # Add entries to `self._current_event_fetches` for each event we're
            # going to pull from the DB. We use a single deferred that resolves
            # to all the events we pulled from the DB (this will result in this
            # function returning more events than requested, but that can happen
            # already due to `_get_events_from_db`).
            fetching_deferred: ObservableDeferred[
                Dict[str, _EventCacheEntry]
            ] = ObservableDeferred(defer.Deferred())
            for event_id in missing_events_ids:
                self._current_event_fetches[event_id] = fetching_deferred

            # Note that _get_events_from_db is also responsible for turning db rows
            # into FrozenEvents (via _get_event_from_row), which involves seeing if
            # the events have been redacted, and if so pulling the redaction event out
            # of the database to check it.
            #
            try:
                missing_events = await self._get_events_from_db(
                    missing_events_ids,
                )

                event_entry_map.update(missing_events)
            except Exception as e:
                # Propagate the failure to any concurrent callers observing
                # `fetching_deferred` before re-raising to our own caller.
                with PreserveLoggingContext():
                    fetching_deferred.errback(e)
                raise e
            finally:
                # Ensure that we mark these events as no longer being fetched.
                for event_id in missing_events_ids:
                    self._current_event_fetches.pop(event_id, None)

            # Resolve the shared deferred outside the current logging context
            # so that observers don't inherit (and then finish) our context.
            with PreserveLoggingContext():
                fetching_deferred.callback(missing_events)

        if already_fetching:
            # Wait for the other event requests to finish and add their results
            # to ours.
            results = await make_deferred_yieldable(
                defer.gatherResults(
                    already_fetching.values(),
                    consumeErrors=True,
                )
            ).addErrback(unwrapFirstError)

            for result in results:
                event_entry_map.update(result)

        if not allow_rejected:
            event_entry_map = {
                event_id: entry
                for event_id, entry in event_entry_map.items()
                if not entry.event.rejected_reason
            }

        return event_entry_map
519593
520594 def _invalidate_get_event_cache (self , event_id ):
521595 self ._get_event_cache .invalidate ((event_id ,))
522596
523- def _get_events_from_cache (self , events , allow_rejected , update_metrics = True ):
524- """Fetch events from the caches
597+ def _get_events_from_cache (
598+ self , events : Iterable [str ], update_metrics : bool = True
599+ ) -> Dict [str , _EventCacheEntry ]:
600+ """Fetch events from the caches.
525601
526- Args:
527- events (Iterable[str]): list of event_ids to fetch
528- allow_rejected (bool): Whether to return events that were rejected
529- update_metrics (bool): Whether to update the cache hit ratio metrics
602+ May return rejected events.
530603
531- Returns:
532- dict of event_id -> _EventCacheEntry for each event_id in cache. If
533- allow_rejected is `False` then there will still be an entry but it
534- will be `None`
604+ Args:
605+ events: list of event_ids to fetch
606+ update_metrics: Whether to update the cache hit ratio metrics
535607 """
536608 event_map = {}
537609
@@ -542,10 +614,7 @@ def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
542614 if not ret :
543615 continue
544616
545- if allow_rejected or not ret .event .rejected_reason :
546- event_map [event_id ] = ret
547- else :
548- event_map [event_id ] = None
617+ event_map [event_id ] = ret
549618
550619 return event_map
551620
@@ -672,23 +741,23 @@ def fire(evs, exc):
672741 with PreserveLoggingContext ():
673742 self .hs .get_reactor ().callFromThread (fire , event_list , e )
674743
675- async def _get_events_from_db (self , event_ids , allow_rejected = False ):
744+ async def _get_events_from_db (
745+ self , event_ids : Iterable [str ]
746+ ) -> Dict [str , _EventCacheEntry ]:
676747 """Fetch a bunch of events from the database.
677748
749+ May return rejected events.
750+
678751 Returned events will be added to the cache for future lookups.
679752
680753 Unknown events are omitted from the response.
681754
682755 Args:
683- event_ids (Iterable[str]): The event_ids of the events to fetch
684-
685- allow_rejected (bool): Whether to include rejected events. If False,
686- rejected events are omitted from the response.
756+ event_ids: The event_ids of the events to fetch
687757
688758 Returns:
689- Dict[str, _EventCacheEntry]:
690- map from event id to result. May return extra events which
691- weren't asked for.
759+ map from event id to result. May return extra events which
760+ weren't asked for.
692761 """
693762 fetched_events = {}
694763 events_to_fetch = event_ids
@@ -717,9 +786,6 @@ async def _get_events_from_db(self, event_ids, allow_rejected=False):
717786
718787 rejected_reason = row ["rejected_reason" ]
719788
720- if not allow_rejected and rejected_reason :
721- continue
722-
723789 # If the event or metadata cannot be parsed, log the error and act
724790 # as if the event is unknown.
725791 try :
0 commit comments