|
1 | 1 | import asyncio |
| 2 | + |
2 | 3 | from asyncio import BaseEventLoop |
3 | 4 |
|
4 | 5 | from ddtrace.context import Context |
5 | | -from ddtrace.contrib.asyncio.helpers import set_call_context |
6 | | -from ddtrace.contrib.asyncio.patch import patch, unpatch |
7 | | -from ddtrace.contrib.asyncio import context_provider |
8 | 6 | from ddtrace.provider import DefaultContextProvider |
| 7 | +from ddtrace.contrib.asyncio.patch import patch, unpatch |
| 8 | +from ddtrace.contrib.asyncio.helpers import set_call_context |
9 | 9 |
|
10 | 10 | from nose.tools import eq_, ok_ |
11 | | - |
12 | 11 | from .utils import AsyncioTestCase, mark_asyncio |
13 | 12 |
|
| 13 | + |
14 | 14 | _orig_create_task = BaseEventLoop.create_task |
15 | 15 |
|
16 | 16 |
|
17 | 17 | class TestAsyncioTracer(AsyncioTestCase): |
18 | | - """ |
19 | | - Ensure that the ``AsyncioTracer`` works for asynchronous execution |
20 | | - within the same ``IOLoop``. |
| 18 | + """Ensure that the tracer works with asynchronous executions within |
| 19 | + the same ``IOLoop``. |
21 | 20 | """ |
22 | 21 | @mark_asyncio |
23 | 22 | def test_get_call_context(self): |
@@ -204,92 +203,102 @@ def f1(): |
204 | 203 | span = spans[0] |
205 | 204 | ok_(span.duration > 0.25, msg='span.duration={}'.format(span.duration)) |
206 | 205 |
|
207 | | - @mark_asyncio |
208 | | - def test_patch_chain(self): |
209 | | - patch(self.tracer) |
210 | | - |
211 | | - assert self.tracer._context_provider is context_provider |
212 | | - |
213 | | - with self.tracer.trace('foo'): |
214 | | - @self.tracer.wrap('f1') |
215 | | - @asyncio.coroutine |
216 | | - def f1(): |
217 | | - yield from asyncio.sleep(0.1) |
218 | | - |
219 | | - @self.tracer.wrap('f2') |
220 | | - @asyncio.coroutine |
221 | | - def f2(): |
222 | | - yield from asyncio.ensure_future(f1()) |
223 | | - |
224 | | - yield from asyncio.ensure_future(f2()) |
225 | | - |
226 | | - traces = list(reversed(self.tracer.writer.pop_traces())) |
227 | | - assert len(traces) == 3 |
228 | | - root_span = traces[0][0] |
229 | | - last_span_id = None |
230 | | - for trace in traces: |
231 | | - assert len(trace) == 1 |
232 | | - span = trace[0] |
233 | | - assert span.trace_id == root_span.trace_id |
234 | | - assert span.parent_id == last_span_id |
235 | | - last_span_id = span.span_id |
| 206 | + |
| 207 | +class TestAsyncioPropagation(AsyncioTestCase): |
| 208 | + """Ensure that asyncio context propagation works between different tasks""" |
| 209 | + def setUp(self): |
| 210 | + # patch asyncio event loop |
| 211 | + super(TestAsyncioPropagation, self).setUp() |
| 212 | + patch() |
| 213 | + |
| 214 | + def tearDown(self): |
| 215 | + # unpatch asyncio event loop |
| 216 | + super(TestAsyncioPropagation, self).tearDown() |
| 217 | + unpatch() |
236 | 218 |
|
237 | 219 | @mark_asyncio |
238 | | - def test_patch_parallel(self): |
239 | | - patch(self.tracer) |
| 220 | + def test_tasks_chaining(self): |
| 221 | + # ensures that the context is propagated between different tasks |
| 222 | + @self.tracer.wrap('spawn_task') |
| 223 | + @asyncio.coroutine |
| 224 | + def coro_2(): |
| 225 | + yield from asyncio.sleep(0.01) |
| 226 | + |
| 227 | + @self.tracer.wrap('main_task') |
| 228 | + @asyncio.coroutine |
| 229 | + def coro_1(): |
| 230 | + yield from asyncio.ensure_future(coro_2()) |
240 | 231 |
|
241 | | - assert self.tracer._context_provider is context_provider |
| 232 | + yield from coro_1() |
242 | 233 |
|
243 | | - with self.tracer.trace('foo'): |
244 | | - @self.tracer.wrap('f1') |
245 | | - @asyncio.coroutine |
246 | | - def f1(): |
247 | | - yield from asyncio.sleep(0.1) |
| 234 | + traces = self.tracer.writer.pop_traces() |
| 235 | + eq_(len(traces), 2) |
| 236 | + eq_(len(traces[0]), 1) |
| 237 | + eq_(len(traces[1]), 1) |
| 238 | + spawn_task = traces[0][0] |
| 239 | + main_task = traces[1][0] |
| 240 | + # check if the context has been correctly propagated |
| 241 | + eq_(spawn_task.trace_id, main_task.trace_id) |
| 242 | + eq_(spawn_task.parent_id, main_task.span_id) |
248 | 243 |
|
249 | | - @self.tracer.wrap('f2') |
250 | | - @asyncio.coroutine |
251 | | - def f2(): |
252 | | - yield from asyncio.sleep(0.1) |
| 244 | + @mark_asyncio |
| 245 | + def test_concurrent_chaining(self): |
| 246 | + # ensures that the context is correctly propagated when |
| 247 | + # concurrent tasks are created from a common tracing block |
| 248 | + @self.tracer.wrap('f1') |
| 249 | + @asyncio.coroutine |
| 250 | + def f1(): |
| 251 | + yield from asyncio.sleep(0.01) |
253 | 252 |
|
| 253 | + @self.tracer.wrap('f2') |
| 254 | + @asyncio.coroutine |
| 255 | + def f2(): |
| 256 | + yield from asyncio.sleep(0.01) |
| 257 | + |
| 258 | + with self.tracer.trace('main_task'): |
254 | 259 | yield from asyncio.gather(f1(), f2()) |
255 | 260 |
|
256 | 261 | traces = self.tracer.writer.pop_traces() |
257 | | - assert len(traces) == 3 |
258 | | - root_span = traces[2][0] |
259 | | - for trace in traces[:2]: |
260 | | - assert len(trace) == 1 |
261 | | - span = trace[0] |
262 | | - assert span.trace_id == root_span.trace_id |
263 | | - assert span.parent_id == root_span.span_id |
| 262 | + eq_(len(traces), 3) |
| 263 | + eq_(len(traces[0]), 1) |
| 264 | + eq_(len(traces[1]), 1) |
| 265 | + eq_(len(traces[2]), 1) |
| 266 | + child_1 = traces[0][0] |
| 267 | + child_2 = traces[1][0] |
| 268 | + main_task = traces[2][0] |
| 269 | + # check if the context has been correctly propagated |
| 270 | + eq_(child_1.trace_id, main_task.trace_id) |
| 271 | + eq_(child_1.parent_id, main_task.span_id) |
| 272 | + eq_(child_2.trace_id, main_task.trace_id) |
| 273 | + eq_(child_2.parent_id, main_task.span_id) |
264 | 274 |
|
265 | 275 | @mark_asyncio |
266 | | - def test_distributed(self): |
267 | | - patch(self.tracer) |
268 | | - |
| 276 | + def test_propagation_with_new_context(self): |
| 277 | + # ensures that if a new Context is attached to the current |
| 278 | + # running Task, a previous trace is resumed |
269 | 279 | task = asyncio.Task.current_task() |
270 | 280 | ctx = Context(trace_id=100, span_id=101) |
271 | 281 | set_call_context(task, ctx) |
272 | 282 |
|
273 | | - with self.tracer.trace('foo'): |
274 | | - pass |
| 283 | + with self.tracer.trace('async_task'): |
| 284 | + yield from asyncio.sleep(0.01) |
275 | 285 |
|
276 | 286 | traces = self.tracer.writer.pop_traces() |
277 | | - assert len(traces) == 1 |
278 | | - trace = traces[0] |
279 | | - assert len(trace) == 1 |
280 | | - span = trace[0] |
281 | | - |
282 | | - assert span.trace_id == ctx._parent_trace_id |
283 | | - assert span.parent_id == ctx._parent_span_id |
| 287 | + eq_(len(traces), 1) |
| 288 | + eq_(len(traces[0]), 1) |
| 289 | + span = traces[0][0] |
| 290 | + eq_(span.trace_id, 100) |
| 291 | + eq_(span.parent_id, 101) |
284 | 292 |
|
285 | 293 | @mark_asyncio |
286 | | - def test_unpatch(self): |
287 | | - patch(self.tracer) |
288 | | - unpatch(self.tracer) |
289 | | - |
290 | | - assert isinstance(self.tracer._context_provider, DefaultContextProvider) |
291 | | - assert BaseEventLoop.create_task == _orig_create_task |
292 | | - |
293 | | - def test_double_patch(self): |
294 | | - patch(self.tracer) |
295 | | - self.test_patch_chain() |
| 294 | + def test_event_loop_unpatch(self): |
| 295 | + # ensures that the event loop can be unpatched |
| 296 | + unpatch() |
| 297 | + ok_(isinstance(self.tracer._context_provider, DefaultContextProvider)) |
| 298 | + ok_(BaseEventLoop.create_task == _orig_create_task) |
| 299 | + |
| 300 | + def test_event_loop_double_patch(self): |
| 301 | + # ensures that double patching will not double instrument |
| 302 | + # the event loop |
| 303 | + patch() |
| 304 | + self.test_tasks_chaining() |
0 commit comments