Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
2961006
Update simplified sliding sync docstring
reivilibre Jun 17, 2025
875dbf7
spelling
reivilibre Jul 10, 2025
09f8633
Add models for Thread Subscriptions extension to Sliding Sync
reivilibre Jul 18, 2025
f1f5657
Add overload for `gather_optional_coroutines`/6
reivilibre Jul 18, 2025
748316c
Add thread subscriptions position to `StreamToken`
reivilibre Jul 18, 2025
4dcd12b
Add `subscribed` and `automatic` to `get_updated_thread_subscriptions…
reivilibre Jul 18, 2025
0ce5dce
Fix thread_subscriptions stream sequence
reivilibre Jul 21, 2025
0c310b9
Add comment to MultiWriterIdGenerator about cursed sequence semantics
reivilibre Aug 20, 2025
18881b1
Add overload for `parse_integer_from_args`
reivilibre Aug 20, 2025
4a34641
Implement sliding sync extension part of MSC4308
reivilibre Aug 20, 2025
e72d6cd
Add companion endpoint for backpagination of thread subscriptions
reivilibre Aug 20, 2025
f4cd180
Newsfile
reivilibre Aug 20, 2025
921cd53
Update tests/rest/client/sliding_sync/test_extension_thread_subscript…
reivilibre Sep 2, 2025
168b67b
Update synapse/handlers/sliding_sync/extensions.py
reivilibre Sep 2, 2025
1374895
Update synapse/handlers/sliding_sync/extensions.py
reivilibre Sep 2, 2025
0cf178a
Add notifier hooks for sliding sync
reivilibre Sep 2, 2025
924c1bf
Use copy_and_replace in get_current_token_for_pagination
reivilibre Sep 3, 2025
fa8e3b6
Simplify if
reivilibre Sep 3, 2025
80679a7
Comment on why we still check limit
reivilibre Sep 9, 2025
00cb14e
Merge branch 'develop' into rei/ssext_threadsubs
reivilibre Sep 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion synapse/handlers/sliding_sync/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,7 +932,7 @@ async def get_thread_subscriptions_extension_response(

to_stream_id = to_token.thread_subscriptions_key

updates = await self.store.get_updated_thread_subscriptions_for_user(
updates = await self.store.get_latest_updated_thread_subscriptions_for_user(
user_id=sync_config.user.to_string(),
from_id=from_stream_id,
to_id=to_stream_id,
Expand Down
145 changes: 143 additions & 2 deletions synapse/rest/client/thread_subscriptions.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,39 @@
from http import HTTPStatus
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, Dict, Optional, Tuple

import attr
from typing_extensions import TypeAlias

from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
parse_and_validate_json_object_from_request,
parse_integer,
parse_string,
)
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.types import JsonDict, RoomID
from synapse.types import (
JsonDict,
RoomID,
SlidingSyncStreamToken,
ThreadSubscriptionsToken,
)
from synapse.types.handlers.sliding_sync import SlidingSyncResult
from synapse.types.rest import RequestBodyModel
from synapse.util.pydantic_models import AnyEventId

if TYPE_CHECKING:
from synapse.server import HomeServer

_ThreadSubscription: TypeAlias = (
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadSubscription
)
_ThreadUnsubscription: TypeAlias = (
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadUnsubscription
)


class ThreadSubscriptionsRestServlet(RestServlet):
PATTERNS = client_patterns(
Expand Down Expand Up @@ -100,6 +118,129 @@ async def on_DELETE(
return HTTPStatus.OK, {}


class ThreadSubscriptionsPaginationRestServlet(RestServlet):
    """Companion endpoint (MSC4308) for backwards pagination over a user's
    thread subscription updates.

    Used to fill in the gap when the thread subscriptions sliding sync
    extension hit its limit and omitted older updates.

    Rate limiting was considered and deliberately omitted: read endpoints
    like /messages and /sync are not rate-limited either, `MAX_LIMIT` bounds
    the per-request cost, and coarse limiting can be applied at the reverse
    proxy if needed.
    """

    PATTERNS = client_patterns(
        "/io.element.msc4308/thread_subscriptions$",
        unstable=True,
        releases=(),
    )
    CATEGORY = "Thread Subscriptions requests (unstable)"

    # Maximum number of thread subscriptions to return in one request.
    MAX_LIMIT = 512

    def __init__(self, hs: "HomeServer"):
        self.auth = hs.get_auth()
        self.is_mine = hs.is_mine
        self.store = hs.get_datastores().main

    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
        """Handle a backwards-pagination request.

        Query parameters:
            limit: maximum number of updates to return (capped at MAX_LIMIT).
            from: (optional) token bounding the end of the stream (inclusive).
            to: (optional) token bounding the start of the stream (exclusive);
                sliding sync `pos` tokens are also accepted here.
            dir: must be "b" (only backwards pagination is supported).

        Returns a JSON body with `subscribed`/`unsubscribed` maps of
        room_id -> thread_root_id -> subscription, plus an `end` token when
        there may be more entries to fetch.
        """
        requester = await self.auth.get_user_by_req(request)

        limit = min(
            parse_integer(request, "limit", default=100, negative=False),
            ThreadSubscriptionsPaginationRestServlet.MAX_LIMIT,
        )
        from_end_opt = parse_string(request, "from", required=False)
        to_start_opt = parse_string(request, "to", required=False)
        _direction = parse_string(request, "dir", required=True, allowed_values=("b",))

        # `negative=False` on `parse_integer` rejects negative values but
        # still allows 0, so this explicit check is required as well.
        if limit <= 0:
            raise SynapseError(
                HTTPStatus.BAD_REQUEST,
                "limit must be greater than 0",
                errcode=Codes.INVALID_PARAM,
            )

        if from_end_opt is not None:
            try:
                # because of backwards pagination, the `from` token is actually the
                # bound closest to the end of the stream
                end_stream_id = ThreadSubscriptionsToken.from_string(
                    from_end_opt
                ).stream_id
            except ValueError:
                raise SynapseError(
                    HTTPStatus.BAD_REQUEST,
                    "`from` is not a valid token",
                    errcode=Codes.INVALID_PARAM,
                )
        else:
            end_stream_id = self.store.get_max_thread_subscriptions_stream_id()

        if to_start_opt is not None:
            # because of backwards pagination, the `to` token is actually the
            # bound closest to the start of the stream
            try:
                start_stream_id = ThreadSubscriptionsToken.from_string(
                    to_start_opt
                ).stream_id
            except ValueError:
                # we also accept sliding sync `pos` tokens on this parameter
                try:
                    sliding_sync_pos = await SlidingSyncStreamToken.from_string(
                        self.store, to_start_opt
                    )
                    start_stream_id = (
                        sliding_sync_pos.stream_token.thread_subscriptions_key
                    )
                except ValueError:
                    raise SynapseError(
                        HTTPStatus.BAD_REQUEST,
                        "`to` is not a valid token",
                        errcode=Codes.INVALID_PARAM,
                    )
        else:
            # the start of time is ID 1; the lower bound is exclusive though
            start_stream_id = 0

        subscriptions = (
            await self.store.get_latest_updated_thread_subscriptions_for_user(
                requester.user.to_string(),
                from_id=start_stream_id,
                to_id=end_stream_id,
                limit=limit,
            )
        )

        subscribed_threads: Dict[str, Dict[str, JsonDict]] = {}
        unsubscribed_threads: Dict[str, Dict[str, JsonDict]] = {}
        for stream_id, room_id, thread_root_id, subscribed, automatic in subscriptions:
            if subscribed:
                subscribed_threads.setdefault(room_id, {})[thread_root_id] = (
                    attr.asdict(
                        _ThreadSubscription(
                            automatic=automatic,
                            bump_stamp=stream_id,
                        )
                    )
                )
            else:
                unsubscribed_threads.setdefault(room_id, {})[thread_root_id] = (
                    attr.asdict(_ThreadUnsubscription(bump_stamp=stream_id))
                )

        result: JsonDict = {}
        if subscribed_threads:
            result["subscribed"] = subscribed_threads
        if unsubscribed_threads:
            result["unsubscribed"] = unsubscribed_threads

        if len(subscriptions) == limit:
            # We hit the limit, so there might be more entries to return.
            # Generate a new token that has moved backwards, ready for the next
            # request.
            min_returned_stream_id, _, _, _, _ = subscriptions[0]
            result["end"] = ThreadSubscriptionsToken(
                # We subtract one because the 'later in the stream' bound is inclusive,
                # and we already saw the element at index 0.
                stream_id=min_returned_stream_id - 1
            ).to_string()

        return HTTPStatus.OK, result


def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
    """Register the thread subscription servlets, gated on MSC4306 being enabled."""
    if not hs.config.experimental.msc4306_enabled:
        return
    for servlet_class in (
        ThreadSubscriptionsRestServlet,
        ThreadSubscriptionsPaginationRestServlet,
    ):
        servlet_class(hs).register(http_server)
20 changes: 14 additions & 6 deletions synapse/storage/databases/main/thread_subscriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,30 +541,38 @@ def get_updated_thread_subscriptions_txn(
get_updated_thread_subscriptions_txn,
)

async def get_updated_thread_subscriptions_for_user(
async def get_latest_updated_thread_subscriptions_for_user(
self, user_id: str, *, from_id: int, to_id: int, limit: int
) -> List[Tuple[int, str, str]]:
"""Get updates to thread subscriptions for a specific user.
) -> List[Tuple[int, str, str, bool, Optional[bool]]]:
"""Get the latest updates to thread subscriptions for a specific user.

Args:
user_id: The ID of the user
from_id: The starting stream ID (exclusive)
to_id: The ending stream ID (inclusive)
limit: The maximum number of rows to return
If there are too many rows to return, rows from the start (closer to `from_id`)
will be omitted.

Returns:
A list of (stream_id, room_id, thread_root_event_id, subscribed, automatic) tuples.
The row with lowest `stream_id` is the first row.
"""

def get_updated_thread_subscriptions_for_user_txn(
txn: LoggingTransaction,
) -> List[Tuple[int, str, str, bool, Optional[bool]]]:
sql = """
WITH the_updates AS (
SELECT stream_id, room_id, event_id, subscribed, automatic
FROM thread_subscriptions
WHERE user_id = ? AND ? < stream_id AND stream_id <= ?
ORDER BY stream_id DESC
LIMIT ?
)
SELECT stream_id, room_id, event_id, subscribed, automatic
FROM thread_subscriptions
WHERE user_id = ? AND ? < stream_id AND stream_id <= ?
FROM the_updates
ORDER BY stream_id ASC
LIMIT ?
"""

txn.execute(sql, (user_id, from_id, to_id, limit))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,121 @@ def test_limit_parameter(self) -> None:
len(thread_subscriptions["subscribed"][room_id]), 3, thread_subscriptions
)

def test_limit_and_companion_backpagination(self) -> None:
    """
    Create 1 thread subscription, do a sync, create 5 more,
    then sync with a limit of 2 and fill in the gap
    using the companion /thread_subscriptions endpoint.
    """

    thread_root_ids: List[str] = []

    def make_subscription() -> None:
        # Send a new thread root event and subscribe the user to it.
        thread_root_resp = self.helper.send(
            room_id, body="Some thread root", tok=user1_tok
        )
        thread_root_ids.append(thread_root_resp["event_id"])
        self._subscribe_to_thread(user1_id, room_id, thread_root_ids[-1])

    user1_id = self.register_user("user1", "pass")
    user1_tok = self.login(user1_id, "pass")
    room_id = self.helper.create_room_as(user1_id, tok=user1_tok)

    # get the baseline stream_id of the thread_subscriptions stream
    # before we write any data.
    # Required because the initial value differs between SQLite and Postgres.
    base = self.store.get_max_thread_subscriptions_stream_id()

    # Make our first subscription
    make_subscription()

    # Sync for the first time
    sync_body = {
        "lists": {},
        "extensions": {EXT_NAME: {"enabled": True, "limit": 2}},
    }

    sync_resp, first_sync_pos = self.do_sync(sync_body, tok=user1_tok)

    thread_subscriptions = sync_resp["extensions"][EXT_NAME]
    self.assertEqual(
        thread_subscriptions["subscribed"],
        {
            room_id: {
                thread_root_ids[0]: {"automatic": False, "bump_stamp": base + 1},
            }
        },
    )

    # Get our pos for the next sync
    # NOTE(review): `first_sync_pos` was already returned by `do_sync`
    # above — this reassignment looks redundant; confirm and simplify.
    first_sync_pos = sync_resp["pos"]

    # Create 5 more thread subscriptions and subscribe to each
    for _ in range(5):
        make_subscription()

    # Now sync again. Our limit is 2,
    # so we should get the latest 2 subscriptions,
    # with a gap of 3 more subscriptions in the middle
    sync_resp, _pos = self.do_sync(sync_body, tok=user1_tok, since=first_sync_pos)

    thread_subscriptions = sync_resp["extensions"][EXT_NAME]
    self.assertEqual(
        thread_subscriptions["subscribed"],
        {
            room_id: {
                thread_root_ids[4]: {"automatic": False, "bump_stamp": base + 5},
                thread_root_ids[5]: {"automatic": False, "bump_stamp": base + 6},
            }
        },
    )
    # 1st backpagination: expecting a page with 2 subscriptions
    page, end_tok = self._do_backpaginate(
        from_tok=thread_subscriptions["prev_batch"],
        to_tok=first_sync_pos,
        limit=2,
        access_token=user1_tok,
    )
    self.assertIsNotNone(end_tok, "backpagination should continue")
    self.assertEqual(
        page["subscribed"],
        {
            room_id: {
                thread_root_ids[2]: {"automatic": False, "bump_stamp": base + 3},
                thread_root_ids[3]: {"automatic": False, "bump_stamp": base + 4},
            }
        },
    )

    # 2nd backpagination: expecting a page with only 1 subscription
    # and no other token for further backpagination
    assert end_tok is not None
    page, end_tok = self._do_backpaginate(
        from_tok=end_tok, to_tok=first_sync_pos, limit=2, access_token=user1_tok
    )
    self.assertIsNone(end_tok, "backpagination should have finished")
    self.assertEqual(
        page["subscribed"],
        {
            room_id: {
                thread_root_ids[1]: {"automatic": False, "bump_stamp": base + 2},
            }
        },
    )

def _do_backpaginate(
    self, *, from_tok: str, to_tok: str, limit: int, access_token: str
) -> Tuple[JsonDict, Optional[str]]:
    """Call the MSC4308 companion endpoint, paginating backwards.

    Returns the response body and the `end` token (None once pagination
    has finished).
    """
    uri = (
        "/_matrix/client/unstable/io.element.msc4308/thread_subscriptions"
        f"?from={from_tok}&to={to_tok}&limit={limit}&dir=b"
    )
    channel = self.make_request("GET", uri, access_token=access_token)
    self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)

    response_body = channel.json_body
    end_token = cast(Optional[str], response_body.get("end"))
    return response_body, end_token

def _subscribe_to_thread(
self, user_id: str, room_id: str, thread_root_id: str
Expand Down
8 changes: 4 additions & 4 deletions tests/storage/test_thread_subscriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def test_purge_thread_subscriptions_for_user(self) -> None:
self._subscribe(self.other_thread_root_id, automatic_event_orderings=None)

subscriptions = self.get_success(
self.store.get_updated_thread_subscriptions_for_user(
self.store.get_latest_updated_thread_subscriptions_for_user(
self.user_id,
from_id=0,
to_id=50,
Expand All @@ -212,7 +212,7 @@ def test_purge_thread_subscriptions_for_user(self) -> None:

# Check user has no subscriptions
subscriptions = self.get_success(
self.store.get_updated_thread_subscriptions_for_user(
self.store.get_latest_updated_thread_subscriptions_for_user(
self.user_id,
from_id=0,
to_id=50,
Expand Down Expand Up @@ -280,7 +280,7 @@ def test_get_updated_thread_subscriptions_for_user(self) -> None:

# Get updates for main user
updates = self.get_success(
self.store.get_updated_thread_subscriptions_for_user(
self.store.get_latest_updated_thread_subscriptions_for_user(
self.user_id, from_id=0, to_id=stream_id2, limit=10
)
)
Expand All @@ -290,7 +290,7 @@ def test_get_updated_thread_subscriptions_for_user(self) -> None:

# Get updates for other user
updates = self.get_success(
self.store.get_updated_thread_subscriptions_for_user(
self.store.get_latest_updated_thread_subscriptions_for_user(
other_user_id, from_id=0, to_id=max(stream_id1, stream_id2), limit=10
)
)
Expand Down