@@ -278,15 +278,19 @@ async def on_receive_pdu(self, origin: str, pdu: EventBase) -> None:
278278 )
279279
280280 try :
281- await self ._process_received_pdu (origin , pdu , state_ids = None )
281+ await self ._process_received_pdu (
282+ origin , pdu , state_ids = None , partial_state = None
283+ )
282284 except PartialStateConflictError :
283285 # The room was un-partial stated while we were processing the PDU.
284286 # Try once more, with full state this time.
285287 logger .info (
286288 "Room %s was un-partial stated while processing the PDU, trying again." ,
287289 room_id ,
288290 )
289- await self ._process_received_pdu (origin , pdu , state_ids = None )
291+ await self ._process_received_pdu (
292+ origin , pdu , state_ids = None , partial_state = None
293+ )
290294
291295 async def on_send_membership_event (
292296 self , origin : str , event : EventBase
@@ -534,14 +538,36 @@ async def update_state_for_partial_state_event(
534538 #
535539 # This is the same operation as we do when we receive a regular event
536540 # over federation.
537- state_ids = await self ._resolve_state_at_missing_prevs (destination , event )
538-
539- # build a new state group for it if need be
540- context = await self ._state_handler .compute_event_context (
541- event ,
542- state_ids_before_event = state_ids ,
541+ state_ids , partial_state = await self ._resolve_state_at_missing_prevs (
542+ destination , event
543543 )
544- if context .partial_state :
544+
545+ # There are three possible cases for (state_ids, partial_state):
546+ # * `state_ids` and `partial_state` are both `None` if we had all the
547+ # prev_events. The prev_events may or may not have partial state and
548+ # we won't know until we compute the event context.
549+ # * `state_ids` is not `None` and `partial_state` is `False` if we were
550+ # missing some prev_events (but we have full state for any we did
551+ # have). We calculated the full state after the prev_events.
552+ # * `state_ids` is not `None` and `partial_state` is `True` if we were
553+ # missing some, but not all, prev_events. At least one of the
554+ # prev_events we did have had partial state, so we calculated a partial
555+ # state after the prev_events.
556+
557+ context = None
558+ if state_ids is not None and partial_state :
559+ # the state after the prev events is still partial. We can't de-partial
560+ # state the event, so don't bother building the event context.
561+ pass
562+ else :
563+ # build a new state group for it if need be
564+ context = await self ._state_handler .compute_event_context (
565+ event ,
566+ state_ids_before_event = state_ids ,
567+ partial_state = partial_state ,
568+ )
569+
570+ if context is None or context .partial_state :
545571 # this can happen if some or all of the event's prev_events still have
546572 # partial state - ie, an event has an earlier stream_ordering than one
547573 # or more of its prev_events, so we de-partial-state it before its
@@ -806,14 +832,39 @@ async def _process_pulled_event(
806832 return
807833
808834 try :
809- state_ids = await self ._resolve_state_at_missing_prevs (origin , event )
810- # TODO(faster_joins): make sure that _resolve_state_at_missing_prevs does
811- # not return partial state
812- # https://github.com/matrix-org/synapse/issues/13002
835+ try :
836+ state_ids , partial_state = await self ._resolve_state_at_missing_prevs (
837+ origin , event
838+ )
839+ await self ._process_received_pdu (
840+ origin ,
841+ event ,
842+ state_ids = state_ids ,
843+ partial_state = partial_state ,
844+ backfilled = backfilled ,
845+ )
846+ except PartialStateConflictError :
847+ # The room was un-partial stated while we were processing the event.
848+ # Try once more, with full state this time.
849+ state_ids , partial_state = await self ._resolve_state_at_missing_prevs (
850+ origin , event
851+ )
813852
814- await self ._process_received_pdu (
815- origin , event , state_ids = state_ids , backfilled = backfilled
816- )
853+ # We ought to have full state now, barring some unlikely race where we left and
854+ # rejoned the room in the background.
855+ if state_ids is not None and partial_state :
856+ raise AssertionError (
857+ f"Event { event .event_id } still has a partial resolved state "
858+ f"after room { event .room_id } was un-partial stated"
859+ )
860+
861+ await self ._process_received_pdu (
862+ origin ,
863+ event ,
864+ state_ids = state_ids ,
865+ partial_state = partial_state ,
866+ backfilled = backfilled ,
867+ )
817868 except FederationError as e :
818869 if e .code == 403 :
819870 logger .warning ("Pulled event %s failed history check." , event_id )
@@ -822,7 +873,7 @@ async def _process_pulled_event(
822873
823874 async def _resolve_state_at_missing_prevs (
824875 self , dest : str , event : EventBase
825- ) -> Optional [StateMap [str ]]:
876+ ) -> Tuple [ Optional [StateMap [str ]], Optional [ bool ]]:
826877 """Calculate the state at an event with missing prev_events.
827878
828879 This is used when we have pulled a batch of events from a remote server, and
@@ -849,8 +900,10 @@ async def _resolve_state_at_missing_prevs(
849900 event: an event to check for missing prevs.
850901
851902 Returns:
852- if we already had all the prev events, `None`. Otherwise, returns
853- the event ids of the state at `event`.
903+ if we already had all the prev events, `None, None`. Otherwise, returns a
904+ tuple containing:
905+ * the event ids of the state at `event`.
906+ * a boolean indicating whether the state may be partial.
854907
855908 Raises:
856909 FederationError if we fail to get the state from the remote server after any
@@ -864,7 +917,7 @@ async def _resolve_state_at_missing_prevs(
864917 missing_prevs = prevs - seen
865918
866919 if not missing_prevs :
867- return None
920+ return None , None
868921
869922 logger .info (
870923 "Event %s is missing prev_events %s: calculating state for a "
@@ -876,9 +929,15 @@ async def _resolve_state_at_missing_prevs(
876929 # resolve them to find the correct state at the current event.
877930
878931 try :
932+ # Determine whether we may be about to retrieve partial state
933+ # Events may be un-partial stated right after we compute the partial state
934+ # flag, but that's okay, as long as the flag errs on the conservative side.
935+ partial_state_flags = await self ._store .get_partial_state_events (seen )
936+ partial_state = any (partial_state_flags .values ())
937+
879938 # Get the state of the events we know about
880939 ours = await self ._state_storage_controller .get_state_groups_ids (
881- room_id , seen
940+ room_id , seen , await_full_state = False
882941 )
883942
884943 # state_maps is a list of mappings from (type, state_key) to event_id
@@ -924,7 +983,7 @@ async def _resolve_state_at_missing_prevs(
924983 "We can't get valid state history." ,
925984 affected = event_id ,
926985 )
927- return state_map
986+ return state_map , partial_state
928987
929988 async def _get_state_ids_after_missing_prev_event (
930989 self ,
@@ -1094,6 +1153,7 @@ async def _process_received_pdu(
10941153 origin : str ,
10951154 event : EventBase ,
10961155 state_ids : Optional [StateMap [str ]],
1156+ partial_state : Optional [bool ],
10971157 backfilled : bool = False ,
10981158 ) -> None :
10991159 """Called when we have a new non-outlier event.
@@ -1117,21 +1177,29 @@ async def _process_received_pdu(
11171177
11181178 state_ids: Normally None, but if we are handling a gap in the graph
11191179 (ie, we are missing one or more prev_events), the resolved state at the
1120- event. Must not be partial state.
1180+ event
1181+
1182+ partial_state:
1183+ `True` if `state_ids` is partial and omits non-critical membership
1184+ events.
1185+ `False` if `state_ids` is the full state.
1186+ `None` if `state_ids` is not provided. In this case, the flag will be
1187+ calculated based on `event`'s prev events.
11211188
11221189 backfilled: True if this is part of a historical batch of events (inhibits
11231190 notification to clients, and validation of device keys.)
11241191
11251192 PartialStateConflictError: if the room was un-partial stated in between
11261193 computing the state at the event and persisting it. The caller should retry
1127- exactly once in this case. Will never be raised if `state_ids` is provided.
1194+ exactly once in this case.
11281195 """
11291196 logger .debug ("Processing event: %s" , event )
11301197 assert not event .internal_metadata .outlier
11311198
11321199 context = await self ._state_handler .compute_event_context (
11331200 event ,
11341201 state_ids_before_event = state_ids ,
1202+ partial_state = partial_state ,
11351203 )
11361204 try :
11371205 await self ._check_event_auth (origin , event , context )
0 commit comments