diff --git a/sentry_sdk/integrations/asyncio.py b/sentry_sdk/integrations/asyncio.py index b515345f1a..f6f1d57c8b 100644 --- a/sentry_sdk/integrations/asyncio.py +++ b/sentry_sdk/integrations/asyncio.py @@ -4,7 +4,12 @@ import sentry_sdk from sentry_sdk.consts import OP from sentry_sdk.integrations import Integration, DidNotEnable -from sentry_sdk.utils import event_from_exception, logger, reraise +from sentry_sdk.utils import ( + event_from_exception, + logger, + reraise, + is_internal_task, +) from sentry_sdk.transport import AsyncHttpTransport try: @@ -69,6 +74,33 @@ def _patched_close() -> None: loop._sentry_flush_patched = True # type: ignore +def _create_task_with_factory( + orig_task_factory: Any, + loop: asyncio.AbstractEventLoop, + coro: Coroutine[Any, Any, Any], + **kwargs: Any, +) -> asyncio.Task[Any]: + task = None + + # Trying to use user set task factory (if there is one) + if orig_task_factory: + task = orig_task_factory(loop, coro, **kwargs) + + if task is None: + # The default task factory in `asyncio` does not have its own function + # but is just a couple of lines in `asyncio.base_events.create_task()` + # Those lines are copied here. + + # WARNING: + # If the default behavior of the task creation in asyncio changes, + # this will break! + task = Task(coro, loop=loop, **kwargs) + if task._source_traceback: # type: ignore + del task._source_traceback[-1] # type: ignore + + return task + + def patch_asyncio() -> None: orig_task_factory = None try: @@ -81,6 +113,14 @@ def _sentry_task_factory( **kwargs: Any, ) -> asyncio.Future[Any]: + # Check if this is an internal Sentry task + is_internal = is_internal_task() + + if is_internal: + return _create_task_with_factory( + orig_task_factory, loop, coro, **kwargs + ) + async def _task_with_sentry_span_creation() -> Any: result = None @@ -98,25 +138,9 @@ async def _task_with_sentry_span_creation() -> Any: return result - task = None - - # Trying to use user set task factory (if there is one) - if orig_task_factory: - task = orig_task_factory( - loop, _task_with_sentry_span_creation(), **kwargs - ) - - if task is None: - # The default task factory in `asyncio` does not have its own function - # but is just a couple of lines in `asyncio.base_events.create_task()` - # Those lines are copied here. - - # WARNING: - # If the default behavior of the task creation in asyncio changes, - # this will break! - task = Task(_task_with_sentry_span_creation(), loop=loop, **kwargs) - if task._source_traceback: # type: ignore - del task._source_traceback[-1] # type: ignore + task = _create_task_with_factory( + orig_task_factory, loop, _task_with_sentry_span_creation(), **kwargs + ) # Set the task name to include the original coroutine's name try: diff --git a/sentry_sdk/transport.py b/sentry_sdk/transport.py index 5c2c864198..6d7e4c4f84 100644 --- a/sentry_sdk/transport.py +++ b/sentry_sdk/transport.py @@ -39,7 +39,12 @@ import certifi from sentry_sdk.consts import EndpointType -from sentry_sdk.utils import Dsn, logger, capture_internal_exceptions +from sentry_sdk.utils import ( + Dsn, + logger, + capture_internal_exceptions, + mark_sentry_task_internal, +) from sentry_sdk.worker import BackgroundWorker, Worker, AsyncWorker from sentry_sdk.envelope import Envelope, Item, PayloadRef @@ -901,7 +906,8 @@ def kill(self: Self) -> Optional[asyncio.Task[None]]: # type: ignore self._worker.kill() try: # Return the pool cleanup task so caller can await it if needed - return self.loop.create_task(self._pool.aclose()) # type: ignore + with mark_sentry_task_internal(): + return self.loop.create_task(self._pool.aclose()) # type: ignore except RuntimeError: logger.warning("Event loop not running, aborting kill.") return None diff --git a/sentry_sdk/utils.py b/sentry_sdk/utils.py index 0842749baf..ee2cbfb5f1 100644 --- a/sentry_sdk/utils.py +++ b/sentry_sdk/utils.py @@ -1,5 +1,6 @@ from __future__ import annotations import base64 +import contextvars import json import linecache import logging @@ -12,6 +13,7 @@ import threading import time from collections import namedtuple +from contextlib import contextmanager from datetime import datetime, timezone from decimal import Decimal from functools import partial, partialmethod, wraps @@ -72,6 +74,25 @@ _installed_modules = None +_is_sentry_internal_task = contextvars.ContextVar( + "is_sentry_internal_task", default=False +) + + +def is_internal_task(): + return _is_sentry_internal_task.get() + + +@contextmanager +def mark_sentry_task_internal(): + """Context manager to mark a task as Sentry internal.""" + token = _is_sentry_internal_task.set(True) + try: + yield + finally: + _is_sentry_internal_task.reset(token) + + BASE64_ALPHABET = re.compile(r"^[a-zA-Z0-9/+=]*$") FALSY_ENV_VALUES = frozenset(("false", "f", "n", "no", "off", "0")) diff --git a/sentry_sdk/worker.py b/sentry_sdk/worker.py index c8dbbb2d73..8a23fa3ee1 100644 --- a/sentry_sdk/worker.py +++ b/sentry_sdk/worker.py @@ -6,7 +6,7 @@ from time import sleep, time from sentry_sdk._queue import Queue, FullError -from sentry_sdk.utils import logger +from sentry_sdk.utils import logger, mark_sentry_task_internal from sentry_sdk.consts import DEFAULT_QUEUE_SIZE from typing import TYPE_CHECKING @@ -231,7 +231,8 @@ def start(self) -> None: self._loop = asyncio.get_running_loop() if self._queue is None: self._queue = asyncio.Queue(maxsize=self._queue_size) - self._task = self._loop.create_task(self._target()) + with mark_sentry_task_internal(): + self._task = self._loop.create_task(self._target()) self._task_for_pid = os.getpid() except RuntimeError: # There is no event loop running @@ -273,7 +274,8 @@ async def _wait_flush(self, timeout: float, callback: Optional[Any] = None) -> N def flush(self, timeout: float, callback: Optional[Any] = None) -> Optional[asyncio.Task[None]]: # type: ignore[override] if self.is_alive and timeout > 0.0 and self._loop and self._loop.is_running(): - return self._loop.create_task(self._wait_flush(timeout, callback)) + with mark_sentry_task_internal(): + return self._loop.create_task(self._wait_flush(timeout, callback)) return None def submit(self, callback: Callable[[], Any]) -> bool: @@ -295,7 +297,8 @@ async def _target(self) -> None: self._queue.task_done() break # Firing tasks instead of awaiting them allows for concurrent requests - task = asyncio.create_task(self._process_callback(callback)) + with mark_sentry_task_internal(): + task = asyncio.create_task(self._process_callback(callback)) # Create a strong reference to the task so it can be cancelled on kill # and does not get garbage collected while running self._active_tasks.add(task) diff --git a/tests/integrations/asyncio/test_asyncio.py b/tests/integrations/asyncio/test_asyncio.py index 5c329f8185..42d8626ff3 100644 --- a/tests/integrations/asyncio/test_asyncio.py +++ b/tests/integrations/asyncio/test_asyncio.py @@ -8,6 +8,8 @@ import sentry_sdk from sentry_sdk.consts import OP from sentry_sdk.integrations.asyncio import AsyncioIntegration, patch_asyncio +from sentry_sdk.utils import mark_sentry_task_internal + try: from contextvars import Context, ContextVar @@ -379,6 +381,55 @@ async def test_span_origin( assert event["spans"][0]["origin"] == "auto.function.asyncio" +@minimum_python_38 +@pytest.mark.asyncio(loop_scope="module") +async def test_internal_tasks_not_wrapped(sentry_init, capture_events): + + sentry_init(integrations=[AsyncioIntegration()], traces_sample_rate=1.0) + events = capture_events() + + # Create a user task that should be wrapped + async def user_task(): + await asyncio.sleep(0.01) + return "user_result" + + # Create an internal task that should NOT be wrapped + async def internal_task(): + await asyncio.sleep(0.01) + return "internal_result" + + with sentry_sdk.start_transaction(name="test_transaction"): + user_task_obj = asyncio.create_task(user_task()) + + with mark_sentry_task_internal(): + internal_task_obj = asyncio.create_task(internal_task()) + + user_result = await user_task_obj + internal_result = await internal_task_obj + + assert user_result == "user_result" + assert internal_result == "internal_result" + + assert len(events) == 1 + transaction = events[0] + + user_spans = [] + internal_spans = [] + + for span in transaction.get("spans", []): + if "user_task" in span.get("description", ""): + user_spans.append(span) + elif "internal_task" in span.get("description", ""): + internal_spans.append(span) + + assert ( + len(user_spans) > 0 + ), f"User task should have been traced. All spans: {[s.get('description') for s in transaction.get('spans', [])]}" + assert ( + len(internal_spans) == 0 + ), f"Internal task should NOT have been traced. All spans: {[s.get('description') for s in transaction.get('spans', [])]}" + + @minimum_python_38 def test_loop_close_patching(sentry_init): sentry_init(integrations=[AsyncioIntegration()]) @@ -405,6 +456,12 @@ def test_loop_close_flushes_async_transport(sentry_init): sentry_init(integrations=[AsyncioIntegration()]) + # Save the current event loop to restore it later + try: + original_loop = asyncio.get_event_loop() + except RuntimeError: + original_loop = None + loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) @@ -415,14 +472,16 @@ def test_loop_close_flushes_async_transport(sentry_init): mock_client = Mock() mock_transport = Mock(spec=AsyncHttpTransport) mock_client.transport = mock_transport - mock_client.close = AsyncMock(return_value=None) + mock_client.close_async = AsyncMock(return_value=None) with patch("sentry_sdk.get_client", return_value=mock_client): loop.close() - mock_client.close.assert_called_once() - mock_client.close.assert_awaited_once() + mock_client.close_async.assert_called_once() + mock_client.close_async.assert_awaited_once() - except Exception: + finally: if not loop.is_closed(): loop.close() + if original_loop: + asyncio.set_event_loop(original_loop)