From d9876b68a3f455599f60dca9ef54b3bf1f19aa77 Mon Sep 17 00:00:00 2001 From: chayleaf Date: Mon, 14 Aug 2023 19:50:04 +0700 Subject: [PATCH 1/2] Implement MSC3051 - A scalable relations format This allows using m.relations in place of m.relates_to for listing multiple relations for a single event. Signed-off-by: chayleaf --- changelog.d/16111.feature | 1 + synapse/api/filtering.py | 11 +- synapse/events/__init__.py | 53 +++-- synapse/handlers/message.py | 88 +++---- synapse/handlers/relations.py | 34 ++- synapse/push/bulk_push_rule_evaluator.py | 81 ++++--- synapse/storage/databases/main/cache.py | 8 +- synapse/storage/databases/main/events.py | 222 +++++++++--------- .../databases/main/events_bg_updates.py | 49 ++-- synapse/storage/schema/__init__.py | 5 +- .../01_allow_multiple_relations.sql.postgres | 17 ++ .../81/01_allow_multiple_relations.sql.sqlite | 17 ++ tests/rest/client/test_relations.py | 97 +++++++- 13 files changed, 409 insertions(+), 274 deletions(-) create mode 100644 changelog.d/16111.feature create mode 100644 synapse/storage/schema/main/delta/81/01_allow_multiple_relations.sql.postgres create mode 100644 synapse/storage/schema/main/delta/81/01_allow_multiple_relations.sql.sqlite diff --git a/changelog.d/16111.feature b/changelog.d/16111.feature new file mode 100644 index 000000000000..369e56e3caeb --- /dev/null +++ b/changelog.d/16111.feature @@ -0,0 +1 @@ +Implement MSC3051 - A scalable relation format. Contributed by @chayleaf. 
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 0995ecbe832a..82e85a4314c7 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -36,7 +36,7 @@ from synapse.api.constants import EduTypes, EventContentFields from synapse.api.errors import SynapseError from synapse.api.presence import UserPresenceState -from synapse.events import EventBase, relation_from_event +from synapse.events import EventBase, relations_from_event from synapse.types import JsonDict, RoomID, UserID if TYPE_CHECKING: @@ -408,18 +408,17 @@ def _check(self, event: FilterEvent) -> bool: labels = content.get(EventContentFields.LABELS, []) # Check if the event has a relation. - rel_type = None + rel_types: List[str] = [] if isinstance(event, EventBase): - relation = relation_from_event(event) - if relation: - rel_type = relation.rel_type + for relation in relations_from_event(event): + rel_types.append(relation.rel_type) field_matchers = { "rooms": lambda v: room_id == v, "senders": lambda v: sender == v, "types": lambda v: _matches_wildcard(ev_type, v), "labels": lambda v: v in labels, - "rel_types": lambda v: rel_type == v, + "rel_types": lambda v: v in rel_types, } result = self._check_fields(field_matchers) diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 35257a3b1ba0..c7d33abf42e0 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -638,32 +638,37 @@ class _EventRelation: aggregation_key: Optional[str] -def relation_from_event(event: EventBase) -> Optional[_EventRelation]: +def relations_from_event(event: EventBase) -> List[_EventRelation]: """ Attempt to parse relation information an event. Returns: - The event relation information, if it is valid. None, otherwise. + All valid event relation information. """ - relation = event.content.get("m.relates_to") - if not relation or not isinstance(relation, collections.abc.Mapping): - # No relation information. 
- return None - - # Relations must have a type and parent event ID. - rel_type = relation.get("rel_type") - if not isinstance(rel_type, str): - return None - - parent_id = relation.get("event_id") - if not isinstance(parent_id, str): - return None - - # Annotations have a key field. - aggregation_key = None - if rel_type == RelationTypes.ANNOTATION: - aggregation_key = relation.get("key") - if not isinstance(aggregation_key, str): - aggregation_key = None - - return _EventRelation(parent_id, rel_type, aggregation_key) + + relations = event.content.get("m.relations") + if not relations or not isinstance(relations, list): + relations = [event.content.get("m.relates_to")] + + ret: List[_EventRelation] = [] + for relation in relations: + if not relation or not isinstance(relation, collections.abc.Mapping): + continue + # Relations must have a type and parent event ID. + rel_type = relation.get("rel_type") + if not isinstance(rel_type, str): + continue + + parent_id = relation.get("event_id") + if not isinstance(parent_id, str): + continue + + # Annotations have a key field. 
+ aggregation_key = None + if rel_type == RelationTypes.ANNOTATION: + aggregation_key = relation.get("key") + if not isinstance(aggregation_key, str): + aggregation_key = None + + ret.append(_EventRelation(parent_id, rel_type, aggregation_key)) + return ret diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index d485f21e49f1..026792507616 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -47,7 +47,7 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.urls import ConsentURIBuilder from synapse.event_auth import validate_event_for_room_version -from synapse.events import EventBase, relation_from_event +from synapse.events import EventBase, relations_from_event from synapse.events.builder import EventBuilder from synapse.events.snapshot import EventContext, UnpersistedEventContextBase from synapse.events.utils import SerializeEventConfig, maybe_upsert_event_field @@ -1334,51 +1334,55 @@ async def _validate_event_relation(self, event: EventBase) -> None: SynapseError if the event is invalid. """ - relation = relation_from_event(event) - if not relation: - return - - parent_event = await self.store.get_event(relation.parent_id, allow_none=True) - if parent_event: - # And in the same room. - if parent_event.room_id != event.room_id: - raise SynapseError(400, "Relations must be in the same room") + for relation in relations_from_event(event): + if not relation: + continue - else: - # There must be some reason that the client knows the event exists, - # see if there are existing relations. If so, assume everything is fine. - if not await self.store.event_is_target_of_relation(relation.parent_id): - # Otherwise, the client can't know about the parent event! - raise SynapseError(400, "Can't send relation to unknown event") - - # If this event is an annotation then we check that that the sender - # can't annotate the same way twice (e.g. stops users from liking an - # event multiple times). 
- if relation.rel_type == RelationTypes.ANNOTATION: - aggregation_key = relation.aggregation_key - - if aggregation_key is None: - raise SynapseError(400, "Missing aggregation key") - - if len(aggregation_key) > 500: - raise SynapseError(400, "Aggregation key is too long") - - already_exists = await self.store.has_user_annotated_event( - relation.parent_id, event.type, aggregation_key, event.sender + parent_event = await self.store.get_event( + relation.parent_id, allow_none=True ) - if already_exists: - raise SynapseError( - 400, - "Can't send same reaction twice", - errcode=Codes.DUPLICATE_ANNOTATION, - ) + if parent_event: + # And in the same room. + if parent_event.room_id != event.room_id: + raise SynapseError(400, "Relations must be in the same room") - # Don't attempt to start a thread if the parent event is a relation. - elif relation.rel_type == RelationTypes.THREAD: - if await self.store.event_includes_relation(relation.parent_id): - raise SynapseError( - 400, "Cannot start threads from an event with a relation" + else: + # There must be some reason that the client knows the event exists, + # see if there are existing relations. If so, assume everything is fine. + if not await self.store.event_is_target_of_relation(relation.parent_id): + # Otherwise, the client can't know about the parent event! + raise SynapseError(400, "Can't send relation to unknown event") + + # If this event is an annotation then we check that that the sender + # can't annotate the same way twice (e.g. stops users from liking an + # event multiple times). 
+ if relation.rel_type == RelationTypes.ANNOTATION: + aggregation_key = relation.aggregation_key + + if aggregation_key is None: + raise SynapseError(400, "Missing aggregation key") + + if len(aggregation_key) > 500: + raise SynapseError(400, "Aggregation key is too long") + + already_exists = await self.store.has_user_annotated_event( + relation.parent_id, event.type, aggregation_key, event.sender ) + if already_exists: + raise SynapseError( + 400, + "Can't send same reaction twice", + errcode=Codes.DUPLICATE_ANNOTATION, + ) + + # Don't attempt to start a thread if the parent event is a relation. + # XXX: should this be commented out alongside multiple relations being + # introduced? + elif relation.rel_type == RelationTypes.THREAD: + if await self.store.event_includes_relation(relation.parent_id): + raise SynapseError( + 400, "Cannot start threads from an event with a relation" + ) @measure_func("handle_new_client_event") async def handle_new_client_event( diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index db97f7aedee6..215d3171c96d 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -19,7 +19,7 @@ from synapse.api.constants import Direction, EventTypes, RelationTypes from synapse.api.errors import SynapseError -from synapse.events import EventBase, relation_from_event +from synapse.events import EventBase, relations_from_event from synapse.events.utils import SerializeEventConfig from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import trace @@ -286,7 +286,7 @@ async def get_references_for_events( async def _get_threads_for_events( self, events_by_id: Dict[str, EventBase], - relations_by_id: Dict[str, str], + relations_by_id: Dict[str, List[str]], user_id: str, ignored_users: FrozenSet[str], ) -> Dict[str, _ThreadAggregation]: @@ -294,7 +294,7 @@ async def _get_threads_for_events( Args: events_by_id: A map of event_id to events to get 
aggregations for threads. - relations_by_id: A map of event_id to the relation type, if one exists + relations_by_id: A map of event_id to the relation types, if any exist for that event. user_id: The user requesting the bundled aggregations. ignored_users: The users ignored by the requesting user. @@ -432,28 +432,34 @@ async def get_bundled_aggregations( """ # De-duplicated events by ID to handle the same event requested multiple times. events_by_id = {} - # A map of event ID to the relation in that event, if there is one. - relations_by_id: Dict[str, str] = {} + # A map of event ID to the relations in that event, if there are any. + relations_by_id: Dict[str, List[str]] = {} for event in events: # State events do not get bundled aggregations. if event.is_state(): continue - relates_to = relation_from_event(event) - if relates_to: + valid = True + + for relates_to in relations_from_event(event): # An event which is a replacement (ie edit) or annotation (ie, # reaction) may not have any other event related to it. + # XXX: should this be removed alongside multiple relations being introduced? if relates_to.rel_type in ( RelationTypes.ANNOTATION, RelationTypes.REPLACE, ): - continue + valid = False + break # Track the event's relation information for later. - relations_by_id[event.event_id] = relates_to.rel_type + if event.event_id not in relations_by_id.keys(): + relations_by_id[event.event_id] = [] + relations_by_id[event.event_id].append(relates_to.rel_type) - # The event should get bundled aggregations. - events_by_id[event.event_id] = event + if valid: + # The event should get bundled aggregations. + events_by_id[event.event_id] = event # event ID -> bundled aggregation in non-serialized form. results: Dict[str, BundledAggregations] = {} @@ -484,7 +490,11 @@ async def get_bundled_aggregations( # # We know that the latest event in a thread has a thread relation # (as that is what makes it part of the thread). 
- relations_by_id[latest_thread_event.event_id] = RelationTypes.THREAD + if latest_thread_event.event_id not in relations_by_id.keys(): + relations_by_id[latest_thread_event.event_id] = [] + relations_by_id[latest_thread_event.event_id].append( + RelationTypes.THREAD + ) async def _fetch_references() -> None: """Fetch any references to bundle with this event.""" diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 990c079c815b..4ac91b3d83bf 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import collections.abc import logging from typing import ( TYPE_CHECKING, @@ -38,7 +39,7 @@ ) from synapse.api.room_versions import PushRuleRoomFlag from synapse.event_auth import auth_types_for_event, get_user_power_level -from synapse.events import EventBase, relation_from_event +from synapse.events import EventBase, relations_from_event from synapse.events.snapshot import EventContext from synapse.state import POWER_KEY from synapse.storage.databases.main.roommember import EventIdMembership @@ -87,9 +88,9 @@ def _should_count_as_unread(event: EventBase, context: EventContext) -> bool: return False # Exclude edits. - relates_to = relation_from_event(event) - if relates_to and relates_to.rel_type == RelationTypes.REPLACE: - return False + for relates_to in relations_from_event(event): + if relates_to.rel_type == RelationTypes.REPLACE: + return False # Mark events that have a non-empty string body as unread. 
body = event.content.get("body") @@ -263,37 +264,42 @@ async def _related_events( """ related_events: Dict[str, Dict[str, JsonValue]] = {} if self._related_event_match_enabled: - related_event_id = event.content.get("m.relates_to", {}).get("event_id") - relation_type = event.content.get("m.relates_to", {}).get("rel_type") - if related_event_id is not None and relation_type is not None: - related_event = await self.store.get_event( - related_event_id, allow_none=True - ) - if related_event is not None: - related_events[relation_type] = _flatten_dict(related_event) - - reply_event_id = ( - event.content.get("m.relates_to", {}) - .get("m.in_reply_to", {}) - .get("event_id") - ) - - # convert replies to pseudo relations - if reply_event_id is not None: - related_event = await self.store.get_event( - reply_event_id, allow_none=True - ) - - if related_event is not None: - related_events["m.in_reply_to"] = _flatten_dict(related_event) - - # indicate that this is from a fallback relation. - if relation_type == "m.thread" and event.content.get( - "m.relates_to", {} - ).get("is_falling_back", False): - related_events["m.in_reply_to"][ - "im.vector.is_falling_back" - ] = "" + relations = event.content.get("m.relations") + if not relations or not isinstance(relations, list): + relations = [event.content.get("m.relates_to")] + + for relation in relations: + if not relation or not isinstance(relation, collections.abc.Mapping): + continue + + related_event_id = relation.get("event_id") + relation_type = relation.get("rel_type") + if related_event_id is not None and relation_type is not None: + related_event = await self.store.get_event( + related_event_id, allow_none=True + ) + if related_event is not None: + related_events[relation_type] = _flatten_dict(related_event) + + reply_event_id = relation.get("m.in_reply_to", {}).get("event_id") + + # convert replies to pseudo relations + # XXX: does this need any changes for multiple relations? 
+ if reply_event_id is not None: + related_event = await self.store.get_event( + reply_event_id, allow_none=True + ) + + if related_event is not None: + related_events["m.in_reply_to"] = _flatten_dict(related_event) + + # indicate that this is from a fallback relation. + if relation_type == "m.thread" and event.content.get( + "m.relates_to", {} + ).get("is_falling_back", False): + related_events["m.in_reply_to"][ + "im.vector.is_falling_back" + ] = "" return related_events @@ -353,11 +359,10 @@ async def _action_for_event_by_user( event, context, event_id_to_event ) - # Find the event's thread ID. - relation = relation_from_event(event) # If the event does not have a relation, then it cannot have a thread ID. thread_id = MAIN_TIMELINE - if relation: + # Find the event's thread ID. + for relation in relations_from_event(event): # Recursively attempt to find the thread this event relates to. if relation.rel_type == RelationTypes.THREAD: thread_id = relation.parent_id diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 2fbd389c7168..44a0c3b0a42f 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -196,7 +196,7 @@ def process_replication_rows( row.type, row.state_key, row.redacts, - row.relates_to, + [row.relates_to], backfilled=True, ) elif stream_name == CachesStream.NAME: @@ -252,7 +252,7 @@ def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None: data.type, data.state_key, data.redacts, - data.relates_to, + [data.relates_to] if type(data.relates_to) is str else [], backfilled=False, ) elif row.type == EventsStreamCurrentStateRow.TypeId: @@ -275,7 +275,7 @@ def _invalidate_caches_for_event( etype: str, state_key: Optional[str], redacts: Optional[str], - relates_to: Optional[str], + relations: List[str], backfilled: bool, ) -> None: # XXX: If you add something to this function make sure you add it to @@ -329,7 +329,7 @@ def 
_invalidate_caches_for_event( "get_forgotten_rooms_for_user", (state_key,) ) - if relates_to: + for relates_to in relations: self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,)) self._attempt_to_invalidate_cache("get_references_for_event", (relates_to,)) self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,)) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index c1353b18c1cd..b394bad43710 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -37,7 +37,7 @@ from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.api.errors import PartialStateConflictError from synapse.api.room_versions import RoomVersions -from synapse.events import EventBase, relation_from_event +from synapse.events import EventBase, relations_from_event from synapse.events.snapshot import EventContext from synapse.logging.opentracing import trace from synapse.storage._base import db_to_json, make_in_list_sql_clause @@ -395,10 +395,9 @@ def _persist_events_txn( if event.redacts: self.store.invalidate_get_event_cache_after_txn(txn, event.redacts) - relates_to = None - relation = relation_from_event(event) - if relation: - relates_to = relation.parent_id + relations = list( + map(lambda rel: rel.parent_id, relations_from_event(event)) + ) assert event.internal_metadata.stream_ordering is not None txn.call_after( @@ -409,7 +408,7 @@ def _persist_events_txn( event.type, getattr(event, "state_key", None), event.redacts, - relates_to, + relations, backfilled=False, ) @@ -1877,49 +1876,45 @@ def _handle_event_relations( txn: The current database transaction. event: The event which might have relations. """ - relation = relation_from_event(event) - if not relation: - # No relation, nothing to do. 
- return - - self.db_pool.simple_insert_txn( - txn, - table="event_relations", - values={ - "event_id": event.event_id, - "relates_to_id": relation.parent_id, - "relation_type": relation.rel_type, - "aggregation_key": relation.aggregation_key, - }, - ) + for relation in relations_from_event(event): + self.db_pool.simple_insert_txn( + txn, + table="event_relations", + values={ + "event_id": event.event_id, + "relates_to_id": relation.parent_id, + "relation_type": relation.rel_type, + "aggregation_key": relation.aggregation_key, + }, + ) - if relation.rel_type == RelationTypes.THREAD: - # Upsert into the threads table, but only overwrite the value if the - # new event is of a later topological order OR if the topological - # ordering is equal, but the stream ordering is later. - sql = """ - INSERT INTO threads (room_id, thread_id, latest_event_id, topological_ordering, stream_ordering) - VALUES (?, ?, ?, ?, ?) - ON CONFLICT (room_id, thread_id) - DO UPDATE SET - latest_event_id = excluded.latest_event_id, - topological_ordering = excluded.topological_ordering, - stream_ordering = excluded.stream_ordering - WHERE - threads.topological_ordering <= excluded.topological_ordering AND - threads.stream_ordering < excluded.stream_ordering - """ + if relation.rel_type == RelationTypes.THREAD: + # Upsert into the threads table, but only overwrite the value if the + # new event is of a later topological order OR if the topological + # ordering is equal, but the stream ordering is later. + sql = """ + INSERT INTO threads (room_id, thread_id, latest_event_id, topological_ordering, stream_ordering) + VALUES (?, ?, ?, ?, ?) 
+ ON CONFLICT (room_id, thread_id) + DO UPDATE SET + latest_event_id = excluded.latest_event_id, + topological_ordering = excluded.topological_ordering, + stream_ordering = excluded.stream_ordering + WHERE + threads.topological_ordering <= excluded.topological_ordering AND + threads.stream_ordering < excluded.stream_ordering + """ - txn.execute( - sql, - ( - event.room_id, - relation.parent_id, - event.event_id, - event.depth, - event.internal_metadata.stream_ordering, - ), - ) + txn.execute( + sql, + ( + event.room_id, + relation.parent_id, + event.event_id, + event.depth, + event.internal_metadata.stream_ordering, + ), + ) def _handle_redact_relations( self, txn: LoggingTransaction, room_id: str, redacted_event_id: str @@ -1932,80 +1927,83 @@ def _handle_redact_relations( room_id: The room ID of the event that was redacted. redacted_event_id: The event that was redacted. """ - - # Fetch the relation of the event being redacted. - row = self.db_pool.simple_select_one_txn( + x = self.db_pool.simple_select_one_txn + # Fetch the relations of the event being redacted. + rows = self.db_pool.simple_select_many_txn( txn, table="event_relations", - keyvalues={"event_id": redacted_event_id}, + column="event_id", + iterable={redacted_event_id}, + keyvalues={}, retcols=("relates_to_id", "relation_type"), - allow_none=True, - ) - # Nothing to do if no relation is found. - if row is None: - return - - redacted_relates_to = row["relates_to_id"] - rel_type = row["relation_type"] - self.db_pool.simple_delete_txn( - txn, table="event_relations", keyvalues={"event_id": redacted_event_id} - ) - - # Any relation information for the related event must be cleared. 
- self.store._invalidate_cache_and_stream( - txn, self.store.get_relations_for_event, (redacted_relates_to,) ) - if rel_type == RelationTypes.REFERENCE: - self.store._invalidate_cache_and_stream( - txn, self.store.get_references_for_event, (redacted_relates_to,) - ) - if rel_type == RelationTypes.REPLACE: - self.store._invalidate_cache_and_stream( - txn, self.store.get_applicable_edit, (redacted_relates_to,) - ) - if rel_type == RelationTypes.THREAD: - self.store._invalidate_cache_and_stream( - txn, self.store.get_thread_summary, (redacted_relates_to,) - ) - self.store._invalidate_cache_and_stream( - txn, self.store.get_thread_participated, (redacted_relates_to,) + for row in rows: + redacted_relates_to = row["relates_to_id"] + rel_type = row["relation_type"] + self.db_pool.simple_delete_txn( + txn, table="event_relations", keyvalues={"event_id": redacted_event_id} ) + + # Any relation information for the related event must be cleared. self.store._invalidate_cache_and_stream( - txn, self.store.get_threads, (room_id,) + txn, self.store.get_relations_for_event, (redacted_relates_to,) ) - - # Find the new latest event in the thread. - sql = """ - SELECT event_id, topological_ordering, stream_ordering - FROM event_relations - INNER JOIN events USING (event_id) - WHERE relates_to_id = ? AND relation_type = ? - ORDER BY topological_ordering DESC, stream_ordering DESC - LIMIT 1 - """ - txn.execute(sql, (redacted_relates_to, RelationTypes.THREAD)) - - # If a latest event is found, update the threads table, this might - # be the same current latest event (if an earlier event in the thread - # was redacted). 
- latest_event_row = txn.fetchone() - if latest_event_row: - self.db_pool.simple_upsert_txn( - txn, - table="threads", - keyvalues={"room_id": room_id, "thread_id": redacted_relates_to}, - values={ - "latest_event_id": latest_event_row[0], - "topological_ordering": latest_event_row[1], - "stream_ordering": latest_event_row[2], - }, + if rel_type == RelationTypes.REFERENCE: + self.store._invalidate_cache_and_stream( + txn, self.store.get_references_for_event, (redacted_relates_to,) ) - - # Otherwise, delete the thread: it no longer exists. - else: - self.db_pool.simple_delete_one_txn( - txn, table="threads", keyvalues={"thread_id": redacted_relates_to} + if rel_type == RelationTypes.REPLACE: + self.store._invalidate_cache_and_stream( + txn, self.store.get_applicable_edit, (redacted_relates_to,) + ) + if rel_type == RelationTypes.THREAD: + self.store._invalidate_cache_and_stream( + txn, self.store.get_thread_summary, (redacted_relates_to,) + ) + self.store._invalidate_cache_and_stream( + txn, self.store.get_thread_participated, (redacted_relates_to,) ) + self.store._invalidate_cache_and_stream( + txn, self.store.get_threads, (room_id,) + ) + + # Find the new latest event in the thread. + sql = """ + SELECT event_id, topological_ordering, stream_ordering + FROM event_relations + INNER JOIN events USING (event_id) + WHERE relates_to_id = ? AND relation_type = ? + ORDER BY topological_ordering DESC, stream_ordering DESC + LIMIT 1 + """ + txn.execute(sql, (redacted_relates_to, RelationTypes.THREAD)) + + # If a latest event is found, update the threads table, this might + # be the same current latest event (if an earlier event in the thread + # was redacted). 
+ latest_event_row = txn.fetchone() + if latest_event_row: + self.db_pool.simple_upsert_txn( + txn, + table="threads", + keyvalues={ + "room_id": room_id, + "thread_id": redacted_relates_to, + }, + values={ + "latest_event_id": latest_event_row[0], + "topological_ordering": latest_event_row[1], + "stream_ordering": latest_event_row[2], + }, + ) + + # Otherwise, delete the thread: it no longer exists. + else: + self.db_pool.simple_delete_one_txn( + txn, + table="threads", + keyvalues={"thread_id": redacted_relates_to}, + ) def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase) -> None: if isinstance(event.content.get("topic"), str): diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index daef3685b09a..9576b769aad5 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -1178,28 +1178,31 @@ def _event_arbitrary_relations_txn(txn: LoggingTransaction) -> int: ) continue - # If there's no relation, skip! - relates_to = event_json["content"].get("m.relates_to") - if not relates_to or not isinstance(relates_to, dict): - continue + relations = event_json["content"].get("m.relations") + if not relations or not isinstance(relations, list): + relations = [event_json["content"].get("m.relates_to")] - # If the relation type or parent event ID is not a string, skip it. - # - # Do not consider relation types that have existed for a long time, - # since they will already be listed in the `event_relations` table. 
- rel_type = relates_to.get("rel_type") - if not isinstance(rel_type, str) or rel_type in ( - RelationTypes.ANNOTATION, - RelationTypes.REFERENCE, - RelationTypes.REPLACE, - ): - continue + for relates_to in relations: + if not relates_to or not isinstance(relates_to, dict): + continue - parent_id = relates_to.get("event_id") - if not isinstance(parent_id, str): - continue + # If the relation type or parent event ID is not a string, skip it. + # + # Do not consider relation types that have existed for a long time, + # since they will already be listed in the `event_relations` table. + rel_type = relates_to.get("rel_type") + if not isinstance(rel_type, str) or rel_type in ( + RelationTypes.ANNOTATION, + RelationTypes.REFERENCE, + RelationTypes.REPLACE, + ): + continue + + parent_id = relates_to.get("event_id") + if not isinstance(parent_id, str): + continue - relations_to_insert.append((event_id, parent_id, rel_type)) + relations_to_insert.append((event_id, parent_id, rel_type)) # Insert the missing data, note that we upsert here in case the event # has already been processed. @@ -1207,10 +1210,10 @@ def _event_arbitrary_relations_txn(txn: LoggingTransaction) -> int: self.db_pool.simple_upsert_many_txn( txn=txn, table="event_relations", - key_names=("event_id",), - key_values=[(r[0],) for r in relations_to_insert], - value_names=("relates_to_id", "relation_type"), - value_values=[r[1:] for r in relations_to_insert], + key_names=("event_id", "relates_to_id", "relation_type"), + key_values=relations_to_insert, + value_names=(), + value_values=[], ) # Iterate the parent IDs and invalidate caches. diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 7de9949a5b7f..969db21109af 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-SCHEMA_VERSION = 80 # remember to update the list below when updating +SCHEMA_VERSION = 81 # remember to update the list below when updating """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the @@ -113,6 +113,9 @@ Changes in SCHEMA_VERSION = 80 - The event_txn_id_device_id is always written to for new events. + +Changes in SCHEMA_VERSION = 81 + - Allow multiple entries in event_relations for a single event """ diff --git a/synapse/storage/schema/main/delta/81/01_allow_multiple_relations.sql.postgres b/synapse/storage/schema/main/delta/81/01_allow_multiple_relations.sql.postgres new file mode 100644 index 000000000000..e159986b249c --- /dev/null +++ b/synapse/storage/schema/main/delta/81/01_allow_multiple_relations.sql.postgres @@ -0,0 +1,17 @@ +/* Copyright 2023 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +DROP INDEX event_relations_id; +CREATE UNIQUE INDEX event_relations_id ON event_relations USING btree (event_id, relates_to_id, relation_type); diff --git a/synapse/storage/schema/main/delta/81/01_allow_multiple_relations.sql.sqlite b/synapse/storage/schema/main/delta/81/01_allow_multiple_relations.sql.sqlite new file mode 100644 index 000000000000..e0f307a4145e --- /dev/null +++ b/synapse/storage/schema/main/delta/81/01_allow_multiple_relations.sql.sqlite @@ -0,0 +1,17 @@ +/* Copyright 2023 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +DROP INDEX event_relations_id; +CREATE UNIQUE INDEX event_relations_id ON event_relations(event_id, relates_to_id, relation_type); diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 75439416c175..2f7b24782b3e 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -71,6 +71,38 @@ def _create_user(self, localpart: str) -> Tuple[str, str]: return user_id, access_token + def _send_content( + self, + content: dict, + event_type: str, + access_token: Optional[str] = None, + parent_id: Optional[str] = None, + expected_response_code: int = 200, + ) -> FakeChannel: + """Helper function to send an event + + Args: + content: The content of the created event. 
+            event_type: The type of the event to create
+            access_token: The access token used to send the event, defaults
+                to `self.user_token`
+            parent_id: Unused here; kept so `_send_relation` can pass its arguments through.
+
+        Returns:
+            FakeChannel
+        """
+        if not access_token:
+            access_token = self.user_token
+
+        channel = self.make_request(
+            "POST",
+            f"/_matrix/client/v3/rooms/{self.room}/send/{event_type}",
+            content,
+            access_token=access_token,
+        )
+        self.assertEqual(expected_response_code, channel.code, channel.json_body)
+        return channel
+
     def _send_relation(
         self,
         relation_type: str,
@@ -96,11 +128,7 @@ def _send_relation(
         Returns:
             FakeChannel
         """
-        if not access_token:
-            access_token = self.user_token
-
         original_id = parent_id if parent_id else self.parent_id
-
         if content is None:
             content = {}
         content["m.relates_to"] = {
@@ -109,15 +137,9 @@ def _send_relation(
         }
         if key is not None:
             content["m.relates_to"]["key"] = key
-
-        channel = self.make_request(
-            "POST",
-            f"/_matrix/client/v3/rooms/{self.room}/send/{event_type}",
-            content,
-            access_token=access_token,
+        return self._send_content(
+            content, event_type, access_token, parent_id, expected_response_code
         )
-        self.assertEqual(expected_response_code, channel.code, channel.json_body)
-        return channel
 
     def _get_related_events(self) -> List[str]:
         """
@@ -184,6 +206,56 @@ def test_send_relation(self) -> None:
             channel.json_body,
         )
 
+    def test_send_multiple_relations(self) -> None:
+        """Tests that sending multiple relations works."""
+        channel = self._send_content(
+            {
+                "m.relations": [
+                    {
+                        "event_id": self.parent_id,
+                        "key": "👍",
+                        "rel_type": "m.reaction",
+                    },
+                    {
+                        "event_id": self.parent_id,
+                        "key": "👎",
+                        "rel_type": "m.reaction",
+                    },
+                ],
+            },
+            "m.reaction",
+        )
+        event_id = channel.json_body["event_id"]
+
+        channel = self.make_request(
+            "GET",
+            f"/rooms/{self.room}/event/{event_id}",
+            access_token=self.user_token,
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+
+        self.assert_dict(
+            {
+                "type": "m.reaction",
+                "sender": self.user_id,
+                "content": {
+                    "m.relations": [
+                        {
+                            "event_id": self.parent_id,
+                            "key": "👍",
+                            "rel_type": "m.reaction",
+                        },
+                        {
+                            "event_id": self.parent_id,
+                            "key": "👎",
+                            "rel_type": "m.reaction",
+                        },
+                    ],
+                },
+            },
+            channel.json_body,
+        )
+
     def test_deny_invalid_event(self) -> None:
         """Test that we deny relations on non-existant events"""
         self._send_relation(
@@ -238,6 +310,7 @@ def test_deny_double_react(self) -> None:
 
     def test_deny_forked_thread(self) -> None:
         """It is invalid to start a thread off a thread."""
+        # XXX: is forking a thread still invalid when expressed via m.relations?
         channel = self._send_relation(
             RelationTypes.THREAD,
             "m.room.message",

From 53779e8f216ee59cc0ddb4cccfa3aa50ad5eda45 Mon Sep 17 00:00:00 2001
From: chayleaf
Date: Mon, 14 Aug 2023 22:09:41 +0700
Subject: [PATCH 2/2] Add failing test

Signed-off-by: chayleaf
---
 tests/rest/client/test_relations.py | 39 +++++++++++++++++++++++++++++
 1 file changed, 39 insertions(+)

diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 2f7b24782b3e..631a442c4bde 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -1173,6 +1173,7 @@ def _test_bundled_aggregations(
 
     def assert_bundle(event_json: JsonDict) -> None:
         """Assert the expected values of the bundled aggregations."""
+        print(event_json)  # FIXME: leftover debug output; remove before merge
         relations_dict = event_json["unsigned"].get("m.relations")
 
         # Ensure the fields are as expected.
@@ -1253,6 +1254,44 @@ def assert_annotations(bundled_aggregations: JsonDict) -> None:
 
         self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 6)
 
+    def test_msc3051_references(self) -> None:
+        """
+        Test that events carrying multiple relations (references) get correctly bundled.
+ """ + channel = self._send_content( + { + "m.relations": [ + { + "event_id": self.parent_id, + "rel_type": RelationTypes.REFERENCE, + }, + ], + }, + "m.room.test", + ) + reply_1 = channel.json_body["event_id"] + + channel = self._send_content( + { + "m.relations": [ + { + "event_id": self.parent_id, + "rel_type": RelationTypes.REFERENCE, + }, + ], + }, + "m.room.test", + ) + reply_2 = channel.json_body["event_id"] + + def assert_annotations(bundled_aggregations: JsonDict) -> None: + self.assertEqual( + {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]}, + bundled_aggregations, + ) + + self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 6) + def test_thread(self) -> None: """ Test that threads get correctly bundled.