Skip to content

Commit a9368d0

Browse files
committed
wip: indexer recover from connection errors
1 parent 8455a7e commit a9368d0

File tree

1 file changed

+55
-16
lines changed

1 file changed

+55
-16
lines changed

share/search/daemon.py

Lines changed: 55 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
import contextlib
22
import collections
3+
from collections.abc import Callable
34
import dataclasses
45
import logging
56
import queue
67
import random
78
import threading
89
import time
910

11+
import amqp.exceptions
1012
from django.conf import settings
13+
import kombu
1114
from kombu.mixins import ConsumerMixin
1215
import sentry_sdk
1316

@@ -27,6 +30,7 @@
2730
MINIMUM_BACKOFF_FACTOR = 1.6 # unitless ratio
2831
MAXIMUM_BACKOFF_FACTOR = 2.0 # unitless ratio
2932
MAXIMUM_BACKOFF_TIMEOUT = 60 # seconds
33+
CONNECTION_HEARTBEAT = 20 # seconds
3034

3135

3236
class TooFastSlowDown(Exception):
@@ -35,7 +39,10 @@ class TooFastSlowDown(Exception):
3539

3640
class IndexerDaemonControl:
3741
def __init__(self, celery_app, *, daemonthread_context=None, stop_event=None):
38-
self.celery_app = celery_app
42+
self.kombu_connection = kombu.Connection(
43+
celery_app.conf.broker_url, # use celery_app.conf for consistent config
44+
heartbeat=CONNECTION_HEARTBEAT,
45+
)
3946
self.daemonthread_context = daemonthread_context
4047
self._daemonthreads = []
4148
# shared stop_event for all threads below
@@ -50,10 +57,16 @@ def start_daemonthreads_for_strategy(self, index_strategy):
5057
)
5158
# spin up daemonthreads, ready for messages
5259
self._daemonthreads.extend(_daemon.start())
53-
# assign a thread to pass messages to this daemon
54-
threading.Thread(
55-
target=CeleryMessageConsumer(self.celery_app, _daemon).run,
56-
).start()
60+
_consumer = KombuMessageConsumer(
61+
kombu_connection=self.kombu_connection.clone(),
62+
stop_event=self.stop_event,
63+
index_strategy=index_strategy,
64+
message_callback=_daemon.on_message,
65+
)
66+
# give the daemon direct access to the connection, for acking purposes
67+
_daemon.ack_callback = _consumer.ensure_ack
68+
# assign a thread for the consumer to receive and enqueue messages to this daemon
69+
threading.Thread(target=_consumer.run).start()
5770
return _daemon
5871

5972
def start_all_daemonthreads(self):
@@ -67,18 +80,16 @@ def stop_daemonthreads(self, *, wait=False):
6780
_thread.join()
6881

6982

70-
class CeleryMessageConsumer(ConsumerMixin):
83+
class KombuMessageConsumer(ConsumerMixin):
7184
PREFETCH_COUNT = 7500
7285

73-
# (from ConsumerMixin)
74-
# should_stop: bool
86+
should_stop: bool # (from ConsumerMixin)
7587

76-
def __init__(self, celery_app, indexer_daemon):
77-
self.connection = celery_app.pool.acquire(block=True)
78-
self.celery_app = celery_app
79-
self.__stop_event = indexer_daemon.stop_event
80-
self.__message_callback = indexer_daemon.on_message
81-
self.__index_strategy = indexer_daemon.index_strategy
88+
def __init__(self, *, kombu_connection, stop_event, message_callback, index_strategy):
89+
self.connection = kombu_connection
90+
self.__stop_event = stop_event
91+
self.__message_callback = message_callback
92+
self.__index_strategy = index_strategy
8293

8394
# overrides ConsumerMixin.run
8495
def run(self):
@@ -112,9 +123,35 @@ def get_consumers(self, Consumer, channel):
112123
def __repr__(self):
113124
return '<{}({})>'.format(self.__class__.__name__, self.__index_strategy.name)
114125

126+
def consume(self, *args, **kwargs):
127+
# wrap `consume` in `kombu.Connection.ensure`, following guidance from
128+
# https://docs.celeryq.dev/projects/kombu/en/stable/userguide/failover.html#consumer
129+
consume = self.connection.ensure(self.connection, super().consume)
130+
return consume(*args, **kwargs)
131+
132+
def ensure_ack(self, daemon_message: messages.DaemonMessage):
133+
# if the connection the message came thru is no longer usable,
134+
# use `kombu.Connection.autoretry` and `kombu.Channel.basic_ack`
135+
# to ensure the ack goes thru
136+
try:
137+
daemon_message.ack()
138+
except (ConnectionError, amqp.exceptions.ConnectionError):
139+
@self.connection.autoretry
140+
def _do_ack(*, channel):
141+
try:
142+
channel.basic_ack(daemon_message.kombu_message.delivery_tag)
143+
finally:
144+
channel.close()
145+
_do_ack()
146+
147+
148+
def _default_ack_callback(daemon_message: messages.DaemonMessage) -> None:
149+
daemon_message.ack()
150+
115151

116152
class IndexerDaemon:
117153
MAX_LOCAL_QUEUE_SIZE = 5000
154+
ack_callback: Callable[[messages.DaemonMessage], None] = _default_ack_callback
118155

119156
def __init__(self, index_strategy, *, stop_event=None, daemonthread_context=None):
120157
self.stop_event = (
@@ -154,6 +191,7 @@ def start_typed_loop_and_queue(self, message_type) -> threading.Thread:
154191
local_message_queue=_queue_from_rabbit_to_daemon,
155192
log_prefix=f'{repr(self)} MessageHandlingLoop: ',
156193
daemonthread_context=self.__daemonthread_context,
194+
ack_callback=self.ack_callback,
157195
)
158196
return _handling_loop.start_thread()
159197

@@ -186,7 +224,8 @@ class MessageHandlingLoop:
186224
stop_event: threading.Event
187225
local_message_queue: queue.Queue
188226
log_prefix: str
189-
daemonthread_context: contextlib.AbstractContextManager
227+
daemonthread_context: Callable[[], contextlib.AbstractContextManager]
228+
ack_callback: Callable[[messages.DaemonMessage], None]
190229
_leftover_daemon_messages_by_target_id = None
191230

192231
def __post_init__(self):
@@ -270,7 +309,7 @@ def _handle_some_messages(self):
270309
sentry_sdk.capture_message('error handling message', extras={'message_response': message_response})
271310
target_id = message_response.index_message.target_id
272311
for daemon_message in daemon_messages_by_target_id.pop(target_id, ()):
273-
daemon_message.ack() # finally set it free
312+
self.ack_callback(daemon_message)
274313
if daemon_messages_by_target_id: # should be empty by now
275314
logger.error('%sUnhandled messages?? %s', self.log_prefix, len(daemon_messages_by_target_id))
276315
sentry_sdk.capture_message(

0 commit comments

Comments
 (0)