@@ -37,6 +37,7 @@ class RelationsWorkerStore(SQLBaseStore):
3737 async def get_relations_for_event (
3838 self ,
3939 event_id : str ,
40+ room_id : str ,
4041 relation_type : Optional [str ] = None ,
4142 event_type : Optional [str ] = None ,
4243 aggregation_key : Optional [str ] = None ,
@@ -49,6 +50,7 @@ async def get_relations_for_event(
4950
5051 Args:
5152 event_id: Fetch events that relate to this event ID.
53+ room_id: The room the event belongs to.
5254 relation_type: Only fetch events with this relation type, if given.
5355 event_type: Only fetch events with this event type, if given.
5456 aggregation_key: Only fetch events with this aggregation key, if given.
@@ -63,8 +65,8 @@ async def get_relations_for_event(
6365 the form `{"event_id": "..."}`.
6466 """
6567
66- where_clause = ["relates_to_id = ?" ]
67- where_args : List [Union [str , int ]] = [event_id ]
68+ where_clause = ["relates_to_id = ?" , "room_id = ?" ]
69+ where_args : List [Union [str , int ]] = [event_id , room_id ]
6870
6971 if relation_type is not None :
7072 where_clause .append ("relation_type = ?" )
@@ -199,6 +201,7 @@ async def event_is_target_of_relation(self, parent_id: str) -> bool:
199201 async def get_aggregation_groups_for_event (
200202 self ,
201203 event_id : str ,
204+ room_id : str ,
202205 event_type : Optional [str ] = None ,
203206 limit : int = 5 ,
204207 direction : str = "b" ,
@@ -213,6 +216,7 @@ async def get_aggregation_groups_for_event(
213216
214217 Args:
215218 event_id: Fetch events that relate to this event ID.
219+ room_id: The room the event belongs to.
216220 event_type: Only fetch events with this event type, if given.
217221 limit: Only fetch the `limit` groups.
218222 direction: Whether to fetch the highest count first (`"b"`) or
@@ -225,8 +229,12 @@ async def get_aggregation_groups_for_event(
225229 `type`, `key` and `count` fields.
226230 """
227231
228- where_clause = ["relates_to_id = ?" , "relation_type = ?" ]
229- where_args : List [Union [str , int ]] = [event_id , RelationTypes .ANNOTATION ]
232+ where_clause = ["relates_to_id = ?" , "room_id = ?" , "relation_type = ?" ]
233+ where_args : List [Union [str , int ]] = [
234+ event_id ,
235+ room_id ,
236+ RelationTypes .ANNOTATION ,
237+ ]
230238
231239 if event_type :
232240 where_clause .append ("type = ?" )
@@ -288,14 +296,17 @@ def _get_aggregation_groups_for_event_txn(
288296 )
289297
290298 @cached ()
291- async def get_applicable_edit (self , event_id : str ) -> Optional [EventBase ]:
299+ async def get_applicable_edit (
300+ self , event_id : str , room_id : str
301+ ) -> Optional [EventBase ]:
292302 """Get the most recent edit (if any) that has happened for the given
293303 event.
294304
295305 Correctly handles checking whether edits were allowed to happen.
296306
297307 Args:
298308 event_id: The original event ID
309+ room_id: The original event's room ID
299310
300311 Returns:
301312 The most recent edit, if any.
@@ -317,13 +328,14 @@ async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
317328 WHERE
318329 relates_to_id = ?
319330 AND relation_type = ?
331+ AND edit.room_id = ?
320332 AND edit.type = 'm.room.message'
321333 ORDER by edit.origin_server_ts DESC, edit.event_id DESC
322334 LIMIT 1
323335 """
324336
325337 def _get_applicable_edit_txn (txn : LoggingTransaction ) -> Optional [str ]:
326- txn .execute (sql , (event_id , RelationTypes .REPLACE ))
338+ txn .execute (sql , (event_id , RelationTypes .REPLACE , room_id ))
327339 row = txn .fetchone ()
328340 if row :
329341 return row [0 ]
@@ -340,13 +352,14 @@ def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:
340352
341353 @cached ()
342354 async def get_thread_summary (
343- self , event_id : str
355+ self , event_id : str , room_id : str
344356 ) -> Tuple [int , Optional [EventBase ]]:
345357 """Get the number of threaded replies, the senders of those replies, and
346358 the latest reply (if any) for the given event.
347359
348360 Args:
349- event_id: The original event ID
361+ event_id: Summarize the thread related to this event ID.
362+ room_id: The room the event belongs to.
350363
351364 Returns:
352365 The number of items in the thread and the most recent response, if any.
@@ -363,12 +376,13 @@ def _get_thread_summary_txn(
363376 INNER JOIN events USING (event_id)
364377 WHERE
365378 relates_to_id = ?
379+ AND room_id = ?
366380 AND relation_type = ?
367381 ORDER BY topological_ordering DESC, stream_ordering DESC
368382 LIMIT 1
369383 """
370384
371- txn .execute (sql , (event_id , RelationTypes .THREAD ))
385+ txn .execute (sql , (event_id , room_id , RelationTypes .THREAD ))
372386 row = txn .fetchone ()
373387 if row is None :
374388 return 0 , None
@@ -378,11 +392,13 @@ def _get_thread_summary_txn(
378392 sql = """
379393 SELECT COALESCE(COUNT(event_id), 0)
380394 FROM event_relations
395+ INNER JOIN events USING (event_id)
381396 WHERE
382397 relates_to_id = ?
398+ AND room_id = ?
383399 AND relation_type = ?
384400 """
385- txn .execute (sql , (event_id , RelationTypes .THREAD ))
401+ txn .execute (sql , (event_id , room_id , RelationTypes .THREAD ))
386402 count = txn .fetchone ()[0 ] # type: ignore[index]
387403
388404 return count , latest_event_id
0 commit comments