@@ -185,6 +185,8 @@ class MultiWriterIdGenerator:
185185 id_column: Column that stores the stream ID.
186186 sequence_name: The name of the postgres sequence used to generate new
187187 IDs.
188+ positive: Whether the IDs are positive (true) or negative (false).
189+ When using negative IDs we go backwards from -1 to -2, -3, etc.
188190 """
189191
190192 def __init__ (
@@ -196,13 +198,19 @@ def __init__(
196198 instance_column : str ,
197199 id_column : str ,
198200 sequence_name : str ,
201+ positive : bool = True ,
199202 ):
200203 self ._db = db
201204 self ._instance_name = instance_name
205+ self ._positive = positive
206+ self ._return_factor = 1 if positive else - 1
202207
203208 # We lock as some functions may be called from DB threads.
204209 self ._lock = threading .Lock ()
205210
211+ # Note: If we are a negative stream then we still store all the IDs as
212+ # positive to make life easier for us, and simply negate the IDs when we
213+ # return them.
206214 self ._current_positions = self ._load_current_ids (
207215 db_conn , table , instance_column , id_column
208216 )
@@ -233,13 +241,16 @@ def __init__(
233241 def _load_current_ids (
234242 self , db_conn , table : str , instance_column : str , id_column : str
235243 ) -> Dict [str , int ]:
244+ # If positive stream aggregate via MAX. For negative stream use MIN
245+ # *and* negate the result to get a positive number.
236246 sql = """
237- SELECT %(instance)s, MAX (%(id)s) FROM %(table)s
247+ SELECT %(instance)s, %(agg)s (%(id)s) FROM %(table)s
238248 GROUP BY %(instance)s
239249 """ % {
240250 "instance" : instance_column ,
241251 "id" : id_column ,
242252 "table" : table ,
253+ "agg" : "MAX" if self ._positive else "-MIN" ,
243254 }
244255
245256 cur = db_conn .cursor ()
@@ -269,15 +280,16 @@ async def get_next(self):
269280 # Assert the fetched ID is actually greater than what we currently
270281 # believe the ID to be. If not, then the sequence and table have got
271282 # out of sync somehow.
272- assert self .get_current_token_for_writer (self ._instance_name ) < next_id
273-
274283 with self ._lock :
284+ assert self ._current_positions .get (self ._instance_name , 0 ) < next_id
285+
275286 self ._unfinished_ids .add (next_id )
276287
277288 @contextlib .contextmanager
278289 def manager ():
279290 try :
280- yield next_id
291+ # Multiply by the return factor so that the ID has correct sign.
292+ yield self ._return_factor * next_id
281293 finally :
282294 self ._mark_id_as_finished (next_id )
283295
@@ -296,15 +308,15 @@ async def get_next_mult(self, n: int):
296308 # Assert the fetched ID is actually greater than any ID we've already
297309 # seen. If not, then the sequence and table have got out of sync
298310 # somehow.
299- assert max (self .get_positions ().values (), default = 0 ) < min (next_ids )
300-
301311 with self ._lock :
312+ assert max (self ._current_positions .values (), default = 0 ) < min (next_ids )
313+
302314 self ._unfinished_ids .update (next_ids )
303315
304316 @contextlib .contextmanager
305317 def manager ():
306318 try :
307- yield next_ids
319+ yield [ self . _return_factor * i for i in next_ids ]
308320 finally :
309321 for i in next_ids :
310322 self ._mark_id_as_finished (i )
@@ -327,7 +339,7 @@ def get_next_txn(self, txn: LoggingTransaction):
327339 txn .call_after (self ._mark_id_as_finished , next_id )
328340 txn .call_on_exception (self ._mark_id_as_finished , next_id )
329341
330- return next_id
342+ return self . _return_factor * next_id
331343
332344 def _mark_id_as_finished (self , next_id : int ):
333345 """The ID has finished being processed so we should advance the
@@ -359,20 +371,25 @@ def get_current_token_for_writer(self, instance_name: str) -> int:
359371 """
360372
361373 with self ._lock :
362- return self ._current_positions .get (instance_name , 0 )
374+ return self ._return_factor * self . _current_positions .get (instance_name , 0 )
363375
364376 def get_positions (self ) -> Dict [str , int ]:
365377 """Get a copy of the current positon map.
366378 """
367379
368380 with self ._lock :
369- return dict (self ._current_positions )
381+ return {
382+ name : self ._return_factor * i
383+ for name , i in self ._current_positions .items ()
384+ }
370385
371386 def advance (self , instance_name : str , new_id : int ):
372387 """Advance the postion of the named writer to the given ID, if greater
373388 than existing entry.
374389 """
375390
391+ new_id *= self ._return_factor
392+
376393 with self ._lock :
377394 self ._current_positions [instance_name ] = max (
378395 new_id , self ._current_positions .get (instance_name , 0 )
@@ -390,7 +407,7 @@ def get_persisted_upto_position(self) -> int:
390407 """
391408
392409 with self ._lock :
393- return self ._persisted_upto_position
410+ return self ._return_factor * self . _persisted_upto_position
394411
395412 def _add_persisted_position (self , new_id : int ):
396413 """Record that we have persisted a position.
0 commit comments