Skip to content

Commit 8e0e402

Browse files
fix(internal): add locks for thread-safety of trace encoder (#3022) (#3024)
* fix(internal): add buffered encoder test for thread safety * fix(internal): wrap encoder with reentrant lock * fix flake8 * add release note * add missing locks Co-authored-by: brettlangdon <[email protected]> (cherry picked from commit dbcc631) Co-authored-by: Tahir H. Butt <[email protected]>
1 parent a7f7b68 commit 8e0e402

File tree

3 files changed

+99
-54
lines changed

3 files changed

+99
-54
lines changed

ddtrace/internal/_encoding.pyx

Lines changed: 62 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ cdef class MsgpackStringTable(StringTable):
208208
self.max_size = max_size
209209
self.pk.length = MSGPACK_STRING_TABLE_LENGTH_PREFIX_SIZE
210210
self._sp_len = 0
211-
self._lock = threading.Lock()
211+
self._lock = threading.RLock()
212212
super(MsgpackStringTable, self).__init__()
213213

214214
assert self.index(ORIGIN_KEY) == 1
@@ -264,7 +264,8 @@ cdef class MsgpackStringTable(StringTable):
264264

265265
@property
266266
def size(self):
267-
return self.pk.length - MSGPACK_ARRAY_LENGTH_PREFIX_SIZE + array_prefix_size(self._next_id)
267+
with self._lock:
268+
return self.pk.length - MSGPACK_ARRAY_LENGTH_PREFIX_SIZE + array_prefix_size(self._next_id)
268269

269270
cdef append_raw(self, long src, Py_ssize_t size):
270271
cdef int res
@@ -284,10 +285,11 @@ cdef class MsgpackStringTable(StringTable):
284285
self._sp_len = 0
285286

286287
cpdef flush(self):
287-
try:
288-
return self.get_bytes()
289-
finally:
290-
self.reset()
288+
with self._lock:
289+
try:
290+
return self.get_bytes()
291+
finally:
292+
self.reset()
291293

292294

293295
cdef class BufferedEncoder(object):
@@ -372,7 +374,7 @@ cdef class MsgpackEncoderBase(BufferedEncoder):
372374
self.max_size = max_size
373375
self.pk.buf_size = buf_size
374376
self.max_item_size = max_item_size if max_item_size < max_size else max_size
375-
self._lock = threading.Lock()
377+
self._lock = threading.RLock()
376378
self._reset_buffer()
377379

378380
def __dealloc__(self):
@@ -393,10 +395,11 @@ cdef class MsgpackEncoderBase(BufferedEncoder):
393395
self.pk.length = MSGPACK_ARRAY_LENGTH_PREFIX_SIZE # Leave room for array length prefix
394396

395397
cpdef encode(self):
396-
if not self._count:
397-
return None
398+
with self._lock:
399+
if not self._count:
400+
return None
398401

399-
return self.flush()
402+
return self.flush()
400403

401404
cdef inline int _update_array_len(self):
402405
"""Update traces array size prefix"""
@@ -437,45 +440,46 @@ cdef class MsgpackEncoderBase(BufferedEncoder):
437440
if L > 0 and trace[0].context is not None and trace[0].context.dd_origin is not None:
438441
dd_origin = self.get_dd_origin_ref(trace[0].context.dd_origin)
439442

440-
with self._lock:
441-
for span in trace:
442-
ret = self.pack_span(span, dd_origin)
443-
if ret != 0: raise RuntimeError("Couldn't pack span")
443+
for span in trace:
444+
ret = self.pack_span(span, dd_origin)
445+
if ret != 0: raise RuntimeError("Couldn't pack span")
444446

445447
return ret
446448

447449
cpdef put(self, list trace):
448450
"""Put a trace (i.e. a list of spans) in the buffer."""
449451
cdef int ret
450452

451-
len_before = self.pk.length
452-
size_before = self.size
453-
try:
454-
ret = self._pack_trace(trace)
455-
if ret: # should not happen.
456-
raise RuntimeError("internal error")
457-
458-
# DEV: msgpack avoids buffer overflows by calling PyMem_Realloc so
459-
# we must check sizes manually.
460-
# TODO: We should probably ensure that the buffer size doesn't
461-
# grow arbitrarily because of the PyMem_Realloc and if it does then
462-
# free and reallocate with the appropriate size.
463-
if self.size - size_before > self.max_item_size:
464-
raise BufferItemTooLarge(self.size - size_before)
465-
466-
if self.size > self.max_size:
467-
raise BufferFull(self.size - size_before)
468-
469-
self._count += 1
470-
except:
471-
# rollback
472-
self.pk.length = len_before
473-
raise
453+
with self._lock:
454+
len_before = self.pk.length
455+
size_before = self.size
456+
try:
457+
ret = self._pack_trace(trace)
458+
if ret: # should not happen.
459+
raise RuntimeError("internal error")
460+
461+
# DEV: msgpack avoids buffer overflows by calling PyMem_Realloc so
462+
# we must check sizes manually.
463+
# TODO: We should probably ensure that the buffer size doesn't
464+
# grow arbitrarily because of the PyMem_Realloc and if it does then
465+
# free and reallocate with the appropriate size.
466+
if self.size - size_before > self.max_item_size:
467+
raise BufferItemTooLarge(self.size - size_before)
468+
469+
if self.size > self.max_size:
470+
raise BufferFull(self.size - size_before)
471+
472+
self._count += 1
473+
except:
474+
# rollback
475+
self.pk.length = len_before
476+
raise
474477

475478
@property
476479
def size(self):
477480
"""Return the size in bytes of the encoder buffer."""
478-
return self.pk.length + array_prefix_size(self._count) - MSGPACK_ARRAY_LENGTH_PREFIX_SIZE
481+
with self._lock:
482+
return self.pk.length + array_prefix_size(self._count) - MSGPACK_ARRAY_LENGTH_PREFIX_SIZE
479483

480484
# ---- Abstract methods ----
481485

@@ -488,10 +492,11 @@ cdef class MsgpackEncoderBase(BufferedEncoder):
488492

489493
cdef class MsgpackEncoderV03(MsgpackEncoderBase):
490494
cpdef flush(self):
491-
try:
492-
return self.get_bytes()
493-
finally:
494-
self._reset_buffer()
495+
with self._lock:
496+
try:
497+
return self.get_bytes()
498+
finally:
499+
self._reset_buffer()
495500

496501
cdef void * get_dd_origin_ref(self, str dd_origin):
497502
return string_to_buff(dd_origin)
@@ -637,24 +642,27 @@ cdef class MsgpackEncoderV05(MsgpackEncoderBase):
637642
self._st = MsgpackStringTable(max_size)
638643

639644
cpdef flush(self):
640-
try:
641-
self._st.append_raw(PyLong_FromLong(<long> self.get_buffer()), <Py_ssize_t> super(MsgpackEncoderV05, self).size)
642-
return self._st.flush()
643-
finally:
644-
self._reset_buffer()
645+
with self._lock:
646+
try:
647+
self._st.append_raw(PyLong_FromLong(<long> self.get_buffer()), <Py_ssize_t> super(MsgpackEncoderV05, self).size)
648+
return self._st.flush()
649+
finally:
650+
self._reset_buffer()
645651

646652
@property
647653
def size(self):
648654
"""Return the size in bytes of the encoder buffer."""
649-
return self._st.size + super(MsgpackEncoderV05, self).size
655+
with self._lock:
656+
return self._st.size + super(MsgpackEncoderV05, self).size
650657

651658
cpdef put(self, list trace):
652-
try:
653-
self._st.savepoint()
654-
super(MsgpackEncoderV05, self).put(trace)
655-
except Exception:
656-
self._st.rollback()
657-
raise
659+
with self._lock:
660+
try:
661+
self._st.savepoint()
662+
super(MsgpackEncoderV05, self).put(trace)
663+
except Exception:
664+
self._st.rollback()
665+
raise
658666

659667
cdef inline int _pack_string(self, object string):
660668
return msgpack_pack_uint32(&self.pk, self._st._index(string))
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
fixes:
3+
- |
4+
The thread safety of the custom buffered encoder was fixed in order to eliminate a potential cause of decoding errors of trace payloads (missing trace data) in the agent.

tests/tracer/test_encoders.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33
import random
44
import string
5+
import threading
56
from unittest import TestCase
67

78
from hypothesis import given
@@ -595,3 +596,35 @@ def test_encoding_invalid_data(data):
595596
encoder.put(trace)
596597

597598
assert encoder.encode() is None
599+
600+
601+
@allencodings
602+
def test_custom_msgpack_encode_thread_safe(encoding):
603+
class TracingThread(threading.Thread):
604+
def __init__(self, encoder, span_count, trace_count):
605+
super(TracingThread, self).__init__()
606+
trace = [
607+
Span(tracer=None, name="span-{}-{}".format(self.name, _), service="threads", resource="TEST")
608+
for _ in range(span_count)
609+
]
610+
self._encoder = encoder
611+
self._trace = trace
612+
self._trace_count = trace_count
613+
614+
def run(self):
615+
for _ in range(self._trace_count):
616+
self._encoder.put(self._trace)
617+
618+
THREADS = 40
619+
SPANS = 15
620+
TRACES = 10
621+
encoder = MSGPACK_ENCODERS[encoding](2 << 20, 2 << 20)
622+
623+
ts = [TracingThread(encoder, random.randint(1, SPANS), random.randint(1, TRACES)) for _ in range(THREADS)]
624+
for t in ts:
625+
t.start()
626+
for t in ts:
627+
t.join()
628+
629+
unpacked = decode(encoder.encode(), reconstruct=True)
630+
assert unpacked is not None

0 commit comments

Comments
 (0)