Skip to content

Commit 35d3f5b

Browse files
author
Emanuele Palazzetti
authored
Merge pull request #297 from palazzem/asyncio-propagation
[asyncio] improved context propagation
2 parents af9b580 + 8401951 commit 35d3f5b

File tree

5 files changed

+125
-107
lines changed

5 files changed

+125
-107
lines changed

ddtrace/context.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,15 @@ class Context(object):
2020
2121
This data structure is thread-safe.
2222
"""
23-
def __init__(self, trace_id=None, span_id=None):
23+
def __init__(self, trace_id=None, span_id=None, sampled=True):
2424
"""
2525
Initialize a new thread-safe ``Context``.
2626
2727
:param int trace_id: trace_id of parent span
2828
:param int span_id: span_id of parent span
2929
"""
3030
self._trace = []
31-
self._sampled = False
31+
self._sampled = sampled
3232
self._finished_spans = 0
3333
self._current_span = None
3434
self._lock = threading.Lock()

ddtrace/contrib/asyncio/helpers.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,11 @@ def _wrap_executor(fn, args, tracer, ctx):
7979

8080

8181
def create_task(*args, **kwargs):
82-
""" This method will enable spawned tasks to parent to the base task context """
83-
return _wrapped_create_task(_orig_create_task, None, args, kwargs)
82+
"""This function spawns a task with a Context that inherits the
83+
`trace_id` and the `parent_id` from the current active one if available.
84+
"""
85+
loop = asyncio.get_event_loop()
86+
return _wrapped_create_task(loop.create_task, None, args, kwargs)
8487

8588

8689
def _wrapped_create_task(wrapped, instance, args, kwargs):

ddtrace/contrib/asyncio/patch.py

Lines changed: 13 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,27 @@
1-
# project
2-
import ddtrace
3-
from ddtrace.util import unwrap
4-
from ddtrace.provider import DefaultContextProvider
5-
6-
# 3p
7-
import wrapt
81
import asyncio
92

10-
from .helpers import _wrapped_create_task
11-
from . import context_provider
3+
from wrapt import wrap_function_wrapper as _w
124

13-
_orig_create_task = asyncio.BaseEventLoop.create_task
5+
from .helpers import _wrapped_create_task
6+
from ...util import unwrap as _u
147

158

16-
def patch(tracer=ddtrace.tracer):
17-
"""
18-
Patches `BaseEventLoop.create_task` to enable spawned tasks to parent to
19-
the base task context. Will also enable the asyncio task context.
9+
def patch():
10+
"""Patches current loop `create_task()` method to enable spawned tasks to
11+
parent to the base task context.
2012
"""
21-
# TODO: figure what to do with helpers.ensure_future and
22-
# helpers.run_in_executor (doesn't work for ProcessPoolExecutor)
2313
if getattr(asyncio, '_datadog_patch', False):
2414
return
2515
setattr(asyncio, '_datadog_patch', True)
2616

27-
tracer.configure(context_provider=context_provider)
28-
wrapt.wrap_function_wrapper('asyncio', 'BaseEventLoop.create_task', _wrapped_create_task)
17+
loop = asyncio.get_event_loop()
18+
_w(loop, 'create_task', _wrapped_create_task)
2919

3020

31-
def unpatch(tracer=ddtrace.tracer):
32-
"""
33-
Remove tracing from patched modules.
34-
"""
21+
def unpatch():
22+
"""Remove tracing from patched modules."""
23+
3524
if getattr(asyncio, '_datadog_patch', False):
3625
setattr(asyncio, '_datadog_patch', False)
37-
38-
tracer.configure(context_provider=DefaultContextProvider())
39-
unwrap(asyncio.BaseEventLoop, 'create_task')
26+
loop = asyncio.get_event_loop()
27+
_u(loop, 'create_task')

tests/contrib/asyncio/test_helpers.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def future_work():
3131
eq_('coroutine', ctx._trace[0].name)
3232
return ctx._trace[0].name
3333

34-
span = self.tracer.trace('coroutine')
34+
self.tracer.trace('coroutine')
3535
# schedule future work and wait for a result
3636
delayed_task = helpers.ensure_future(future_work(), tracer=self.tracer)
3737
result = yield from asyncio.wait_for(delayed_task, timeout=1)
@@ -67,3 +67,21 @@ def future_work():
6767
span.finish()
6868
result = yield from future
6969
ok_(result)
70+
71+
@mark_asyncio
72+
def test_create_task(self):
73+
# the helper should create a new Task that has the Context attached
74+
@asyncio.coroutine
75+
def future_work():
76+
# the ctx is available in this task
77+
ctx = self.tracer.get_call_context()
78+
eq_(0, len(ctx._trace))
79+
child_span = self.tracer.trace('child_task')
80+
return child_span
81+
82+
root_span = self.tracer.trace('main_task')
83+
# schedule future work and wait for a result
84+
task = helpers.create_task(future_work())
85+
result = yield from task
86+
eq_(root_span.trace_id, result.trace_id)
87+
eq_(root_span.span_id, result.parent_id)

tests/contrib/asyncio/test_tracer.py

Lines changed: 86 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,22 @@
11
import asyncio
2+
23
from asyncio import BaseEventLoop
34

45
from ddtrace.context import Context
5-
from ddtrace.contrib.asyncio.helpers import set_call_context
6-
from ddtrace.contrib.asyncio.patch import patch, unpatch
7-
from ddtrace.contrib.asyncio import context_provider
86
from ddtrace.provider import DefaultContextProvider
7+
from ddtrace.contrib.asyncio.patch import patch, unpatch
8+
from ddtrace.contrib.asyncio.helpers import set_call_context
99

1010
from nose.tools import eq_, ok_
11-
1211
from .utils import AsyncioTestCase, mark_asyncio
1312

13+
1414
_orig_create_task = BaseEventLoop.create_task
1515

1616

1717
class TestAsyncioTracer(AsyncioTestCase):
18-
"""
19-
Ensure that the ``AsyncioTracer`` works for asynchronous execution
20-
within the same ``IOLoop``.
18+
"""Ensure that the tracer works with asynchronous executions within
19+
the same ``IOLoop``.
2120
"""
2221
@mark_asyncio
2322
def test_get_call_context(self):
@@ -204,92 +203,102 @@ def f1():
204203
span = spans[0]
205204
ok_(span.duration > 0.25, msg='span.duration={}'.format(span.duration))
206205

207-
@mark_asyncio
208-
def test_patch_chain(self):
209-
patch(self.tracer)
210-
211-
assert self.tracer._context_provider is context_provider
212-
213-
with self.tracer.trace('foo'):
214-
@self.tracer.wrap('f1')
215-
@asyncio.coroutine
216-
def f1():
217-
yield from asyncio.sleep(0.1)
218-
219-
@self.tracer.wrap('f2')
220-
@asyncio.coroutine
221-
def f2():
222-
yield from asyncio.ensure_future(f1())
223-
224-
yield from asyncio.ensure_future(f2())
225-
226-
traces = list(reversed(self.tracer.writer.pop_traces()))
227-
assert len(traces) == 3
228-
root_span = traces[0][0]
229-
last_span_id = None
230-
for trace in traces:
231-
assert len(trace) == 1
232-
span = trace[0]
233-
assert span.trace_id == root_span.trace_id
234-
assert span.parent_id == last_span_id
235-
last_span_id = span.span_id
206+
207+
class TestAsyncioPropagation(AsyncioTestCase):
208+
"""Ensure that asyncio context propagation works between different tasks"""
209+
def setUp(self):
210+
# patch asyncio event loop
211+
super(TestAsyncioPropagation, self).setUp()
212+
patch()
213+
214+
def tearDown(self):
215+
# unpatch asyncio event loop
216+
super(TestAsyncioPropagation, self).tearDown()
217+
unpatch()
236218

237219
@mark_asyncio
238-
def test_patch_parallel(self):
239-
patch(self.tracer)
220+
def test_tasks_chaining(self):
221+
# ensures that the context is propagated between different tasks
222+
@self.tracer.wrap('spawn_task')
223+
@asyncio.coroutine
224+
def coro_2():
225+
yield from asyncio.sleep(0.01)
226+
227+
@self.tracer.wrap('main_task')
228+
@asyncio.coroutine
229+
def coro_1():
230+
yield from asyncio.ensure_future(coro_2())
240231

241-
assert self.tracer._context_provider is context_provider
232+
yield from coro_1()
242233

243-
with self.tracer.trace('foo'):
244-
@self.tracer.wrap('f1')
245-
@asyncio.coroutine
246-
def f1():
247-
yield from asyncio.sleep(0.1)
234+
traces = self.tracer.writer.pop_traces()
235+
eq_(len(traces), 2)
236+
eq_(len(traces[0]), 1)
237+
eq_(len(traces[1]), 1)
238+
spawn_task = traces[0][0]
239+
main_task = traces[1][0]
240+
# check if the context has been correctly propagated
241+
eq_(spawn_task.trace_id, main_task.trace_id)
242+
eq_(spawn_task.parent_id, main_task.span_id)
248243

249-
@self.tracer.wrap('f2')
250-
@asyncio.coroutine
251-
def f2():
252-
yield from asyncio.sleep(0.1)
244+
@mark_asyncio
245+
def test_concurrent_chaining(self):
246+
# ensures that the context is correctly propagated when
247+
# concurrent tasks are created from a common tracing block
248+
@self.tracer.wrap('f1')
249+
@asyncio.coroutine
250+
def f1():
251+
yield from asyncio.sleep(0.01)
253252

253+
@self.tracer.wrap('f2')
254+
@asyncio.coroutine
255+
def f2():
256+
yield from asyncio.sleep(0.01)
257+
258+
with self.tracer.trace('main_task'):
254259
yield from asyncio.gather(f1(), f2())
255260

256261
traces = self.tracer.writer.pop_traces()
257-
assert len(traces) == 3
258-
root_span = traces[2][0]
259-
for trace in traces[:2]:
260-
assert len(trace) == 1
261-
span = trace[0]
262-
assert span.trace_id == root_span.trace_id
263-
assert span.parent_id == root_span.span_id
262+
eq_(len(traces), 3)
263+
eq_(len(traces[0]), 1)
264+
eq_(len(traces[1]), 1)
265+
eq_(len(traces[2]), 1)
266+
child_1 = traces[0][0]
267+
child_2 = traces[1][0]
268+
main_task = traces[2][0]
269+
# check if the context has been correctly propagated
270+
eq_(child_1.trace_id, main_task.trace_id)
271+
eq_(child_1.parent_id, main_task.span_id)
272+
eq_(child_2.trace_id, main_task.trace_id)
273+
eq_(child_2.parent_id, main_task.span_id)
264274

265275
@mark_asyncio
266-
def test_distributed(self):
267-
patch(self.tracer)
268-
276+
def test_propagation_with_new_context(self):
277+
# ensures that if a new Context is attached to the current
278+
# running Task, a previous trace is resumed
269279
task = asyncio.Task.current_task()
270280
ctx = Context(trace_id=100, span_id=101)
271281
set_call_context(task, ctx)
272282

273-
with self.tracer.trace('foo'):
274-
pass
283+
with self.tracer.trace('async_task'):
284+
yield from asyncio.sleep(0.01)
275285

276286
traces = self.tracer.writer.pop_traces()
277-
assert len(traces) == 1
278-
trace = traces[0]
279-
assert len(trace) == 1
280-
span = trace[0]
281-
282-
assert span.trace_id == ctx._parent_trace_id
283-
assert span.parent_id == ctx._parent_span_id
287+
eq_(len(traces), 1)
288+
eq_(len(traces[0]), 1)
289+
span = traces[0][0]
290+
eq_(span.trace_id, 100)
291+
eq_(span.parent_id, 101)
284292

285293
@mark_asyncio
286-
def test_unpatch(self):
287-
patch(self.tracer)
288-
unpatch(self.tracer)
289-
290-
assert isinstance(self.tracer._context_provider, DefaultContextProvider)
291-
assert BaseEventLoop.create_task == _orig_create_task
292-
293-
def test_double_patch(self):
294-
patch(self.tracer)
295-
self.test_patch_chain()
294+
def test_event_loop_unpatch(self):
295+
# ensures that the event loop can be unpatched
296+
unpatch()
297+
ok_(isinstance(self.tracer._context_provider, DefaultContextProvider))
298+
ok_(BaseEventLoop.create_task == _orig_create_task)
299+
300+
def test_event_loop_double_patch(self):
301+
# ensures that double patching will not double instrument
302+
# the event loop
303+
patch()
304+
self.test_tasks_chaining()

0 commit comments

Comments
 (0)