Skip to content
This repository was archived by the owner on Mar 26, 2024. It is now read-only.

Commit c71199e

Browse files
authored
Update all stream IDs after processing replication rows (matrix-org#14723) (#52)
* Update all stream IDs after processing replication rows (matrix-org#14723).

  This creates a new store method, `process_replication_position`, that is called after `process_replication_rows`. Moving stream ID advances here guarantees that any relevant cache invalidations have been applied before the stream is advanced. This avoids race conditions where Python switches between threads midway through the `process_replication_rows` method, allowing stream IDs to be advanced before caches are invalidated due to class resolution ordering. See this comment/issue for further discussion: matrix-org#14158 (comment).

  Conflicts:
  - synapse/storage/databases/main/devices.py
  - synapse/storage/databases/main/events_worker.py

* Fix bad cherry-picking

* Remove leftover stream advance
1 parent 90878d6 commit c71199e

File tree

15 files changed

+115
-66
lines changed

15 files changed

+115
-66
lines changed

changelog.d/14723.bugfix

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Ensure stream IDs are always updated after caches get invalidated with workers. Contributed by Nick @ Beeper (@fizzadar).

synapse/replication/tcp/client.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,9 @@ async def on_rdata(
148148
rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
149149
"""
150150
self.store.process_replication_rows(stream_name, instance_name, token, rows)
151+
# NOTE: this must be called after process_replication_rows to ensure any
152+
# cache invalidations are first handled before any stream ID advances.
153+
self.store.process_replication_position(stream_name, instance_name, token)
151154

152155
if self.send_handler:
153156
await self.send_handler.process_replication_rows(stream_name, token, rows)

synapse/storage/_base.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,22 @@ def process_replication_rows( # noqa: B027 (no-op by design)
5959
token: int,
6060
rows: Iterable[Any],
6161
) -> None:
62-
pass
62+
"""
63+
Used by storage classes to invalidate caches based on incoming replication data. These
64+
must not update any ID generators, use `process_replication_position`.
65+
"""
66+
67+
def process_replication_position( # noqa: B027 (no-op by design)
68+
self,
69+
stream_name: str,
70+
instance_name: str,
71+
token: int,
72+
) -> None:
73+
"""
74+
Used by storage classes to advance ID generators based on incoming replication data. This
75+
is called after process_replication_rows such that caches are invalidated before any token
76+
positions advance.
77+
"""
6378

6479
def _invalidate_state_caches(
6580
self, room_id: str, members_changed: Collection[str]

synapse/storage/databases/main/account_data.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -415,10 +415,7 @@ def process_replication_rows(
415415
token: int,
416416
rows: Iterable[Any],
417417
) -> None:
418-
if stream_name == TagAccountDataStream.NAME:
419-
self._account_data_id_gen.advance(instance_name, token)
420-
elif stream_name == AccountDataStream.NAME:
421-
self._account_data_id_gen.advance(instance_name, token)
418+
if stream_name == AccountDataStream.NAME:
422419
for row in rows:
423420
if not row.room_id:
424421
self.get_global_account_data_by_type_for_user.invalidate(
@@ -433,6 +430,15 @@ def process_replication_rows(
433430

434431
super().process_replication_rows(stream_name, instance_name, token, rows)
435432

433+
def process_replication_position(
434+
self, stream_name: str, instance_name: str, token: int
435+
) -> None:
436+
if stream_name == TagAccountDataStream.NAME:
437+
self._account_data_id_gen.advance(instance_name, token)
438+
elif stream_name == AccountDataStream.NAME:
439+
self._account_data_id_gen.advance(instance_name, token)
440+
super().process_replication_position(stream_name, instance_name, token)
441+
436442
async def add_account_data_to_room(
437443
self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
438444
) -> int:

synapse/storage/databases/main/cache.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -164,9 +164,6 @@ def process_replication_rows(
164164
backfilled=True,
165165
)
166166
elif stream_name == CachesStream.NAME:
167-
if self._cache_id_gen:
168-
self._cache_id_gen.advance(instance_name, token)
169-
170167
for row in rows:
171168
if row.cache_func == CURRENT_STATE_CACHE_NAME:
172169
if row.keys is None:
@@ -182,6 +179,14 @@ def process_replication_rows(
182179

183180
super().process_replication_rows(stream_name, instance_name, token, rows)
184181

182+
def process_replication_position(
183+
self, stream_name: str, instance_name: str, token: int
184+
) -> None:
185+
if stream_name == CachesStream.NAME:
186+
if self._cache_id_gen:
187+
self._cache_id_gen.advance(instance_name, token)
188+
super().process_replication_position(stream_name, instance_name, token)
189+
185190
def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
186191
data = row.data
187192

@@ -198,8 +203,14 @@ def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
198203
backfilled=False,
199204
)
200205
elif row.type == EventsStreamCurrentStateRow.TypeId:
201-
# TODO: Nothing to do here, handled in events_worker, cleanup?
202-
pass
206+
assert isinstance(data, EventsStreamCurrentStateRow)
207+
self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token)
208+
209+
if data.type == EventTypes.Member:
210+
self.get_rooms_for_user_with_stream_ordering.invalidate(
211+
(data.state_key,)
212+
)
213+
self.get_rooms_for_user.invalidate((data.state_key,))
203214
else:
204215
raise Exception("Unknown events stream row type %s" % (row.type,))
205216

synapse/storage/databases/main/deviceinbox.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,10 +160,15 @@ def process_replication_rows(
160160
self._device_federation_outbox_stream_cache.entity_has_changed(
161161
row.entity, token
162162
)
163-
# Important that the ID gen advances after stream change caches
164-
self._device_inbox_id_gen.advance(instance_name, token)
165163
return super().process_replication_rows(stream_name, instance_name, token, rows)
166164

165+
def process_replication_position(
166+
self, stream_name: str, instance_name: str, token: int
167+
) -> None:
168+
if stream_name == ToDeviceStream.NAME:
169+
self._device_inbox_id_gen.advance(instance_name, token)
170+
super().process_replication_position(stream_name, instance_name, token)
171+
167172
def get_to_device_stream_token(self) -> int:
168173
return self._device_inbox_id_gen.get_current_token()
169174

synapse/storage/databases/main/devices.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -163,15 +163,20 @@ def process_replication_rows(
163163
) -> None:
164164
if stream_name == DeviceListsStream.NAME:
165165
self._invalidate_caches_for_devices(token, rows)
166-
# Important that the ID gen advances after stream change caches
167-
self._device_list_id_gen.advance(instance_name, token)
168166
elif stream_name == UserSignatureStream.NAME:
169167
for row in rows:
170168
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
171-
# Important that the ID gen advances after stream change caches
172-
self._device_list_id_gen.advance(instance_name, token)
173169
return super().process_replication_rows(stream_name, instance_name, token, rows)
174170

171+
def process_replication_position(
172+
self, stream_name: str, instance_name: str, token: int
173+
) -> None:
174+
if stream_name == DeviceListsStream.NAME:
175+
self._device_list_id_gen.advance(instance_name, token)
176+
elif stream_name == UserSignatureStream.NAME:
177+
self._device_list_id_gen.advance(instance_name, token)
178+
super().process_replication_position(stream_name, instance_name, token)
179+
175180
def _invalidate_caches_for_devices(
176181
self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]
177182
) -> None:

synapse/storage/databases/main/event_federation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1187,7 +1187,7 @@ async def get_forward_extremities_for_room_at_stream_ordering(
11871187
"""
11881188
# We want to make the cache more effective, so we clamp to the last
11891189
# change before the given ordering.
1190-
last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id)
1190+
last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id) # type: ignore[attr-defined]
11911191

11921192
# We don't always have a full stream_to_exterm_id table, e.g. after
11931193
# the upgrade that introduced it, so we make sure we never ask for a

synapse/storage/databases/main/events_worker.py

Lines changed: 3 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -249,22 +249,6 @@ def __init__(
249249
prefilled_cache=curr_state_delta_prefill,
250250
)
251251

252-
event_cache_prefill, min_event_val = self.db_pool.get_cache_dict(
253-
db_conn,
254-
"events",
255-
entity_column="room_id",
256-
stream_column="stream_ordering",
257-
max_value=events_max,
258-
)
259-
self._events_stream_cache = StreamChangeCache(
260-
"EventsRoomStreamChangeCache",
261-
min_event_val,
262-
prefilled_cache=event_cache_prefill,
263-
)
264-
self._membership_stream_cache = StreamChangeCache(
265-
"MembershipStreamChangeCache", events_max
266-
)
267-
268252
if hs.config.worker.run_background_tasks:
269253
# We periodically clean out old transaction ID mappings
270254
self._clock.looping_call(
@@ -325,35 +309,14 @@ def get_chain_id_txn(txn: Cursor) -> int:
325309
id_column="chain_id",
326310
)
327311

328-
def process_replication_rows(
329-
self,
330-
stream_name: str,
331-
instance_name: str,
332-
token: int,
333-
rows: Iterable[Any],
312+
def process_replication_position(
313+
self, stream_name: str, instance_name: str, token: int
334314
) -> None:
335-
# Process event stream replication rows, handling both the ID generators from the events
336-
# worker store and the stream change caches in this store as the two are interlinked.
337315
if stream_name == EventsStream.NAME:
338-
for row in rows:
339-
if row.type == EventsStreamEventRow.TypeId:
340-
self._events_stream_cache.entity_has_changed(
341-
row.data.room_id, token
342-
)
343-
if row.data.type == EventTypes.Member:
344-
self._membership_stream_cache.entity_has_changed(
345-
row.data.state_key, token
346-
)
347-
if row.type == EventsStreamCurrentStateRow.TypeId:
348-
self._curr_state_delta_stream_cache.entity_has_changed(
349-
row.data.room_id, token
350-
)
351-
# Important that the ID gen advances after stream change caches
352316
self._stream_id_gen.advance(instance_name, token)
353317
elif stream_name == BackfillStream.NAME:
354318
self._backfill_id_gen.advance(instance_name, -token)
355-
356-
super().process_replication_rows(stream_name, instance_name, token, rows)
319+
super().process_replication_position(stream_name, instance_name, token)
357320

358321
async def have_censored_event(self, event_id: str) -> bool:
359322
"""Check if an event has been censored, i.e. if the content of the event has been erased

synapse/storage/databases/main/presence.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -439,8 +439,14 @@ def process_replication_rows(
439439
rows: Iterable[Any],
440440
) -> None:
441441
if stream_name == PresenceStream.NAME:
442-
self._presence_id_gen.advance(instance_name, token)
443442
for row in rows:
444443
self.presence_stream_cache.entity_has_changed(row.user_id, token)
445444
self._get_presence_for_user.invalidate((row.user_id,))
446445
return super().process_replication_rows(stream_name, instance_name, token, rows)
446+
447+
def process_replication_position(
448+
self, stream_name: str, instance_name: str, token: int
449+
) -> None:
450+
if stream_name == PresenceStream.NAME:
451+
self._presence_id_gen.advance(instance_name, token)
452+
super().process_replication_position(stream_name, instance_name, token)

0 commit comments

Comments (0)