Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 335ebb2

Browse files
squahtx and richvdh authored
Faster room joins: avoid blocking when pulling events with missing prevs (#13355)
Avoid blocking on full state in `_resolve_state_at_missing_prevs` and return a new flag indicating whether the resolved state is partial. Thread that flag around so that it makes it into the event context. Co-authored-by: Richard van der Hoff <[email protected]>
1 parent 8b60329 commit 335ebb2

File tree

8 files changed

+124
-33
lines changed

8 files changed

+124
-33
lines changed

changelog.d/13355.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Faster room joins: avoid blocking when pulling events with partially missing prev events.

synapse/handlers/federation_event.py

Lines changed: 92 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -278,15 +278,19 @@ async def on_receive_pdu(self, origin: str, pdu: EventBase) -> None:
278278
)
279279

280280
try:
281-
await self._process_received_pdu(origin, pdu, state_ids=None)
281+
await self._process_received_pdu(
282+
origin, pdu, state_ids=None, partial_state=None
283+
)
282284
except PartialStateConflictError:
283285
# The room was un-partial stated while we were processing the PDU.
284286
# Try once more, with full state this time.
285287
logger.info(
286288
"Room %s was un-partial stated while processing the PDU, trying again.",
287289
room_id,
288290
)
289-
await self._process_received_pdu(origin, pdu, state_ids=None)
291+
await self._process_received_pdu(
292+
origin, pdu, state_ids=None, partial_state=None
293+
)
290294

291295
async def on_send_membership_event(
292296
self, origin: str, event: EventBase
@@ -534,14 +538,36 @@ async def update_state_for_partial_state_event(
534538
#
535539
# This is the same operation as we do when we receive a regular event
536540
# over federation.
537-
state_ids = await self._resolve_state_at_missing_prevs(destination, event)
538-
539-
# build a new state group for it if need be
540-
context = await self._state_handler.compute_event_context(
541-
event,
542-
state_ids_before_event=state_ids,
541+
state_ids, partial_state = await self._resolve_state_at_missing_prevs(
542+
destination, event
543543
)
544-
if context.partial_state:
544+
545+
# There are three possible cases for (state_ids, partial_state):
546+
# * `state_ids` and `partial_state` are both `None` if we had all the
547+
# prev_events. The prev_events may or may not have partial state and
548+
# we won't know until we compute the event context.
549+
# * `state_ids` is not `None` and `partial_state` is `False` if we were
550+
# missing some prev_events (but we have full state for any we did
551+
# have). We calculated the full state after the prev_events.
552+
# * `state_ids` is not `None` and `partial_state` is `True` if we were
553+
# missing some, but not all, prev_events. At least one of the
554+
# prev_events we did have had partial state, so we calculated a partial
555+
# state after the prev_events.
556+
557+
context = None
558+
if state_ids is not None and partial_state:
559+
# the state after the prev events is still partial. We can't de-partial
560+
# state the event, so don't bother building the event context.
561+
pass
562+
else:
563+
# build a new state group for it if need be
564+
context = await self._state_handler.compute_event_context(
565+
event,
566+
state_ids_before_event=state_ids,
567+
partial_state=partial_state,
568+
)
569+
570+
if context is None or context.partial_state:
545571
# this can happen if some or all of the event's prev_events still have
546572
# partial state - ie, an event has an earlier stream_ordering than one
547573
# or more of its prev_events, so we de-partial-state it before its
@@ -806,14 +832,39 @@ async def _process_pulled_event(
806832
return
807833

808834
try:
809-
state_ids = await self._resolve_state_at_missing_prevs(origin, event)
810-
# TODO(faster_joins): make sure that _resolve_state_at_missing_prevs does
811-
# not return partial state
812-
# https://github.com/matrix-org/synapse/issues/13002
835+
try:
836+
state_ids, partial_state = await self._resolve_state_at_missing_prevs(
837+
origin, event
838+
)
839+
await self._process_received_pdu(
840+
origin,
841+
event,
842+
state_ids=state_ids,
843+
partial_state=partial_state,
844+
backfilled=backfilled,
845+
)
846+
except PartialStateConflictError:
847+
# The room was un-partial stated while we were processing the event.
848+
# Try once more, with full state this time.
849+
state_ids, partial_state = await self._resolve_state_at_missing_prevs(
850+
origin, event
851+
)
813852

814-
await self._process_received_pdu(
815-
origin, event, state_ids=state_ids, backfilled=backfilled
816-
)
853+
# We ought to have full state now, barring some unlikely race where we left and
854+
# rejoined the room in the background.
855+
if state_ids is not None and partial_state:
856+
raise AssertionError(
857+
f"Event {event.event_id} still has a partial resolved state "
858+
f"after room {event.room_id} was un-partial stated"
859+
)
860+
861+
await self._process_received_pdu(
862+
origin,
863+
event,
864+
state_ids=state_ids,
865+
partial_state=partial_state,
866+
backfilled=backfilled,
867+
)
817868
except FederationError as e:
818869
if e.code == 403:
819870
logger.warning("Pulled event %s failed history check.", event_id)
@@ -822,7 +873,7 @@ async def _process_pulled_event(
822873

823874
async def _resolve_state_at_missing_prevs(
824875
self, dest: str, event: EventBase
825-
) -> Optional[StateMap[str]]:
876+
) -> Tuple[Optional[StateMap[str]], Optional[bool]]:
826877
"""Calculate the state at an event with missing prev_events.
827878
828879
This is used when we have pulled a batch of events from a remote server, and
@@ -849,8 +900,10 @@ async def _resolve_state_at_missing_prevs(
849900
event: an event to check for missing prevs.
850901
851902
Returns:
852-
if we already had all the prev events, `None`. Otherwise, returns
853-
the event ids of the state at `event`.
903+
if we already had all the prev events, `None, None`. Otherwise, returns a
904+
tuple containing:
905+
* the event ids of the state at `event`.
906+
* a boolean indicating whether the state may be partial.
854907
855908
Raises:
856909
FederationError if we fail to get the state from the remote server after any
@@ -864,7 +917,7 @@ async def _resolve_state_at_missing_prevs(
864917
missing_prevs = prevs - seen
865918

866919
if not missing_prevs:
867-
return None
920+
return None, None
868921

869922
logger.info(
870923
"Event %s is missing prev_events %s: calculating state for a "
@@ -876,9 +929,15 @@ async def _resolve_state_at_missing_prevs(
876929
# resolve them to find the correct state at the current event.
877930

878931
try:
932+
# Determine whether we may be about to retrieve partial state
933+
# Events may be un-partial stated right after we compute the partial state
934+
# flag, but that's okay, as long as the flag errs on the conservative side.
935+
partial_state_flags = await self._store.get_partial_state_events(seen)
936+
partial_state = any(partial_state_flags.values())
937+
879938
# Get the state of the events we know about
880939
ours = await self._state_storage_controller.get_state_groups_ids(
881-
room_id, seen
940+
room_id, seen, await_full_state=False
882941
)
883942

884943
# state_maps is a list of mappings from (type, state_key) to event_id
@@ -924,7 +983,7 @@ async def _resolve_state_at_missing_prevs(
924983
"We can't get valid state history.",
925984
affected=event_id,
926985
)
927-
return state_map
986+
return state_map, partial_state
928987

929988
async def _get_state_ids_after_missing_prev_event(
930989
self,
@@ -1094,6 +1153,7 @@ async def _process_received_pdu(
10941153
origin: str,
10951154
event: EventBase,
10961155
state_ids: Optional[StateMap[str]],
1156+
partial_state: Optional[bool],
10971157
backfilled: bool = False,
10981158
) -> None:
10991159
"""Called when we have a new non-outlier event.
@@ -1117,21 +1177,29 @@ async def _process_received_pdu(
11171177
11181178
state_ids: Normally None, but if we are handling a gap in the graph
11191179
(ie, we are missing one or more prev_events), the resolved state at the
1120-
event. Must not be partial state.
1180+
event
1181+
1182+
partial_state:
1183+
`True` if `state_ids` is partial and omits non-critical membership
1184+
events.
1185+
`False` if `state_ids` is the full state.
1186+
`None` if `state_ids` is not provided. In this case, the flag will be
1187+
calculated based on `event`'s prev events.
11211188
11221189
backfilled: True if this is part of a historical batch of events (inhibits
11231190
notification to clients, and validation of device keys.)
11241191
11251192
PartialStateConflictError: if the room was un-partial stated in between
11261193
computing the state at the event and persisting it. The caller should retry
1127-
exactly once in this case. Will never be raised if `state_ids` is provided.
1194+
exactly once in this case.
11281195
"""
11291196
logger.debug("Processing event: %s", event)
11301197
assert not event.internal_metadata.outlier
11311198

11321199
context = await self._state_handler.compute_event_context(
11331200
event,
11341201
state_ids_before_event=state_ids,
1202+
partial_state=partial_state,
11351203
)
11361204
try:
11371205
await self._check_event_auth(origin, event, context)

synapse/handlers/message.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1135,6 +1135,10 @@ async def create_new_client_event(
11351135
context = await self.state.compute_event_context(
11361136
event,
11371137
state_ids_before_event=state_map_for_event,
1138+
# TODO(faster_joins): check how MSC2716 works and whether we can have
1139+
# partial state here
1140+
# https://github.com/matrix-org/synapse/issues/13003
1141+
partial_state=False,
11381142
)
11391143
else:
11401144
context = await self.state.compute_event_context(event)

synapse/state/__init__.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ async def compute_event_context(
255255
self,
256256
event: EventBase,
257257
state_ids_before_event: Optional[StateMap[str]] = None,
258-
partial_state: bool = False,
258+
partial_state: Optional[bool] = None,
259259
) -> EventContext:
260260
"""Build an EventContext structure for a non-outlier event.
261261
@@ -270,8 +270,12 @@ async def compute_event_context(
270270
it can't be calculated from existing events. This is normally
271271
only specified when receiving an event from federation where we
272272
don't have the prev events, e.g. when backfilling.
273-
partial_state: True if `state_ids_before_event` is partial and omits
274-
non-critical membership events
273+
partial_state:
274+
`True` if `state_ids_before_event` is partial and omits non-critical
275+
membership events.
276+
`False` if `state_ids_before_event` is the full state.
277+
`None` when `state_ids_before_event` is not provided. In this case, the
278+
flag will be calculated based on `event`'s prev events.
275279
Returns:
276280
The event context.
277281
"""
@@ -298,12 +302,14 @@ async def compute_event_context(
298302
)
299303
)
300304

305+
# the partial_state flag must be provided
306+
assert partial_state is not None
301307
else:
302308
# otherwise, we'll need to resolve the state across the prev_events.
303309

304310
# partial_state should not be set explicitly in this case:
305311
# we work it out dynamically
306-
assert not partial_state
312+
assert partial_state is None
307313

308314
# if any of the prev-events have partial state, so do we.
309315
# (This is slightly racy - the prev-events might get fixed up before we use
@@ -313,13 +319,13 @@ async def compute_event_context(
313319
incomplete_prev_events = await self.store.get_partial_state_events(
314320
prev_event_ids
315321
)
316-
if any(incomplete_prev_events.values()):
322+
partial_state = any(incomplete_prev_events.values())
323+
if partial_state:
317324
logger.debug(
318325
"New/incoming event %s refers to prev_events %s with partial state",
319326
event.event_id,
320327
[k for (k, v) in incomplete_prev_events.items() if v],
321328
)
322-
partial_state = True
323329

324330
logger.debug("calling resolve_state_groups from compute_event_context")
325331
# we've already taken into account partial state, so no need to wait for

synapse/storage/controllers/state.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,13 +82,15 @@ async def get_state_group_delta(
8282
return state_group_delta.prev_group, state_group_delta.delta_ids
8383

8484
async def get_state_groups_ids(
85-
self, _room_id: str, event_ids: Collection[str]
85+
self, _room_id: str, event_ids: Collection[str], await_full_state: bool = True
8686
) -> Dict[int, MutableStateMap[str]]:
8787
"""Get the event IDs of all the state for the state groups for the given events
8888
8989
Args:
9090
_room_id: id of the room for these events
9191
event_ids: ids of the events
92+
await_full_state: if `True`, will block if we do not yet have complete
93+
state at these events.
9294
9395
Returns:
9496
dict of state_group_id -> (dict of (type, state_key) -> event id)
@@ -100,7 +102,9 @@ async def get_state_groups_ids(
100102
if not event_ids:
101103
return {}
102104

103-
event_to_groups = await self.get_state_group_for_events(event_ids)
105+
event_to_groups = await self.get_state_group_for_events(
106+
event_ids, await_full_state=await_full_state
107+
)
104108

105109
groups = set(event_to_groups.values())
106110
group_to_state = await self.stores.state._get_state_for_groups(groups)

tests/handlers/test_federation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,7 @@ def test_backfill_with_many_backward_extremities(self) -> None:
287287
state_ids={
288288
(e.type, e.state_key): e.event_id for e in current_state
289289
},
290+
partial_state=False,
290291
)
291292
)
292293

tests/storage/test_events.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,11 @@ def prepare(self, reactor, clock, homeserver):
7070
def persist_event(self, event, state=None):
7171
"""Persist the event, with optional state"""
7272
context = self.get_success(
73-
self.state.compute_event_context(event, state_ids_before_event=state)
73+
self.state.compute_event_context(
74+
event,
75+
state_ids_before_event=state,
76+
partial_state=None if state is None else False,
77+
)
7478
)
7579
self.get_success(self._persistence.persist_event(event, context))
7680

@@ -148,6 +152,7 @@ def test_do_not_prune_gap_if_state_different(self):
148152
self.state.compute_event_context(
149153
remote_event_2,
150154
state_ids_before_event=state_before_gap,
155+
partial_state=False,
151156
)
152157
)
153158

tests/test_state.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,7 @@ def test_annotate_with_old_message(self):
462462
state_ids_before_event={
463463
(e.type, e.state_key): e.event_id for e in old_state
464464
},
465+
partial_state=False,
465466
)
466467
)
467468

@@ -492,6 +493,7 @@ def test_annotate_with_old_state(self):
492493
state_ids_before_event={
493494
(e.type, e.state_key): e.event_id for e in old_state
494495
},
496+
partial_state=False,
495497
)
496498
)
497499

0 commit comments

Comments
 (0)