diff --git a/changelog.d/18756.misc b/changelog.d/18756.misc new file mode 100644 index 00000000000..c4353f776ed --- /dev/null +++ b/changelog.d/18756.misc @@ -0,0 +1 @@ +Update implementation of [MSC4306: Thread Subscriptions](https://github.com/matrix-org/matrix-doc/issues/4306) to include automatic subscription conflict prevention as introduced in later drafts. \ No newline at end of file diff --git a/synapse/api/errors.py b/synapse/api/errors.py index b832c2f6a15..ec4d707b7b0 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -140,6 +140,12 @@ class Codes(str, Enum): # Part of MSC4155 INVITE_BLOCKED = "ORG.MATRIX.MSC4155.M_INVITE_BLOCKED" + # Part of MSC4306: Thread Subscriptions + MSC4306_CONFLICTING_UNSUBSCRIPTION = ( + "IO.ELEMENT.MSC4306.M_CONFLICTING_UNSUBSCRIPTION" + ) + MSC4306_NOT_IN_THREAD = "IO.ELEMENT.MSC4306.M_NOT_IN_THREAD" + class CodeMessageException(RuntimeError): """An exception with integer code, a message string attributes and optional headers. diff --git a/synapse/handlers/thread_subscriptions.py b/synapse/handlers/thread_subscriptions.py index 79e4d6040dd..bda43429491 100644 --- a/synapse/handlers/thread_subscriptions.py +++ b/synapse/handlers/thread_subscriptions.py @@ -1,9 +1,15 @@ import logging +from http import HTTPStatus from typing import TYPE_CHECKING, Optional -from synapse.api.errors import AuthError, NotFoundError -from synapse.storage.databases.main.thread_subscriptions import ThreadSubscription -from synapse.types import UserID +from synapse.api.constants import RelationTypes +from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError +from synapse.events import relation_from_event +from synapse.storage.databases.main.thread_subscriptions import ( + AutomaticSubscriptionConflicted, + ThreadSubscription, +) +from synapse.types import EventOrderings, UserID if TYPE_CHECKING: from synapse.server import HomeServer @@ -55,42 +61,79 @@ async def subscribe_user_to_thread( room_id: str, 
thread_root_event_id: str, *, - automatic: bool, + automatic_event_id: Optional[str], ) -> Optional[int]: """Sets or updates a user's subscription settings for a specific thread root. Args: requester_user_id: The ID of the user whose settings are being updated. thread_root_event_id: The event ID of the thread root. - automatic: whether the user was subscribed by an automatic decision by - their client. + automatic_event_id: if the user was subscribed by an automatic decision by + their client, the event ID that caused this. Returns: The stream ID for this update, if the update isn't no-opped. Raises: NotFoundError if the user cannot access the thread root event, or it isn't - known to this homeserver. + known to this homeserver. Ditto for the automatic cause event if supplied. + + SynapseError(400, M_NOT_IN_THREAD): if client supplied an automatic cause event + but that event is not part of this thread. + + SynapseError(409, M_CONFLICTING_UNSUBSCRIPTION): if client requested an automatic subscription + but it was skipped because the cause event is not logically later than an unsubscription. 
""" # First check that the user can access the thread root event # and that it exists try: - event = await self.event_handler.get_event( + thread_root_event = await self.event_handler.get_event( user_id, room_id, thread_root_event_id ) - if event is None: + if thread_root_event is None: raise NotFoundError("No such thread root") except AuthError: logger.info("rejecting thread subscriptions change (thread not accessible)") raise NotFoundError("No such thread root") - return await self.store.subscribe_user_to_thread( + if automatic_event_id: + autosub_cause_event = await self.event_handler.get_event( + user_id, room_id, automatic_event_id + ) + if autosub_cause_event is None: + raise NotFoundError("Automatic subscription event not found") + relation = relation_from_event(autosub_cause_event) + if ( + relation is None + or relation.rel_type != RelationTypes.THREAD + or relation.parent_id != thread_root_event_id + ): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Automatic subscription must use an event in the thread", + errcode=Codes.MSC4306_NOT_IN_THREAD, + ) + + automatic_event_orderings = EventOrderings.from_event(autosub_cause_event) + else: + automatic_event_orderings = None + + outcome = await self.store.subscribe_user_to_thread( user_id.to_string(), - event.room_id, + room_id, thread_root_event_id, - automatic=automatic, + automatic_event_orderings=automatic_event_orderings, ) + if isinstance(outcome, AutomaticSubscriptionConflicted): + raise SynapseError( + HTTPStatus.CONFLICT, + "Automatic subscription obsoleted by an unsubscription request.", + errcode=Codes.MSC4306_CONFLICTING_UNSUBSCRIPTION, + ) + + return outcome + async def unsubscribe_user_from_thread( self, user_id: UserID, room_id: str, thread_root_event_id: str ) -> Optional[int]: diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 9694fff4fee..ec7e935d6a3 100644 --- a/synapse/replication/tcp/streams/_base.py +++ 
b/synapse/replication/tcp/streams/_base.py @@ -739,7 +739,7 @@ class ThreadSubscriptionsStreamRow: NAME = "thread_subscriptions" ROW_TYPE = ThreadSubscriptionsStreamRow - def __init__(self, hs: Any): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main super().__init__( hs.get_instance_name(), @@ -751,7 +751,7 @@ async def _update_function( self, instance_name: str, from_token: int, to_token: int, limit: int ) -> StreamUpdateResult: updates = await self.store.get_updated_thread_subscriptions( - from_token, to_token, limit + from_id=from_token, to_id=to_token, limit=limit ) rows = [ ( diff --git a/synapse/rest/client/thread_subscriptions.py b/synapse/rest/client/thread_subscriptions.py index eb724500b20..4e7b5d06dbe 100644 --- a/synapse/rest/client/thread_subscriptions.py +++ b/synapse/rest/client/thread_subscriptions.py @@ -1,7 +1,6 @@ from http import HTTPStatus -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING, Optional, Tuple -from synapse._pydantic_compat import StrictBool from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import ( @@ -12,6 +11,7 @@ from synapse.rest.client._base import client_patterns from synapse.types import JsonDict, RoomID from synapse.types.rest import RequestBodyModel +from synapse.util.pydantic_models import AnyEventId if TYPE_CHECKING: from synapse.server import HomeServer @@ -32,7 +32,12 @@ def __init__(self, hs: "HomeServer"): self.handler = hs.get_thread_subscriptions_handler() class PutBody(RequestBodyModel): - automatic: StrictBool + automatic: Optional[AnyEventId] + """ + If supplied, the event ID of an event giving rise to this automatic subscription. + + If omitted, this subscription is a manual subscription. 
+ """ async def on_GET( self, request: SynapseRequest, room_id: str, thread_root_id: str @@ -63,15 +68,15 @@ async def on_PUT( raise SynapseError( HTTPStatus.BAD_REQUEST, "Invalid event ID", errcode=Codes.INVALID_PARAM ) - requester = await self.auth.get_user_by_req(request) - body = parse_and_validate_json_object_from_request(request, self.PutBody) + requester = await self.auth.get_user_by_req(request) + await self.handler.subscribe_user_to_thread( requester.user, room_id, thread_root_id, - automatic=body.automatic, + automatic_event_id=body.automatic, ) return HTTPStatus.OK, {} diff --git a/synapse/storage/databases/main/thread_subscriptions.py b/synapse/storage/databases/main/thread_subscriptions.py index 4933224f0f6..a99ef430717 100644 --- a/synapse/storage/databases/main/thread_subscriptions.py +++ b/synapse/storage/databases/main/thread_subscriptions.py @@ -14,7 +14,6 @@ from typing import ( TYPE_CHECKING, Any, - Dict, Iterable, List, Optional, @@ -33,6 +32,7 @@ ) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.util.id_generators import MultiWriterIdGenerator +from synapse.types import EventOrderings from synapse.util.caches.descriptors import cached if TYPE_CHECKING: @@ -50,6 +50,14 @@ class ThreadSubscription: """ +class AutomaticSubscriptionConflicted: + """ + Marker return value to signal that an automatic subscription was skipped, + because it conflicted with an unsubscription that we consider to have + been made later than the event causing the automatic subscription. 
+ """ + + class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore): def __init__( self, @@ -101,61 +109,172 @@ def process_replication_position( self._thread_subscriptions_id_gen.advance(instance_name, token) super().process_replication_position(stream_name, instance_name, token) + @staticmethod + def _should_skip_autosubscription_after_unsubscription( + *, + autosub: EventOrderings, + unsubscribed_at: EventOrderings, + ) -> bool: + """ + Returns whether an automatic subscription occurring *after* an unsubscription + should be skipped, because the unsubscription already 'acknowledges' the event + causing the automatic subscription (the cause event). + + To determine *after*, we use `stream_ordering` unless the event is backfilled + (negative `stream_ordering`) and fallback to topological ordering. + + Args: + autosub: the stream_ordering and topological_ordering of the cause event + unsubscribed_at: + the maximum stream ordering and the maximum topological ordering at the time of unsubscription + + Returns: + True if the automatic subscription should be skipped + """ + # For normal rooms, these two orderings should be positive, because + # they don't refer to a specific event but rather the maximum at the + # time of unsubscription. + # + # However, for rooms that have never been joined and that are being peeked at, + # we might not have a single non-backfilled event and therefore the stream + # ordering might be negative, so we don't assert this case. 
+ assert unsubscribed_at.topological > 0 + + unsubscribed_at_backfilled = unsubscribed_at.stream < 0 + if ( + not unsubscribed_at_backfilled + and unsubscribed_at.stream >= autosub.stream > 0 + ): + # non-backfilled events: the unsubscription is later according to + # the stream + return True + + if autosub.stream < 0: + # the auto-subscription cause event was backfilled, so fall back to + # topological ordering + if unsubscribed_at.topological >= autosub.topological: + return True + + return False + async def subscribe_user_to_thread( - self, user_id: str, room_id: str, thread_root_event_id: str, *, automatic: bool - ) -> Optional[int]: + self, + user_id: str, + room_id: str, + thread_root_event_id: str, + *, + automatic_event_orderings: Optional[EventOrderings], + ) -> Optional[Union[int, AutomaticSubscriptionConflicted]]: """Updates a user's subscription settings for a specific thread root. If no change would be made to the subscription, does not produce any database change. + Case-by-case: + - if we already have an automatic subscription: + - new automatic subscriptions will be no-ops (no database write), + - new manual subscriptions will overwrite the automatic subscription + - if we already have a manual subscription: + we don't update (no database write) in either case, because: + - the existing manual subscription wins over a new automatic subscription request + - there would be no need to write a manual subscription because we already have one + Args: user_id: The ID of the user whose settings are being updated. room_id: The ID of the room the thread root belongs to. thread_root_event_id: The event ID of the thread root. - automatic: Whether the subscription was performed automatically by the user's client. - Only `False` will overwrite an existing value of automatic for a subscription row. + automatic_event_orderings: + Value depends on whether the subscription was performed automatically by the user's client. + For manual subscriptions: None. 
+ For automatic subscriptions: the orderings of the event. Returns: - The stream ID for this update, if the update isn't no-opped. + If a subscription is made: (int) the stream ID for this update. + If a subscription already exists and did not need to be updated: None + If an automatic subscription conflicted with an unsubscription: AutomaticSubscriptionConflicted """ assert self._can_write_to_thread_subscriptions - def _subscribe_user_to_thread_txn(txn: LoggingTransaction) -> Optional[int]: - already_automatic = self.db_pool.simple_select_one_onecol_txn( + def _subscribe_user_to_thread_txn( + txn: LoggingTransaction, + ) -> Optional[Union[int, AutomaticSubscriptionConflicted]]: + requested_automatic = automatic_event_orderings is not None + + row = self.db_pool.simple_select_one_txn( txn, table="thread_subscriptions", keyvalues={ "user_id": user_id, "event_id": thread_root_event_id, "room_id": room_id, - "subscribed": True, }, - retcol="automatic", + retcols=( + "subscribed", + "automatic", + "unsubscribed_at_stream_ordering", + "unsubscribed_at_topological_ordering", + ), allow_none=True, ) - if already_automatic is None: - already_subscribed = False - already_automatic = True - else: - already_subscribed = True - # convert int (SQLite bool) to Python bool - already_automatic = bool(already_automatic) - - if already_subscribed and already_automatic == automatic: - # there is nothing we need to do here + if row is None: + # We have never subscribed before, simply insert the row and finish + stream_id = self._thread_subscriptions_id_gen.get_next_txn(txn) + self.db_pool.simple_insert_txn( + txn, + table="thread_subscriptions", + values={ + "user_id": user_id, + "event_id": thread_root_event_id, + "room_id": room_id, + "subscribed": True, + "stream_id": stream_id, + "instance_name": self._instance_name, + "automatic": requested_automatic, + "unsubscribed_at_stream_ordering": None, + "unsubscribed_at_topological_ordering": None, + }, + ) + txn.call_after( + 
self.get_subscription_for_thread.invalidate, + (user_id, room_id, thread_root_event_id), + ) + return stream_id + + # we already have either a subscription or a prior unsubscription here + ( + subscribed, + already_automatic, + unsubscribed_at_stream_ordering, + unsubscribed_at_topological_ordering, + ) = row + + if subscribed and (not already_automatic or requested_automatic): + # we are already subscribed and the current subscription state + # is good enough (either we already have a manual subscription, + # or we requested an automatic subscription) + # In that case, nothing to change here. + # (See docstring for case-by-case explanation) return None - stream_id = self._thread_subscriptions_id_gen.get_next_txn(txn) - - values: Dict[str, Optional[Union[bool, int, str]]] = { - "subscribed": True, - "stream_id": stream_id, - "instance_name": self._instance_name, - "automatic": already_automatic and automatic, - } + if not subscribed and requested_automatic: + assert automatic_event_orderings is not None + # we previously unsubscribed and we are now automatically subscribing + # Check whether the new autosubscription should be skipped + if ThreadSubscriptionsWorkerStore._should_skip_autosubscription_after_unsubscription( + autosub=automatic_event_orderings, + unsubscribed_at=EventOrderings( + unsubscribed_at_stream_ordering, + unsubscribed_at_topological_ordering, + ), + ): + # skip the subscription + return AutomaticSubscriptionConflicted() + + # At this point: we have now finished checking that we need to make + # a subscription, updating the current row. 
- self.db_pool.simple_upsert_txn( + stream_id = self._thread_subscriptions_id_gen.get_next_txn(txn) + self.db_pool.simple_update_txn( txn, table="thread_subscriptions", keyvalues={ @@ -163,9 +282,15 @@ def _subscribe_user_to_thread_txn(txn: LoggingTransaction) -> Optional[int]: "event_id": thread_root_event_id, "room_id": room_id, }, - values=values, + updatevalues={ + "subscribed": True, + "stream_id": stream_id, + "instance_name": self._instance_name, + "automatic": requested_automatic, + "unsubscribed_at_stream_ordering": None, + "unsubscribed_at_topological_ordering": None, + }, ) - txn.call_after( self.get_subscription_for_thread.invalidate, (user_id, room_id, thread_root_event_id), @@ -214,6 +339,21 @@ def _unsubscribe_user_from_thread_txn(txn: LoggingTransaction) -> Optional[int]: stream_id = self._thread_subscriptions_id_gen.get_next_txn(txn) + # Find the maximum stream ordering and topological ordering of the room, + # which we then store against this unsubscription so we can skip future + # automatic subscriptions that are caused by an event logically earlier + # than this unsubscription. + txn.execute( + """ + SELECT MAX(stream_ordering) AS mso, MAX(topological_ordering) AS mto FROM events + WHERE room_id = ? 
+ """, + (room_id,), + ) + ord_row = txn.fetchone() + assert ord_row is not None + max_stream_ordering, max_topological_ordering = ord_row + self.db_pool.simple_update_txn( txn, table="thread_subscriptions", @@ -227,6 +367,8 @@ def _unsubscribe_user_from_thread_txn(txn: LoggingTransaction) -> Optional[int]: "subscribed": False, "stream_id": stream_id, "instance_name": self._instance_name, + "unsubscribed_at_stream_ordering": max_stream_ordering, + "unsubscribed_at_topological_ordering": max_topological_ordering, }, ) @@ -316,7 +458,7 @@ def get_max_thread_subscriptions_stream_id(self) -> int: return self._thread_subscriptions_id_gen.get_current_token() async def get_updated_thread_subscriptions( - self, from_id: int, to_id: int, limit: int + self, *, from_id: int, to_id: int, limit: int ) -> List[Tuple[int, str, str, str]]: """Get updates to thread subscriptions between two stream IDs. @@ -349,7 +491,7 @@ def get_updated_thread_subscriptions_txn( ) async def get_updated_thread_subscriptions_for_user( - self, user_id: str, from_id: int, to_id: int, limit: int + self, user_id: str, *, from_id: int, to_id: int, limit: int ) -> List[Tuple[int, str, str]]: """Get updates to thread subscriptions for a specific user. diff --git a/synapse/storage/schema/main/delta/92/09_thread_subscriptions_update.sql b/synapse/storage/schema/main/delta/92/09_thread_subscriptions_update.sql new file mode 100644 index 00000000000..03b8a1a6355 --- /dev/null +++ b/synapse/storage/schema/main/delta/92/09_thread_subscriptions_update.sql @@ -0,0 +1,20 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2025 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. 
+-- +-- See the GNU Affero General Public License for more details: +-- . + +-- The maximum stream_ordering in the room when the unsubscription was made. +ALTER TABLE thread_subscriptions + ADD COLUMN unsubscribed_at_stream_ordering BIGINT; + +-- The maximum topological_ordering in the room when the unsubscription was made. +ALTER TABLE thread_subscriptions + ADD COLUMN unsubscribed_at_topological_ordering BIGINT; diff --git a/synapse/storage/schema/main/delta/92/09_thread_subscriptions_update.sql.postgres b/synapse/storage/schema/main/delta/92/09_thread_subscriptions_update.sql.postgres new file mode 100644 index 00000000000..fc5d555db54 --- /dev/null +++ b/synapse/storage/schema/main/delta/92/09_thread_subscriptions_update.sql.postgres @@ -0,0 +1,18 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2025 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- . 
+ +COMMENT ON COLUMN thread_subscriptions.unsubscribed_at_stream_ordering IS + $$The maximum stream_ordering in the room when the unsubscription was made.$$; + +COMMENT ON COLUMN thread_subscriptions.unsubscribed_at_topological_ordering IS + $$The maximum topological_ordering in the room when the unsubscription was made.$$; diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index 3b516fce3d9..0ea3a0a4a83 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -73,6 +73,7 @@ from typing_extensions import Self from synapse.appservice.api import ApplicationService + from synapse.events import EventBase from synapse.storage.databases.main import DataStore, PurgeEventsStore from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore from synapse.storage.util.id_generators import MultiWriterIdGenerator @@ -1464,3 +1465,31 @@ class ScheduledTask: result: Optional[JsonMapping] # Optional error that should be assigned a value when the status is FAILED error: Optional[str] + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class EventOrderings: + stream: int + """ + The stream_ordering of the event. + Negative numbers mean the event was backfilled. + """ + + topological: int + """ + The topological_ordering of the event. + Currently this is equivalent to the `depth` attributes of + the PDU. + """ + + @staticmethod + def from_event(event: "EventBase") -> "EventOrderings": + """ + Get the orderings from an event. 
+ + Preconditions: + - the event must have been persisted (otherwise it won't have a stream ordering) + """ + stream = event.internal_metadata.stream_ordering + assert stream is not None + return EventOrderings(stream, event.depth) diff --git a/synapse/util/pydantic_models.py b/synapse/util/pydantic_models.py index ba9e7bb7d56..4880709501b 100644 --- a/synapse/util/pydantic_models.py +++ b/synapse/util/pydantic_models.py @@ -13,7 +13,11 @@ # # -from synapse._pydantic_compat import BaseModel, Extra +import re +from typing import Any, Callable, Generator + +from synapse._pydantic_compat import BaseModel, Extra, StrictStr +from synapse.types import EventID class ParseModel(BaseModel): @@ -37,3 +41,43 @@ class Config: extra = Extra.ignore # By default, don't allow fields to be reassigned after parsing. allow_mutation = False + + +class AnyEventId(StrictStr): + """ + A validator for strings that need to be an Event ID. + + Accepts any valid grammar of Event ID from any room version. + """ + + EVENT_ID_HASH_ROOM_VERSION_3_PLUS = re.compile( + r"^([a-zA-Z0-9-_]{43}|[a-zA-Z0-9+/]{43})$" + ) + + @classmethod + def __get_validators__(cls) -> Generator[Callable[..., Any], Any, Any]: + yield from super().__get_validators__() # type: ignore + yield cls.validate_event_id + + @classmethod + def validate_event_id(cls, value: str) -> str: + if not value.startswith("$"): + raise ValueError("Event ID must start with `$`") + + if ":" in value: + # Room versions 1 and 2 + EventID.from_string(value) # throws on fail + else: + # Room versions 3+: event ID is $ + a base64 sha256 hash + # Room version 3 is base64, 4+ are base64Url + # In both cases, the base64 is unpadded. + # refs: + # - https://spec.matrix.org/v1.15/rooms/v3/ e.g. $acR1l0raoZnm60CBwAVgqbZqoO/mYU81xysh1u7XcJk + # - https://spec.matrix.org/v1.15/rooms/v4/ e.g. 
$Rqnc-F-dvnEYJTyHq_iKxU2bZ1CI92-kuZq3a5lr5Zg + b64_hash = value[1:] + if cls.EVENT_ID_HASH_ROOM_VERSION_3_PLUS.fullmatch(b64_hash) is None: + raise ValueError( + "Event ID must either have a domain part or be a valid hash" + ) + + return value diff --git a/tests/replication/tcp/streams/test_thread_subscriptions.py b/tests/replication/tcp/streams/test_thread_subscriptions.py index 035f0618716..7283aa851e2 100644 --- a/tests/replication/tcp/streams/test_thread_subscriptions.py +++ b/tests/replication/tcp/streams/test_thread_subscriptions.py @@ -62,7 +62,7 @@ def test_thread_subscription_updates(self) -> None: "@test_user:example.org", room_id, thread_root_id, - automatic=True, + automatic_event_orderings=None, ) ) updates.append(thread_root_id) @@ -75,7 +75,7 @@ def test_thread_subscription_updates(self) -> None: "@test_user:example.org", other_room_id, other_thread_root_id, - automatic=False, + automatic_event_orderings=None, ) ) @@ -124,7 +124,7 @@ def test_multiple_users_thread_subscription_updates(self) -> None: for user_id in users: self.get_success( store.subscribe_user_to_thread( - user_id, room_id, thread_root_id, automatic=True + user_id, room_id, thread_root_id, automatic_event_orderings=None ) ) diff --git a/tests/rest/client/test_thread_subscriptions.py b/tests/rest/client/test_thread_subscriptions.py index 624cb9c726e..3fbf3c5bfaa 100644 --- a/tests/rest/client/test_thread_subscriptions.py +++ b/tests/rest/client/test_thread_subscriptions.py @@ -15,6 +15,7 @@ from twisted.internet.testing import MemoryReactor +from synapse.api.errors import Codes from synapse.rest import admin from synapse.rest.client import login, profile, room, thread_subscriptions from synapse.server import HomeServer @@ -49,15 +50,16 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # Create a room and send a message to use as a thread root self.room_id = self.helper.create_room_as(self.user_id, tok=self.token) self.helper.join(self.room_id, 
self.other_user_id, tok=self.other_token) - response = self.helper.send(self.room_id, body="Root message", tok=self.token) - self.root_event_id = response["event_id"] + (self.root_event_id,) = self.helper.send_messages( + self.room_id, 1, tok=self.token + ) # Send a message in the thread - self.helper.send_event( - room_id=self.room_id, - type="m.room.message", - content={ - "body": "Thread message", + self.threaded_events = self.helper.send_messages( + self.room_id, + 2, + content_fn=lambda idx: { + "body": f"Thread message {idx}", "msgtype": "m.text", "m.relates_to": { "rel_type": "m.thread", @@ -106,9 +108,7 @@ def test_subscribe_manual_then_automatic(self) -> None: channel = self.make_request( "PUT", f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", - { - "automatic": False, - }, + {}, access_token=self.token, ) self.assertEqual(channel.code, HTTPStatus.OK) @@ -127,7 +127,7 @@ def test_subscribe_manual_then_automatic(self) -> None: channel = self.make_request( "PUT", f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", - {"automatic": True}, + {"automatic": self.threaded_events[0]}, access_token=self.token, ) self.assertEqual(channel.code, HTTPStatus.OK) @@ -148,11 +148,11 @@ def test_subscribe_automatic_then_manual(self) -> None: "PUT", f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", { - "automatic": True, + "automatic": self.threaded_events[0], }, access_token=self.token, ) - self.assertEqual(channel.code, HTTPStatus.OK) + self.assertEqual(channel.code, HTTPStatus.OK, channel.text_body) # Assert the subscription was saved channel = self.make_request( @@ -167,7 +167,7 @@ def test_subscribe_automatic_then_manual(self) -> None: channel = self.make_request( "PUT", f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", - {"automatic": False}, + {}, access_token=self.token, ) self.assertEqual(channel.code, HTTPStatus.OK) @@ -187,7 +187,7 @@ def test_unsubscribe(self) -> None: "PUT", 
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", { - "automatic": True, + "automatic": self.threaded_events[0], }, access_token=self.token, ) @@ -202,7 +202,6 @@ def test_unsubscribe(self) -> None: self.assertEqual(channel.code, HTTPStatus.OK) self.assertEqual(channel.json_body, {"automatic": True}) - # Now also register a manual subscription channel = self.make_request( "DELETE", f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", @@ -210,7 +209,6 @@ def test_unsubscribe(self) -> None: ) self.assertEqual(channel.code, HTTPStatus.OK) - # Assert the manual subscription was not overridden channel = self.make_request( "GET", f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", @@ -224,7 +222,7 @@ def test_set_thread_subscription_nonexistent_thread(self) -> None: channel = self.make_request( "PUT", f"{PREFIX}/{self.room_id}/thread/$nonexistent:example.org/subscription", - {"automatic": True}, + {}, access_token=self.token, ) self.assertEqual(channel.code, HTTPStatus.NOT_FOUND) @@ -238,7 +236,7 @@ def test_set_thread_subscription_no_access(self) -> None: channel = self.make_request( "PUT", f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", - {"automatic": True}, + {}, access_token=no_access_token, ) self.assertEqual(channel.code, HTTPStatus.NOT_FOUND) @@ -249,8 +247,105 @@ def test_invalid_body(self) -> None: channel = self.make_request( "PUT", f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", - # non-boolean `automatic` - {"automatic": "true"}, + # non-Event ID `automatic` + {"automatic": True}, access_token=self.token, ) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST) + + channel = self.make_request( + "PUT", + f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", + # non-Event ID `automatic` + {"automatic": "$malformedEventId"}, + access_token=self.token, + ) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST) + + def 
test_auto_subscribe_cause_event_not_in_thread(self) -> None: + """ + Test making an automatic subscription, where the cause event is not + actually in the thread. + This is an error. + """ + (unrelated_event_id,) = self.helper.send_messages( + self.room_id, 1, tok=self.token + ) + channel = self.make_request( + "PUT", + f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", + {"automatic": unrelated_event_id}, + access_token=self.token, + ) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.text_body) + self.assertEqual(channel.json_body["errcode"], Codes.MSC4306_NOT_IN_THREAD) + + def test_auto_resubscription_conflict(self) -> None: + """ + Test that an automatic subscription that conflicts with an unsubscription + is skipped. + """ + # Reuse the test that subscribes and unsubscribes + self.test_unsubscribe() + + # Now no matter which event we present as the cause of an automatic subscription, + # the automatic subscription is skipped. + # This is because the unsubscription happened after all of the events. 
+ for event in self.threaded_events: + channel = self.make_request( + "PUT", + f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", + { + "automatic": event, + }, + access_token=self.token, + ) + self.assertEqual(channel.code, HTTPStatus.CONFLICT, channel.text_body) + self.assertEqual( + channel.json_body["errcode"], + Codes.MSC4306_CONFLICTING_UNSUBSCRIPTION, + channel.text_body, + ) + + # Check the subscription was not made + channel = self.make_request( + "GET", + f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", + access_token=self.token, + ) + self.assertEqual(channel.code, HTTPStatus.NOT_FOUND) + + # But if a new event is sent after the unsubscription took place, + # that one can be used for an automatic subscription + (later_event_id,) = self.helper.send_messages( + self.room_id, + 1, + content_fn=lambda _: { + "body": "Thread message after unsubscription", + "msgtype": "m.text", + "m.relates_to": { + "rel_type": "m.thread", + "event_id": self.root_event_id, + }, + }, + tok=self.token, + ) + + channel = self.make_request( + "PUT", + f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", + { + "automatic": later_event_id, + }, + access_token=self.token, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.text_body) + + # Check the subscription was made + channel = self.make_request( + "GET", + f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", + access_token=self.token, + ) + self.assertEqual(channel.code, HTTPStatus.OK) + self.assertEqual(channel.json_body, {"automatic": True}) diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 9f5a5c85a75..129161815f6 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -29,12 +29,14 @@ from typing import ( Any, AnyStr, + Callable, Dict, Iterable, Literal, Mapping, MutableMapping, Optional, + Sequence, Tuple, overload, ) @@ -45,7 +47,7 @@ from twisted.internet.testing import MemoryReactorClock from 
twisted.web.server import Site -from synapse.api.constants import Membership, ReceiptTypes +from synapse.api.constants import EventTypes, Membership, ReceiptTypes from synapse.api.errors import Codes from synapse.server import HomeServer from synapse.types import JsonDict @@ -394,6 +396,32 @@ def send( custom_headers=custom_headers, ) + def send_messages( + self, + room_id: str, + num_events: int, + content_fn: Callable[[int], JsonDict] = lambda idx: { + "msgtype": "m.text", + "body": f"Test event {idx}", + }, + tok: Optional[str] = None, + ) -> Sequence[str]: + """ + Helper to send a handful of sequential events and return their event IDs as a sequence. + """ + event_ids = [] + + for event_index in range(num_events): + response = self.send_event( + room_id, + EventTypes.Message, + content_fn(event_index), + tok=tok, + ) + event_ids.append(response["event_id"]) + + return event_ids + def send_event( self, room_id: str, diff --git a/tests/storage/test_thread_subscriptions.py b/tests/storage/test_thread_subscriptions.py index 69317d5b0cd..c09a4a9a441 100644 --- a/tests/storage/test_thread_subscriptions.py +++ b/tests/storage/test_thread_subscriptions.py @@ -12,13 +12,18 @@ # . 
# -from typing import Optional +from typing import Optional, Union from twisted.internet.testing import MemoryReactor from synapse.server import HomeServer from synapse.storage.database import LoggingTransaction +from synapse.storage.databases.main.thread_subscriptions import ( + AutomaticSubscriptionConflicted, + ThreadSubscriptionsWorkerStore, +) from synapse.storage.engines.sqlite import Sqlite3Engine +from synapse.types import EventOrderings from synapse.util import Clock from tests import unittest @@ -97,10 +102,10 @@ def _subscribe( self, thread_root_id: str, *, - automatic: bool, + automatic_event_orderings: Optional[EventOrderings], room_id: Optional[str] = None, user_id: Optional[str] = None, - ) -> Optional[int]: + ) -> Optional[Union[int, AutomaticSubscriptionConflicted]]: if user_id is None: user_id = self.user_id @@ -112,7 +117,7 @@ def _subscribe( user_id, room_id, thread_root_id, - automatic=automatic, + automatic_event_orderings=automatic_event_orderings, ) ) @@ -149,7 +154,7 @@ def test_set_and_get_thread_subscription(self) -> None: # Subscribe self._subscribe( self.thread_root_id, - automatic=True, + automatic_event_orderings=EventOrderings(1, 1), ) # Assert subscription went through @@ -164,7 +169,7 @@ def test_set_and_get_thread_subscription(self) -> None: # Now make it a manual subscription self._subscribe( self.thread_root_id, - automatic=False, + automatic_event_orderings=None, ) # Assert the manual subscription overrode the automatic one @@ -178,8 +183,10 @@ def test_set_and_get_thread_subscription(self) -> None: def test_purge_thread_subscriptions_for_user(self) -> None: """Test purging all thread subscription settings for a user.""" # Set subscription settings for multiple threads - self._subscribe(self.thread_root_id, automatic=True) - self._subscribe(self.other_thread_root_id, automatic=False) + self._subscribe( + self.thread_root_id, automatic_event_orderings=EventOrderings(1, 1) + ) + self._subscribe(self.other_thread_root_id, 
automatic_event_orderings=None) subscriptions = self.get_success( self.store.get_updated_thread_subscriptions_for_user( @@ -217,20 +224,32 @@ def test_purge_thread_subscriptions_for_user(self) -> None: def test_get_updated_thread_subscriptions(self) -> None: """Test getting updated thread subscriptions since a stream ID.""" - stream_id1 = self._subscribe(self.thread_root_id, automatic=False) - stream_id2 = self._subscribe(self.other_thread_root_id, automatic=True) - assert stream_id1 is not None - assert stream_id2 is not None + stream_id1 = self._subscribe( + self.thread_root_id, automatic_event_orderings=EventOrderings(1, 1) + ) + stream_id2 = self._subscribe( + self.other_thread_root_id, automatic_event_orderings=EventOrderings(2, 2) + ) + assert stream_id1 is not None and not isinstance( + stream_id1, AutomaticSubscriptionConflicted + ) + assert stream_id2 is not None and not isinstance( + stream_id2, AutomaticSubscriptionConflicted + ) # Get updates since initial ID (should include both changes) updates = self.get_success( - self.store.get_updated_thread_subscriptions(0, stream_id2, 10) + self.store.get_updated_thread_subscriptions( + from_id=0, to_id=stream_id2, limit=10 + ) ) self.assertEqual(len(updates), 2) # Get updates since first change (should include only the second change) updates = self.get_success( - self.store.get_updated_thread_subscriptions(stream_id1, stream_id2, 10) + self.store.get_updated_thread_subscriptions( + from_id=stream_id1, to_id=stream_id2, limit=10 + ) ) self.assertEqual( updates, @@ -242,21 +261,27 @@ def test_get_updated_thread_subscriptions_for_user(self) -> None: other_user_id = "@other_user:test" # Set thread subscription for main user - stream_id1 = self._subscribe(self.thread_root_id, automatic=True) - assert stream_id1 is not None + stream_id1 = self._subscribe( + self.thread_root_id, automatic_event_orderings=EventOrderings(1, 1) + ) + assert stream_id1 is not None and not isinstance( + stream_id1, 
AutomaticSubscriptionConflicted + ) # Set thread subscription for other user stream_id2 = self._subscribe( self.other_thread_root_id, - automatic=True, + automatic_event_orderings=EventOrderings(1, 1), user_id=other_user_id, ) - assert stream_id2 is not None + assert stream_id2 is not None and not isinstance( + stream_id2, AutomaticSubscriptionConflicted + ) # Get updates for main user updates = self.get_success( self.store.get_updated_thread_subscriptions_for_user( - self.user_id, 0, stream_id2, 10 + self.user_id, from_id=0, to_id=stream_id2, limit=10 ) ) self.assertEqual(updates, [(stream_id1, self.room_id, self.thread_root_id)]) @@ -264,9 +289,41 @@ def test_get_updated_thread_subscriptions_for_user(self) -> None: # Get updates for other user updates = self.get_success( self.store.get_updated_thread_subscriptions_for_user( - other_user_id, 0, max(stream_id1, stream_id2), 10 + other_user_id, from_id=0, to_id=max(stream_id1, stream_id2), limit=10 ) ) self.assertEqual( updates, [(stream_id2, self.room_id, self.other_thread_root_id)] ) + + def test_should_skip_autosubscription_after_unsubscription(self) -> None: + """ + Tests the comparison logic for whether an autosubscription should be skipped + due to a chronologically earlier but logically later unsubscription. 
+ """ + + func = ThreadSubscriptionsWorkerStore._should_skip_autosubscription_after_unsubscription + + # Order of arguments: + # automatic cause event: stream order, then topological order + # unsubscribe maximums: stream order, then topological order + + # both orderings agree that the unsub is after the cause event + self.assertTrue( + func(autosub=EventOrderings(1, 1), unsubscribed_at=EventOrderings(2, 2)) + ) + + # topological ordering is inconsistent with stream ordering, + # in that case favour stream ordering because it's what /sync uses + self.assertTrue( + func(autosub=EventOrderings(1, 2), unsubscribed_at=EventOrderings(2, 1)) + ) + + # the automatic subscription is caused by a backfilled event here + # unfortunately we must fall back to topological ordering here + self.assertTrue( + func(autosub=EventOrderings(-50, 2), unsubscribed_at=EventOrderings(2, 3)) + ) + self.assertFalse( + func(autosub=EventOrderings(-50, 2), unsubscribed_at=EventOrderings(2, 1)) + )