Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 1acc897

Browse files
authored
Implement MSC3816, consider the root event for thread participation. (#12766)
As opposed to only considering a user to have "participated" if they replied to the thread.
1 parent fcd8703 commit 1acc897

File tree

3 files changed

+97
-47
lines changed

3 files changed

+97
-47
lines changed

changelog.d/12766.bugfix

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Implement [MSC3816](https://github.com/matrix-org/matrix-spec-proposals/pull/3816): sending the root event in a thread should count as "participated" in it.

synapse/handlers/relations.py

Lines changed: 37 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import logging
15-
from typing import (
16-
TYPE_CHECKING,
17-
Collection,
18-
Dict,
19-
FrozenSet,
20-
Iterable,
21-
List,
22-
Optional,
23-
Tuple,
24-
)
15+
from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple
2516

2617
import attr
2718

@@ -256,13 +247,19 @@ async def get_annotations_for_event(
256247

257248
return filtered_results
258249

259-
async def get_threads_for_events(
260-
self, event_ids: Collection[str], user_id: str, ignored_users: FrozenSet[str]
250+
async def _get_threads_for_events(
251+
self,
252+
events_by_id: Dict[str, EventBase],
253+
relations_by_id: Dict[str, str],
254+
user_id: str,
255+
ignored_users: FrozenSet[str],
261256
) -> Dict[str, _ThreadAggregation]:
262257
"""Get the bundled aggregations for threads for the requested events.
263258
264259
Args:
265-
event_ids: Events to get aggregations for threads.
260+
events_by_id: A map of event_id to events to get aggregations for threads.
261+
relations_by_id: A map of event_id to the relation type, if one exists
262+
for that event.
266263
user_id: The user requesting the bundled aggregations.
267264
ignored_users: The users ignored by the requesting user.
268265
@@ -273,16 +270,34 @@ async def get_threads_for_events(
273270
"""
274271
user = UserID.from_string(user_id)
275272

273+
# It is not valid to start a thread on an event which itself relates to another event.
274+
event_ids = [eid for eid in events_by_id.keys() if eid not in relations_by_id]
275+
276276
# Fetch thread summaries.
277277
summaries = await self._main_store.get_thread_summaries(event_ids)
278278

279-
# Only fetch participated for a limited selection based on what had
280-
# summaries.
279+
# Limit fetching whether the requester has participated in a thread to
280+
# events which are thread roots.
281281
thread_event_ids = [
282282
event_id for event_id, summary in summaries.items() if summary
283283
]
284-
participated = await self._main_store.get_threads_participated(
285-
thread_event_ids, user_id
284+
285+
# Pre-seed thread participation with whether the requester sent the event.
286+
participated = {
287+
event_id: events_by_id[event_id].sender == user_id
288+
for event_id in thread_event_ids
289+
}
290+
# For events the requester did not send, check the database for whether
291+
# the requester sent a threaded reply.
292+
participated.update(
293+
await self._main_store.get_threads_participated(
294+
[
295+
event_id
296+
for event_id in thread_event_ids
297+
if not participated[event_id]
298+
],
299+
user_id,
300+
)
286301
)
287302

288303
# Then subtract off the results for any ignored users.
@@ -343,7 +358,8 @@ async def get_threads_for_events(
343358
count=thread_count,
344359
# If there's a thread summary it must also exist in the
345360
# participated dictionary.
346-
current_user_participated=participated[event_id],
361+
current_user_participated=events_by_id[event_id].sender == user_id
362+
or participated[event_id],
347363
)
348364

349365
return results
@@ -401,9 +417,9 @@ async def get_bundled_aggregations(
401417
# events to be fetched. Thus, we check those first!
402418

403419
# Fetch thread summaries (but only for the directly requested events).
404-
threads = await self.get_threads_for_events(
405-
# It is not valid to start a thread on an event which itself relates to another event.
406-
[eid for eid in events_by_id.keys() if eid not in relations_by_id],
420+
threads = await self._get_threads_for_events(
421+
events_by_id,
422+
relations_by_id,
407423
user_id,
408424
ignored_users,
409425
)

tests/rest/client/test_relations.py

Lines changed: 59 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -896,6 +896,7 @@ def _test_bundled_aggregations(
896896
relation_type: str,
897897
assertion_callable: Callable[[JsonDict], None],
898898
expected_db_txn_for_event: int,
899+
access_token: Optional[str] = None,
899900
) -> None:
900901
"""
901902
Makes requests to various endpoints which should include bundled aggregations
@@ -907,7 +908,9 @@ def _test_bundled_aggregations(
907908
for relation-specific assertions.
908909
expected_db_txn_for_event: The number of database transactions which
909910
are expected for a call to /event/.
911+
access_token: The access token to user, defaults to self.user_token.
910912
"""
913+
access_token = access_token or self.user_token
911914

912915
def assert_bundle(event_json: JsonDict) -> None:
913916
"""Assert the expected values of the bundled aggregations."""
@@ -921,7 +924,7 @@ def assert_bundle(event_json: JsonDict) -> None:
921924
channel = self.make_request(
922925
"GET",
923926
f"/rooms/{self.room}/event/{self.parent_id}",
924-
access_token=self.user_token,
927+
access_token=access_token,
925928
)
926929
self.assertEqual(200, channel.code, channel.json_body)
927930
assert_bundle(channel.json_body)
@@ -932,7 +935,7 @@ def assert_bundle(event_json: JsonDict) -> None:
932935
channel = self.make_request(
933936
"GET",
934937
f"/rooms/{self.room}/messages?dir=b",
935-
access_token=self.user_token,
938+
access_token=access_token,
936939
)
937940
self.assertEqual(200, channel.code, channel.json_body)
938941
assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
@@ -941,15 +944,15 @@ def assert_bundle(event_json: JsonDict) -> None:
941944
channel = self.make_request(
942945
"GET",
943946
f"/rooms/{self.room}/context/{self.parent_id}",
944-
access_token=self.user_token,
947+
access_token=access_token,
945948
)
946949
self.assertEqual(200, channel.code, channel.json_body)
947950
assert_bundle(channel.json_body["event"])
948951

949952
# Request sync.
950953
filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 4}}}')
951954
channel = self.make_request(
952-
"GET", f"/sync?filter={filter}", access_token=self.user_token
955+
"GET", f"/sync?filter={filter}", access_token=access_token
953956
)
954957
self.assertEqual(200, channel.code, channel.json_body)
955958
room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
@@ -962,7 +965,7 @@ def assert_bundle(event_json: JsonDict) -> None:
962965
"/search",
963966
# Search term matches the parent message.
964967
content={"search_categories": {"room_events": {"search_term": "Hi"}}},
965-
access_token=self.user_token,
968+
access_token=access_token,
966969
)
967970
self.assertEqual(200, channel.code, channel.json_body)
968971
chunk = [
@@ -1037,30 +1040,60 @@ def test_thread(self) -> None:
10371040
"""
10381041
Test that threads get correctly bundled.
10391042
"""
1040-
self._send_relation(RelationTypes.THREAD, "m.room.test")
1041-
channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
1043+
# The root message is from "user", send replies as "user2".
1044+
self._send_relation(
1045+
RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
1046+
)
1047+
channel = self._send_relation(
1048+
RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
1049+
)
10421050
thread_2 = channel.json_body["event_id"]
10431051

1044-
def assert_thread(bundled_aggregations: JsonDict) -> None:
1045-
self.assertEqual(2, bundled_aggregations.get("count"))
1046-
self.assertTrue(bundled_aggregations.get("current_user_participated"))
1047-
# The latest thread event has some fields that don't matter.
1048-
self.assert_dict(
1049-
{
1050-
"content": {
1051-
"m.relates_to": {
1052-
"event_id": self.parent_id,
1053-
"rel_type": RelationTypes.THREAD,
1054-
}
1052+
# This needs two assertion functions which are identical except for whether
1053+
# the current_user_participated flag is True, create a factory for the
1054+
# two versions.
1055+
def _gen_assert(participated: bool) -> Callable[[JsonDict], None]:
1056+
def assert_thread(bundled_aggregations: JsonDict) -> None:
1057+
self.assertEqual(2, bundled_aggregations.get("count"))
1058+
self.assertEqual(
1059+
participated, bundled_aggregations.get("current_user_participated")
1060+
)
1061+
# The latest thread event has some fields that don't matter.
1062+
self.assert_dict(
1063+
{
1064+
"content": {
1065+
"m.relates_to": {
1066+
"event_id": self.parent_id,
1067+
"rel_type": RelationTypes.THREAD,
1068+
}
1069+
},
1070+
"event_id": thread_2,
1071+
"sender": self.user2_id,
1072+
"type": "m.room.test",
10551073
},
1056-
"event_id": thread_2,
1057-
"sender": self.user_id,
1058-
"type": "m.room.test",
1059-
},
1060-
bundled_aggregations.get("latest_event"),
1061-
)
1074+
bundled_aggregations.get("latest_event"),
1075+
)
10621076

1063-
self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9)
1077+
return assert_thread
1078+
1079+
# The "user" sent the root event and is making queries for the bundled
1080+
# aggregations: they have participated.
1081+
self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8)
1082+
# The "user2" sent replies in the thread and is making queries for the
1083+
# bundled aggregations: they have participated.
1084+
#
1085+
# Note that this re-uses some cached values, so the total number of
1086+
# queries is much smaller.
1087+
self._test_bundled_aggregations(
1088+
RelationTypes.THREAD, _gen_assert(True), 2, access_token=self.user2_token
1089+
)
1090+
1091+
# A user with no interactions with the thread: they have not participated.
1092+
user3_id, user3_token = self._create_user("charlie")
1093+
self.helper.join(self.room, user=user3_id, tok=user3_token)
1094+
self._test_bundled_aggregations(
1095+
RelationTypes.THREAD, _gen_assert(False), 2, access_token=user3_token
1096+
)
10641097

10651098
def test_thread_with_bundled_aggregations_for_latest(self) -> None:
10661099
"""
@@ -1106,7 +1139,7 @@ def assert_thread(bundled_aggregations: JsonDict) -> None:
11061139
bundled_aggregations["latest_event"].get("unsigned"),
11071140
)
11081141

1109-
self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9)
1142+
self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8)
11101143

11111144
def test_nested_thread(self) -> None:
11121145
"""

0 commit comments

Comments
 (0)