Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit bbb3c86

Browse files
authored
Make MultiWriterIDGenerator work for streams that use negative stream IDs (#8203)
This is so that we can use it for the backfill events stream.
1 parent 318245e commit bbb3c86

File tree

3 files changed

+134
-11
lines changed

3 files changed

+134
-11
lines changed

changelog.d/8203.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Make `MultiWriterIDGenerator` work for streams that use negative values.

synapse/storage/util/id_generators.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,8 @@ class MultiWriterIdGenerator:
185185
id_column: Column that stores the stream ID.
186186
sequence_name: The name of the postgres sequence used to generate new
187187
IDs.
188+
positive: Whether the IDs are positive (true) or negative (false).
189+
When using negative IDs we go backwards from -1 to -2, -3, etc.
188190
"""
189191

190192
def __init__(
@@ -196,13 +198,19 @@ def __init__(
196198
instance_column: str,
197199
id_column: str,
198200
sequence_name: str,
201+
positive: bool = True,
199202
):
200203
self._db = db
201204
self._instance_name = instance_name
205+
self._positive = positive
206+
self._return_factor = 1 if positive else -1
202207

203208
# We lock as some functions may be called from DB threads.
204209
self._lock = threading.Lock()
205210

211+
# Note: If we are a negative stream then we still store all the IDs as
212+
# positive to make life easier for us, and simply negate the IDs when we
213+
# return them.
206214
self._current_positions = self._load_current_ids(
207215
db_conn, table, instance_column, id_column
208216
)
@@ -233,13 +241,16 @@ def __init__(
233241
def _load_current_ids(
234242
self, db_conn, table: str, instance_column: str, id_column: str
235243
) -> Dict[str, int]:
244+
# If positive stream aggregate via MAX. For negative stream use MIN
245+
# *and* negate the result to get a positive number.
236246
sql = """
237-
SELECT %(instance)s, MAX(%(id)s) FROM %(table)s
247+
SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s
238248
GROUP BY %(instance)s
239249
""" % {
240250
"instance": instance_column,
241251
"id": id_column,
242252
"table": table,
253+
"agg": "MAX" if self._positive else "-MIN",
243254
}
244255

245256
cur = db_conn.cursor()
@@ -269,15 +280,16 @@ async def get_next(self):
269280
# Assert the fetched ID is actually greater than what we currently
270281
# believe the ID to be. If not, then the sequence and table have got
271282
# out of sync somehow.
272-
assert self.get_current_token_for_writer(self._instance_name) < next_id
273-
274283
with self._lock:
284+
assert self._current_positions.get(self._instance_name, 0) < next_id
285+
275286
self._unfinished_ids.add(next_id)
276287

277288
@contextlib.contextmanager
278289
def manager():
279290
try:
280-
yield next_id
291+
# Multiply by the return factor so that the ID has correct sign.
292+
yield self._return_factor * next_id
281293
finally:
282294
self._mark_id_as_finished(next_id)
283295

@@ -296,15 +308,15 @@ async def get_next_mult(self, n: int):
296308
# Assert the fetched ID is actually greater than any ID we've already
297309
# seen. If not, then the sequence and table have got out of sync
298310
# somehow.
299-
assert max(self.get_positions().values(), default=0) < min(next_ids)
300-
301311
with self._lock:
312+
assert max(self._current_positions.values(), default=0) < min(next_ids)
313+
302314
self._unfinished_ids.update(next_ids)
303315

304316
@contextlib.contextmanager
305317
def manager():
306318
try:
307-
yield next_ids
319+
yield [self._return_factor * i for i in next_ids]
308320
finally:
309321
for i in next_ids:
310322
self._mark_id_as_finished(i)
@@ -327,7 +339,7 @@ def get_next_txn(self, txn: LoggingTransaction):
327339
txn.call_after(self._mark_id_as_finished, next_id)
328340
txn.call_on_exception(self._mark_id_as_finished, next_id)
329341

330-
return next_id
342+
return self._return_factor * next_id
331343

332344
def _mark_id_as_finished(self, next_id: int):
333345
"""The ID has finished being processed so we should advance the
@@ -359,20 +371,25 @@ def get_current_token_for_writer(self, instance_name: str) -> int:
359371
"""
360372

361373
with self._lock:
362-
return self._current_positions.get(instance_name, 0)
374+
return self._return_factor * self._current_positions.get(instance_name, 0)
363375

364376
def get_positions(self) -> Dict[str, int]:
365377
"""Get a copy of the current position map.
366378
"""
367379

368380
with self._lock:
369-
return dict(self._current_positions)
381+
return {
382+
name: self._return_factor * i
383+
for name, i in self._current_positions.items()
384+
}
370385

371386
def advance(self, instance_name: str, new_id: int):
372387
"""Advance the position of the named writer to the given ID, if greater
373388
than existing entry.
374389
"""
375390

391+
new_id *= self._return_factor
392+
376393
with self._lock:
377394
self._current_positions[instance_name] = max(
378395
new_id, self._current_positions.get(instance_name, 0)
@@ -390,7 +407,7 @@ def get_persisted_upto_position(self) -> int:
390407
"""
391408

392409
with self._lock:
393-
return self._persisted_upto_position
410+
return self._return_factor * self._persisted_upto_position
394411

395412
def _add_persisted_position(self, new_id: int):
396413
"""Record that we have persisted a position.

tests/storage/test_id_generators.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,3 +264,108 @@ def test_get_persisted_upto_position_get_next(self):
264264
# We assume that so long as `get_next` does correctly advance the
265265
# `persisted_upto_position` in this case, then it will be correct in the
266266
# other cases that are tested above (since they'll hit the same code).
267+
268+
269+
class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
270+
"""Tests a MultiWriterIdGenerator that produces *negative* stream IDs.
271+
"""
272+
273+
if not USE_POSTGRES_FOR_TESTS:
274+
skip = "Requires Postgres"
275+
276+
def prepare(self, reactor, clock, hs):
277+
self.store = hs.get_datastore()
278+
self.db_pool = self.store.db_pool # type: DatabasePool
279+
280+
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
281+
282+
def _setup_db(self, txn):
283+
txn.execute("CREATE SEQUENCE foobar_seq")
284+
txn.execute(
285+
"""
286+
CREATE TABLE foobar (
287+
stream_id BIGINT NOT NULL,
288+
instance_name TEXT NOT NULL,
289+
data TEXT
290+
);
291+
"""
292+
)
293+
294+
def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
295+
def _create(conn):
296+
return MultiWriterIdGenerator(
297+
conn,
298+
self.db_pool,
299+
instance_name=instance_name,
300+
table="foobar",
301+
instance_column="instance_name",
302+
id_column="stream_id",
303+
sequence_name="foobar_seq",
304+
positive=False,
305+
)
306+
307+
return self.get_success(self.db_pool.runWithConnection(_create))
308+
309+
def _insert_row(self, instance_name: str, stream_id: int):
310+
"""Insert one row as the given instance with given stream_id.
311+
"""
312+
313+
def _insert(txn):
314+
txn.execute(
315+
"INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
316+
)
317+
318+
self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
319+
320+
def test_single_instance(self):
321+
"""Test that reads and writes from a single process are handled
322+
correctly.
323+
"""
324+
id_gen = self._create_id_generator()
325+
326+
with self.get_success(id_gen.get_next()) as stream_id:
327+
self._insert_row("master", stream_id)
328+
329+
self.assertEqual(id_gen.get_positions(), {"master": -1})
330+
self.assertEqual(id_gen.get_current_token_for_writer("master"), -1)
331+
self.assertEqual(id_gen.get_persisted_upto_position(), -1)
332+
333+
with self.get_success(id_gen.get_next_mult(3)) as stream_ids:
334+
for stream_id in stream_ids:
335+
self._insert_row("master", stream_id)
336+
337+
self.assertEqual(id_gen.get_positions(), {"master": -4})
338+
self.assertEqual(id_gen.get_current_token_for_writer("master"), -4)
339+
self.assertEqual(id_gen.get_persisted_upto_position(), -4)
340+
341+
# Test loading from DB by creating a second ID gen
342+
second_id_gen = self._create_id_generator()
343+
344+
self.assertEqual(second_id_gen.get_positions(), {"master": -4})
345+
self.assertEqual(second_id_gen.get_current_token_for_writer("master"), -4)
346+
self.assertEqual(second_id_gen.get_persisted_upto_position(), -4)
347+
348+
def test_multiple_instance(self):
349+
"""Tests that having multiple instances that get advanced over
350+
federation works correctly.
351+
"""
352+
id_gen_1 = self._create_id_generator("first")
353+
id_gen_2 = self._create_id_generator("second")
354+
355+
with self.get_success(id_gen_1.get_next()) as stream_id:
356+
self._insert_row("first", stream_id)
357+
id_gen_2.advance("first", stream_id)
358+
359+
self.assertEqual(id_gen_1.get_positions(), {"first": -1})
360+
self.assertEqual(id_gen_2.get_positions(), {"first": -1})
361+
self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
362+
self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
363+
364+
with self.get_success(id_gen_2.get_next()) as stream_id:
365+
self._insert_row("second", stream_id)
366+
id_gen_1.advance("second", stream_id)
367+
368+
self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2})
369+
self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
370+
self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
371+
self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)

0 commit comments

Comments
 (0)