|  | 
| 11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
| 12 | 12 | # See the License for the specific language governing permissions and | 
| 13 | 13 | # limitations under the License. | 
|  | 14 | +from abc import ABC, abstractmethod | 
| 14 | 15 | from typing import TYPE_CHECKING, List, Optional, Tuple | 
| 15 | 16 | 
 | 
| 16 | 17 | import attr | 
|  | 
| 26 | 27 |     from synapse.types.state import StateFilter | 
| 27 | 28 | 
 | 
| 28 | 29 | 
 | 
|  | 30 | +class UnpersistedEventContextBase(ABC): | 
|  | 31 | +    """ | 
|  | 32 | +    This is a base class for EventContext and UnpersistedEventContext, objects which | 
|  | 33 | +    hold information relevant to storing an associated event. Note that an | 
|  | 34 | +    UnpersistedEventContexts must be converted into an EventContext before it is | 
|  | 35 | +    suitable to send to the db with its associated event. | 
|  | 36 | +
 | 
|  | 37 | +    Attributes: | 
|  | 38 | +        _storage: storage controllers for interfacing with the database | 
|  | 39 | +        app_service: If the associated event is being sent by a (local) application service, that | 
|  | 40 | +            app service. | 
|  | 41 | +    """ | 
|  | 42 | + | 
|  | 43 | +    def __init__(self, storage_controller: "StorageControllers"): | 
|  | 44 | +        self._storage: "StorageControllers" = storage_controller | 
|  | 45 | +        self.app_service: Optional[ApplicationService] = None | 
|  | 46 | + | 
|  | 47 | +    @abstractmethod | 
|  | 48 | +    async def persist( | 
|  | 49 | +        self, | 
|  | 50 | +        event: EventBase, | 
|  | 51 | +    ) -> "EventContext": | 
|  | 52 | +        """ | 
|  | 53 | +        A method to convert an UnpersistedEventContext to an EventContext, suitable for | 
|  | 54 | +        sending to the database with the associated event. | 
|  | 55 | +        """ | 
|  | 56 | +        pass | 
|  | 57 | + | 
|  | 58 | +    @abstractmethod | 
|  | 59 | +    async def get_prev_state_ids( | 
|  | 60 | +        self, state_filter: Optional["StateFilter"] = None | 
|  | 61 | +    ) -> StateMap[str]: | 
|  | 62 | +        """ | 
|  | 63 | +        Gets the room state at the event (ie not including the event if the event is a | 
|  | 64 | +        state event). | 
|  | 65 | +
 | 
|  | 66 | +        Args: | 
|  | 67 | +            state_filter: specifies the type of state event to fetch from DB, example: | 
|  | 68 | +            EventTypes.JoinRules | 
|  | 69 | +        """ | 
|  | 70 | +        pass | 
|  | 71 | + | 
|  | 72 | + | 
| 29 | 73 | @attr.s(slots=True, auto_attribs=True) | 
| 30 |  | -class EventContext: | 
|  | 74 | +class EventContext(UnpersistedEventContextBase): | 
| 31 | 75 |     """ | 
| 32 | 76 |     Holds information relevant to persisting an event | 
| 33 | 77 | 
 | 
| @@ -77,9 +121,6 @@ class EventContext: | 
| 77 | 121 |         delta_ids: If ``prev_group`` is not None, the state delta between ``prev_group`` | 
| 78 | 122 |             and ``state_group``. | 
| 79 | 123 | 
 | 
| 80 |  | -        app_service: If this event is being sent by a (local) application service, that | 
| 81 |  | -            app service. | 
| 82 |  | -
 | 
| 83 | 124 |         partial_state: if True, we may be storing this event with a temporary, | 
| 84 | 125 |             incomplete state. | 
| 85 | 126 |     """ | 
| @@ -122,6 +163,9 @@ def for_outlier( | 
| 122 | 163 |         """Return an EventContext instance suitable for persisting an outlier event""" | 
| 123 | 164 |         return EventContext(storage=storage) | 
| 124 | 165 | 
 | 
|  | 166 | +    async def persist(self, event: EventBase) -> "EventContext": | 
|  | 167 | +        return self | 
|  | 168 | + | 
| 125 | 169 |     async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict: | 
| 126 | 170 |         """Converts self to a type that can be serialized as JSON, and then | 
| 127 | 171 |         deserialized by `deserialize` | 
| @@ -254,6 +298,128 @@ async def get_prev_state_ids( | 
| 254 | 298 |         ) | 
| 255 | 299 | 
 | 
| 256 | 300 | 
 | 
|  | 301 | +@attr.s(slots=True, auto_attribs=True) | 
|  | 302 | +class UnpersistedEventContext(UnpersistedEventContextBase): | 
|  | 303 | +    """ | 
|  | 304 | +    The event context holds information about the state groups for an event. It is important | 
|  | 305 | +    to remember that an event technically has two state groups: the state group before the | 
|  | 306 | +    event, and the state group after the event. If the event is not a state event, the state | 
|  | 307 | +    group will not change (ie the state group before the event will be the same as the state | 
|  | 308 | +    group after the event), but if it is a state event the state group before the event | 
|  | 309 | +    will differ from the state group after the event. | 
|  | 310 | +    This is a version of an EventContext before the new state group (if any) has been | 
|  | 311 | +    computed and stored. It contains information about the state before the event (which | 
|  | 312 | +    also may be the information after the event, if the event is not a state event). The | 
|  | 313 | +    UnpersistedEventContext must be converted into an EventContext by calling the method | 
|  | 314 | +    'persist' on it before it is suitable to be sent to the DB for processing. | 
|  | 315 | +
 | 
|  | 316 | +        state_group_after_event: | 
|  | 317 | +             The state group after the event. This will always be None until it is persisted. | 
|  | 318 | +             If the event is not a state event, this will be the same as | 
|  | 319 | +             state_group_before_event. | 
|  | 320 | +
 | 
|  | 321 | +        state_group_before_event: | 
|  | 322 | +            The ID of the state group representing the state of the room before this event. | 
|  | 323 | +
 | 
|  | 324 | +        state_delta_due_to_event: | 
|  | 325 | +            If the event is a state event, then this is the delta of the state between | 
|  | 326 | +             `state_group` and `state_group_before_event` | 
|  | 327 | +
 | 
|  | 328 | +        prev_group_for_state_group_before_event: | 
|  | 329 | +            If it is known, ``state_group_before_event``'s previous state group. | 
|  | 330 | +
 | 
|  | 331 | +        delta_ids_to_state_group_before_event: | 
|  | 332 | +             If ``prev_group_for_state_group_before_event`` is not None, the state delta | 
|  | 333 | +             between ``prev_group_for_state_group_before_event`` and ``state_group_before_event``. | 
|  | 334 | +
 | 
|  | 335 | +        partial_state: | 
|  | 336 | +            Whether the event has partial state. | 
|  | 337 | +
 | 
|  | 338 | +        state_map_before_event: | 
|  | 339 | +            A map of the state before the event, i.e. the state at `state_group_before_event` | 
|  | 340 | +    """ | 
|  | 341 | + | 
|  | 342 | +    _storage: "StorageControllers" | 
|  | 343 | +    state_group_before_event: Optional[int] | 
|  | 344 | +    state_group_after_event: Optional[int] | 
|  | 345 | +    state_delta_due_to_event: Optional[dict] | 
|  | 346 | +    prev_group_for_state_group_before_event: Optional[int] | 
|  | 347 | +    delta_ids_to_state_group_before_event: Optional[StateMap[str]] | 
|  | 348 | +    partial_state: bool | 
|  | 349 | +    state_map_before_event: Optional[StateMap[str]] = None | 
|  | 350 | + | 
|  | 351 | +    async def get_prev_state_ids( | 
|  | 352 | +        self, state_filter: Optional["StateFilter"] = None | 
|  | 353 | +    ) -> StateMap[str]: | 
|  | 354 | +        """ | 
|  | 355 | +        Gets the room state map, excluding this event. | 
|  | 356 | +
 | 
|  | 357 | +        Args: | 
|  | 358 | +            state_filter: specifies the type of state event to fetch from DB | 
|  | 359 | +
 | 
|  | 360 | +        Returns: | 
|  | 361 | +            Maps a (type, state_key) to the event ID of the state event matching | 
|  | 362 | +            this tuple. | 
|  | 363 | +        """ | 
|  | 364 | +        if self.state_map_before_event: | 
|  | 365 | +            return self.state_map_before_event | 
|  | 366 | + | 
|  | 367 | +        assert self.state_group_before_event is not None | 
|  | 368 | +        return await self._storage.state.get_state_ids_for_group( | 
|  | 369 | +            self.state_group_before_event, state_filter | 
|  | 370 | +        ) | 
|  | 371 | + | 
|  | 372 | +    async def persist(self, event: EventBase) -> EventContext: | 
|  | 373 | +        """ | 
|  | 374 | +        Creates a full `EventContext` for the event, persisting any referenced state that | 
|  | 375 | +        has not yet been persisted. | 
|  | 376 | +
 | 
|  | 377 | +        Args: | 
|  | 378 | +             event: event that the EventContext is associated with. | 
|  | 379 | +
 | 
|  | 380 | +        Returns: An EventContext suitable for sending to the database with the event | 
|  | 381 | +        for persisting | 
|  | 382 | +        """ | 
|  | 383 | +        assert self.partial_state is not None | 
|  | 384 | + | 
|  | 385 | +        # If we have a full set of state for before the event but don't have a state | 
|  | 386 | +        # group for that state, we need to get one | 
|  | 387 | +        if self.state_group_before_event is None: | 
|  | 388 | +            assert self.state_map_before_event | 
|  | 389 | +            state_group_before_event = await self._storage.state.store_state_group( | 
|  | 390 | +                event.event_id, | 
|  | 391 | +                event.room_id, | 
|  | 392 | +                prev_group=self.prev_group_for_state_group_before_event, | 
|  | 393 | +                delta_ids=self.delta_ids_to_state_group_before_event, | 
|  | 394 | +                current_state_ids=self.state_map_before_event, | 
|  | 395 | +            ) | 
|  | 396 | +            self.state_group_before_event = state_group_before_event | 
|  | 397 | + | 
|  | 398 | +        # if the event isn't a state event the state group doesn't change | 
|  | 399 | +        if not self.state_delta_due_to_event: | 
|  | 400 | +            state_group_after_event = self.state_group_before_event | 
|  | 401 | + | 
|  | 402 | +        # otherwise if it is a state event we need to get a state group for it | 
|  | 403 | +        else: | 
|  | 404 | +            state_group_after_event = await self._storage.state.store_state_group( | 
|  | 405 | +                event.event_id, | 
|  | 406 | +                event.room_id, | 
|  | 407 | +                prev_group=self.state_group_before_event, | 
|  | 408 | +                delta_ids=self.state_delta_due_to_event, | 
|  | 409 | +                current_state_ids=None, | 
|  | 410 | +            ) | 
|  | 411 | + | 
|  | 412 | +        return EventContext.with_state( | 
|  | 413 | +            storage=self._storage, | 
|  | 414 | +            state_group=state_group_after_event, | 
|  | 415 | +            state_group_before_event=self.state_group_before_event, | 
|  | 416 | +            state_delta_due_to_event=self.state_delta_due_to_event, | 
|  | 417 | +            partial_state=self.partial_state, | 
|  | 418 | +            prev_group=self.state_group_before_event, | 
|  | 419 | +            delta_ids=self.state_delta_due_to_event, | 
|  | 420 | +        ) | 
|  | 421 | + | 
|  | 422 | + | 
| 257 | 423 | def _encode_state_dict( | 
| 258 | 424 |     state_dict: Optional[StateMap[str]], | 
| 259 | 425 | ) -> Optional[List[Tuple[str, str, str]]]: | 
|  | 
0 commit comments