Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/14174.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Support for thread-specific notifications & receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771) and [MSC3773](https://github.com/matrix-org/matrix-spec-proposals/pull/3773)).
44 changes: 42 additions & 2 deletions synapse/rest/client/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import logging
from typing import TYPE_CHECKING, Tuple

from synapse.api.constants import ReceiptTypes
from synapse.api.constants import MAIN_TIMELINE, ReceiptTypes
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
Expand Down Expand Up @@ -83,7 +83,7 @@ async def on_POST(
)

# Ensure the event ID roughly correlates to the thread ID.
if thread_id != await self._main_store.get_thread_id(event_id):
if not await self._is_valid_thread_id(event_id, thread_id):
raise SynapseError(
400,
f"event_id {event_id} is not related to thread {thread_id}",
Expand All @@ -109,6 +109,46 @@ async def on_POST(

return 200, {}

async def _is_valid_thread_id(self, event_id: str, thread_id: str) -> bool:
"""
The thread ID provided must relate (in a vague sense) to the event ID.

We check this to ensure clients aren't sending bogus receipts.

A thread ID is considered valid if:

1. The event has a thread relation which matches the thread ID.
2. The event has children events which form a thread relation (i.e. the
event is a thread root).
2. The event is related to an event (recursively) which satisfies 1 or 2.

Given the following DAG:

A <---[m.thread]-- B <--[m.annotation]-- C
^
|--[m.reference]-- D <--[m.annotation]-- E

It is valid to send a receipt for thread A on A, B, C, D, or E.

It is valid to send a receipt for the main thread on A, D, and E.

Args:
event_id: The event ID to check.
thread_id: The thread ID the event is potentially part of.

Returns:
True if the event belongs to the given thread.
"""

# If the receipt is on the main timeline, it is enough to check whether
# the event is directly related to a thread.
if thread_id == MAIN_TIMELINE:
return MAIN_TIMELINE == await self._main_store.get_thread_id(event_id)

# Otherwise, check if the event is directly part of a thread, or is the
# root message (or related to the root message) of a thread.
return thread_id == await self._main_store.get_thread_id_for_receipts(event_id)


def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReceiptRestServlet(hs).register(http_server)
1 change: 1 addition & 0 deletions synapse/storage/databases/main/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def _invalidate_caches_for_event(
self._attempt_to_invalidate_cache("get_relations_for_event", (redacts,))
self._attempt_to_invalidate_cache("get_applicable_edit", (redacts,))
self._attempt_to_invalidate_cache("get_thread_id", (redacts,))
self._attempt_to_invalidate_cache("get_thread_id_for_receipts", (redacts,))

if etype == EventTypes.Member:
self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
Expand Down
74 changes: 70 additions & 4 deletions synapse/storage/databases/main/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,29 +946,39 @@ async def get_thread_id(self, event_id: str) -> str:
Get the thread ID for an event. This considers multi-level relations,
e.g. an annotation to an event which is part of a thread.

It only searches up the relations tree, i.e. it only considers events
which are the parent of the given event.

See also get_thread_id_for_receipts.

Args:
event_id: The event ID to fetch the thread ID for.

Returns:
The event ID of the root event in the thread, if this event is part
of a thread. "main", otherwise.
"""

# Since event relations form a tree, we should only ever find 0 or 1
# results from the below query.
sql = """
WITH RECURSIVE related_events AS (
SELECT event_id, relates_to_id, relation_type
SELECT event_id, relates_to_id, relation_type, 0 depth
FROM event_relations
WHERE event_id = ?
UNION SELECT e.event_id, e.relates_to_id, e.relation_type
UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1
FROM event_relations e
INNER JOIN related_events r ON r.relates_to_id = e.event_id
) SELECT relates_to_id FROM related_events WHERE relation_type = 'm.thread';
WHERE depth <= 3
)
SELECT relates_to_id FROM related_events
WHERE relation_type = 'm.thread'
ORDER BY depth DESC
LIMIT 1;
"""

def _get_thread_id(txn: LoggingTransaction) -> str:
txn.execute(sql, (event_id,))
# TODO Should we ensure there's only a single result here?
row = txn.fetchone()
if row:
return row[0]
Expand All @@ -978,6 +988,62 @@ def _get_thread_id(txn: LoggingTransaction) -> str:

return await self.db_pool.runInteraction("get_thread_id", _get_thread_id)

@cached()
async def get_thread_id_for_receipts(self, event_id: str) -> str:
"""
Get the thread ID for an event by traversing to the top-most related event
and confirming any children events form a thread.

Given the following DAG:

A <---[m.thread]-- B <--[m.annotation]-- C
^
|--[m.reference]-- D <--[m.annotation]-- E

It considers A, B, C, D, and E as part of the thread.

See also get_thread_id.

Args:
event_id: The event ID to fetch the thread ID for.

Returns:
The event ID of the root event in the thread, if this event is part
of a thread. "main", otherwise.
"""

# Recurse up to the *root* node, then select relations of that to
# see if there are thread children.
sql = """
SELECT relates_to_id FROM event_relations WHERE relates_to_id = COALESCE((
WITH RECURSIVE related_events AS (
SELECT event_id, relates_to_id, relation_type, 0 depth
FROM event_relations
WHERE event_id = ?
UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1
FROM event_relations e
INNER JOIN related_events r ON r.relates_to_id = e.event_id
WHERE depth <= 3
)
SELECT relates_to_id FROM related_events
ORDER BY depth DESC
LIMIT 1
), ?) AND relation_type = 'm.thread' LIMIT 1;
"""

def _get_related_thread_id(txn: LoggingTransaction) -> str:
txn.execute(sql, (event_id, event_id))
row = txn.fetchone()
if row:
return row[0]

# If no thread was found, it is part of the main timeline.
return MAIN_TIMELINE

return await self.db_pool.runInteraction(
"get_related_thread_id", _get_related_thread_id
)


class RelationsStore(RelationsWorkerStore):
pass
111 changes: 111 additions & 0 deletions tests/storage/test_relations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from twisted.test.proto_helpers import MemoryReactor

from synapse.api.constants import MAIN_TIMELINE
from synapse.server import HomeServer
from synapse.util import Clock

from tests import unittest


class RelationsStoreTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
"""
Creates a DAG:
A <---[m.thread]-- B <--[m.annotation]-- C
^
|--[m.reference]-- D <--[m.annotation]-- E
F <--[m.annotation]-- G
"""
self._main_store = self.hs.get_datastores().main

self._create_relation("A", "B", "m.thread")
self._create_relation("B", "C", "m.annotation")
self._create_relation("A", "D", "m.reference")
self._create_relation("A", "E", "m.annotation")
self._create_relation("F", "G", "m.annotation")

def _create_relation(self, parent_id: str, event_id: str, rel_type: str) -> None:
self.get_success(
self._main_store.db_pool.simple_insert(
table="event_relations",
values={
"event_id": event_id,
"relates_to_id": parent_id,
"relation_type": rel_type,
},
)
)

def test_get_thread_id(self) -> None:
"""
Ensure that get_thread_id only searches up the tree for threads.
"""
# The thread itself and children of it return the thread.
thread_id = self.get_success(self._main_store.get_thread_id("B"))
self.assertEqual("A", thread_id)

thread_id = self.get_success(self._main_store.get_thread_id("C"))
self.assertEqual("A", thread_id)

# But the root and events related to the root do not.
thread_id = self.get_success(self._main_store.get_thread_id("A"))
self.assertEqual(MAIN_TIMELINE, thread_id)

thread_id = self.get_success(self._main_store.get_thread_id("D"))
self.assertEqual(MAIN_TIMELINE, thread_id)

thread_id = self.get_success(self._main_store.get_thread_id("E"))
self.assertEqual(MAIN_TIMELINE, thread_id)

# Events which are not related to a thread at all should return the
# main timeline.
thread_id = self.get_success(self._main_store.get_thread_id("F"))
self.assertEqual(MAIN_TIMELINE, thread_id)

thread_id = self.get_success(self._main_store.get_thread_id("G"))
self.assertEqual(MAIN_TIMELINE, thread_id)

def test_get_thread_id_for_receipts(self) -> None:
"""
Ensure that get_thread_id_for_receipts searches up and down the tree for a thread.
"""
# All of the events are considered related to this thread.
thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("A"))
self.assertEqual("A", thread_id)

thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("B"))
self.assertEqual("A", thread_id)

thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("C"))
self.assertEqual("A", thread_id)

thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("D"))
self.assertEqual("A", thread_id)

thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("E"))
self.assertEqual("A", thread_id)

# Events which are not related to a thread at all should return the
# main timeline.
thread_id = self.get_success(self._main_store.get_thread_id("F"))
self.assertEqual(MAIN_TIMELINE, thread_id)

thread_id = self.get_success(self._main_store.get_thread_id("G"))
self.assertEqual(MAIN_TIMELINE, thread_id)