@@ -98,6 +98,7 @@
 )
 from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
 from synapse.storage.databases.main.stream import StreamWorkerStore
+from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 
@@ -232,6 +233,104 @@ def __init__(
             replaces_index="event_push_summary_user_rm",
         )
 
+        self.db_pool.updates.register_background_index_update(
+            "event_push_summary_unique_index2",
+            index_name="event_push_summary_unique_index2",
+            table="event_push_summary",
+            columns=["user_id", "room_id", "thread_id"],
+            unique=True,
+        )
+
+        self.db_pool.updates.register_background_update_handler(
+            "event_push_backfill_thread_id",
+            self._background_backfill_thread_id,
+        )
+
+    async def _background_backfill_thread_id(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
252+ """
253+ Fill in the thread_id field for event_push_actions and event_push_summary.
254+
255+ This is preparatory so that it can be made non-nullable in the future.
256+
257+ Because all current (null) data is done in an unthreaded manner this
258+ simply assumes it is on the "main" timeline. Since event_push_actions
259+ are periodically cleared it is not possible to correctly re-calculate
260+ the thread_id.
261+ """
+        event_push_actions_done = progress.get("event_push_actions_done", False)
+
+        def add_thread_id_txn(
+            txn: LoggingTransaction, table_name: str, start_stream_ordering: int
+        ) -> int:
+            sql = f"""
+            SELECT stream_ordering
+            FROM {table_name}
+            WHERE
+                thread_id IS NULL
+                AND stream_ordering > ?
+            ORDER BY stream_ordering
+            LIMIT ?
+            """
+            txn.execute(sql, (start_stream_ordering, batch_size))
+
+            # No more rows to process.
+            rows = txn.fetchall()
+            if not rows:
+                progress[f"{table_name}_done"] = True
+                self.db_pool.updates._background_update_progress_txn(
+                    txn, "event_push_backfill_thread_id", progress
+                )
+                return 0
+
+            # Update the thread ID for any of those rows.
+            max_stream_ordering = rows[-1][0]
+
+            sql = f"""
+            UPDATE {table_name}
+            SET thread_id = 'main'
+            WHERE stream_ordering <= ? AND thread_id IS NULL
+            """
+            txn.execute(sql, (max_stream_ordering,))
+
+            # Update progress.
+            processed_rows = txn.rowcount
+            progress[f"max_{table_name}_stream_ordering"] = max_stream_ordering
+            self.db_pool.updates._background_update_progress_txn(
+                txn, "event_push_backfill_thread_id", progress
+            )
+
+            return processed_rows
+
+        # First update the event_push_actions table, then the event_push_summary table.
+        #
+        # Note that the event_push_actions_staging table is ignored since it is
+        # assumed that items in that table will only exist for a short period of
+        # time.
+        if not event_push_actions_done:
+            result = await self.db_pool.runInteraction(
+                "event_push_backfill_thread_id",
+                add_thread_id_txn,
+                "event_push_actions",
+                progress.get("max_event_push_actions_stream_ordering", 0),
+            )
+        else:
+            result = await self.db_pool.runInteraction(
+                "event_push_backfill_thread_id",
+                add_thread_id_txn,
+                "event_push_summary",
+                progress.get("max_event_push_summary_stream_ordering", 0),
+            )
+
+            # Only end the update once the event_push_summary pass reports no
+            # remaining rows; ending it whenever *any* pass returns zero would
+            # stop before event_push_summary was processed at all.
+            if not result:
+                await self.db_pool.updates._end_background_update(
+                    "event_push_backfill_thread_id"
+                )
+
+        return result
333+
     @cached(tree=True, max_entries=5000)
     async def get_unread_event_push_actions_by_room_for_user(
         self,
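
The handler above is the usual batched-backfill shape: select up to `batch_size` candidate rows, stamp everything up to the highest `stream_ordering` seen with a single ranged UPDATE, record progress, and report how many rows were touched so the background updater can pace itself. A minimal, self-contained sketch of that pattern, using `sqlite3` rather than Synapse's database pool (table and column names mirror the diff; everything else is illustrative):

```python
import sqlite3


def backfill_thread_id_batch(
    conn: sqlite3.Connection, table_name: str, batch_size: int
) -> int:
    """One batch: stamp up to batch_size rows that still have a NULL thread_id."""
    rows = conn.execute(
        f"""
        SELECT stream_ordering FROM {table_name}
        WHERE thread_id IS NULL
        ORDER BY stream_ordering
        LIMIT ?
        """,
        (batch_size,),
    ).fetchall()
    if not rows:
        return 0  # Nothing left; the caller can mark this table as done.

    # One ranged UPDATE up to the highest stream_ordering in the batch is
    # cheaper than updating row by row.
    max_stream_ordering = rows[-1][0]
    cur = conn.execute(
        f"UPDATE {table_name} SET thread_id = 'main' "
        "WHERE stream_ordering <= ? AND thread_id IS NULL",
        (max_stream_ordering,),
    )
    conn.commit()
    return cur.rowcount
```

The new `event_push_summary_unique_index2` registration corresponds roughly to `CREATE UNIQUE INDEX ... ON event_push_summary (user_id, room_id, thread_id)` built in the background, while the backfill prepares the column for the eventual NOT NULL constraint mentioned in the docstring.
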
@@ -670,6 +769,7 @@ async def add_push_actions_to_staging(
         event_id: str,
         user_id_actions: Dict[str, Collection[Union[Mapping, str]]],
         count_as_unread: bool,
+        thread_id: str,
     ) -> None:
         """Add the push actions for the event to the push action staging area.
 
@@ -678,6 +778,7 @@ async def add_push_actions_to_staging(
             user_id_actions: A mapping of user_id to list of push actions, where
                 an action can either be a string or dict.
             count_as_unread: Whether this event should increment unread counts.
+            thread_id: The thread this event is part of, if applicable.
         """
         if not user_id_actions:
             return
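
For context on the new argument: in Matrix threads (MSC3440) a threaded event points at its thread root via `m.relates_to` with `rel_type: "m.thread"`, and this PR uses the sentinel `"main"` for everything unthreaded. A hypothetical helper, not part of this diff, showing how a caller might derive the value:

```python
from typing import Any, Dict


# Hypothetical helper: the thread_id is the event ID of the thread root for
# threaded events, and the "main" sentinel used throughout this PR otherwise.
def thread_id_for(event_content: Dict[str, Any]) -> str:
    relation = event_content.get("m.relates_to") or {}
    if relation.get("rel_type") == "m.thread":
        return relation["event_id"]
    return "main"
```
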
@@ -686,7 +787,7 @@ async def add_push_actions_to_staging(
         # can be used to insert into the `event_push_actions_staging` table.
         def _gen_entry(
             user_id: str, actions: Collection[Union[Mapping, str]]
-        ) -> Tuple[str, str, str, int, int, int]:
+        ) -> Tuple[str, str, str, int, int, int, str]:
             is_highlight = 1 if _action_has_highlight(actions) else 0
             notif = 1 if "notify" in actions else 0
             return (
@@ -696,11 +797,20 @@ def _gen_entry(
                 notif,  # notif column
                 is_highlight,  # highlight column
                 int(count_as_unread),  # unread column
+                thread_id,  # thread_id column
             )
 
         await self.db_pool.simple_insert_many(
             "event_push_actions_staging",
-            keys=("event_id", "user_id", "actions", "notif", "highlight", "unread"),
+            keys=(
+                "event_id",
+                "user_id",
+                "actions",
+                "notif",
+                "highlight",
+                "unread",
+                "thread_id",
+            ),
             values=[
                 _gen_entry(user_id, actions)
                 for user_id, actions in user_id_actions.items()
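
Taken together, a call into the staging area now carries the thread. A hedged sketch of a call site (`store` stands in for whatever object mixes in this worker store; the real caller is Synapse's bulk push rule evaluator, and the identifiers below are illustrative):

```python
async def stage_example(store) -> None:
    # Illustrative values only; the signature matches the diff above.
    await store.add_push_actions_to_staging(
        event_id="$143273582443PhrSn:example.org",
        user_id_actions={
            "@alice:example.org": ["notify", {"set_tweak": "highlight"}],
        },
        count_as_unread=False,
        thread_id="main",  # or the thread root's event ID for a threaded event
    )
```
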
@@ -981,6 +1091,8 @@ def _handle_new_receipts_for_notifs_txn(self, txn: LoggingTransaction) -> bool:
             )
 
             # Replace the previous summary with the new counts.
+            #
+            # TODO(threads): Upsert per-thread instead of setting them all to main.
             self.db_pool.simple_upsert_txn(
                 txn,
                 table="event_push_summary",
@@ -990,6 +1102,7 @@ def _handle_new_receipts_for_notifs_txn(self, txn: LoggingTransaction) -> bool:
9901102 "unread_count" : unread_count ,
9911103 "stream_ordering" : old_rotate_stream_ordering ,
9921104 "last_receipt_stream_ordering" : stream_ordering ,
1105+ "thread_id" : "main" ,
9931106 },
9941107 )
9951108
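
With `thread_id` in the row, `simple_upsert_txn` behaves like a native upsert once the unique index from the background update exists. One plausible shape of the underlying SQL, assuming conflicts are resolved against the new `(user_id, room_id, thread_id)` index (SQLite/PostgreSQL `ON CONFLICT` syntax; the exact statement Synapse generates may differ, and every row is still pinned to `'main'` per the TODO):

```python
# Illustrative only: the approximate statement behind the upsert above.
UPSERT_SQL = """
INSERT INTO event_push_summary
    (user_id, room_id, notif_count, unread_count,
     stream_ordering, last_receipt_stream_ordering, thread_id)
VALUES (?, ?, ?, ?, ?, ?, 'main')
ON CONFLICT (user_id, room_id, thread_id)
DO UPDATE SET
    notif_count = EXCLUDED.notif_count,
    unread_count = EXCLUDED.unread_count,
    stream_ordering = EXCLUDED.stream_ordering,
    last_receipt_stream_ordering = EXCLUDED.last_receipt_stream_ordering
"""
```
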
@@ -1138,17 +1251,19 @@ def _rotate_notifs_before_txn(
 
         logger.info("Rotating notifications, handling %d rows", len(summaries))
 
+        # TODO(threads): Update on a per-thread basis.
         self.db_pool.simple_upsert_many_txn(
             txn,
             table="event_push_summary",
             key_names=("user_id", "room_id"),
             key_values=[(user_id, room_id) for user_id, room_id in summaries],
-            value_names=("notif_count", "unread_count", "stream_ordering"),
+            value_names=("notif_count", "unread_count", "stream_ordering", "thread_id"),
             value_values=[
                 (
                     summary.notif_count,
                     summary.unread_count,
                     summary.stream_ordering,
+                    "main",
                 )
                 for summary in summaries.values()
             ],