Skip to content

Commit 5d3a565

Browse files
author
Emanuele Palazzetti
committed
Merge branch 'palazzem/celery-integration' into 'master'
2 parents 5a65259 + a553ed6 commit 5d3a565

25 files changed

+867
-1022
lines changed

ddtrace/bootstrap/sitecustomize.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,6 @@ def update_patched_modules():
7070
if opts:
7171
tracer.configure(**opts)
7272

73-
if not hasattr(sys, 'argv'):
74-
sys.argv = ['']
75-
7673
if patch:
7774
update_patched_modules()
7875
from ddtrace import patch_all; patch_all(**EXTRA_PATCHED_MODULES) # noqa

ddtrace/contrib/celery/__init__.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
"""
22
The Celery integration will trace all tasks that are executed in the
3-
background. To trace your Celery application, call the patch method::
3+
background. Functions and class based tasks are traced only if the Celery API
4+
is used, so calling the function directly or via the ``run()`` method will not
5+
generate traces. On the other hand, calling ``apply()`` and ``apply_async()``
6+
will produce tracing data. To trace your Celery application, call the patch method::
47
58
import celery
69
from ddtrace import patch
@@ -12,36 +15,33 @@
1215
def my_task():
1316
pass
1417
15-
1618
class MyTask(app.Task):
1719
def run(self):
1820
pass
1921
2022
21-
If you don't need to patch all Celery tasks, you can patch individual
22-
applications or tasks using a fine grain patching method::
23+
To change Celery service name, you can update the attached ``Pin``
24+
instance::
2325
24-
import celery
25-
from ddtrace.contrib.celery import patch_app, patch_task
26+
from ddtrace import Pin
2627
27-
# patch only this application
2828
app = celery.Celery()
29-
app = patch_app(app)
3029
31-
# or if you didn't patch the whole application, just patch
32-
# a single function or class based Task
3330
@app.task
34-
def fn_task():
31+
def compute_stats():
3532
pass
3633
34+
# globally
35+
Pin.override(app, service='background-jobs')
36+
37+
# by task
38+
Pin.override(compute_stats, service='data-processing')
3739
38-
class BaseClassTask(celery.Task):
39-
def run(self):
40-
pass
4140
41+
By default, reported service names are:
42+
* ``celery-producer`` when tasks are enqueued for processing
43+
* ``celery-worker`` when tasks are processed by a Celery process
4244
43-
BaseClassTask = patch_task(BaseClassTask)
44-
fn_task = patch_task(fn_task)
4545
"""
4646
from ...utils.importlib import require_modules
4747

ddtrace/contrib/celery/app.py

Lines changed: 41 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,74 +1,53 @@
1-
# Standard library
2-
import types
1+
from celery import signals
32

4-
# Third party
5-
import wrapt
6-
7-
# Project
83
from ddtrace import Pin
4+
from ddtrace.pin import _DD_PIN_NAME
95
from ddtrace.ext import AppTypes
10-
from .task import patch_task, unpatch_task
11-
from .util import APP, WORKER_SERVICE, require_pin
6+
7+
from .constants import APP, WORKER_SERVICE
8+
from .signals import (
9+
trace_prerun,
10+
trace_postrun,
11+
trace_before_publish,
12+
trace_after_publish,
13+
trace_failure,
14+
)
1215

1316

1417
def patch_app(app, pin=None):
15-
""" patch_app will add tracing to a celery app """
18+
"""Attach the Pin class to the application and connect
19+
our handlers to Celery signals.
20+
"""
21+
if getattr(app, '__datadog_patch', False):
22+
return
23+
setattr(app, '__datadog_patch', True)
24+
25+
# attach the PIN object
1626
pin = pin or Pin(service=WORKER_SERVICE, app=APP, app_type=AppTypes.worker)
17-
patch_methods = [
18-
('task', _app_task),
19-
]
20-
for method_name, wrapper in patch_methods:
21-
# Get the original method
22-
method = getattr(app, method_name, None)
23-
if method is None:
24-
continue
25-
26-
# Do not patch if method is already patched
27-
if isinstance(method, wrapt.ObjectProxy):
28-
continue
29-
30-
# Patch method
31-
setattr(app, method_name, wrapt.FunctionWrapper(method, wrapper))
32-
33-
# patch the Task class if available
34-
setattr(app, 'Task', patch_task(app.Task))
35-
36-
# Attach our pin to the app
3727
pin.onto(app)
28+
# connect to the Signal framework
29+
signals.task_prerun.connect(trace_prerun)
30+
signals.task_postrun.connect(trace_postrun)
31+
signals.before_task_publish.connect(trace_before_publish)
32+
signals.after_task_publish.connect(trace_after_publish)
33+
signals.task_failure.connect(trace_failure)
3834
return app
3935

4036

4137
def unpatch_app(app):
42-
""" unpatch_app will remove tracing from a celery app """
43-
patched_methods = [
44-
'task',
45-
]
46-
for method_name in patched_methods:
47-
# Get the wrapped method
48-
wrapper = getattr(app, method_name, None)
49-
if wrapper is None:
50-
continue
51-
52-
# Only unpatch if the wrapper is an `ObjectProxy`
53-
if not isinstance(wrapper, wrapt.ObjectProxy):
54-
continue
55-
56-
# Restore original method
57-
setattr(app, method_name, wrapper.__wrapped__)
58-
59-
# restore the original Task class
60-
setattr(app, 'Task', unpatch_task(app.Task))
61-
return app
62-
63-
64-
@require_pin
65-
def _app_task(pin, func, app, args, kwargs):
66-
task = func(*args, **kwargs)
67-
68-
# `app.task` is a decorator which may return a function wrapper
69-
if isinstance(task, types.FunctionType):
70-
def wrapper(func, instance, args, kwargs):
71-
return patch_task(func(*args, **kwargs), pin=pin)
72-
return wrapt.FunctionWrapper(task, wrapper)
73-
74-
return patch_task(task, pin=pin)
38+
"""Remove the Pin instance from the application and disconnect
39+
our handlers from Celery signal framework.
40+
"""
41+
if not getattr(app, '__datadog_patch', False):
42+
return
43+
setattr(app, '__datadog_patch', False)
44+
45+
pin = Pin.get_from(app)
46+
if pin is not None:
47+
delattr(app, _DD_PIN_NAME)
48+
49+
signals.task_prerun.disconnect(trace_prerun)
50+
signals.task_postrun.disconnect(trace_postrun)
51+
signals.before_task_publish.disconnect(trace_before_publish)
52+
signals.after_task_publish.disconnect(trace_after_publish)
53+
signals.task_failure.disconnect(trace_failure)
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from os import getenv
2+
3+
# Celery Context key
4+
CTX_KEY = '__dd_task_span'
5+
6+
# Span names
7+
PRODUCER_ROOT_SPAN = 'celery.apply'
8+
WORKER_ROOT_SPAN = 'celery.run'
9+
10+
# Task operations
11+
TASK_TAG_KEY = 'celery.action'
12+
TASK_APPLY = 'apply'
13+
TASK_APPLY_ASYNC = 'apply_async'
14+
TASK_RUN = 'run'
15+
16+
# Service info
17+
APP = 'celery'
18+
PRODUCER_SERVICE = getenv('DATADOG_SERVICE_NAME') or 'celery-producer'
19+
WORKER_SERVICE = getenv('DATADOG_SERVICE_NAME') or 'celery-worker'

ddtrace/contrib/celery/patch.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,6 @@
11
import celery
22

3-
from wrapt import wrap_function_wrapper as _w
4-
53
from .app import patch_app, unpatch_app
6-
from .task import _wrap_shared_task
7-
from .registry import _wrap_register
8-
from ...utils.wrappers import unwrap as _u
94

105

116
def patch():
@@ -14,13 +9,9 @@ def patch():
149
case of Django-Celery integration, also the `@shared_task` decorator
1510
must be instrumented because Django doesn't use the Celery registry.
1611
"""
17-
setattr(celery, 'Celery', patch_app(celery.Celery))
18-
_w('celery.app.registry', 'TaskRegistry.register', _wrap_register)
19-
_w('celery', 'shared_task', _wrap_shared_task)
12+
patch_app(celery.Celery)
2013

2114

2215
def unpatch():
23-
"""Removes instrumentation from Celery"""
24-
setattr(celery, 'Celery', unpatch_app(celery.Celery))
25-
_u(celery.app.registry.TaskRegistry, 'register')
26-
_u(celery, 'shared_task')
16+
"""Disconnect all signals and remove Tracing capabilities"""
17+
unpatch_app(celery.Celery)

ddtrace/contrib/celery/registry.py

Lines changed: 0 additions & 15 deletions
This file was deleted.

ddtrace/contrib/celery/signals.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
import logging
2+
3+
from ddtrace import Pin
4+
5+
from celery import registry
6+
7+
from . import constants as c
8+
from .utils import (
9+
tags_from_context,
10+
retrieve_task_id,
11+
attach_span,
12+
detach_span,
13+
retrieve_span,
14+
)
15+
16+
17+
log = logging.getLogger(__name__)
18+
19+
20+
def trace_prerun(*args, **kwargs):
21+
# safe-guard to avoid crashes in case the signals API
22+
# changes in Celery
23+
task = kwargs.get('sender')
24+
task_id = kwargs.get('task_id')
25+
if task is None or task_id is None:
26+
log.debug('unable to extract the Task and the task_id. This version of Celery may not be supported.')
27+
return
28+
29+
# retrieve the task Pin or fallback to the global one
30+
pin = Pin.get_from(task) or Pin.get_from(task.app)
31+
if pin is None:
32+
return
33+
34+
# propagate the `Span` in the current task Context
35+
span = pin.tracer.trace(c.WORKER_ROOT_SPAN, service=c.WORKER_SERVICE, resource=task.name)
36+
attach_span(task, task_id, span)
37+
38+
39+
def trace_postrun(*args, **kwargs):
40+
# safe-guard to avoid crashes in case the signals API
41+
# changes in Celery
42+
task = kwargs.get('sender')
43+
task_id = kwargs.get('task_id')
44+
if task is None or task_id is None:
45+
log.debug('unable to extract the Task and the task_id. This version of Celery may not be supported.')
46+
return
47+
48+
# retrieve and finish the Span
49+
span = retrieve_span(task, task_id)
50+
if span is None:
51+
return
52+
else:
53+
# request context tags
54+
span.set_tag(c.TASK_TAG_KEY, c.TASK_RUN)
55+
span.set_tags(tags_from_context(kwargs))
56+
span.set_tags(tags_from_context(task.request))
57+
span.finish()
58+
detach_span(task, task_id)
59+
60+
61+
def trace_before_publish(*args, **kwargs):
62+
# `before_task_publish` signal doesn't propagate the task instance so
63+
# we need to retrieve it from the Celery Registry to access the `Pin`. The
64+
# `Task` instance **does not** include any information about the current
65+
# execution, so it **must not** be used to retrieve `request` data.
66+
task_name = kwargs.get('sender')
67+
task = registry.tasks.get(task_name)
68+
task_id = retrieve_task_id(kwargs)
69+
# safe-guard to avoid crashes in case the signals API
70+
# changes in Celery
71+
if task is None or task_id is None:
72+
log.debug('unable to extract the Task and the task_id. This version of Celery may not be supported.')
73+
return
74+
75+
# propagate the `Span` in the current task Context
76+
pin = Pin.get_from(task) or Pin.get_from(task.app)
77+
if pin is None:
78+
return
79+
80+
# apply some tags here because most of the data is not available
81+
# in the task_after_publish signal
82+
span = pin.tracer.trace(c.PRODUCER_ROOT_SPAN, service=c.PRODUCER_SERVICE, resource=task_name)
83+
span.set_tag(c.TASK_TAG_KEY, c.TASK_APPLY_ASYNC)
84+
span.set_tag('celery.id', task_id)
85+
span.set_tags(tags_from_context(kwargs))
86+
# Note: adding tags from `traceback` or `state` calls will make an
87+
# API call to the backend for the properties so we should rely
88+
# only on the given `Context`
89+
attach_span(task, task_id, span)
90+
91+
92+
def trace_after_publish(*args, **kwargs):
93+
task_name = kwargs.get('sender')
94+
task = registry.tasks.get(task_name)
95+
task_id = retrieve_task_id(kwargs)
96+
# safe-guard to avoid crashes in case the signals API
97+
# changes in Celery
98+
if task is None or task_id is None:
99+
log.debug('unable to extract the Task and the task_id. This version of Celery may not be supported.')
100+
return
101+
102+
# retrieve and finish the Span
103+
span = retrieve_span(task, task_id)
104+
if span is None:
105+
return
106+
else:
107+
span.finish()
108+
detach_span(task, task_id)
109+
110+
111+
def trace_failure(*args, **kwargs):
112+
# safe-guard to avoid crashes in case the signals API
113+
# changes in Celery
114+
task = kwargs.get('sender')
115+
task_id = kwargs.get('task_id')
116+
if task is None or task_id is None:
117+
log.debug('unable to extract the Task and the task_id. This version of Celery may not be supported.')
118+
return
119+
120+
# retrieve and finish the Span
121+
span = retrieve_span(task, task_id)
122+
if span is None:
123+
return
124+
else:
125+
# add Exception tags; post signals are still called
126+
# so we don't need to attach other tags here
127+
ex = kwargs.get('einfo')
128+
if ex is None:
129+
return
130+
span.set_exc_info(ex.type, ex.exception, ex.tb)

0 commit comments

Comments
 (0)