|
| 1 | +import inspect |
| 2 | +from typing import Any |
| 3 | +from typing import Dict |
| 4 | +from typing import List |
| 5 | +from typing import Optional |
| 6 | +from typing import Tuple |
| 7 | +from typing import TypeVar |
| 8 | +import weakref |
| 9 | + |
| 10 | +import dd_trace_api |
| 11 | + |
| 12 | +import ddtrace |
| 13 | +from ddtrace.internal.logger import get_logger |
| 14 | +from ddtrace.internal.wrapping.context import WrappingContext |
| 15 | + |
| 16 | + |
| 17 | +_DD_HOOK_NAME = "dd.hook" |
| 18 | +_TRACER_KEY = "Tracer" |
| 19 | +_STUB_TO_REAL = weakref.WeakKeyDictionary() |
| 20 | +_STUB_TO_REAL[dd_trace_api.tracer] = ddtrace.tracer |
| 21 | +log = get_logger(__name__) |
| 22 | +T = TypeVar("T") |
| 23 | +_FN_PARAMS: Dict[str, List[str]] = dict() |
| 24 | + |
| 25 | + |
| 26 | +def _params_for_fn(wrapping_context: WrappingContext, instance: dd_trace_api._Stub, fn_name: str): |
| 27 | + key = f"{instance.__class__.__name__}.{fn_name}" |
| 28 | + if key not in _FN_PARAMS: |
| 29 | + _FN_PARAMS[key] = list(inspect.signature(wrapping_context.__wrapped__).parameters.keys()) |
| 30 | + return _FN_PARAMS[key] |
| 31 | + |
| 32 | + |
| 33 | +class DDTraceAPIWrappingContextBase(WrappingContext): |
| 34 | + def _handle_return(self) -> None: |
| 35 | + stub = self.get_local("self") |
| 36 | + fn_name = self.__frame__.f_code.co_name |
| 37 | + _call_on_real_instance( |
| 38 | + stub, |
| 39 | + fn_name, |
| 40 | + self.get_local("retval"), |
| 41 | + **{param: self.get_local(param) for param in _params_for_fn(self, stub, fn_name) if param != "self"}, |
| 42 | + ) |
| 43 | + |
| 44 | + def __return__(self, value: T) -> T: |
| 45 | + """Always return the original value no matter what our instrumentation does""" |
| 46 | + try: |
| 47 | + self._handle_return() |
| 48 | + except Exception: # noqa: E722 |
| 49 | + log.debug("Error handling instrumentation return", exc_info=True) |
| 50 | + |
| 51 | + return value |
| 52 | + |
| 53 | + |
| 54 | +def _proxy_span_arguments(args: List, kwargs: Dict) -> Tuple[List, Dict]: |
| 55 | + """Convert all dd_trace_api.Span objects in the args/kwargs collections to their held ddtrace.Span objects""" |
| 56 | + |
| 57 | + def convert(arg): |
| 58 | + return _STUB_TO_REAL[arg] if isinstance(arg, dd_trace_api.Span) else arg |
| 59 | + |
| 60 | + return [convert(arg) for arg in args], {name: convert(kwarg) for name, kwarg in kwargs.items()} |
| 61 | + |
| 62 | + |
| 63 | +def _call_on_real_instance( |
| 64 | + operand_stub: dd_trace_api._Stub, method_name: str, retval_from_api: Optional[Any], *args: List, **kwargs: Dict |
| 65 | +) -> None: |
| 66 | + """ |
| 67 | + Call `method_name` on the real object corresponding to `operand_stub` with `args` and `kwargs` as arguments. |
| 68 | +
|
| 69 | + Store the value that will be returned from the API call we're in the middle of, for the purpose |
| 70 | + of mapping from those Stub objects to their real counterparts. |
| 71 | + """ |
| 72 | + args, kwargs = _proxy_span_arguments(args, kwargs) |
| 73 | + retval_from_impl = getattr(_STUB_TO_REAL[operand_stub], method_name)(*args, **kwargs) |
| 74 | + if retval_from_api is not None: |
| 75 | + _STUB_TO_REAL[retval_from_api] = retval_from_impl |
| 76 | + |
| 77 | + |
| 78 | +def get_version() -> str: |
| 79 | + return getattr(dd_trace_api, "__version__", "") |
| 80 | + |
| 81 | + |
| 82 | +def patch(tracer=None): |
| 83 | + if getattr(dd_trace_api, "__datadog_patch", False): |
| 84 | + return |
| 85 | + _STUB_TO_REAL[dd_trace_api.tracer] = tracer |
| 86 | + |
| 87 | + DDTraceAPIWrappingContextBase(dd_trace_api.Tracer.start_span).wrap() |
| 88 | + DDTraceAPIWrappingContextBase(dd_trace_api.Tracer.trace).wrap() |
| 89 | + DDTraceAPIWrappingContextBase(dd_trace_api.Tracer.current_span).wrap() |
| 90 | + DDTraceAPIWrappingContextBase(dd_trace_api.Tracer.current_root_span).wrap() |
| 91 | + DDTraceAPIWrappingContextBase(dd_trace_api.Span.finish).wrap() |
| 92 | + DDTraceAPIWrappingContextBase(dd_trace_api.Span.set_exc_info).wrap() |
| 93 | + DDTraceAPIWrappingContextBase(dd_trace_api.Span.finish_with_ancestors).wrap() |
| 94 | + DDTraceAPIWrappingContextBase(dd_trace_api.Span.set_tags).wrap() |
| 95 | + DDTraceAPIWrappingContextBase(dd_trace_api.Span.set_traceback).wrap() |
| 96 | + DDTraceAPIWrappingContextBase(dd_trace_api.Span.__enter__).wrap() |
| 97 | + DDTraceAPIWrappingContextBase(dd_trace_api.Span.__exit__).wrap() |
| 98 | + |
| 99 | + dd_trace_api.__datadog_patch = True |
| 100 | + |
| 101 | + |
| 102 | +def unpatch(): |
| 103 | + if not getattr(dd_trace_api, "__datadog_patch", False): |
| 104 | + return |
| 105 | + dd_trace_api.__datadog_patch = False |
| 106 | + |
| 107 | + DDTraceAPIWrappingContextBase.extract(dd_trace_api.Tracer.start_span).unwrap() |
| 108 | + DDTraceAPIWrappingContextBase.extract(dd_trace_api.Tracer.trace).unwrap() |
| 109 | + DDTraceAPIWrappingContextBase.extract(dd_trace_api.Tracer.current_span).unwrap() |
| 110 | + DDTraceAPIWrappingContextBase.extract(dd_trace_api.Tracer.current_root_span).unwrap() |
| 111 | + DDTraceAPIWrappingContextBase.extract(dd_trace_api.Span.finish).unwrap() |
| 112 | + DDTraceAPIWrappingContextBase.extract(dd_trace_api.Span.set_exc_info).unwrap() |
| 113 | + DDTraceAPIWrappingContextBase.extract(dd_trace_api.Span.finish_with_ancestors).unwrap() |
| 114 | + DDTraceAPIWrappingContextBase.extract(dd_trace_api.Span.set_tags).unwrap() |
| 115 | + DDTraceAPIWrappingContextBase.extract(dd_trace_api.Span.set_traceback).unwrap() |
| 116 | + DDTraceAPIWrappingContextBase.extract(dd_trace_api.Span.__enter__).unwrap() |
| 117 | + DDTraceAPIWrappingContextBase.extract(dd_trace_api.Span.__exit__).unwrap() |
| 118 | + |
| 119 | + dd_trace_api.__datadog_patch = False |
0 commit comments