Commit 3cfce55

lrafeei, mergify[bot], and TimPansino authored
Celery (re)instrumentation (#1429)
* Add support for custom task classes
* Initial commit
* Distributed Tracing tests
* Fix remaining tests and add custom task tests
* Clean up comments in file
* Ruff format
* Fix comment typo
* Reviewer suggestions and more tests
* Clean up commented code
* Add comments and shuffle code
* Fix celery linter errors
* Fix azure linter errors
* Fix linter error

---------

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: Tim Pansino <[email protected]>
1 parent 02d0071 commit 3cfce55

File tree

8 files changed: +617 −207 lines

newrelic/config.py

Lines changed: 1 addition & 6 deletions
@@ -4010,18 +4010,13 @@ def _process_module_builtin_defaults():
         "instrument_rest_framework_decorators",
     )
 
-    _process_module_definition("celery.task.base", "newrelic.hooks.application_celery", "instrument_celery_app_task")
-    _process_module_definition("celery.app.task", "newrelic.hooks.application_celery", "instrument_celery_app_task")
+    _process_module_definition("celery.local", "newrelic.hooks.application_celery", "instrument_celery_local")
     _process_module_definition("celery.app.trace", "newrelic.hooks.application_celery", "instrument_celery_app_trace")
     _process_module_definition("celery.worker", "newrelic.hooks.application_celery", "instrument_celery_worker")
-    _process_module_definition(
-        "celery.concurrency.processes", "newrelic.hooks.application_celery", "instrument_celery_worker"
-    )
     _process_module_definition(
         "celery.concurrency.prefork", "newrelic.hooks.application_celery", "instrument_celery_worker"
     )
 
-    _process_module_definition("celery.app.base", "newrelic.hooks.application_celery", "instrument_celery_app_base")
     _process_module_definition("billiard.pool", "newrelic.hooks.application_celery", "instrument_billiard_pool")
 
     _process_module_definition("flup.server.cgi", "newrelic.hooks.adapter_flup", "instrument_flup_server_cgi")

newrelic/hooks/application_celery.py

Lines changed: 127 additions & 74 deletions
@@ -28,7 +28,8 @@
 from newrelic.api.message_trace import MessageTrace
 from newrelic.api.pre_function import wrap_pre_function
 from newrelic.api.transaction import current_transaction
-from newrelic.common.object_wrapper import FunctionWrapper, _NRBoundFunctionWrapper, wrap_function_wrapper
+from newrelic.common.object_wrapper import FunctionWrapper, wrap_function_wrapper
+from newrelic.common.signature import bind_args
 from newrelic.core.agent import shutdown_agent
 
 UNKNOWN_TASK_NAME = "<Unknown Task>"
@@ -63,6 +64,10 @@ def task_info(instance, *args, **kwargs):
     return task_name, task_source
 
 
+# =============
+# Celery instrumentation for direct task calls (__call__ or run)
+
+
 def CeleryTaskWrapper(wrapped):
     def wrapper(wrapped, instance, args, kwargs):
         transaction = current_transaction(active_only=False)
@@ -103,15 +108,26 @@ def wrapper(wrapped, instance, args, kwargs):
             # Attempt to grab distributed tracing headers
             try:
                 # Headers on earlier versions of Celery may end up as attributes
-                # on the request context instead of as custom headers. Handler this
-                # by defaulting to using vars() if headers is not available
-                request = instance.request
+                # on the request context instead of as custom headers. Handle this
+                # by defaulting to using `vars()` if headers is not available
+
+                # If there is no request, the request property will return
+                # a new instance of `celery.Context()` instead of `None`, so
+                # this will be handled by accessing the request_stack directly.
+                request = instance and instance.request_stack and instance.request_stack.top
                 headers = getattr(request, "headers", None) or vars(request)
 
                 settings = transaction.settings
                 if headers is not None and settings is not None:
                     if settings.distributed_tracing.enabled:
-                        transaction.accept_distributed_trace_headers(headers, transport_type="AMQP")
+                        # Generate DT headers if they do not already exist in the incoming request
+                        if not transaction.accept_distributed_trace_headers(headers, transport_type="AMQP"):
+                            try:
+                                dt_headers = MessageTrace.generate_request_headers(transaction)
+                                if dt_headers:
+                                    headers.update(dict(dt_headers))
+                            except Exception:
+                                pass
                     elif transaction.settings.cross_application_tracer.enabled:
                         transaction._process_incoming_cat_headers(
                             headers.get(MessageTrace.cat_id_key, None),
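
The new accept-or-generate branch is the heart of this hunk: if `accept_distributed_trace_headers` returns falsy (no usable inbound trace context), the agent now writes freshly generated headers back onto the request so that any messages the task produces still carry a trace. Condensed, the control flow is roughly the following (the helper name is ours, not the agent's):

from newrelic.api.message_trace import MessageTrace

def _accept_or_seed_dt_headers(transaction, headers):
    # Join the caller's trace if the incoming headers allow it.
    if transaction.accept_distributed_trace_headers(headers, transport_type="AMQP"):
        return
    try:
        # Otherwise seed the request with this transaction's own headers
        # so downstream work initiated by the task stays in one trace.
        dt_headers = MessageTrace.generate_request_headers(transaction)
        if dt_headers:
            headers.update(dict(dt_headers))
    except Exception:
        pass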
@@ -139,7 +155,7 @@
     # Celery has included a monkey-patching provision which did not perform this
     # optimization on functions that were monkey-patched. Unfortunately, our
     # wrappers are too transparent for celery to detect that they've even been
-    # monky-patched. To circumvent this, we set the __module__ of our wrapped task
+    # monkey-patched. To circumvent this, we set the __module__ of our wrapped task
     # to this file which causes celery to properly detect that it has been patched.
     #
     # For versions of celery 2.5.3 to 2.5.5
@@ -159,85 +175,114 @@ def run(self, *args, **kwargs):
     return wrapped_task
 
 
-def instrument_celery_app_task(module):
-    # Triggered for both 'celery.app.task' and 'celery.task.base'.
+def instrument_celery_local(module):
+    if hasattr(module, "Proxy"):
+        # This is used in the case where the function is
+        # called directly on the Proxy object (rather than
+        # using `delay` or `apply_async`)
+        module.Proxy.__call__ = CeleryTaskWrapper(module.Proxy.__call__)
 
-    if hasattr(module, "BaseTask"):
-        # Need to add a wrapper for background task entry point.
 
-        # In Celery 2.2 the 'BaseTask' class actually resided in the
-        # module 'celery.task.base'. In Celery 2.3 the 'BaseTask' class
-        # moved to 'celery.app.task' but an alias to it was retained in
-        # the module 'celery.task.base'. We need to detect both module
-        # imports, but we check the module name associated with
-        # 'BaseTask' to ensure that we do not instrument the class via
-        # the alias in Celery 2.3 and later.
+# =============
 
-        # In Celery 2.5+, although 'BaseTask' still exists execution of
-        # the task doesn't pass through it. For Celery 2.5+ need to wrap
-        # the tracer instead.
+# =============
+# Celery Instrumentation for delay/apply_async/apply:
 
-        if module.BaseTask.__module__ == module.__name__:
-            module.BaseTask.__call__ = CeleryTaskWrapper(module.BaseTask.__call__)
 
+def wrap_task_call(wrapped, instance, args, kwargs):
+    transaction = current_transaction(active_only=False)
 
-def wrap_Celery_send_task(wrapped, instance, args, kwargs):
-    transaction = current_transaction()
-    if not transaction:
-        return wrapped(*args, **kwargs)
+    # Grab task name and source
+    _name, _source = task_info(wrapped, *args, **kwargs)
 
-    # Merge distributed tracing headers into outgoing task headers
-    try:
-        dt_headers = MessageTrace.generate_request_headers(transaction)
-        original_headers = kwargs.get("headers", None)
-        if dt_headers:
-            if not original_headers:
-                kwargs["headers"] = dict(dt_headers)
-            else:
-                kwargs["headers"] = dt_headers = dict(dt_headers)
-                dt_headers.update(dict(original_headers))
-    except Exception:
-        pass
-
-    return wrapped(*args, **kwargs)
-
-
-def wrap_worker_optimizations(wrapped, instance, args, kwargs):
-    # Attempt to uninstrument BaseTask before stack protection is installed or uninstalled
-    try:
-        from celery.app.task import BaseTask
+    # A Celery Task can be called either outside of a transaction, or
+    # within the context of an existing transaction. There are 3
+    # possibilities we need to handle:
+    #
+    #   1. In an inactive transaction
+    #
+    #      If the end_of_transaction() or ignore_transaction() API calls
+    #      have been invoked, this task may be called in the context
+    #      of an inactive transaction. In this case, don't wrap the task
+    #      in any way. Just run the original function.
+    #
+    #   2. In an active transaction
+    #
+    #      Run the original function inside a FunctionTrace.
+    #
+    #   3. Outside of a transaction
+    #
+    #      This is the typical case for a celery Task. Since it's not
+    #      running inside of an existing transaction, we want to create
+    #      a new background transaction for it.
 
-        if isinstance(BaseTask.__call__, _NRBoundFunctionWrapper):
-            BaseTask.__call__ = BaseTask.__call__.__wrapped__
-    except Exception:
-        BaseTask = None
+    if transaction and (transaction.ignore_transaction or transaction.stopped):
+        return wrapped(*args, **kwargs)
 
-    # Allow metaprogramming to run
-    result = wrapped(*args, **kwargs)
+    elif transaction:
+        with FunctionTrace(_name, source=_source):
+            return wrapped(*args, **kwargs)
 
-    # Rewrap finalized BaseTask
-    if BaseTask:  # Ensure imports succeeded
-        BaseTask.__call__ = CeleryTaskWrapper(BaseTask.__call__)
+    else:
+        with BackgroundTask(application_instance(), _name, "Celery", source=_source) as transaction:
+            # Attempt to grab distributed tracing headers
+            try:
+                # Headers on earlier versions of Celery may end up as attributes
+                # on the request context instead of as custom headers. Handle this
+                # by defaulting to using `vars()` if headers is not available
+
+                # If there is no request, the request property will return
+                # a new instance of `celery.Context()` instead of `None`, so
+                # this will be handled by accessing the request_stack directly.
+                request = wrapped and wrapped.request_stack and wrapped.request_stack.top
+                headers = getattr(request, "headers", None) or vars(request)
+
+                settings = transaction.settings
+                if headers is not None and settings is not None:
+                    if settings.distributed_tracing.enabled:
+                        # Generate DT headers if they do not already exist in the incoming request
+                        if not transaction.accept_distributed_trace_headers(headers, transport_type="AMQP"):
+                            try:
+                                dt_headers = MessageTrace.generate_request_headers(transaction)
+                                if dt_headers:
+                                    headers.update(dict(dt_headers))
+                            except Exception:
+                                pass
+                    elif transaction.settings.cross_application_tracer.enabled:
+                        transaction._process_incoming_cat_headers(
+                            headers.get(MessageTrace.cat_id_key, None),
+                            headers.get(MessageTrace.cat_transaction_key, None),
+                        )
+            except Exception:
+                pass
 
-    return result
+            return wrapped(*args, **kwargs)
 
 
-def instrument_celery_app_base(module):
-    if hasattr(module, "Celery") and hasattr(module.Celery, "send_task"):
-        wrap_function_wrapper(module, "Celery.send_task", wrap_Celery_send_task)
+def wrap_build_tracer(wrapped, instance, args, kwargs):
+    class TaskWrapper(FunctionWrapper):
+        def run(self, *args, **kwargs):
+            return self.__call__(*args, **kwargs)
 
+    try:
+        bound_args = bind_args(wrapped, args, kwargs)
+        task = bound_args.get("task", None)
 
-def instrument_celery_worker(module):
-    # Triggered for 'celery.worker' and 'celery.concurrency.processes'.
+        task = TaskWrapper(task, wrap_task_call)
+        # Reset __module__ to be less transparent so celery detects our monkey-patching
+        task.__module__ = wrap_task_call.__module__
+        bound_args["task"] = task
 
-    if hasattr(module, "process_initializer"):
-        # We try and force registration of default application after
-        # fork of worker process rather than lazily on first request.
+        return wrapped(**bound_args)
+    except:
+        # If we can't bind the args, we just call the wrapped function
+        return wrapped(*args, **kwargs)
 
-        # Originally the 'process_initializer' function was located in
-        # 'celery.worker'. In Celery 2.5 the function 'process_initializer'
-        # was moved to the module 'celery.concurrency.processes'.
 
+def instrument_celery_worker(module):
+    if hasattr(module, "process_initializer"):
+        # We try and force activation of the agent before
+        # the worker process starts.
         _process_initializer = module.process_initializer
 
         @functools.wraps(module.process_initializer)
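
Why the nested `TaskWrapper` defines `run`: when celery builds a task's tracer it may invoke `task.run` directly rather than `task(...)` (for tasks without a custom `__call__`), which would bypass a plain `FunctionWrapper`. Routing `run` through `__call__` keeps both entry points inside `wrap_task_call`. A standalone toy showing the same delegation, independent of celery:

from newrelic.common.object_wrapper import FunctionWrapper

class TracingWrapper(FunctionWrapper):
    # Mirror of TaskWrapper: run() is a second door into __call__,
    # so the wrapper function fires for either invocation style.
    def run(self, *args, **kwargs):
        return self.__call__(*args, **kwargs)

def trace(wrapped, instance, args, kwargs):
    print(f"tracing {wrapped.__name__}")  # stand-in for wrap_task_call
    return wrapped(*args, **kwargs)

def add(x, y):
    return x + y

task = TracingWrapper(add, trace)
assert task(1, 2) == 3      # direct call hits the wrapper
assert task.run(1, 2) == 3  # run() is delegated to __call__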
@@ -247,6 +292,18 @@ def process_initializer(*args, **kwargs):
 
         module.process_initializer = process_initializer
 
+    if hasattr(module, "process_destructor"):
+        # We try and force shutdown of the agent before
+        # the worker process exits.
+        _process_destructor = module.process_destructor
+
+        @functools.wraps(module.process_destructor)
+        def process_destructor(*args, **kwargs):
+            shutdown_agent()
+            return _process_destructor(*args, **kwargs)
+
+        module.process_destructor = process_destructor
+
 
 def instrument_celery_loaders_base(module):
     def force_application_activation(*args, **kwargs):
@@ -259,14 +316,10 @@ def instrument_billiard_pool(module):
     def force_agent_shutdown(*args, **kwargs):
         shutdown_agent()
 
-    if hasattr(module, "Worker"):
+    if hasattr(module, "Worker") and hasattr(module.Worker, "_do_exit"):
         wrap_pre_function(module, "Worker._do_exit", force_agent_shutdown)
 
 
 def instrument_celery_app_trace(module):
-    # Uses same wrapper for setup and reset worker optimizations to prevent patching and unpatching from removing wrappers
-    if hasattr(module, "setup_worker_optimizations"):
-        wrap_function_wrapper(module, "setup_worker_optimizations", wrap_worker_optimizations)
-
-    if hasattr(module, "reset_worker_optimizations"):
-        wrap_function_wrapper(module, "reset_worker_optimizations", wrap_worker_optimizations)
+    if hasattr(module, "build_tracer"):
+        wrap_function_wrapper(module, "build_tracer", wrap_build_tracer)

tests/application_celery/_target_application.py

Lines changed: 25 additions & 10 deletions
@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from celery import Celery, shared_task
-from testing_support.validators.validate_distributed_trace_accepted import validate_distributed_trace_accepted
+from celery import Celery, Task, shared_task
 
 from newrelic.api.transaction import current_transaction
 
@@ -47,11 +46,27 @@ def shared_task_add(x, y):
     return x + y
 
 
-@app.task
-@validate_distributed_trace_accepted(transport_type="AMQP")
-def assert_dt():
-    # Basic checks for DT delegated to task
-    txn = current_transaction()
-    assert txn, "No transaction active."
-    assert txn.name == "_target_application.assert_dt", f"Transaction name does not match: {txn.name}"
-    return 1
+class CustomCeleryTaskWithSuper(Task):
+    def __call__(self, *args, **kwargs):
+        transaction = current_transaction()
+        if transaction:
+            transaction.add_custom_attribute("custom_task_attribute", "Called with super")
+        return super().__call__(*args, **kwargs)
+
+
+class CustomCeleryTaskWithRun(Task):
+    def __call__(self, *args, **kwargs):
+        transaction = current_transaction()
+        if transaction:
+            transaction.add_custom_attribute("custom_task_attribute", "Called with run")
+        return self.run(*args, **kwargs)
+
+
+@app.task(base=CustomCeleryTaskWithSuper)
+def add_with_super(x, y):
+    return x + y
+
+
+@app.task(base=CustomCeleryTaskWithRun)
+def add_with_run(x, y):
+    return x + y
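
The two custom bases pin down the two dispatch styles the new `wrap_build_tracer` path must survive: a `__call__` override that defers to `super().__call__()` versus one that invokes `self.run()` directly. A usage sketch (broker and worker setup omitted):

# Either style should yield an instrumented background transaction carrying
# the custom_task_attribute set inside __call__.
result = add_with_super.delay(1, 2)
assert result.get(timeout=10) == 3

result = add_with_run.apply_async((1, 2))
assert result.get(timeout=10) == 3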

tests/application_celery/conftest.py

Lines changed: 11 additions & 0 deletions
@@ -11,7 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import pytest
+from celery.app.trace import reset_worker_optimizations, setup_worker_optimizations
 from testing_support.fixtures import collector_agent_registration_fixture, collector_available_fixture
 
 _default_settings = {
@@ -43,3 +45,12 @@ def celery_worker_parameters():
 @pytest.fixture(scope="session", autouse=True)
 def celery_worker_available(celery_session_worker):
     return celery_session_worker
+
+
+@pytest.fixture(scope="session", autouse=True, params=[False, True], ids=["unpatched", "patched"])
+def with_worker_optimizations(request, celery_worker_available):
+    if request.param:
+        setup_worker_optimizations(celery_worker_available.app)
+
+    yield request.param
+    reset_worker_optimizations()
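
Because this fixture is session-scoped, autoused, and parametrized, the whole suite runs twice: once with celery left alone ("unpatched") and once after `setup_worker_optimizations` ("patched"), the very call that used to strip the old `BaseTask.__call__` wrapper. Any test picks this up implicitly; a hypothetical example:

from _target_application import add_with_super

def test_add_survives_worker_optimizations(with_worker_optimizations):
    # Runs under both fixture ids; instrumentation must hold either way,
    # since build_tracer is wrapped regardless of the optimization state.
    result = add_with_super.delay(1, 2)
    assert result.get(timeout=10) == 3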
