2222import attr
2323from typing_extensions import Deque
2424
25+ from synapse .metrics .background_process_metrics import run_as_background_process
2526from synapse .storage .database import DatabasePool , LoggingTransaction
2627from synapse .storage .util .sequence import PostgresSequenceGenerator
2728
@@ -184,12 +185,16 @@ class MultiWriterIdGenerator:
184185 Args:
185186 db_conn
186187 db
188+ stream_name: A name for the stream.
187189 instance_name: The name of this instance.
188190 table: Database table associated with stream.
189191 instance_column: Column that stores the row's writer's instance name
190192 id_column: Column that stores the stream ID.
191193 sequence_name: The name of the postgres sequence used to generate new
192194 IDs.
195+ writers: A list of known writers to use to populate current positions
196+ on startup. Can be empty if nothing uses `get_current_token` or
197+ `get_positions` (e.g. caches stream).
193198 positive: Whether the IDs are positive (true) or negative (false).
194199 When using negative IDs we go backwards from -1 to -2, -3, etc.
195200 """
@@ -198,16 +203,20 @@ def __init__(
198203 self ,
199204 db_conn ,
200205 db : DatabasePool ,
206+ stream_name : str ,
201207 instance_name : str ,
202208 table : str ,
203209 instance_column : str ,
204210 id_column : str ,
205211 sequence_name : str ,
212+ writers : List [str ],
206213 positive : bool = True ,
207214 ):
208215 self ._db = db
216+ self ._stream_name = stream_name
209217 self ._instance_name = instance_name
210218 self ._positive = positive
219+ self ._writers = writers
211220 self ._return_factor = 1 if positive else - 1
212221
213222 # We lock as some functions may be called from DB threads.
@@ -216,9 +225,7 @@ def __init__(
216225 # Note: If we are a negative stream then we still store all the IDs as
217226 # positive to make life easier for us, and simply negate the IDs when we
218227 # return them.
219- self ._current_positions = self ._load_current_ids (
220- db_conn , table , instance_column , id_column
221- )
228+ self ._current_positions = {} # type: Dict[str, int]
222229
223230 # Set of local IDs that we're still processing. The current position
224231 # should be less than the minimum of this set (if not empty).
@@ -251,30 +258,80 @@ def __init__(
251258
252259 self ._sequence_gen = PostgresSequenceGenerator (sequence_name )
253260
261+ # This goes and fills out the above state from the database.
262+ self ._load_current_ids (db_conn , table , instance_column , id_column )
263+
def _load_current_ids(
    self, db_conn, table: str, instance_column: str, id_column: str
) -> None:
    """Fill in the in-memory stream positions from the database.

    Sets `self._current_positions` (from the `stream_positions` table, when
    `self._writers` is non-empty) and `self._persisted_upto_position`, and
    feeds any later rows found in `table` to `_add_persisted_position`.

    Positions are held in memory in "positive form": for a negative stream
    the raw database IDs are multiplied by `self._return_factor` (-1).

    Args:
        db_conn: Database connection used to obtain a cursor.
        table: Database table associated with the stream.
        instance_column: Column that stores the row's writer's instance name.
        id_column: Column that stores the stream ID.
    """
    cur = db_conn.cursor()

    # Make sure the cursor is released even if one of the queries fails.
    try:
        # Load the current positions of all writers for the stream.
        if self._writers:
            sql = """
                SELECT instance_name, stream_id FROM stream_positions
                WHERE stream_name = ?
            """
            sql = self._db.engine.convert_param_style(sql)

            cur.execute(sql, (self._stream_name,))

            self._current_positions = {
                instance: stream_id * self._return_factor
                for instance, stream_id in cur
                if instance in self._writers
            }

        # We set the `_persisted_upto_position` to be the minimum of all
        # current positions. If empty we use the max stream ID from the DB
        # table.
        min_stream_id = min(self._current_positions.values(), default=None)

        if min_stream_id is None:
            sql = """
                SELECT COALESCE(%(agg)s(%(id)s), 1) FROM %(table)s
            """ % {
                "id": id_column,
                "table": table,
                "agg": "MAX" if self._positive else "-MIN",
            }
            cur.execute(sql)
            (stream_id,) = cur.fetchone()
            self._persisted_upto_position = stream_id
            return

        # If we have a min_stream_id then we pull out everything greater
        # than it from the DB so that we can prefill
        # `_known_persisted_positions` and get a more accurate
        # `_persisted_upto_position`.
        #
        # We also check if any of the later rows are from this instance, in
        # which case we use that for this instance's current position. This
        # is to handle the case where we didn't finish persisting to the
        # stream positions table before restart (or the stream position
        # table otherwise got out of date).
        sql = """
            SELECT %(instance)s, %(id)s FROM %(table)s
            WHERE ? %(cmp)s %(id)s
        """ % {
            "id": id_column,
            "table": table,
            "instance": instance_column,
            "cmp": "<=" if self._positive else ">=",
        }
        sql = self._db.engine.convert_param_style(sql)

        # `min_stream_id` is in positive form but the ID column holds raw
        # values, so convert back before comparing. Without this a negative
        # stream would match every row in the table.
        cur.execute(sql, (min_stream_id * self._return_factor,))

        self._persisted_upto_position = min_stream_id

        # Convert to positive form and process in ascending order, so that
        # this instance's current position ends up at its newest row rather
        # than at whichever row the database happened to return last.
        rows = [
            (self._return_factor * stream_id, instance)
            for instance, stream_id in cur
        ]
        rows.sort(key=lambda row: row[0])

        with self._lock:
            for stream_id, instance in rows:
                self._add_persisted_position(stream_id)

                if instance == self._instance_name:
                    self._current_positions[instance] = stream_id
    finally:
        cur.close()
278335
279336 def _load_next_id_txn (self , txn ) -> int :
280337 return self ._sequence_gen .get_next_id_txn (txn )
@@ -316,6 +373,21 @@ def get_next_txn(self, txn: LoggingTransaction):
316373 txn .call_after (self ._mark_id_as_finished , next_id )
317374 txn .call_on_exception (self ._mark_id_as_finished , next_id )
318375
376+ # Update the `stream_positions` table with newly updated stream
377+ # ID (unless self._writers is not set in which case we don't
378+ # bother, as nothing will read it).
379+ #
380+ # We only do this on the success path so that the persisted current
381+ # position points to a persisted row with the correct instance name.
382+ if self ._writers :
383+ txn .call_after (
384+ run_as_background_process ,
385+ "MultiWriterIdGenerator._update_table" ,
386+ self ._db .runInteraction ,
387+ "MultiWriterIdGenerator._update_table" ,
388+ self ._update_stream_positions_table_txn ,
389+ )
390+
319391 return self ._return_factor * next_id
320392
321393 def _mark_id_as_finished (self , next_id : int ):
@@ -447,6 +519,28 @@ def _add_persisted_position(self, new_id: int):
447519 # do.
448520 break
449521
def _update_stream_positions_table_txn(self, txn):
    """Update the `stream_positions` table with newly persisted position.

    Does nothing when `self._writers` is empty.

    Args:
        txn: The database transaction on which the upsert is executed.
    """

    if not self._writers:
        return

    # We upsert the value, ensuring on conflict that we always increase the
    # value (or decrease if stream goes backwards).
    sql = """
        INSERT INTO stream_positions (stream_name, instance_name, stream_id)
        VALUES (?, ?, ?)
        ON CONFLICT (stream_name, instance_name)
        DO UPDATE SET
            stream_id = %(agg)s(stream_positions.stream_id, EXCLUDED.stream_id)
    """ % {
        "agg": "GREATEST" if self._positive else "LEAST",
    }

    # Bind the position as a plain scalar. Wrapping it in a tuple would pass
    # a sequence as a single bind parameter, which only works if the driver
    # happens to adapt one-element tuples.
    pos = self.get_current_token_for_writer(self._instance_name)
    txn.execute(sql, (self._stream_name, self._instance_name, pos))
543+
450544
451545@attr .s (slots = True )
452546class _AsyncCtxManagerWrapper :
@@ -503,4 +597,16 @@ async def __aexit__(self, exc_type, exc, tb):
503597 if exc_type is not None :
504598 return False
505599
600+ # Update the `stream_positions` table with newly updated stream
601+ # ID (unless self._writers is not set in which case we don't
602+ # bother, as nothing will read it).
603+ #
604+ # We only do this on the success path so that the persisted current
605+ # position points to a persisted row with the correct instance name.
606+ if self .id_gen ._writers :
607+ await self .id_gen ._db .runInteraction (
608+ "MultiWriterIdGenerator._update_table" ,
609+ self .id_gen ._update_stream_positions_table_txn ,
610+ )
611+
506612 return False
0 commit comments