1313# limitations under the License.
1414
1515import logging
16- from typing import TYPE_CHECKING , Dict , Iterable , List , Optional , Tuple , Union , cast
16+ from typing import (
17+ TYPE_CHECKING ,
18+ Collection ,
19+ Dict ,
20+ Iterable ,
21+ List ,
22+ Optional ,
23+ Tuple ,
24+ Union ,
25+ cast ,
26+ )
1727
1828import attr
1929from frozendict import frozendict
2030
21- from synapse .api .constants import EventTypes , RelationTypes
31+ from synapse .api .constants import RelationTypes
2232from synapse .events import EventBase
2333from synapse .storage ._base import SQLBaseStore
2434from synapse .storage .database import (
2838 make_in_list_sql_clause ,
2939)
3040from synapse .storage .databases .main .stream import generate_pagination_where_clause
41+ from synapse .storage .engines import PostgresEngine
3142from synapse .storage .relations import (
3243 AggregationPaginationToken ,
3344 PaginationChunk ,
3445 RelationPaginationToken ,
3546)
3647from synapse .types import JsonDict
37- from synapse .util .caches .descriptors import cached
48+ from synapse .util .caches .descriptors import cached , cachedList
3849
3950if TYPE_CHECKING :
4051 from synapse .server import HomeServer
@@ -340,20 +351,24 @@ def _get_aggregation_groups_for_event_txn(
340351 )
341352
342353 @cached ()
343- async def get_applicable_edit (
344- self , event_id : str , room_id : str
345- ) -> Optional [EventBase ]:
354+ def get_applicable_edit (self , event_id : str ) -> Optional [EventBase ]:
355+ raise NotImplementedError ()
356+
357+ @cachedList (cached_method_name = "get_applicable_edit" , list_name = "event_ids" )
358+ async def _get_applicable_edits (
359+ self , event_ids : Collection [str ]
360+ ) -> Dict [str , Optional [EventBase ]]:
346361 """Get the most recent edit (if any) that has happened for the given
347- event .
362+ events .
348363
349364 Correctly handles checking whether edits were allowed to happen.
350365
351366 Args:
352- event_id: The original event ID
353- room_id: The original event's room ID
367+ event_ids: The original event IDs
354368
355369 Returns:
356- The most recent edit, if any.
370+ A map of the most recent edit for each event. If there are no edits,
371+ the event will map to None.
357372 """
358373
359374 # We only allow edits for `m.room.message` events that have the same sender
@@ -362,37 +377,67 @@ async def get_applicable_edit(
362377
363378 # Fetches latest edit that has the same type and sender as the
364379 # original, and is an `m.room.message`.
365- sql = """
366- SELECT edit.event_id FROM events AS edit
367- INNER JOIN event_relations USING (event_id)
368- INNER JOIN events AS original ON
369- original.event_id = relates_to_id
370- AND edit.type = original.type
371- AND edit.sender = original.sender
372- WHERE
373- relates_to_id = ?
374- AND relation_type = ?
375- AND edit.room_id = ?
376- AND edit.type = 'm.room.message'
377- ORDER by edit.origin_server_ts DESC, edit.event_id DESC
378- LIMIT 1
379- """
380+ if isinstance (self .database_engine , PostgresEngine ):
381+ # The `DISTINCT ON` clause will pick the *first* row it encounters,
382+ # so ordering by origin server ts + event ID desc will ensure we get
383+ # the latest edit.
384+ sql = """
385+ SELECT DISTINCT ON (original.event_id) original.event_id, edit.event_id FROM events AS edit
386+ INNER JOIN event_relations USING (event_id)
387+ INNER JOIN events AS original ON
388+ original.event_id = relates_to_id
389+ AND edit.type = original.type
390+ AND edit.sender = original.sender
391+ AND edit.room_id = original.room_id
392+ WHERE
393+ %s
394+ AND relation_type = ?
395+ AND edit.type = 'm.room.message'
396+ ORDER by original.event_id DESC, edit.origin_server_ts DESC, edit.event_id DESC
397+ """
398+ else :
399+ # SQLite uses a simplified query which returns all edits for an
400+ # original event. The results are then de-duplicated when turned into
401+ # a dict. Due to the chosen ordering, the latest edit stomps on
402+ # earlier edits.
403+ sql = """
404+ SELECT original.event_id, edit.event_id FROM events AS edit
405+ INNER JOIN event_relations USING (event_id)
406+ INNER JOIN events AS original ON
407+ original.event_id = relates_to_id
408+ AND edit.type = original.type
409+ AND edit.sender = original.sender
410+ AND edit.room_id = original.room_id
411+ WHERE
412+ %s
413+ AND relation_type = ?
414+ AND edit.type = 'm.room.message'
415+ ORDER by edit.origin_server_ts, edit.event_id
416+ """
380417
381- def _get_applicable_edit_txn (txn : LoggingTransaction ) -> Optional [str ]:
382- txn .execute (sql , (event_id , RelationTypes .REPLACE , room_id ))
383- row = txn .fetchone ()
384- if row :
385- return row [0 ]
386- return None
418+ def _get_applicable_edits_txn (txn : LoggingTransaction ) -> Dict [str , str ]:
419+ clause , args = make_in_list_sql_clause (
420+ txn .database_engine , "relates_to_id" , event_ids
421+ )
422+ args .append (RelationTypes .REPLACE )
387423
388- edit_id = await self .db_pool .runInteraction (
389- "get_applicable_edit" , _get_applicable_edit_txn
424+ txn .execute (sql % (clause ,), args )
425+ return dict (cast (Iterable [Tuple [str , str ]], txn .fetchall ()))
426+
427+ edit_ids = await self .db_pool .runInteraction (
428+ "get_applicable_edits" , _get_applicable_edits_txn
390429 )
391430
392- if not edit_id :
393- return None
431+ edits = await self .get_events (edit_ids .values ()) # type: ignore[attr-defined]
394432
395- return await self .get_event (edit_id , allow_none = True ) # type: ignore[attr-defined]
433+ # Map to the original event IDs to the edit events.
434+ #
435+ # There might not be an edit event due to there being no edits or
436+ # due to the event not being known, either case is treated the same.
437+ return {
438+ original_event_id : edits .get (edit_ids .get (original_event_id ))
439+ for original_event_id in event_ids
440+ }
396441
397442 @cached ()
398443 async def get_thread_summary (
@@ -612,9 +657,6 @@ async def _get_bundled_aggregation_for_event(
612657 The bundled aggregations for an event, if bundled aggregations are
613658 enabled and the event can have bundled aggregations.
614659 """
615- # State events and redacted events do not get bundled aggregations.
616- if event .is_state () or event .internal_metadata .is_redacted ():
617- return None
618660
619661 # Do not bundle aggregations for an event which represents an edit or an
620662 # annotation. It does not make sense for them to have related events.
@@ -642,13 +684,6 @@ async def _get_bundled_aggregation_for_event(
642684 if references .chunk :
643685 aggregations .references = references .to_dict ()
644686
645- edit = None
646- if event .type == EventTypes .Message :
647- edit = await self .get_applicable_edit (event_id , room_id )
648-
649- if edit :
650- aggregations .replace = edit
651-
652687 # If this event is the start of a thread, include a summary of the replies.
653688 if self ._msc3440_enabled :
654689 thread_count , latest_thread_event = await self .get_thread_summary (
@@ -668,9 +703,7 @@ async def _get_bundled_aggregation_for_event(
668703 return aggregations
669704
670705 async def get_bundled_aggregations (
671- self ,
672- events : Iterable [EventBase ],
673- user_id : str ,
706+ self , events : Iterable [EventBase ], user_id : str
674707 ) -> Dict [str , BundledAggregations ]:
675708 """Generate bundled aggregations for events.
676709
@@ -683,13 +716,28 @@ async def get_bundled_aggregations(
683716 events may have bundled aggregations in the results.
684717 """
685718
686- # TODO Parallelize.
687- results = {}
719+ # State events and redacted events do not get bundled aggregations.
720+ events = [
721+ event
722+ for event in events
723+ if not event .is_state () and not event .internal_metadata .is_redacted ()
724+ ]
725+
726+ # event ID -> bundled aggregation in non-serialized form.
727+ results : Dict [str , BundledAggregations ] = {}
728+
729+ # Fetch other relations per event.
688730 for event in events :
689731 event_result = await self ._get_bundled_aggregation_for_event (event , user_id )
690732 if event_result :
691733 results [event .event_id ] = event_result
692734
735+ # Fetch any edits.
736+ event_ids = [event .event_id for event in events ]
737+ edits = await self ._get_applicable_edits (event_ids )
738+ for event_id , edit in edits .items ():
739+ results .setdefault (event_id , BundledAggregations ()).replace = edit
740+
693741 return results
694742
695743
0 commit comments