Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit c37dad6

Browse files
authored
Improve event caching code (#10119)
Ensure we only load an event from the DB once when the same event is requested multiple times concurrently.
1 parent 11540be commit c37dad6

File tree

4 files changed

+158
-43
lines changed

4 files changed

+158
-43
lines changed

changelog.d/10119.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Improve event caching mechanism to avoid having multiple copies of an event in memory at a time.

synapse/storage/databases/main/events_worker.py

Lines changed: 105 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
import logging
1616
import threading
17-
from collections import namedtuple
1817
from typing import (
1918
Collection,
2019
Container,
@@ -27,6 +26,7 @@
2726
overload,
2827
)
2928

29+
import attr
3030
from constantly import NamedConstant, Names
3131
from typing_extensions import Literal
3232

@@ -42,7 +42,11 @@
4242
from synapse.events import EventBase, make_event_from_dict
4343
from synapse.events.snapshot import EventContext
4444
from synapse.events.utils import prune_event
45-
from synapse.logging.context import PreserveLoggingContext, current_context
45+
from synapse.logging.context import (
46+
PreserveLoggingContext,
47+
current_context,
48+
make_deferred_yieldable,
49+
)
4650
from synapse.metrics.background_process_metrics import (
4751
run_as_background_process,
4852
wrap_as_background_process,
@@ -56,6 +60,8 @@
5660
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
5761
from synapse.storage.util.sequence import build_sequence_generator
5862
from synapse.types import JsonDict, get_domain_from_id
63+
from synapse.util import unwrapFirstError
64+
from synapse.util.async_helpers import ObservableDeferred
5965
from synapse.util.caches.descriptors import cached, cachedList
6066
from synapse.util.caches.lrucache import LruCache
6167
from synapse.util.iterutils import batch_iter
@@ -74,7 +80,10 @@
7480
EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
7581

7682

77-
_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
83+
@attr.s(slots=True, auto_attribs=True)
84+
class _EventCacheEntry:
85+
event: EventBase
86+
redacted_event: Optional[EventBase]
7887

7988

8089
class EventRedactBehaviour(Names):
@@ -161,6 +170,13 @@ def __init__(self, database: DatabasePool, db_conn, hs):
161170
max_size=hs.config.caches.event_cache_size,
162171
)
163172

173+
# Map from event ID to a deferred that will result in a map from event
174+
# ID to cache entry. Note that the returned dict may not have the
175+
# requested event in it if the event isn't in the DB.
176+
self._current_event_fetches: Dict[
177+
str, ObservableDeferred[Dict[str, _EventCacheEntry]]
178+
] = {}
179+
164180
self._event_fetch_lock = threading.Condition()
165181
self._event_fetch_list = []
166182
self._event_fetch_ongoing = 0
@@ -476,7 +492,9 @@ async def get_events_as_list(
476492

477493
return events
478494

479-
async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
495+
async def _get_events_from_cache_or_db(
496+
self, event_ids: Iterable[str], allow_rejected: bool = False
497+
) -> Dict[str, _EventCacheEntry]:
480498
"""Fetch a bunch of events from the cache or the database.
481499
482500
If events are pulled from the database, they will be cached for future lookups.
@@ -485,53 +503,107 @@ async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
485503
486504
Args:
487505
488-
event_ids (Iterable[str]): The event_ids of the events to fetch
506+
event_ids: The event_ids of the events to fetch
489507
490-
allow_rejected (bool): Whether to include rejected events. If False,
508+
allow_rejected: Whether to include rejected events. If False,
491509
rejected events are omitted from the response.
492510
493511
Returns:
494-
Dict[str, _EventCacheEntry]:
495-
map from event id to result
512+
map from event id to result
496513
"""
497514
event_entry_map = self._get_events_from_cache(
498-
event_ids, allow_rejected=allow_rejected
515+
event_ids,
499516
)
500517

501-
missing_events_ids = [e for e in event_ids if e not in event_entry_map]
518+
missing_events_ids = {e for e in event_ids if e not in event_entry_map}
519+
520+
# We now look up if we're already fetching some of the events in the DB,
521+
# if so we wait for those lookups to finish instead of pulling the same
522+
# events out of the DB multiple times.
523+
already_fetching: Dict[str, defer.Deferred] = {}
524+
525+
for event_id in missing_events_ids:
526+
deferred = self._current_event_fetches.get(event_id)
527+
if deferred is not None:
528+
# We're already pulling the event out of the DB. Add the deferred
529+
# to the collection of deferreds to wait on.
530+
already_fetching[event_id] = deferred.observe()
531+
532+
missing_events_ids.difference_update(already_fetching)
502533

503534
if missing_events_ids:
504535
log_ctx = current_context()
505536
log_ctx.record_event_fetch(len(missing_events_ids))
506537

538+
# Add entries to `self._current_event_fetches` for each event we're
539+
# going to pull from the DB. We use a single deferred that resolves
540+
# to all the events we pulled from the DB (this will result in this
541+
# function returning more events than requested, but that can happen
542+
# already due to `_get_events_from_db`).
543+
fetching_deferred: ObservableDeferred[
544+
Dict[str, _EventCacheEntry]
545+
] = ObservableDeferred(defer.Deferred())
546+
for event_id in missing_events_ids:
547+
self._current_event_fetches[event_id] = fetching_deferred
548+
507549
# Note that _get_events_from_db is also responsible for turning db rows
508550
# into FrozenEvents (via _get_event_from_row), which involves seeing if
509551
# the events have been redacted, and if so pulling the redaction event out
510552
# of the database to check it.
511553
#
512-
missing_events = await self._get_events_from_db(
513-
missing_events_ids, allow_rejected=allow_rejected
514-
)
554+
try:
555+
missing_events = await self._get_events_from_db(
556+
missing_events_ids,
557+
)
515558

516-
event_entry_map.update(missing_events)
559+
event_entry_map.update(missing_events)
560+
except Exception as e:
561+
with PreserveLoggingContext():
562+
fetching_deferred.errback(e)
563+
raise e
564+
finally:
565+
# Ensure that we mark these events as no longer being fetched.
566+
for event_id in missing_events_ids:
567+
self._current_event_fetches.pop(event_id, None)
568+
569+
with PreserveLoggingContext():
570+
fetching_deferred.callback(missing_events)
571+
572+
if already_fetching:
573+
# Wait for the other event requests to finish and add their results
574+
# to ours.
575+
results = await make_deferred_yieldable(
576+
defer.gatherResults(
577+
already_fetching.values(),
578+
consumeErrors=True,
579+
)
580+
).addErrback(unwrapFirstError)
581+
582+
for result in results:
583+
event_entry_map.update(result)
584+
585+
if not allow_rejected:
586+
event_entry_map = {
587+
event_id: entry
588+
for event_id, entry in event_entry_map.items()
589+
if not entry.event.rejected_reason
590+
}
517591

518592
return event_entry_map
519593

520594
def _invalidate_get_event_cache(self, event_id):
521595
self._get_event_cache.invalidate((event_id,))
522596

523-
def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
524-
"""Fetch events from the caches
597+
def _get_events_from_cache(
598+
self, events: Iterable[str], update_metrics: bool = True
599+
) -> Dict[str, _EventCacheEntry]:
600+
"""Fetch events from the caches.
525601
526-
Args:
527-
events (Iterable[str]): list of event_ids to fetch
528-
allow_rejected (bool): Whether to return events that were rejected
529-
update_metrics (bool): Whether to update the cache hit ratio metrics
602+
May return rejected events.
530603
531-
Returns:
532-
dict of event_id -> _EventCacheEntry for each event_id in cache. If
533-
allow_rejected is `False` then there will still be an entry but it
534-
will be `None`
604+
Args:
605+
events: list of event_ids to fetch
606+
update_metrics: Whether to update the cache hit ratio metrics
535607
"""
536608
event_map = {}
537609

@@ -542,10 +614,7 @@ def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
542614
if not ret:
543615
continue
544616

545-
if allow_rejected or not ret.event.rejected_reason:
546-
event_map[event_id] = ret
547-
else:
548-
event_map[event_id] = None
617+
event_map[event_id] = ret
549618

550619
return event_map
551620

@@ -672,23 +741,23 @@ def fire(evs, exc):
672741
with PreserveLoggingContext():
673742
self.hs.get_reactor().callFromThread(fire, event_list, e)
674743

675-
async def _get_events_from_db(self, event_ids, allow_rejected=False):
744+
async def _get_events_from_db(
745+
self, event_ids: Iterable[str]
746+
) -> Dict[str, _EventCacheEntry]:
676747
"""Fetch a bunch of events from the database.
677748
749+
May return rejected events.
750+
678751
Returned events will be added to the cache for future lookups.
679752
680753
Unknown events are omitted from the response.
681754
682755
Args:
683-
event_ids (Iterable[str]): The event_ids of the events to fetch
684-
685-
allow_rejected (bool): Whether to include rejected events. If False,
686-
rejected events are omitted from the response.
756+
event_ids: The event_ids of the events to fetch
687757
688758
Returns:
689-
Dict[str, _EventCacheEntry]:
690-
map from event id to result. May return extra events which
691-
weren't asked for.
759+
map from event id to result. May return extra events which
760+
weren't asked for.
692761
"""
693762
fetched_events = {}
694763
events_to_fetch = event_ids
@@ -717,9 +786,6 @@ async def _get_events_from_db(self, event_ids, allow_rejected=False):
717786

718787
rejected_reason = row["rejected_reason"]
719788

720-
if not allow_rejected and rejected_reason:
721-
continue
722-
723789
# If the event or metadata cannot be parsed, log the error and act
724790
# as if the event is unknown.
725791
try:

synapse/storage/databases/main/roommember.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -629,14 +629,12 @@ async def _get_joined_users_from_context(
629629
# We don't update the event cache hit ratio as it completely throws off
630630
# the hit ratio counts. After all, we don't populate the cache if we
631631
# miss it here
632-
event_map = self._get_events_from_cache(
633-
member_event_ids, allow_rejected=False, update_metrics=False
634-
)
632+
event_map = self._get_events_from_cache(member_event_ids, update_metrics=False)
635633

636634
missing_member_event_ids = []
637635
for event_id in member_event_ids:
638636
ev_entry = event_map.get(event_id)
639-
if ev_entry:
637+
if ev_entry and not ev_entry.event.rejected_reason:
640638
if ev_entry.event.membership == Membership.JOIN:
641639
users_in_room[ev_entry.event.state_key] = ProfileInfo(
642640
display_name=ev_entry.event.content.get("displayname", None),

tests/storage/databases/main/test_events_worker.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414
import json
1515

1616
from synapse.logging.context import LoggingContext
17+
from synapse.rest import admin
18+
from synapse.rest.client.v1 import login, room
1719
from synapse.storage.databases.main.events_worker import EventsWorkerStore
20+
from synapse.util.async_helpers import yieldable_gather_results
1821

1922
from tests import unittest
2023

@@ -94,3 +97,50 @@ def test_query_via_event_cache(self):
9497
res = self.get_success(self.store.have_seen_events("room1", ["event10"]))
9598
self.assertEquals(res, {"event10"})
9699
self.assertEquals(ctx.get_resource_usage().db_txn_count, 0)
100+
101+
102+
class EventCacheTestCase(unittest.HomeserverTestCase):
103+
"""Test that the various layers of the event cache work."""
104+
105+
servlets = [
106+
admin.register_servlets,
107+
room.register_servlets,
108+
login.register_servlets,
109+
]
110+
111+
def prepare(self, reactor, clock, hs):
112+
self.store: EventsWorkerStore = hs.get_datastore()
113+
114+
self.user = self.register_user("user", "pass")
115+
self.token = self.login(self.user, "pass")
116+
117+
self.room = self.helper.create_room_as(self.user, tok=self.token)
118+
119+
res = self.helper.send(self.room, tok=self.token)
120+
self.event_id = res["event_id"]
121+
122+
# Reset the event cache so the tests start with it empty
123+
self.store._get_event_cache.clear()
124+
125+
def test_simple(self):
126+
"""Test that we cache events that we pull from the DB."""
127+
128+
with LoggingContext("test") as ctx:
129+
self.get_success(self.store.get_event(self.event_id))
130+
131+
# We should have fetched the event from the DB
132+
self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)
133+
134+
def test_dedupe(self):
135+
"""Test that if we request the same event multiple times we only pull it
136+
out once.
137+
"""
138+
139+
with LoggingContext("test") as ctx:
140+
d = yieldable_gather_results(
141+
self.store.get_event, [self.event_id, self.event_id]
142+
)
143+
self.get_success(d)
144+
145+
# We should have fetched the event from the DB only once, even though it was requested twice
146+
self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)

0 commit comments

Comments
 (0)