diff --git a/sentry_sdk/opentelemetry/contextvars_context.py b/sentry_sdk/opentelemetry/contextvars_context.py index abd4c60d3f..34d7866f3c 100644 --- a/sentry_sdk/opentelemetry/contextvars_context.py +++ b/sentry_sdk/opentelemetry/contextvars_context.py @@ -1,4 +1,5 @@ -from typing import cast, TYPE_CHECKING +from __future__ import annotations +from typing import TYPE_CHECKING from opentelemetry.trace import get_current_span, set_span_in_context from opentelemetry.trace.span import INVALID_SPAN @@ -13,36 +14,37 @@ SENTRY_USE_CURRENT_SCOPE_KEY, SENTRY_USE_ISOLATION_SCOPE_KEY, ) +from sentry_sdk.opentelemetry.scope import PotelScope, validate_scopes if TYPE_CHECKING: - from typing import Optional from contextvars import Token - import sentry_sdk.opentelemetry.scope as scope class SentryContextVarsRuntimeContext(ContextVarsRuntimeContext): - def attach(self, context): - # type: (Context) -> Token[Context] - scopes = get_value(SENTRY_SCOPES_KEY, context) + def attach(self, context: Context) -> Token[Context]: + scopes = validate_scopes(get_value(SENTRY_SCOPES_KEY, context)) - should_fork_isolation_scope = context.pop( - SENTRY_FORK_ISOLATION_SCOPE_KEY, False + should_fork_isolation_scope = bool( + context.pop(SENTRY_FORK_ISOLATION_SCOPE_KEY, False) ) - should_fork_isolation_scope = cast("bool", should_fork_isolation_scope) should_use_isolation_scope = context.pop(SENTRY_USE_ISOLATION_SCOPE_KEY, None) - should_use_isolation_scope = cast( - "Optional[scope.PotelScope]", should_use_isolation_scope + should_use_isolation_scope = ( + should_use_isolation_scope + if isinstance(should_use_isolation_scope, PotelScope) + else None ) should_use_current_scope = context.pop(SENTRY_USE_CURRENT_SCOPE_KEY, None) - should_use_current_scope = cast( - "Optional[scope.PotelScope]", should_use_current_scope + should_use_current_scope = ( + should_use_current_scope + if isinstance(should_use_current_scope, PotelScope) + else None ) if scopes: - scopes = cast("tuple[scope.PotelScope, scope.PotelScope]", scopes) - (current_scope, isolation_scope) = scopes + current_scope = scopes[0] + isolation_scope = scopes[1] else: current_scope = sentry_sdk.get_current_scope() isolation_scope = sentry_sdk.get_isolation_scope() diff --git a/sentry_sdk/opentelemetry/propagator.py b/sentry_sdk/opentelemetry/propagator.py index 16a0d19cc9..f76dcc3906 100644 --- a/sentry_sdk/opentelemetry/propagator.py +++ b/sentry_sdk/opentelemetry/propagator.py @@ -1,4 +1,4 @@ -from typing import cast +from __future__ import annotations from opentelemetry import trace from opentelemetry.context import ( @@ -37,12 +37,12 @@ extract_sentrytrace_data, should_propagate_trace, ) +from sentry_sdk.opentelemetry.scope import validate_scopes from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Optional, Set - import sentry_sdk.opentelemetry.scope as scope class SentryPropagator(TextMapPropagator): @@ -50,8 +50,12 @@ class SentryPropagator(TextMapPropagator): Propagates tracing headers for Sentry's tracing system in a way OTel understands. """ - def extract(self, carrier, context=None, getter=default_getter): - # type: (CarrierT, Optional[Context], Getter[CarrierT]) -> Context + def extract( + self, + carrier: CarrierT, + context: Optional[Context] = None, + getter: Getter[CarrierT] = default_getter, + ) -> Context: if context is None: context = get_current() @@ -93,13 +97,16 @@ def extract(self, carrier, context=None, getter=default_getter): modified_context = trace.set_span_in_context(span, context) return modified_context - def inject(self, carrier, context=None, setter=default_setter): - # type: (CarrierT, Optional[Context], Setter[CarrierT]) -> None - scopes = get_value(SENTRY_SCOPES_KEY, context) + def inject( + self, + carrier: CarrierT, + context: Optional[Context] = None, + setter: Setter[CarrierT] = default_setter, + ) -> None: + scopes = validate_scopes(get_value(SENTRY_SCOPES_KEY, context)) if not scopes: return - scopes = cast("tuple[scope.PotelScope, scope.PotelScope]", scopes) (current_scope, _) = scopes span = current_scope.span @@ -114,6 +121,5 @@ def inject(self, carrier, context=None, setter=default_setter): setter.set(carrier, key, value) @property - def fields(self): - # type: () -> Set[str] + def fields(self) -> Set[str]: return {SENTRY_TRACE_HEADER_NAME, BAGGAGE_HEADER_NAME} diff --git a/sentry_sdk/opentelemetry/sampler.py b/sentry_sdk/opentelemetry/sampler.py index ab3defe3de..878b856f5a 100644 --- a/sentry_sdk/opentelemetry/sampler.py +++ b/sentry_sdk/opentelemetry/sampler.py @@ -1,5 +1,5 @@ +from __future__ import annotations from decimal import Decimal -from typing import cast from opentelemetry import trace from opentelemetry.sdk.trace.sampling import Sampler, SamplingResult, Decision @@ -21,15 +21,16 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Any, Optional, Sequence, Union + from typing import Any, Optional, Sequence from opentelemetry.context import Context from opentelemetry.trace import Link, SpanKind from opentelemetry.trace.span import SpanContext from opentelemetry.util.types import Attributes -def get_parent_sampled(parent_context, trace_id): - # type: (Optional[SpanContext], int) -> Optional[bool] +def get_parent_sampled( + parent_context: Optional[SpanContext], trace_id: int +) -> Optional[bool]: if parent_context is None: return None @@ -54,8 +55,9 @@ def get_parent_sampled(parent_context, trace_id): return None -def get_parent_sample_rate(parent_context, trace_id): - # type: (Optional[SpanContext], int) -> Optional[float] +def get_parent_sample_rate( + parent_context: Optional[SpanContext], trace_id: int +) -> Optional[float]: if parent_context is None: return None @@ -74,8 +76,9 @@ def get_parent_sample_rate(parent_context, trace_id): return None -def get_parent_sample_rand(parent_context, trace_id): - # type: (Optional[SpanContext], int) -> Optional[Decimal] +def get_parent_sample_rand( + parent_context: Optional[SpanContext], trace_id: int +) -> Optional[Decimal]: if parent_context is None: return None @@ -91,8 +94,12 @@ def get_parent_sample_rand(parent_context, trace_id): return None -def dropped_result(span_context, attributes, sample_rate=None, sample_rand=None): - # type: (SpanContext, Attributes, Optional[float], Optional[Decimal]) -> SamplingResult +def dropped_result( + span_context: SpanContext, + attributes: Attributes, + sample_rate: Optional[float] = None, + sample_rand: Optional[Decimal] = None, +) -> SamplingResult: """ React to a span getting unsampled and return a DROP SamplingResult. @@ -129,8 +136,12 @@ def dropped_result(span_context, attributes, sample_rate=None, sample_rand=None) ) -def sampled_result(span_context, attributes, sample_rate=None, sample_rand=None): - # type: (SpanContext, Attributes, Optional[float], Optional[Decimal]) -> SamplingResult +def sampled_result( + span_context: SpanContext, + attributes: Attributes, + sample_rate: Optional[float] = None, + sample_rand: Optional[Decimal] = None, +) -> SamplingResult: """ React to a span being sampled and return a sampled SamplingResult. @@ -151,8 +162,12 @@ def sampled_result(span_context, attributes, sample_rate=None, sample_rand=None) ) -def _update_trace_state(span_context, sampled, sample_rate=None, sample_rand=None): - # type: (SpanContext, bool, Optional[float], Optional[Decimal]) -> TraceState +def _update_trace_state( + span_context: SpanContext, + sampled: bool, + sample_rate: Optional[float] = None, + sample_rand: Optional[Decimal] = None, +) -> TraceState: trace_state = span_context.trace_state sampled = "true" if sampled else "false" @@ -175,15 +190,14 @@ def _update_trace_state(span_context, sampled, sample_rate=None, sample_rand=Non class SentrySampler(Sampler): def should_sample( self, - parent_context, # type: Optional[Context] - trace_id, # type: int - name, # type: str - kind=None, # type: Optional[SpanKind] - attributes=None, # type: Attributes - links=None, # type: Optional[Sequence[Link]] - trace_state=None, # type: Optional[TraceState] - ): - # type: (...) -> SamplingResult + parent_context: Optional[Context], + trace_id: int, + name: str, + kind: Optional[SpanKind] = None, + attributes: Attributes = None, + links: Optional[Sequence[Link]] = None, + trace_state: Optional[TraceState] = None, + ) -> SamplingResult: client = sentry_sdk.get_client() parent_span_context = trace.get_current_span(parent_context).get_span_context() @@ -209,13 +223,12 @@ def should_sample( sample_rand = parent_sample_rand else: # We are the head SDK and we need to generate a new sample_rand - sample_rand = cast(Decimal, _generate_sample_rand(str(trace_id), (0, 1))) + sample_rand = _generate_sample_rand(str(trace_id), (0, 1)) # Explicit sampled value provided at start_span - custom_sampled = cast( - "Optional[bool]", attributes.get(SentrySpanAttribute.CUSTOM_SAMPLED) - ) - if custom_sampled is not None: + custom_sampled = attributes.get(SentrySpanAttribute.CUSTOM_SAMPLED) + + if custom_sampled is not None and isinstance(custom_sampled, bool): if is_root_span: sample_rate = float(custom_sampled) if sample_rate > 0: @@ -262,7 +275,8 @@ def should_sample( sample_rate_to_propagate = sample_rate # If the sample rate is invalid, drop the span - if not is_valid_sample_rate(sample_rate, source=self.__class__.__name__): + sample_rate = is_valid_sample_rate(sample_rate, source=self.__class__.__name__) + if sample_rate is None: logger.warning( f"[Tracing.Sampler] Discarding {name} because of invalid sample rate." ) @@ -275,7 +289,6 @@ def should_sample( sample_rate_to_propagate = sample_rate # Compare sample_rand to sample_rate to make the final sampling decision - sample_rate = float(cast("Union[bool, float, int]", sample_rate)) sampled = sample_rand < Decimal.from_float(sample_rate) if sampled: @@ -307,9 +320,13 @@ def get_description(self) -> str: return self.__class__.__name__ -def create_sampling_context(name, attributes, parent_span_context, trace_id): - # type: (str, Attributes, Optional[SpanContext], int) -> dict[str, Any] - sampling_context = { +def create_sampling_context( + name: str, + attributes: Attributes, + parent_span_context: Optional[SpanContext], + trace_id: int, +) -> dict[str, Any]: + sampling_context: dict[str, Any] = { "transaction_context": { "name": name, "op": attributes.get(SentrySpanAttribute.OP) if attributes else None, @@ -318,7 +335,7 @@ def create_sampling_context(name, attributes, parent_span_context, trace_id): ), }, "parent_sampled": get_parent_sampled(parent_span_context, trace_id), - } # type: dict[str, Any] + } if attributes is not None: sampling_context.update(attributes) diff --git a/sentry_sdk/opentelemetry/scope.py b/sentry_sdk/opentelemetry/scope.py index 4db5e288e3..ec398093c7 100644 --- a/sentry_sdk/opentelemetry/scope.py +++ b/sentry_sdk/opentelemetry/scope.py @@ -1,4 +1,4 @@ -from typing import cast +from __future__ import annotations from contextlib import contextmanager import warnings @@ -24,9 +24,6 @@ SENTRY_USE_ISOLATION_SCOPE_KEY, TRACESTATE_SAMPLED_KEY, ) -from sentry_sdk.opentelemetry.contextvars_context import ( - SentryContextVarsRuntimeContext, -) from sentry_sdk.opentelemetry.utils import trace_state_from_baggage from sentry_sdk.scope import Scope, ScopeType from sentry_sdk.tracing import Span @@ -38,26 +35,21 @@ class PotelScope(Scope): @classmethod - def _get_scopes(cls): - # type: () -> Optional[Tuple[PotelScope, PotelScope]] + def _get_scopes(cls) -> Optional[Tuple[PotelScope, PotelScope]]: """ Returns the current scopes tuple on the otel context. Internal use only. """ - return cast( - "Optional[Tuple[PotelScope, PotelScope]]", get_value(SENTRY_SCOPES_KEY) - ) + return validate_scopes(get_value(SENTRY_SCOPES_KEY)) @classmethod - def get_current_scope(cls): - # type: () -> PotelScope + def get_current_scope(cls) -> PotelScope: """ Returns the current scope. """ return cls._get_current_scope() or _INITIAL_CURRENT_SCOPE @classmethod - def _get_current_scope(cls): - # type: () -> Optional[PotelScope] + def _get_current_scope(cls) -> Optional[PotelScope]: """ Returns the current scope without creating a new one. Internal use only. """ @@ -65,16 +57,14 @@ def _get_current_scope(cls): return scopes[0] if scopes else None @classmethod - def get_isolation_scope(cls): - # type: () -> PotelScope + def get_isolation_scope(cls) -> PotelScope: """ Returns the isolation scope. """ return cls._get_isolation_scope() or _INITIAL_ISOLATION_SCOPE @classmethod - def _get_isolation_scope(cls): - # type: () -> Optional[PotelScope] + def _get_isolation_scope(cls) -> Optional[PotelScope]: """ Returns the isolation scope without creating a new one. Internal use only. """ @@ -82,8 +72,9 @@ def _get_isolation_scope(cls): return scopes[1] if scopes else None @contextmanager - def continue_trace(self, environ_or_headers): - # type: (Dict[str, Any]) -> Generator[None, None, None] + def continue_trace( + self, environ_or_headers: Dict[str, Any] + ) -> Generator[None, None, None]: """ Sets the propagation context from environment or headers to continue an incoming trace. Any span started within this context manager will use the same trace_id, parent_span_id @@ -98,8 +89,7 @@ def continue_trace(self, environ_or_headers): with use_span(NonRecordingSpan(span_context)): yield - def _incoming_otel_span_context(self): - # type: () -> Optional[SpanContext] + def _incoming_otel_span_context(self) -> Optional[SpanContext]: if self._propagation_context is None: return None # If sentry-trace extraction didn't have a parent_span_id, we don't have an upstream header @@ -132,8 +122,7 @@ def _incoming_otel_span_context(self): return span_context - def start_transaction(self, **kwargs): - # type: (Any) -> Span + def start_transaction(self, **kwargs: Any) -> Span: """ .. deprecated:: 3.0.0 This function is deprecated and will be removed in a future release. @@ -146,8 +135,7 @@ def start_transaction(self, **kwargs): ) return self.start_span(**kwargs) - def start_span(self, **kwargs): - # type: (Any) -> Span + def start_span(self, **kwargs: Any) -> Span: return Span(**kwargs) @@ -155,8 +143,7 @@ def start_span(self, **kwargs): _INITIAL_ISOLATION_SCOPE = PotelScope(ty=ScopeType.ISOLATION) -def setup_initial_scopes(): - # type: () -> None +def setup_initial_scopes() -> None: global _INITIAL_CURRENT_SCOPE, _INITIAL_ISOLATION_SCOPE _INITIAL_CURRENT_SCOPE = PotelScope(ty=ScopeType.CURRENT) _INITIAL_ISOLATION_SCOPE = PotelScope(ty=ScopeType.ISOLATION) @@ -165,17 +152,18 @@ def setup_initial_scopes(): attach(set_value(SENTRY_SCOPES_KEY, scopes)) -def setup_scope_context_management(): - # type: () -> None +def setup_scope_context_management() -> None: import opentelemetry.context + from sentry_sdk.opentelemetry.contextvars_context import ( + SentryContextVarsRuntimeContext, + ) opentelemetry.context._RUNTIME_CONTEXT = SentryContextVarsRuntimeContext() setup_initial_scopes() @contextmanager -def isolation_scope(): - # type: () -> Generator[PotelScope, None, None] +def isolation_scope() -> Generator[PotelScope, None, None]: context = set_value(SENTRY_FORK_ISOLATION_SCOPE_KEY, True) token = attach(context) try: @@ -185,8 +173,7 @@ def isolation_scope(): @contextmanager -def new_scope(): - # type: () -> Generator[PotelScope, None, None] +def new_scope() -> Generator[PotelScope, None, None]: token = attach(get_current()) try: yield PotelScope.get_current_scope() @@ -195,8 +182,7 @@ def new_scope(): @contextmanager -def use_scope(scope): - # type: (PotelScope) -> Generator[PotelScope, None, None] +def use_scope(scope: PotelScope) -> Generator[PotelScope, None, None]: context = set_value(SENTRY_USE_CURRENT_SCOPE_KEY, scope) token = attach(context) @@ -207,8 +193,9 @@ def use_scope(scope): @contextmanager -def use_isolation_scope(isolation_scope): - # type: (PotelScope) -> Generator[PotelScope, None, None] +def use_isolation_scope( + isolation_scope: PotelScope, +) -> Generator[PotelScope, None, None]: context = set_value(SENTRY_USE_ISOLATION_SCOPE_KEY, isolation_scope) token = attach(context) @@ -216,3 +203,15 @@ def use_isolation_scope(isolation_scope): yield isolation_scope finally: detach(token) + + +def validate_scopes(scopes: Any) -> Optional[Tuple[PotelScope, PotelScope]]: + if ( + isinstance(scopes, tuple) + and len(scopes) == 2 + and isinstance(scopes[0], PotelScope) + and isinstance(scopes[1], PotelScope) + ): + return scopes + else: + return None diff --git a/sentry_sdk/opentelemetry/span_processor.py b/sentry_sdk/opentelemetry/span_processor.py index a148fb0f62..f35af99920 100644 --- a/sentry_sdk/opentelemetry/span_processor.py +++ b/sentry_sdk/opentelemetry/span_processor.py @@ -1,5 +1,5 @@ +from __future__ import annotations from collections import deque, defaultdict -from typing import cast from opentelemetry.trace import ( format_trace_id, @@ -52,30 +52,24 @@ class SentrySpanProcessor(SpanProcessor): Converts OTel spans into Sentry spans so they can be sent to the Sentry backend. """ - def __new__(cls): - # type: () -> SentrySpanProcessor + def __new__(cls) -> SentrySpanProcessor: if not hasattr(cls, "instance"): cls.instance = super().__new__(cls) return cls.instance - def __init__(self): - # type: () -> None - self._children_spans = defaultdict( - list - ) # type: DefaultDict[int, List[ReadableSpan]] - self._dropped_spans = defaultdict(lambda: 0) # type: DefaultDict[int, int] + def __init__(self) -> None: + self._children_spans: DefaultDict[int, List[ReadableSpan]] = defaultdict(list) + self._dropped_spans: DefaultDict[int, int] = defaultdict(lambda: 0) - def on_start(self, span, parent_context=None): - # type: (Span, Optional[Context]) -> None + def on_start(self, span: Span, parent_context: Optional[Context] = None) -> None: if is_sentry_span(span): return self._add_root_span(span, get_current_span(parent_context)) self._start_profile(span) - def on_end(self, span): - # type: (ReadableSpan) -> None + def on_end(self, span: ReadableSpan) -> None: if is_sentry_span(span): return @@ -88,18 +82,15 @@ def on_end(self, span): self._append_child_span(span) # TODO-neel-potel not sure we need a clear like JS - def shutdown(self): - # type: () -> None + def shutdown(self) -> None: pass # TODO-neel-potel change default? this is 30 sec # TODO-neel-potel call this in client.flush - def force_flush(self, timeout_millis=30000): - # type: (int) -> bool + def force_flush(self, timeout_millis: int = 30000) -> bool: return True - def _add_root_span(self, span, parent_span): - # type: (Span, AbstractSpan) -> None + def _add_root_span(self, span: Span, parent_span: AbstractSpan) -> None: """ This is required to make Span.root_span work since we can't traverse back to the root purely with otel efficiently. @@ -112,8 +103,7 @@ def _add_root_span(self, span, parent_span): # root span points to itself set_sentry_meta(span, "root_span", span) - def _start_profile(self, span): - # type: (Span) -> None + def _start_profile(self, span: Span) -> None: try_autostart_continuous_profiler() profiler_id = get_profiler_id() @@ -148,14 +138,12 @@ def _start_profile(self, span): span.set_attribute(SPANDATA.PROFILER_ID, profiler_id) set_sentry_meta(span, "continuous_profile", continuous_profile) - def _stop_profile(self, span): - # type: (ReadableSpan) -> None + def _stop_profile(self, span: ReadableSpan) -> None: continuous_profiler = get_sentry_meta(span, "continuous_profile") if continuous_profiler: continuous_profiler.stop() - def _flush_root_span(self, span): - # type: (ReadableSpan) -> None + def _flush_root_span(self, span: ReadableSpan) -> None: transaction_event = self._root_span_to_transaction_event(span) if not transaction_event: return @@ -176,8 +164,7 @@ def _flush_root_span(self, span): sentry_sdk.capture_event(transaction_event) self._cleanup_references([span] + collected_spans) - def _append_child_span(self, span): - # type: (ReadableSpan) -> None + def _append_child_span(self, span: ReadableSpan) -> None: if not span.parent: return @@ -192,14 +179,13 @@ def _append_child_span(self, span): else: self._dropped_spans[span.parent.span_id] += 1 - def _collect_children(self, span): - # type: (ReadableSpan) -> tuple[List[ReadableSpan], int] + def _collect_children(self, span: ReadableSpan) -> tuple[List[ReadableSpan], int]: if not span.context: return [], 0 children = [] dropped_spans = 0 - bfs_queue = deque() # type: Deque[int] + bfs_queue: Deque[int] = deque() bfs_queue.append(span.context.span_id) while bfs_queue: @@ -215,8 +201,7 @@ def _collect_children(self, span): # we construct the event from scratch here # and not use the current Transaction class for easier refactoring - def _root_span_to_transaction_event(self, span): - # type: (ReadableSpan) -> Optional[Event] + def _root_span_to_transaction_event(self, span: ReadableSpan) -> Optional[Event]: if not span.context: return None @@ -250,23 +235,20 @@ def _root_span_to_transaction_event(self, span): } ) - profile = cast("Optional[Profile]", get_sentry_meta(span, "profile")) - if profile: + profile = get_sentry_meta(span, "profile") + if profile is not None and isinstance(profile, Profile): profile.__exit__(None, None, None) if profile.valid(): event["profile"] = profile return event - def _span_to_json(self, span): - # type: (ReadableSpan) -> Optional[dict[str, Any]] + def _span_to_json(self, span: ReadableSpan) -> Optional[dict[str, Any]]: if not span.context: return None - # This is a safe cast because dict[str, Any] is a superset of Event - span_json = cast( - "dict[str, Any]", self._common_span_transaction_attributes_as_json(span) - ) + # need to ignore the type here due to TypedDict nonsense + span_json: Optional[dict[str, Any]] = self._common_span_transaction_attributes_as_json(span) # type: ignore if span_json is None: return None @@ -299,15 +281,16 @@ def _span_to_json(self, span): return span_json - def _common_span_transaction_attributes_as_json(self, span): - # type: (ReadableSpan) -> Optional[Event] + def _common_span_transaction_attributes_as_json( + self, span: ReadableSpan + ) -> Optional[Event]: if not span.start_time or not span.end_time: return None - common_json = { + common_json: Event = { "start_timestamp": convert_from_otel_timestamp(span.start_time), "timestamp": convert_from_otel_timestamp(span.end_time), - } # type: Event + } tags = extract_span_attributes(span, SentrySpanAttribute.TAG) if tags: @@ -315,13 +298,11 @@ def _common_span_transaction_attributes_as_json(self, span): return common_json - def _cleanup_references(self, spans): - # type: (List[ReadableSpan]) -> None + def _cleanup_references(self, spans: List[ReadableSpan]) -> None: for span in spans: delete_sentry_meta(span) - def _log_debug_info(self): - # type: () -> None + def _log_debug_info(self) -> None: import pprint pprint.pprint( diff --git a/sentry_sdk/opentelemetry/tracing.py b/sentry_sdk/opentelemetry/tracing.py index 5002f71c50..a736a4a477 100644 --- a/sentry_sdk/opentelemetry/tracing.py +++ b/sentry_sdk/opentelemetry/tracing.py @@ -1,3 +1,4 @@ +from __future__ import annotations from opentelemetry import trace from opentelemetry.propagate import set_global_textmap from opentelemetry.sdk.trace import TracerProvider, Span, ReadableSpan @@ -10,16 +11,14 @@ from sentry_sdk.utils import logger -def patch_readable_span(): - # type: () -> None +def patch_readable_span() -> None: """ We need to pass through sentry specific metadata/objects from Span to ReadableSpan to work with them consistently in the SpanProcessor. """ old_readable_span = Span._readable_span - def sentry_patched_readable_span(self): - # type: (Span) -> ReadableSpan + def sentry_patched_readable_span(self: Span) -> ReadableSpan: readable_span = old_readable_span(self) readable_span._sentry_meta = getattr(self, "_sentry_meta", {}) # type: ignore[attr-defined] return readable_span @@ -27,8 +26,7 @@ def sentry_patched_readable_span(self): Span._readable_span = sentry_patched_readable_span # type: ignore[method-assign] -def setup_sentry_tracing(): - # type: () -> None +def setup_sentry_tracing() -> None: # TracerProvider can only be set once. If we're the first ones setting it, # there's no issue. If it already exists, we need to patch it. from opentelemetry.trace import _TRACER_PROVIDER diff --git a/sentry_sdk/opentelemetry/utils.py b/sentry_sdk/opentelemetry/utils.py index abee007a6b..114b1dfd36 100644 --- a/sentry_sdk/opentelemetry/utils.py +++ b/sentry_sdk/opentelemetry/utils.py @@ -1,5 +1,5 @@ +from __future__ import annotations import re -from typing import cast from datetime import datetime, timezone from urllib3.util import parse_url as urlparse @@ -30,9 +30,11 @@ from sentry_sdk._types import TYPE_CHECKING if TYPE_CHECKING: - from typing import Any, Optional, Mapping, Sequence, Union + from typing import Any, Optional, Mapping, Sequence, Union, Type, TypeVar from sentry_sdk._types import OtelExtractedSpanData + T = TypeVar("T") + GRPC_ERROR_MAP = { "1": SPANSTATUS.CANCELLED, @@ -54,8 +56,7 @@ } -def is_sentry_span(span): - # type: (ReadableSpan) -> bool +def is_sentry_span(span: ReadableSpan) -> bool: """ Break infinite loop: HTTP requests to Sentry are caught by OTel and send again to Sentry. @@ -65,10 +66,8 @@ def is_sentry_span(span): if not span.attributes: return False - span_url = span.attributes.get(SpanAttributes.HTTP_URL, None) - span_url = cast("Optional[str]", span_url) - - if not span_url: + span_url = get_typed_attribute(span.attributes, SpanAttributes.HTTP_URL, str) + if span_url is None: return False dsn_url = None @@ -89,32 +88,30 @@ def is_sentry_span(span): return False -def convert_from_otel_timestamp(time): - # type: (int) -> datetime +def convert_from_otel_timestamp(time: int) -> datetime: """Convert an OTel nanosecond-level timestamp to a datetime.""" return datetime.fromtimestamp(time / 1e9, timezone.utc) -def convert_to_otel_timestamp(time): - # type: (Union[datetime, float]) -> int +def convert_to_otel_timestamp(time: Union[datetime, float]) -> int: """Convert a datetime to an OTel timestamp (with nanosecond precision).""" if isinstance(time, datetime): return int(time.timestamp() * 1e9) return int(time * 1e9) -def extract_transaction_name_source(span): - # type: (ReadableSpan) -> tuple[Optional[str], Optional[str]] +def extract_transaction_name_source( + span: ReadableSpan, +) -> tuple[Optional[str], Optional[str]]: if not span.attributes: return (None, None) return ( - cast("Optional[str]", span.attributes.get(SentrySpanAttribute.NAME)), - cast("Optional[str]", span.attributes.get(SentrySpanAttribute.SOURCE)), + get_typed_attribute(span.attributes, SentrySpanAttribute.NAME, str), + get_typed_attribute(span.attributes, SentrySpanAttribute.SOURCE, str), ) -def extract_span_data(span): - # type: (ReadableSpan) -> OtelExtractedSpanData +def extract_span_data(span: ReadableSpan) -> OtelExtractedSpanData: op = span.name description = span.name status, http_status = extract_span_status(span) @@ -122,15 +119,15 @@ def extract_span_data(span): if span.attributes is None: return (op, description, status, http_status, origin) - attribute_op = cast("Optional[str]", span.attributes.get(SentrySpanAttribute.OP)) + attribute_op = get_typed_attribute(span.attributes, SentrySpanAttribute.OP, str) op = attribute_op or op - description = cast( - "str", span.attributes.get(SentrySpanAttribute.DESCRIPTION) or description + description = ( + get_typed_attribute(span.attributes, SentrySpanAttribute.DESCRIPTION, str) + or description ) - origin = cast("Optional[str]", span.attributes.get(SentrySpanAttribute.ORIGIN)) + origin = get_typed_attribute(span.attributes, SentrySpanAttribute.ORIGIN, str) - http_method = span.attributes.get(SpanAttributes.HTTP_METHOD) - http_method = cast("Optional[str]", http_method) + http_method = get_typed_attribute(span.attributes, SpanAttributes.HTTP_METHOD, str) if http_method: return span_data_for_http_method(span) @@ -165,11 +162,10 @@ def extract_span_data(span): return (op, description, status, http_status, origin) -def span_data_for_http_method(span): - # type: (ReadableSpan) -> OtelExtractedSpanData +def span_data_for_http_method(span: ReadableSpan) -> OtelExtractedSpanData: span_attributes = span.attributes or {} - op = cast("Optional[str]", span_attributes.get(SentrySpanAttribute.OP)) + op = get_typed_attribute(span_attributes, SentrySpanAttribute.OP, str) if op is None: op = "http" @@ -184,10 +180,9 @@ def span_data_for_http_method(span): peer_name = span_attributes.get(SpanAttributes.NET_PEER_NAME) # TODO-neel-potel remove description completely - description = span_attributes.get( - SentrySpanAttribute.DESCRIPTION - ) or span_attributes.get(SentrySpanAttribute.NAME) - description = cast("Optional[str]", description) + description = get_typed_attribute( + span_attributes, SentrySpanAttribute.DESCRIPTION, str + ) or get_typed_attribute(span_attributes, SentrySpanAttribute.NAME, str) if description is None: description = f"{http_method}" @@ -199,7 +194,7 @@ def span_data_for_http_method(span): description = f"{http_method} {peer_name}" else: url = span_attributes.get(SpanAttributes.HTTP_URL) - url = cast("Optional[str]", url) + url = get_typed_attribute(span_attributes, SpanAttributes.HTTP_URL, str) if url: parsed_url = urlparse(url) @@ -210,28 +205,24 @@ def span_data_for_http_method(span): status, http_status = extract_span_status(span) - origin = cast("Optional[str]", span_attributes.get(SentrySpanAttribute.ORIGIN)) + origin = get_typed_attribute(span_attributes, SentrySpanAttribute.ORIGIN, str) return (op, description, status, http_status, origin) -def span_data_for_db_query(span): - # type: (ReadableSpan) -> OtelExtractedSpanData +def span_data_for_db_query(span: ReadableSpan) -> OtelExtractedSpanData: span_attributes = span.attributes or {} - op = cast("str", span_attributes.get(SentrySpanAttribute.OP, OP.DB)) - - statement = span_attributes.get(SpanAttributes.DB_STATEMENT, None) - statement = cast("Optional[str]", statement) + op = get_typed_attribute(span_attributes, SentrySpanAttribute.OP, str) or OP.DB + statement = get_typed_attribute(span_attributes, SpanAttributes.DB_STATEMENT, str) description = statement or span.name - origin = cast("Optional[str]", span_attributes.get(SentrySpanAttribute.ORIGIN)) + origin = get_typed_attribute(span_attributes, SentrySpanAttribute.ORIGIN, str) return (op, description, None, None, origin) -def extract_span_status(span): - # type: (ReadableSpan) -> tuple[Optional[str], Optional[int]] +def extract_span_status(span: ReadableSpan) -> tuple[Optional[str], Optional[int]]: span_attributes = span.attributes or {} status = span.status or None @@ -266,8 +257,19 @@ def extract_span_status(span): return (SPANSTATUS.UNKNOWN_ERROR, None) -def infer_status_from_attributes(span_attributes): - # type: (Mapping[str, str | bool | int | float | Sequence[str] | Sequence[bool] | Sequence[int] | Sequence[float]]) -> tuple[Optional[str], Optional[int]] +def infer_status_from_attributes( + span_attributes: Mapping[ + str, + str + | bool + | int + | float + | Sequence[str] + | Sequence[bool] + | Sequence[int] + | Sequence[float], + ], +) -> tuple[Optional[str], Optional[int]]: http_status = get_http_status_code(span_attributes) if http_status: @@ -280,10 +282,23 @@ def infer_status_from_attributes(span_attributes): return (None, None) -def get_http_status_code(span_attributes): - # type: (Mapping[str, str | bool | int | float | Sequence[str] | Sequence[bool] | Sequence[int] | Sequence[float]]) -> Optional[int] +def get_http_status_code( + span_attributes: Mapping[ + str, + str + | bool + | int + | float + | Sequence[str] + | Sequence[bool] + | Sequence[int] + | Sequence[float], + ], +) -> Optional[int]: try: - http_status = span_attributes.get(SpanAttributes.HTTP_RESPONSE_STATUS_CODE) + http_status = get_typed_attribute( + span_attributes, SpanAttributes.HTTP_RESPONSE_STATUS_CODE, int + ) except AttributeError: # HTTP_RESPONSE_STATUS_CODE was added in 1.21, so if we're on an older # OTel version SpanAttributes.HTTP_RESPONSE_STATUS_CODE will throw an @@ -292,19 +307,18 @@ def get_http_status_code(span_attributes): if http_status is None: # Fall back to the deprecated attribute - http_status = span_attributes.get(SpanAttributes.HTTP_STATUS_CODE) - - http_status = cast("Optional[int]", http_status) + http_status = get_typed_attribute( + span_attributes, SpanAttributes.HTTP_STATUS_CODE, int + ) return http_status -def extract_span_attributes(span, namespace): - # type: (ReadableSpan, str) -> dict[str, Any] +def extract_span_attributes(span: ReadableSpan, namespace: str) -> dict[str, Any]: """ Extract Sentry-specific span attributes and make them look the way Sentry expects. """ - extracted_attrs = {} # type: dict[str, Any] + extracted_attrs: dict[str, Any] = {} for attr, value in (span.attributes or {}).items(): if attr.startswith(namespace): @@ -314,8 +328,9 @@ def extract_span_attributes(span, namespace): return extracted_attrs -def get_trace_context(span, span_data=None): - # type: (ReadableSpan, Optional[OtelExtractedSpanData]) -> dict[str, Any] +def get_trace_context( + span: ReadableSpan, span_data: Optional[OtelExtractedSpanData] = None +) -> dict[str, Any]: if not span.context: return {} @@ -328,13 +343,13 @@ def get_trace_context(span, span_data=None): (op, _, status, _, origin) = span_data - trace_context = { + trace_context: dict[str, Any] = { "trace_id": trace_id, "span_id": span_id, "parent_span_id": parent_span_id, "op": op, "origin": origin or DEFAULT_SPAN_ORIGIN, - } # type: dict[str, Any] + } if status: trace_context["status"] = status @@ -350,8 +365,7 @@ def get_trace_context(span, span_data=None): return trace_context -def trace_state_from_baggage(baggage): - # type: (Baggage) -> TraceState +def trace_state_from_baggage(baggage: Baggage) -> TraceState: items = [] for k, v in baggage.sentry_items.items(): key = Baggage.SENTRY_PREFIX + quote(k) @@ -360,13 +374,11 @@ def trace_state_from_baggage(baggage): return TraceState(items) -def baggage_from_trace_state(trace_state): - # type: (TraceState) -> Baggage +def baggage_from_trace_state(trace_state: TraceState) -> Baggage: return Baggage(dsc_from_trace_state(trace_state)) -def serialize_trace_state(trace_state): - # type: (TraceState) -> str +def serialize_trace_state(trace_state: TraceState) -> str: sentry_items = [] for k, v in trace_state.items(): if Baggage.SENTRY_PREFIX_REGEX.match(k): @@ -374,8 +386,7 @@ def serialize_trace_state(trace_state): return ",".join(key + "=" + value for key, value in sentry_items) -def dsc_from_trace_state(trace_state): - # type: (TraceState) -> dict[str, str] +def dsc_from_trace_state(trace_state: TraceState) -> dict[str, str]: dsc = {} for k, v in trace_state.items(): if Baggage.SENTRY_PREFIX_REGEX.match(k): @@ -384,16 +395,14 @@ def dsc_from_trace_state(trace_state): return dsc -def has_incoming_trace(trace_state): - # type: (TraceState) -> bool +def has_incoming_trace(trace_state: TraceState) -> bool: """ The existence of a sentry-trace_id in the baggage implies we continued an upstream trace. """ return (Baggage.SENTRY_PREFIX + "trace_id") in trace_state -def get_trace_state(span): - # type: (Union[AbstractSpan, ReadableSpan]) -> TraceState +def get_trace_state(span: Union[AbstractSpan, ReadableSpan]) -> TraceState: """ Get the existing trace_state with sentry items or populate it if we are the head SDK. @@ -451,34 +460,45 @@ def get_trace_state(span): return trace_state -def get_sentry_meta(span, key): - # type: (Union[AbstractSpan, ReadableSpan], str) -> Any +def get_sentry_meta(span: Union[AbstractSpan, ReadableSpan], key: str) -> Any: sentry_meta = getattr(span, "_sentry_meta", None) return sentry_meta.get(key) if sentry_meta else None -def set_sentry_meta(span, key, value): - # type: (Union[AbstractSpan, ReadableSpan], str, Any) -> None +def set_sentry_meta( + span: Union[AbstractSpan, ReadableSpan], key: str, value: Any +) -> None: sentry_meta = getattr(span, "_sentry_meta", {}) sentry_meta[key] = value span._sentry_meta = sentry_meta # type: ignore[union-attr] -def delete_sentry_meta(span): - # type: (Union[AbstractSpan, ReadableSpan]) -> None +def delete_sentry_meta(span: Union[AbstractSpan, ReadableSpan]) -> None: try: del span._sentry_meta # type: ignore[union-attr] except AttributeError: pass -def get_profile_context(span): - # type: (ReadableSpan) -> Optional[dict[str, str]] +def get_profile_context(span: ReadableSpan) -> Optional[dict[str, str]]: if not span.attributes: return None - profiler_id = cast("Optional[str]", span.attributes.get(SPANDATA.PROFILER_ID)) + profiler_id = get_typed_attribute(span.attributes, SPANDATA.PROFILER_ID, str) if profiler_id is None: return None return {"profiler_id": profiler_id} + + +def get_typed_attribute( + attributes: Mapping[str, Any], key: str, type: Type[T] +) -> Optional[T]: + """ + helper method to coerce types of attribute values + """ + value = attributes.get(key) + if value is not None and isinstance(value, type): + return value + else: + return None diff --git a/sentry_sdk/profiler/transaction_profiler.py b/sentry_sdk/profiler/transaction_profiler.py index 095ce2f2f9..623f57f2c3 100644 --- a/sentry_sdk/profiler/transaction_profiler.py +++ b/sentry_sdk/profiler/transaction_profiler.py @@ -281,7 +281,8 @@ def _set_initial_sampling_decision(self, sampling_context): self.sampled = False return - if not is_valid_sample_rate(sample_rate, source="Profiling"): + sample_rate = is_valid_sample_rate(sample_rate, source="Profiling") + if sample_rate is None: logger.warning( "[Profiling] Discarding profile because of invalid sample rate." ) @@ -291,14 +292,14 @@ def _set_initial_sampling_decision(self, sampling_context): # Now we roll the dice. random.random is inclusive of 0, but not of 1, # so strict < is safe here. In case sample_rate is a boolean, cast it # to a float (True becomes 1.0 and False becomes 0.0) - self.sampled = random.random() < float(sample_rate) + self.sampled = random.random() < sample_rate if self.sampled: logger.debug("[Profiling] Initializing profile") else: logger.debug( "[Profiling] Discarding profile because it's not included in the random sample (sample rate = {sample_rate})".format( - sample_rate=float(sample_rate) + sample_rate=sample_rate ) ) diff --git a/sentry_sdk/tracing_utils.py b/sentry_sdk/tracing_utils.py index a092f2e40a..fecb82e09e 100644 --- a/sentry_sdk/tracing_utils.py +++ b/sentry_sdk/tracing_utils.py @@ -752,7 +752,7 @@ def get_current_span( def _generate_sample_rand( trace_id: Optional[str], interval: tuple[float, float] = (0.0, 1.0), -) -> Optional[Decimal]: +) -> Decimal: """Generate a sample_rand value from a trace ID. The generated value will be pseudorandomly chosen from the provided diff --git a/sentry_sdk/utils.py b/sentry_sdk/utils.py index 27d46e9e58..746d1eae54 100644 --- a/sentry_sdk/utils.py +++ b/sentry_sdk/utils.py @@ -1561,10 +1561,11 @@ def parse_url(url: str, sanitize: bool = True) -> ParsedUrl: ) -def is_valid_sample_rate(rate: Any, source: str) -> bool: +def is_valid_sample_rate(rate: Any, source: str) -> Optional[float]: """ Checks the given sample rate to make sure it is valid type and value (a boolean or a number between 0 and 1, inclusive). + Returns the final float value to use if valid. """ # both booleans and NaN are instances of Real, so a) checking for Real @@ -1576,7 +1577,7 @@ def is_valid_sample_rate(rate: Any, source: str) -> bool: source=source, rate=rate, type=type(rate) ) ) - return False + return None # in case rate is a boolean, it will get cast to 1 if it's True and 0 if it's False rate = float(rate) @@ -1586,9 +1587,9 @@ def is_valid_sample_rate(rate: Any, source: str) -> bool: source=source, rate=rate ) ) - return False + return None - return True + return rate def match_regex_list( diff --git a/tests/integrations/logging/test_logging.py b/tests/integrations/logging/test_logging.py index d1f5d448b6..931c58d04f 100644 --- a/tests/integrations/logging/test_logging.py +++ b/tests/integrations/logging/test_logging.py @@ -259,8 +259,8 @@ def test_logging_captured_warnings(sentry_init, capture_events, recwarn): assert events[1]["logentry"]["params"] == [] # Using recwarn suppresses the "third" warning in the test output - assert len(recwarn) == 1 - assert str(recwarn[0].message) == "third" + third_warnings = [w for w in recwarn if str(w.message) == "third"] + assert len(third_warnings) == 1 def test_ignore_logger(sentry_init, capture_events): diff --git a/tests/test_utils.py b/tests/test_utils.py index e5bad4fa72..963b937380 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -493,7 +493,7 @@ def test_accepts_valid_sample_rate(rate): with mock.patch.object(logger, "warning", mock.Mock()): result = is_valid_sample_rate(rate, source="Testing") assert logger.warning.called is False - assert result is True + assert result == float(rate) @pytest.mark.parametrize( @@ -514,7 +514,7 @@ def test_warns_on_invalid_sample_rate(rate, StringContaining): # noqa: N803 with mock.patch.object(logger, "warning", mock.Mock()): result = is_valid_sample_rate(rate, source="Testing") logger.warning.assert_any_call(StringContaining("Given sample rate is invalid")) - assert result is False + assert result is None @pytest.mark.parametrize(