11import contextlib
22import collections
3+ from collections .abc import Callable
34import dataclasses
45import logging
56import queue
67import random
78import threading
89import time
910
11+ import amqp .exceptions
1012from django .conf import settings
13+ import kombu
1114from kombu .mixins import ConsumerMixin
1215import sentry_sdk
1316
2730MINIMUM_BACKOFF_FACTOR = 1.6 # unitless ratio
2831MAXIMUM_BACKOFF_FACTOR = 2.0 # unitless ratio
2932MAXIMUM_BACKOFF_TIMEOUT = 60 # seconds
33+ CONNECTION_HEARTBEAT = 20 # seconds
3034
3135
3236class TooFastSlowDown (Exception ):
@@ -35,7 +39,10 @@ class TooFastSlowDown(Exception):
3539
3640class IndexerDaemonControl :
3741 def __init__ (self , celery_app , * , daemonthread_context = None , stop_event = None ):
38- self .celery_app = celery_app
42+ self .kombu_connection = kombu .Connection (
43+ celery_app .conf .broker_url , # use celery_app.conf for consistent config
44+ heartbeat = CONNECTION_HEARTBEAT ,
45+ )
3946 self .daemonthread_context = daemonthread_context
4047 self ._daemonthreads = []
4148 # shared stop_event for all threads below
@@ -50,10 +57,16 @@ def start_daemonthreads_for_strategy(self, index_strategy):
5057 )
5158 # spin up daemonthreads, ready for messages
5259 self ._daemonthreads .extend (_daemon .start ())
53- # assign a thread to pass messages to this daemon
54- threading .Thread (
55- target = CeleryMessageConsumer (self .celery_app , _daemon ).run ,
56- ).start ()
60+ _consumer = KombuMessageConsumer (
61+ kombu_connection = self .kombu_connection .clone (),
62+ stop_event = self .stop_event ,
63+ index_strategy = index_strategy ,
64+ message_callback = _daemon .on_message ,
65+ )
66+ # give the daemon direct access to the connection, for acking purposes
67+ _daemon .ack_callback = _consumer .ensure_ack
68+ # assign a thread for the consumer to receive and enqueue messages to this daemon
69+ threading .Thread (target = _consumer .run ).start ()
5770 return _daemon
5871
5972 def start_all_daemonthreads (self ):
@@ -67,18 +80,16 @@ def stop_daemonthreads(self, *, wait=False):
6780 _thread .join ()
6881
6982
70- class CeleryMessageConsumer (ConsumerMixin ):
83+ class KombuMessageConsumer (ConsumerMixin ):
7184 PREFETCH_COUNT = 7500
7285
73- # (from ConsumerMixin)
74- # should_stop: bool
86+ should_stop : bool # (from ConsumerMixin)
7587
76- def __init__ (self , celery_app , indexer_daemon ):
77- self .connection = celery_app .pool .acquire (block = True )
78- self .celery_app = celery_app
79- self .__stop_event = indexer_daemon .stop_event
80- self .__message_callback = indexer_daemon .on_message
81- self .__index_strategy = indexer_daemon .index_strategy
88+ def __init__ (self , * , kombu_connection , stop_event , message_callback , index_strategy ):
89+ self .connection = kombu_connection
90+ self .__stop_event = stop_event
91+ self .__message_callback = message_callback
92+ self .__index_strategy = index_strategy
8293
8394 # overrides ConsumerMixin.run
8495 def run (self ):
@@ -112,9 +123,35 @@ def get_consumers(self, Consumer, channel):
112123 def __repr__ (self ):
113124 return '<{}({})>' .format (self .__class__ .__name__ , self .__index_strategy .name )
114125
126+ def consume (self , * args , ** kwargs ):
127+ # wrap `consume` in `kombu.Connection.ensure`, following guidance from
128+ # https://docs.celeryq.dev/projects/kombu/en/stable/userguide/failover.html#consumer
129+ consume = self .connection .ensure (self .connection , super ().consume )
130+ return consume (* args , ** kwargs )
131+
132+ def ensure_ack (self , daemon_message : messages .DaemonMessage ):
133+ # if the connection the message came thru is no longer usable,
134+ # use `kombu.Connection.autoretry` and `kombu.Channel.basic_ack`
135+ # to ensure the ack goes thru
136+ try :
137+ daemon_message .ack ()
138+ except (ConnectionError , amqp .exceptions .ConnectionError ):
139+ @self .connection .autoretry
140+ def _do_ack (* , channel ):
141+ try :
142+ channel .basic_ack (daemon_message .kombu_message .delivery_tag )
143+ finally :
144+ channel .close ()
145+ _do_ack ()
146+
147+
148+ def _default_ack_callback (daemon_message : messages .DaemonMessage ) -> None :
149+ daemon_message .ack ()
150+
115151
116152class IndexerDaemon :
117153 MAX_LOCAL_QUEUE_SIZE = 5000
154+ ack_callback : Callable [[messages .DaemonMessage ], None ] = _default_ack_callback
118155
119156 def __init__ (self , index_strategy , * , stop_event = None , daemonthread_context = None ):
120157 self .stop_event = (
@@ -154,6 +191,7 @@ def start_typed_loop_and_queue(self, message_type) -> threading.Thread:
154191 local_message_queue = _queue_from_rabbit_to_daemon ,
155192 log_prefix = f'{ repr (self )} MessageHandlingLoop: ' ,
156193 daemonthread_context = self .__daemonthread_context ,
194+ ack_callback = self .ack_callback ,
157195 )
158196 return _handling_loop .start_thread ()
159197
@@ -186,7 +224,8 @@ class MessageHandlingLoop:
186224 stop_event : threading .Event
187225 local_message_queue : queue .Queue
188226 log_prefix : str
189- daemonthread_context : contextlib .AbstractContextManager
227+ daemonthread_context : Callable [[], contextlib .AbstractContextManager ]
228+ ack_callback : Callable [[messages .DaemonMessage ], None ]
190229 _leftover_daemon_messages_by_target_id = None
191230
192231 def __post_init__ (self ):
@@ -270,7 +309,7 @@ def _handle_some_messages(self):
270309 sentry_sdk .capture_message ('error handling message' , extras = {'message_response' : message_response })
271310 target_id = message_response .index_message .target_id
272311 for daemon_message in daemon_messages_by_target_id .pop (target_id , ()):
273- daemon_message . ack () # finally set it free
312+ self . ack_callback ( daemon_message )
274313 if daemon_messages_by_target_id : # should be empty by now
275314 logger .error ('%sUnhandled messages?? %s' , self .log_prefix , len (daemon_messages_by_target_id ))
276315 sentry_sdk .capture_message (
0 commit comments