@@ -403,29 +403,41 @@ cdef stack_collect(ignore_profiler, thread_time, max_nframes, interval, wall_tim
403403 return stack_events, exc_events
404404
405405
406+ # cython does not play well with mypy
407+ if typing.TYPE_CHECKING:
408+ _T = typing.TypeVar(" _T" )
409+ _thread_link_base = typing.Generic[_T]
410+ _weakref_type = weakref.ReferenceType[_T]
411+ else :
412+ _thread_link_base = object
413+ _weakref_type = typing.Any
414+
415+
406416@ attr.s (slots = True , eq = False )
407- class _ThreadSpanLinks (object ):
417+ class _ThreadLink (_thread_link_base ):
418+ """ Link a thread with an object.
419+
420+ Object is removed when the thread disappears.
421+ """
408422
409423 # Key is a thread_id
410- # Value is a weakref to latest active span
411- _thread_id_to_spans = attr.ib(factory = dict , repr = False , init = False , type = typing.Dict[int , ddspan.Span ])
424+ # Value is a weakref to an object
425+ _thread_id_to_object = attr.ib(factory = dict , repr = False , init = False , type = typing.Dict[int , _weakref_type ])
412426 _lock = attr.ib(factory = nogevent.Lock, repr = False , init = False , type = nogevent.Lock)
413427
414- def link_span (
428+ def link_object (
415429 self ,
416- span # type: typing.Optional[typing.Union[context. Context , ddspan.Span]]
430+ obj # type: _T
417431 ):
418432 # type: (...) -> None
419- """ Link a span to its running environment.
420-
421- Track threads, tasks, etc.
422- """
433+ """ Link an object to the current running thread."""
423434 # Since we're going to iterate over the set, make sure it's locked
424- if isinstance (span, ddspan.Span):
425- with self ._lock:
426- self ._thread_id_to_spans[nogevent.thread_get_ident()] = weakref.ref(span)
435+ with self ._lock:
436+ self ._thread_id_to_object[nogevent.thread_get_ident()] = weakref.ref(obj)
427437
428- def clear_threads (self , existing_thread_ids ):
438+ def clear_threads (self ,
439+ existing_thread_ids , # type: typing.Set[int]
440+ ):
429441 """ Clear the stored list of threads based on the list of existing thread ids.
430442
431443 If any thread that is part of this list was stored, its data will be deleted.
@@ -434,9 +446,48 @@ class _ThreadSpanLinks(object):
434446 """
435447 with self ._lock:
436448 # Iterate over a copy of the list of keys since it's mutated during our iteration.
437- for thread_id in list (self ._thread_id_to_spans .keys()):
449+ for thread_id in list (self ._thread_id_to_object .keys()):
438450 if thread_id not in existing_thread_ids:
439- del self ._thread_id_to_spans[thread_id]
451+ del self ._thread_id_to_object[thread_id]
452+
453+ def get_object (
454+ self ,
455+ thread_id # type: int
456+ ):
457+ # type: (...) -> _T
458+ """ Return the object attached to thread.
459+
460+ :param thread_id: The thread id.
461+ :return: The attached object.
462+ """
463+
464+ with self ._lock:
465+ obj_ref = self ._thread_id_to_object.get(thread_id)
466+ if obj_ref is not None :
467+ return obj_ref()
468+
469+
470+ if typing.TYPE_CHECKING:
471+ _thread_span_links_base = _ThreadLink[ddspan.Span]
472+ else :
473+ _thread_span_links_base = _ThreadLink
474+
475+
476+ @ attr.s (slots = True , eq = False )
477+ class _ThreadSpanLinks (_thread_span_links_base ):
478+
479+ def link_span (
480+ self ,
481+ span # type: typing.Optional[typing.Union[context.Context , ddspan.Span]]
482+ ):
483+ # type: (...) -> None
484+ """ Link a span to its running environment.
485+
486+ Track threads, tasks, etc.
487+ """
488+ # Since we're going to iterate over the set, make sure it's locked
489+ if isinstance (span, ddspan.Span):
490+ self .link_object(span)
440491
441492 def get_active_span_from_thread_id (
442493 self ,
@@ -448,14 +499,10 @@ class _ThreadSpanLinks(object):
448499 :param thread_id: The thread id.
449500 :return: A set with the active spans.
450501 """
451-
452- with self ._lock:
453- active_span_ref = self ._thread_id_to_spans.get(thread_id)
454- if active_span_ref is not None :
455- active_span = active_span_ref()
456- if active_span is not None and not active_span.finished:
457- return active_span
458- return None
502+ active_span = self .get_object(thread_id)
503+ if active_span is not None and not active_span.finished:
504+ return active_span
505+ return None
459506
460507
461508def _default_min_interval_time ():
0 commit comments