1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414from abc import ABC , abstractmethod
15- from typing import TYPE_CHECKING , List , Optional , Tuple
15+ from typing import TYPE_CHECKING , Dict , List , Optional , Tuple
1616
1717import attr
1818from immutabledict import immutabledict
@@ -107,33 +107,32 @@ class EventContext(UnpersistedEventContextBase):
107107 state_delta_due_to_event: If `state_group` and `state_group_before_event` are not None
108108 then this is the delta of the state between the two groups.
109109
110- prev_group: If it is known, ``state_group``'s prev_group. Note that this being
111- None does not necessarily mean that ``state_group`` does not have
112- a prev_group!
110+ state_group_deltas: If not empty, this is a dict collecting a mapping of the state
111+ difference between state groups.
113112
114- If the event is a state event, this is normally the same as
115- ``state_group_before_event``.
113+ The keys are a tuple of two integers: the initial group and final state group.
114+ The corresponding value is a state map representing the state delta between
115+ these state groups.
116116
117- If ``state_group`` is None (ie, the event is an outlier), ``prev_group``
118- will always also be ``None``.
117+ The dictionary is expected to have at most two entries with state groups of:
119118
120- Note that this *not* (necessarily) the state group associated with
121- ``_prev_state_ids``.
119+ 1. The state group before the event and after the event.
120+ 2. The state group preceding the state group before the event and the
121+ state group before the event.
122122
123- delta_ids: If ``prev_group`` is not None, the state delta between ``prev_group``
124- and ``state_group`` .
123+ This information is collected and stored as part of an optimization for persisting
124+ events .
125125
126126 partial_state: if True, we may be storing this event with a temporary,
127127 incomplete state.
128128 """
129129
130130 _storage : "StorageControllers"
131+ state_group_deltas : Dict [Tuple [int , int ], StateMap [str ]]
131132 rejected : Optional [str ] = None
132133 _state_group : Optional [int ] = None
133134 state_group_before_event : Optional [int ] = None
134135 _state_delta_due_to_event : Optional [StateMap [str ]] = None
135- prev_group : Optional [int ] = None
136- delta_ids : Optional [StateMap [str ]] = None
137136 app_service : Optional [ApplicationService ] = None
138137
139138 partial_state : bool = False
@@ -145,16 +144,14 @@ def with_state(
145144 state_group_before_event : Optional [int ],
146145 state_delta_due_to_event : Optional [StateMap [str ]],
147146 partial_state : bool ,
148- prev_group : Optional [int ] = None ,
149- delta_ids : Optional [StateMap [str ]] = None ,
147+ state_group_deltas : Dict [Tuple [int , int ], StateMap [str ]],
150148 ) -> "EventContext" :
151149 return EventContext (
152150 storage = storage ,
153151 state_group = state_group ,
154152 state_group_before_event = state_group_before_event ,
155153 state_delta_due_to_event = state_delta_due_to_event ,
156- prev_group = prev_group ,
157- delta_ids = delta_ids ,
154+ state_group_deltas = state_group_deltas ,
158155 partial_state = partial_state ,
159156 )
160157
@@ -163,7 +160,7 @@ def for_outlier(
163160 storage : "StorageControllers" ,
164161 ) -> "EventContext" :
165162 """Return an EventContext instance suitable for persisting an outlier event"""
166- return EventContext (storage = storage )
163+ return EventContext (storage = storage , state_group_deltas = {} )
167164
168165 async def persist (self , event : EventBase ) -> "EventContext" :
169166 return self
@@ -183,13 +180,15 @@ async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict:
183180 "state_group" : self ._state_group ,
184181 "state_group_before_event" : self .state_group_before_event ,
185182 "rejected" : self .rejected ,
186- "prev_group " : self .prev_group ,
183+ "state_group_deltas " : _encode_state_group_delta ( self .state_group_deltas ) ,
187184 "state_delta_due_to_event" : _encode_state_dict (
188185 self ._state_delta_due_to_event
189186 ),
190- "delta_ids" : _encode_state_dict (self .delta_ids ),
191187 "app_service_id" : self .app_service .id if self .app_service else None ,
192188 "partial_state" : self .partial_state ,
189+ # add dummy delta_ids and prev_group for backwards compatibility
190+ "delta_ids" : None ,
191+ "prev_group" : None ,
193192 }
194193
195194 @staticmethod
@@ -204,17 +203,24 @@ def deserialize(storage: "StorageControllers", input: JsonDict) -> "EventContext
204203 Returns:
205204 The event context.
206205 """
206+ # workaround for backwards/forwards compatibility: if the input doesn't have a value
207+ # for "state_group_deltas" just assign an empty dict
208+ state_group_deltas = input .get ("state_group_deltas" , None )
209+ if state_group_deltas :
210+ state_group_deltas = _decode_state_group_delta (state_group_deltas )
211+ else :
212+ state_group_deltas = {}
213+
207214 context = EventContext (
208215 # We use the state_group and prev_state_id stuff to pull the
209216 # current_state_ids out of the DB and construct prev_state_ids.
210217 storage = storage ,
211218 state_group = input ["state_group" ],
212219 state_group_before_event = input ["state_group_before_event" ],
213- prev_group = input [ "prev_group" ] ,
220+ state_group_deltas = state_group_deltas ,
214221 state_delta_due_to_event = _decode_state_dict (
215222 input ["state_delta_due_to_event" ]
216223 ),
217- delta_ids = _decode_state_dict (input ["delta_ids" ]),
218224 rejected = input ["rejected" ],
219225 partial_state = input .get ("partial_state" , False ),
220226 )
@@ -349,7 +355,7 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
349355 _storage : "StorageControllers"
350356 state_group_before_event : Optional [int ]
351357 state_group_after_event : Optional [int ]
352- state_delta_due_to_event : Optional [dict ]
358+ state_delta_due_to_event : Optional [StateMap [ str ] ]
353359 prev_group_for_state_group_before_event : Optional [int ]
354360 delta_ids_to_state_group_before_event : Optional [StateMap [str ]]
355361 partial_state : bool
@@ -380,26 +386,16 @@ async def batch_persist_unpersisted_contexts(
380386
381387 events_and_persisted_context = []
382388 for event , unpersisted_context in amended_events_and_context :
383- if event .is_state ():
384- context = EventContext (
385- storage = unpersisted_context ._storage ,
386- state_group = unpersisted_context .state_group_after_event ,
387- state_group_before_event = unpersisted_context .state_group_before_event ,
388- state_delta_due_to_event = unpersisted_context .state_delta_due_to_event ,
389- partial_state = unpersisted_context .partial_state ,
390- prev_group = unpersisted_context .state_group_before_event ,
391- delta_ids = unpersisted_context .state_delta_due_to_event ,
392- )
393- else :
394- context = EventContext (
395- storage = unpersisted_context ._storage ,
396- state_group = unpersisted_context .state_group_after_event ,
397- state_group_before_event = unpersisted_context .state_group_before_event ,
398- state_delta_due_to_event = unpersisted_context .state_delta_due_to_event ,
399- partial_state = unpersisted_context .partial_state ,
400- prev_group = unpersisted_context .prev_group_for_state_group_before_event ,
401- delta_ids = unpersisted_context .delta_ids_to_state_group_before_event ,
402- )
389+ state_group_deltas = unpersisted_context ._build_state_group_deltas ()
390+
391+ context = EventContext (
392+ storage = unpersisted_context ._storage ,
393+ state_group = unpersisted_context .state_group_after_event ,
394+ state_group_before_event = unpersisted_context .state_group_before_event ,
395+ state_delta_due_to_event = unpersisted_context .state_delta_due_to_event ,
396+ partial_state = unpersisted_context .partial_state ,
397+ state_group_deltas = state_group_deltas ,
398+ )
403399 events_and_persisted_context .append ((event , context ))
404400 return events_and_persisted_context
405401
@@ -452,28 +448,93 @@ async def persist(self, event: EventBase) -> EventContext:
452448
453449 # if the event isn't a state event the state group doesn't change
454450 if not self .state_delta_due_to_event :
455- state_group_after_event = self .state_group_before_event
451+ self . state_group_after_event = self .state_group_before_event
456452
457453 # otherwise if it is a state event we need to get a state group for it
458454 else :
459- state_group_after_event = await self ._storage .state .store_state_group (
455+ self . state_group_after_event = await self ._storage .state .store_state_group (
460456 event .event_id ,
461457 event .room_id ,
462458 prev_group = self .state_group_before_event ,
463459 delta_ids = self .state_delta_due_to_event ,
464460 current_state_ids = None ,
465461 )
466462
463+ state_group_deltas = self ._build_state_group_deltas ()
464+
467465 return EventContext .with_state (
468466 storage = self ._storage ,
469- state_group = state_group_after_event ,
467+ state_group = self . state_group_after_event ,
470468 state_group_before_event = self .state_group_before_event ,
471469 state_delta_due_to_event = self .state_delta_due_to_event ,
470+ state_group_deltas = state_group_deltas ,
472471 partial_state = self .partial_state ,
473- prev_group = self .state_group_before_event ,
474- delta_ids = self .state_delta_due_to_event ,
475472 )
476473
474+ def _build_state_group_deltas (self ) -> Dict [Tuple [int , int ], StateMap ]:
475+ """
476+ Collect deltas between the state groups associated with this context
477+ """
478+ state_group_deltas = {}
479+
480+ # if we know the state group before the event and after the event, add them and the
481+ # state delta between them to state_group_deltas
482+ if self .state_group_before_event and self .state_group_after_event :
483+ # if we have the state groups we should have the delta
484+ assert self .state_delta_due_to_event is not None
485+ state_group_deltas [
486+ (
487+ self .state_group_before_event ,
488+ self .state_group_after_event ,
489+ )
490+ ] = self .state_delta_due_to_event
491+
492+ # the state group before the event may also have a state group which precedes it, if
493+ # we have that and the state group before the event, add them and the state
494+ # delta between them to state_group_deltas
495+ if (
496+ self .prev_group_for_state_group_before_event
497+ and self .state_group_before_event
498+ ):
499+ # if we have both state groups we should have the delta between them
500+ assert self .delta_ids_to_state_group_before_event is not None
501+ state_group_deltas [
502+ (
503+ self .prev_group_for_state_group_before_event ,
504+ self .state_group_before_event ,
505+ )
506+ ] = self .delta_ids_to_state_group_before_event
507+
508+ return state_group_deltas
509+
510+
511+ def _encode_state_group_delta (
512+ state_group_delta : Dict [Tuple [int , int ], StateMap [str ]]
513+ ) -> List [Tuple [int , int , Optional [List [Tuple [str , str , str ]]]]]:
514+ if not state_group_delta :
515+ return []
516+
517+ state_group_delta_encoded = []
518+ for key , value in state_group_delta .items ():
519+ state_group_delta_encoded .append ((key [0 ], key [1 ], _encode_state_dict (value )))
520+
521+ return state_group_delta_encoded
522+
523+
524+ def _decode_state_group_delta (
525+ input : List [Tuple [int , int , List [Tuple [str , str , str ]]]]
526+ ) -> Dict [Tuple [int , int ], StateMap [str ]]:
527+ if not input :
528+ return {}
529+
530+ state_group_deltas = {}
531+ for state_group_1 , state_group_2 , state_dict in input :
532+ state_map = _decode_state_dict (state_dict )
533+ assert state_map is not None
534+ state_group_deltas [(state_group_1 , state_group_2 )] = state_map
535+
536+ return state_group_deltas
537+
477538
478539def _encode_state_dict (
479540 state_dict : Optional [StateMap [str ]],
0 commit comments