Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 0c1c6a1

Browse files
committed
Improve validation for threads.
1 parent 575b296 commit 0c1c6a1

File tree

4 files changed

+32
-7
lines changed

4 files changed

+32
-7
lines changed

synapse/events/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,7 @@ async def _injected_bundled_aggregations(
506506
(
507507
thread_count,
508508
latest_thread_event,
509-
) = await self.store.get_thread_summary(event_id)
509+
) = await self.store.get_thread_summary(event_id, room_id)
510510
if latest_thread_event:
511511
aggregations[RelationTypes.THREAD] = {
512512
# Don't bundle aggregations as this could recurse forever.

synapse/storage/databases/main/events.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1783,7 +1783,9 @@ def _handle_event_relations(
17831783
txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
17841784

17851785
if rel_type == RelationTypes.THREAD:
1786-
txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,))
1786+
txn.call_after(
1787+
self.store.get_thread_summary.invalidate, (parent_id, event.room_id)
1788+
)
17871789

17881790
def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
17891791
"""Handles keeping track of insertion events and edges/connections.

synapse/storage/databases/main/relations.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -348,13 +348,14 @@ def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:
348348

349349
@cached()
350350
async def get_thread_summary(
351-
self, event_id: str
351+
self, event_id: str, room_id: str
352352
) -> Tuple[int, Optional[EventBase]]:
353353
"""Get the number of threaded replies, the senders of those replies, and
354354
the latest reply (if any) for the given event.
355355
356356
Args:
357-
event_id: The original event ID
357+
event_id: Summarize the thread related to this event ID.
358+
room_id: The room the event belongs to.
358359
359360
Returns:
360361
The number of items in the thread and the most recent response, if any.
@@ -371,12 +372,13 @@ def _get_thread_summary_txn(
371372
INNER JOIN events USING (event_id)
372373
WHERE
373374
relates_to_id = ?
375+
AND room_id = ?
374376
AND relation_type = ?
375377
ORDER BY topological_ordering DESC, stream_ordering DESC
376378
LIMIT 1
377379
"""
378380

379-
txn.execute(sql, (event_id, RelationTypes.THREAD))
381+
txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
380382
row = txn.fetchone()
381383
if row is None:
382384
return 0, None
@@ -386,11 +388,13 @@ def _get_thread_summary_txn(
386388
sql = """
387389
SELECT COALESCE(COUNT(event_id), 0)
388390
FROM event_relations
391+
INNER JOIN events USING (event_id)
389392
WHERE
390393
relates_to_id = ?
394+
AND room_id = ?
391395
AND relation_type = ?
392396
"""
393-
txn.execute(sql, (event_id, RelationTypes.THREAD))
397+
txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
394398
count = txn.fetchone()[0] # type: ignore[index]
395399

396400
return count, latest_event_id

tests/rest/client/test_relations.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -653,6 +653,7 @@ def test_aggregation_get_event_for_thread(self):
653653
},
654654
)
655655

656+
@unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
656657
def test_ignore_invalid_room(self):
657658
"""Test that we ignore invalid relations over federation."""
658659
# Create another room and send a message in it.
@@ -664,7 +665,8 @@ def test_ignore_invalid_room(self):
664665
with patch(
665666
"synapse.handlers.message.EventCreationHandler._validate_event_relation"
666667
):
667-
# Generate a reaction and reference relations from a different room.
668+
# Generate a reaction, reference, and thread relations from a
669+
# different room.
668670
self.get_success(
669671
inject_event(
670672
self.hs,
@@ -698,6 +700,23 @@ def test_ignore_invalid_room(self):
698700
)
699701
)
700702

703+
self.get_success(
704+
inject_event(
705+
self.hs,
706+
room_id=self.room,
707+
type="m.room.message",
708+
sender=self.user_id,
709+
content={
710+
"body": "foo",
711+
"msgtype": "m.text",
712+
"m.relates_to": {
713+
"rel_type": RelationTypes.THREAD,
714+
"event_id": parent_id,
715+
},
716+
},
717+
)
718+
)
719+
701720
# They should be ignored when fetching relations.
702721
channel = self.make_request(
703722
"GET",

0 commit comments

Comments
 (0)