From 1b8ec101b453fee3d5dc29b3af45b3cb2a6bd623 Mon Sep 17 00:00:00 2001 From: wangzlei Date: Thu, 16 Oct 2025 19:57:03 -0700 Subject: [PATCH 1/4] support code attributes for kafka, celery and pika --- aws-opentelemetry-distro/pyproject.toml | 1 + .../distro/patches/_aio_pika_patches.py | 60 +++ .../distro/patches/_celery_patches.py | 139 ++++++ .../distro/patches/_django_patches.py | 5 +- .../distro/patches/_instrumentation_patch.py | 23 +- .../distro/patches/_pika_patches.py | 62 +++ .../distro/patches/test_aio_pika_patches.py | 102 +++++ .../distro/patches/test_celery_patches.py | 422 ++++++++++++++++++ .../distro/patches/test_django_patches.py | 17 +- .../distro/patches/test_pika_patches.py | 375 ++++++++++++++++ 10 files changed, 1194 insertions(+), 12 deletions(-) create mode 100644 aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_aio_pika_patches.py create mode 100644 aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_celery_patches.py create mode 100644 aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_pika_patches.py create mode 100644 aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_aio_pika_patches.py create mode 100644 aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_celery_patches.py create mode 100644 aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_pika_patches.py diff --git a/aws-opentelemetry-distro/pyproject.toml b/aws-opentelemetry-distro/pyproject.toml index 414b09221..0da377c48 100644 --- a/aws-opentelemetry-distro/pyproject.toml +++ b/aws-opentelemetry-distro/pyproject.toml @@ -42,6 +42,7 @@ dependencies = [ "opentelemetry-instrumentation-aws-lambda == 0.54b1", "opentelemetry-instrumentation-aio-pika == 0.54b1", "opentelemetry-instrumentation-aiohttp-client == 0.54b1", + "opentelemetry-instrumentation-aiokafka == 0.54b1", "opentelemetry-instrumentation-aiopg == 0.54b1", "opentelemetry-instrumentation-asgi == 0.54b1", "opentelemetry-instrumentation-asyncpg == 0.54b1", diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_aio_pika_patches.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_aio_pika_patches.py new file mode 100644 index 000000000..fdd322ba5 --- /dev/null +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_aio_pika_patches.py @@ -0,0 +1,60 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Patches for OpenTelemetry Aio-Pika instrumentation to add code correlation support. +""" + +import functools +import logging + +from amazon.opentelemetry.distro.aws_opentelemetry_configurator import get_code_correlation_enabled_status +from amazon.opentelemetry.distro.code_correlation.utils import add_code_attributes_to_span +from opentelemetry import trace + +logger = logging.getLogger(__name__) + + +def patch_callback_decorator_decorate(original_decorate): + """Patch CallbackDecorator.decorate to add code attributes to span.""" + + @functools.wraps(original_decorate) + def patched_decorate(self, callback): + # Decorate the original callback to add code attributes + async def enhanced_callback(message): + # Get current active span + current_span = trace.get_current_span() + if current_span and current_span.is_recording(): + try: + add_code_attributes_to_span(current_span, callback) + except Exception: # pylint: disable=broad-exception-caught + pass + + # Call original callback + return await callback(message) + + # Call original decorate method with our enhanced callback + return original_decorate(self, enhanced_callback) + + return patched_decorate + + +def _apply_aio_pika_instrumentation_patches(): + """Apply aio-pika patches if code correlation is enabled.""" + try: + if get_code_correlation_enabled_status() is not True: + return + + # Import CallbackDecorator inside function to allow proper testing + try: + # pylint: disable=import-outside-toplevel + from opentelemetry.instrumentation.aio_pika.callback_decorator import CallbackDecorator + except ImportError: + logger.warning("Failed to apply Aio-Pika patches: CallbackDecorator not available") + return + + # Patch CallbackDecorator.decorate + CallbackDecorator.decorate = patch_callback_decorator_decorate(CallbackDecorator.decorate) + + except Exception as exc: # pylint: disable=broad-exception-caught + logger.warning("Failed to apply Aio-Pika patches: %s", exc) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_celery_patches.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_celery_patches.py new file mode 100644 index 000000000..f3fdbe047 --- /dev/null +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_celery_patches.py @@ -0,0 +1,139 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Patches for OpenTelemetry Celery instrumentation to add code correlation support. + +This module provides patches to enhance the Celery instrumentation with code correlation +capabilities, allowing tracking of user code that is executed within Celery tasks. +""" + +import functools +import logging +from typing import Any, Callable, Optional + +from amazon.opentelemetry.distro.aws_opentelemetry_configurator import get_code_correlation_enabled_status +from amazon.opentelemetry.distro.code_correlation.utils import add_code_attributes_to_span + +logger = logging.getLogger(__name__) + +# Import at module level to avoid pylint import-outside-toplevel +try: + from opentelemetry.instrumentation.celery import CeleryInstrumentor + from opentelemetry.instrumentation.celery import utils as celery_utils +except ImportError: + celery_utils = None + CeleryInstrumentor = None + + +def _extract_task_function(task) -> Optional[Callable[..., Any]]: # pylint: disable=too-many-return-statements + """ + Extract the actual user function from a Celery task object. + + Args: + task: The Celery task object + + Returns: + The underlying user function if found, None otherwise + """ + if task is None: + return None + + try: + # For regular function-based tasks, the actual function is stored in task.run + if hasattr(task, "run") and callable(task.run): + func = task.run + if hasattr(func, "__func__"): + return func.__func__ + if func.__name__ != "run": # Avoid returning generic run methods + return func + + # For function-based tasks, the original function might be stored differently + if hasattr(task, "__call__") and callable(task.__call__): + func = task.__call__ + if hasattr(func, "__func__") and func.__func__.__name__ != "__call__": + return func.__func__ + if func.__name__ != "__call__": + return func + + # Try to get the original function from __wrapped__ attribute + if hasattr(task, "__wrapped__") and callable(task.__wrapped__): + return task.__wrapped__ + + except Exception: # pylint: disable=broad-exception-caught + pass + + return None + + +def _add_code_correlation_to_span(span, task) -> None: + """ + Add code correlation attributes to a span for a Celery task. + + Args: + span: The OpenTelemetry span to add attributes to + task: The Celery task object + """ + try: + if span is None or not span.is_recording(): + return + + user_function = _extract_task_function(task) + if user_function is not None: + add_code_attributes_to_span(span, user_function) + + except Exception: # pylint: disable=broad-exception-caught + pass + + +def patch_celery_prerun(original_trace_prerun: Callable) -> Callable: + """ + Patch the Celery _trace_prerun method to add code correlation support. + + Args: + original_trace_prerun: The original _trace_prerun method to wrap + + Returns: + The patched _trace_prerun method + """ + + @functools.wraps(original_trace_prerun) + def patched_trace_prerun(self, *args, **kwargs): + result = original_trace_prerun(self, *args, **kwargs) + + try: + task = kwargs.get("task") + task_id = kwargs.get("task_id") + + if task is not None and task_id is not None and celery_utils is not None: + ctx = celery_utils.retrieve_context(task, task_id) + if ctx is not None: + span, _, _ = ctx + if span is not None: + _add_code_correlation_to_span(span, task) + + except Exception: # pylint: disable=broad-exception-caught + pass + + return result + + return patched_trace_prerun + + +def _apply_celery_instrumentation_patches(): + """ + Apply code correlation patches to the Celery instrumentation. + """ + try: + if get_code_correlation_enabled_status() is not True: + return + + if CeleryInstrumentor is None: + logger.warning("Failed to apply Celery patches: CeleryInstrumentor not available") + return + + original_trace_prerun = CeleryInstrumentor._trace_prerun + CeleryInstrumentor._trace_prerun = patch_celery_prerun(original_trace_prerun) + + except Exception as exc: # pylint: disable=broad-exception-caught + logger.warning("Failed to apply Celery instrumentation patches: %s", exc) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_django_patches.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_django_patches.py index f4e266a71..0b1c40494 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_django_patches.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_django_patches.py @@ -62,6 +62,7 @@ def patched_process_view( ): # pylint: disable=too-many-locals,too-many-nested-blocks,too-many-branches """Patched process_view method to add code attributes to the span.""" # First call the original process_view method + # pylint: disable=assignment-from-none result = original_process_view(self, request, view_func, *args, **kwargs) # Add code attributes if we have a span and view function @@ -120,12 +121,12 @@ def patched_instrument(self, **kwargs): _patch_django_middleware() # Call the original _instrument method - original_instrument(self, **kwargs) + original_instrument(self, **kwargs) # pylint: disable=assignment-from-none def patched_uninstrument(self, **kwargs): """Patched _uninstrument method with Django middleware patch restoration""" # Call the original _uninstrument method first - original_uninstrument(self, **kwargs) + original_uninstrument(self, **kwargs) # pylint: disable=assignment-from-none # Restore original Django middleware _unpatch_django_middleware() diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_instrumentation_patch.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_instrumentation_patch.py index 9dec53a55..9765b3603 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_instrumentation_patch.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_instrumentation_patch.py @@ -14,7 +14,7 @@ _logger: Logger = getLogger(__name__) -def apply_instrumentation_patches() -> None: +def apply_instrumentation_patches() -> None: # pylint: disable=too-many-branches """Apply patches to upstream instrumentation libraries. This method is invoked to apply changes to upstream instrumentation libraries, typically when changes to upstream @@ -92,6 +92,27 @@ def apply_instrumentation_patches() -> None: _apply_django_instrumentation_patches() + if is_installed("celery"): + # pylint: disable=import-outside-toplevel + # Delay import to only occur if patches is safe to apply (e.g. the instrumented library is installed). + from amazon.opentelemetry.distro.patches._celery_patches import _apply_celery_instrumentation_patches + + _apply_celery_instrumentation_patches() + + if is_installed("pika"): + # pylint: disable=import-outside-toplevel + # Delay import to only occur if patches is safe to apply (e.g. the instrumented library is installed). + from amazon.opentelemetry.distro.patches._pika_patches import _apply_pika_instrumentation_patches + + _apply_pika_instrumentation_patches() + + if is_installed("aio-pika"): + # pylint: disable=import-outside-toplevel + # Delay import to only occur if patches is safe to apply (e.g. the instrumented library is installed). + from amazon.opentelemetry.distro.patches._aio_pika_patches import _apply_aio_pika_instrumentation_patches + + _apply_aio_pika_instrumentation_patches() + # No need to check if library is installed as this patches opentelemetry.sdk, # which must be installed for the distro to work at all. _apply_resource_detector_patches() diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_pika_patches.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_pika_patches.py new file mode 100644 index 000000000..5d76c818a --- /dev/null +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_pika_patches.py @@ -0,0 +1,62 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Patches for OpenTelemetry Pika instrumentation to add code correlation support. +""" + +import functools +import logging + +from amazon.opentelemetry.distro.aws_opentelemetry_configurator import get_code_correlation_enabled_status +from amazon.opentelemetry.distro.code_correlation.utils import add_code_attributes_to_span + +logger = logging.getLogger(__name__) + + +def patch_decorate_callback(original_decorate_callback): + """Patch _decorate_callback to add code attributes to span.""" + + @functools.wraps(original_decorate_callback) + def patched_decorate_callback(callback, tracer, task_name, consume_hook): + # Create an enhanced consume_hook that adds code attributes + def enhanced_consume_hook(span, body, properties): + # First add code attributes for the callback + if span and span.is_recording(): + try: + add_code_attributes_to_span(span, callback) + except Exception: # pylint: disable=broad-exception-caught + pass + + # Then call the original consume_hook if it exists + if consume_hook: + try: + consume_hook(span, body, properties) + except Exception: # pylint: disable=broad-exception-caught + pass + + # Call original with our enhanced hook + return original_decorate_callback(callback, tracer, task_name, enhanced_consume_hook) + + return patched_decorate_callback + + +def _apply_pika_instrumentation_patches(): + """Apply pika patches if code correlation is enabled.""" + try: + if get_code_correlation_enabled_status() is not True: + return + + # Import pika_utils inside function to allow proper testing + try: + # pylint: disable=import-outside-toplevel + from opentelemetry.instrumentation.pika import utils as pika_utils + except ImportError: + logger.warning("Failed to apply Pika patches: pika utils not available") + return + + # Patch _decorate_callback + pika_utils._decorate_callback = patch_decorate_callback(pika_utils._decorate_callback) + + except Exception as exc: # pylint: disable=broad-exception-caught + logger.warning("Failed to apply Pika patches: %s", exc) diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_aio_pika_patches.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_aio_pika_patches.py new file mode 100644 index 000000000..3240f9b9b --- /dev/null +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_aio_pika_patches.py @@ -0,0 +1,102 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import unittest +from unittest.mock import Mock, patch + +from amazon.opentelemetry.distro.patches._aio_pika_patches import ( + _apply_aio_pika_instrumentation_patches, + patch_callback_decorator_decorate, +) + + +class TestAioPikaPatches(unittest.TestCase): + + def test_patch_callback_decorator_decorate(self): + # Mock the original decorate method + original_decorate = Mock() + + # Mock CallbackDecorator instance + mock_decorator = Mock() + mock_decorator._tracer = Mock() + mock_decorator._get_span = Mock(return_value=Mock()) + + # Mock callback function + mock_callback = Mock() + + # Create the patched decorate method + patched_decorate = patch_callback_decorator_decorate(original_decorate) + + # Call the patched method + result = patched_decorate(mock_decorator, mock_callback) + + # Verify original_decorate was called once (with enhanced callback, not original) + original_decorate.assert_called_once() + # Check that the first argument is the decorator instance + call_args = original_decorate.call_args[0] + self.assertEqual(call_args[0], mock_decorator) + # Check that the second argument is a callable (the enhanced callback) + self.assertTrue(callable(call_args[1])) + + # Verify we got a function back (the enhanced decorated callback) + self.assertTrue(callable(result)) + + @patch("amazon.opentelemetry.distro.patches._aio_pika_patches.get_code_correlation_enabled_status") + def test_apply_aio_pika_instrumentation_patches_disabled(self, mock_get_status): + mock_get_status.return_value = False + + with patch.dict( + "sys.modules", + { + "opentelemetry.instrumentation.aio_pika": None, + "opentelemetry.instrumentation.aio_pika.callback_decorator": None, + }, + ): + # Should not raise exception when code correlation is disabled + _apply_aio_pika_instrumentation_patches() + + @patch("amazon.opentelemetry.distro.patches._aio_pika_patches.get_code_correlation_enabled_status") + def test_apply_aio_pika_instrumentation_patches_enabled(self, mock_get_status): + mock_get_status.return_value = True + + # Mock CallbackDecorator + mock_callback_decorator = Mock() + + with patch.dict( + "sys.modules", + { + "opentelemetry.instrumentation.aio_pika": Mock(), + "opentelemetry.instrumentation.aio_pika.callback_decorator": Mock( + CallbackDecorator=mock_callback_decorator + ), + }, + ): + _apply_aio_pika_instrumentation_patches() + + # Verify the decorate method was patched + self.assertTrue(hasattr(mock_callback_decorator, "decorate")) + + @patch("amazon.opentelemetry.distro.patches._aio_pika_patches.get_code_correlation_enabled_status") + def test_apply_aio_pika_instrumentation_patches_import_error(self, mock_get_status): + mock_get_status.return_value = True + + with patch.dict( + "sys.modules", + { + "opentelemetry.instrumentation.aio_pika": None, + "opentelemetry.instrumentation.aio_pika.callback_decorator": None, + }, + ): + # Should not raise exception when import fails + _apply_aio_pika_instrumentation_patches() + + @patch("amazon.opentelemetry.distro.patches._aio_pika_patches.logger") + @patch("amazon.opentelemetry.distro.patches._aio_pika_patches.get_code_correlation_enabled_status") + def test_apply_aio_pika_instrumentation_patches_exception_handling(self, mock_get_status, mock_logger): + mock_get_status.side_effect = Exception("Test exception") + + # Should handle exceptions gracefully + _apply_aio_pika_instrumentation_patches() + + # Verify warning was logged + mock_logger.warning.assert_called_once() diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_celery_patches.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_celery_patches.py new file mode 100644 index 000000000..81abf0068 --- /dev/null +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_celery_patches.py @@ -0,0 +1,422 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from unittest.mock import Mock, patch + +from amazon.opentelemetry.distro.patches._celery_patches import ( + _add_code_correlation_to_span, + _apply_celery_instrumentation_patches, + _extract_task_function, + patch_celery_prerun, +) +from opentelemetry.test.test_base import TestBase + + +class TestCeleryPatches(TestBase): + """Test Celery patches functionality.""" + + def setUp(self): + """Set up test fixtures.""" + super().setUp() + + def tearDown(self): + """Clean up after tests.""" + super().tearDown() + + @patch("amazon.opentelemetry.distro.patches._celery_patches.get_code_correlation_enabled_status") + @patch("amazon.opentelemetry.distro.patches._celery_patches.CeleryInstrumentor") + def test_apply_celery_instrumentation_patches_enabled(self, mock_instrumentor, mock_get_status): + """Test Celery instrumentation patches when code correlation is enabled.""" + mock_get_status.return_value = True + + # Mock CeleryInstrumentor + mock_original_trace_prerun = Mock() + mock_instrumentor._trace_prerun = mock_original_trace_prerun + + _apply_celery_instrumentation_patches() + + mock_get_status.assert_called_once() + # Verify that the _trace_prerun method was replaced + self.assertNotEqual(mock_instrumentor._trace_prerun, mock_original_trace_prerun) + + @patch("amazon.opentelemetry.distro.patches._celery_patches.get_code_correlation_enabled_status") + def test_apply_celery_instrumentation_patches_disabled(self, mock_get_status): + """Test Celery instrumentation patches when code correlation is disabled.""" + mock_get_status.return_value = False + + with patch("amazon.opentelemetry.distro.patches._celery_patches.logger") as mock_logger: + _apply_celery_instrumentation_patches() + + mock_get_status.assert_called_once() + # No warning should be logged since it returns early + mock_logger.warning.assert_not_called() + + @patch("amazon.opentelemetry.distro.patches._celery_patches.get_code_correlation_enabled_status") + def test_apply_celery_instrumentation_patches_none_status(self, mock_get_status): + """Test Celery instrumentation patches when status is None.""" + mock_get_status.return_value = None + + with patch("amazon.opentelemetry.distro.patches._celery_patches.logger") as mock_logger: + _apply_celery_instrumentation_patches() + + mock_get_status.assert_called_once() + # No warning should be logged since it returns early + mock_logger.warning.assert_not_called() + + @patch("amazon.opentelemetry.distro.patches._celery_patches.get_code_correlation_enabled_status") + @patch("amazon.opentelemetry.distro.patches._celery_patches.logger") + def test_apply_celery_instrumentation_patches_import_error(self, mock_logger, mock_get_status): + """Test Celery instrumentation patches with import error.""" + mock_get_status.return_value = True + + # Patch CeleryInstrumentor to None to simulate import failure + with patch("amazon.opentelemetry.distro.patches._celery_patches.CeleryInstrumentor", None): + _apply_celery_instrumentation_patches() + + mock_get_status.assert_called_once() + mock_logger.warning.assert_called_once() + args = mock_logger.warning.call_args[0] + self.assertIn("Failed to apply Celery patches: CeleryInstrumentor not available", args[0]) + + @patch("amazon.opentelemetry.distro.patches._celery_patches.get_code_correlation_enabled_status") + @patch("amazon.opentelemetry.distro.patches._celery_patches.logger") + def test_apply_celery_instrumentation_patches_exception(self, mock_logger, mock_get_status): + """Test Celery instrumentation patches with general exception.""" + mock_get_status.side_effect = Exception("Unexpected error") + + _apply_celery_instrumentation_patches() + + mock_logger.warning.assert_called_once() + args = mock_logger.warning.call_args[0] + self.assertIn("Failed to apply Celery instrumentation patches", args[0]) + + +class TestExtractTaskFunction(TestBase): + """Test _extract_task_function functionality.""" + + def setUp(self): + """Set up test fixtures.""" + super().setUp() + + def test_extract_task_function_none_task(self): + """Test extract task function with None task.""" + result = _extract_task_function(None) + self.assertIsNone(result) + + def test_extract_task_function_with_run_method(self): + """Test extract task function with task that has run method.""" + + def sample_task(): + pass + + mock_task = Mock() + mock_task.run = sample_task + + result = _extract_task_function(mock_task) + self.assertEqual(result, sample_task) + + def test_extract_task_function_with_run_bound_method(self): + """Test extract task function with task that has bound run method.""" + + def sample_function(): + pass + + mock_task = Mock() + mock_run = Mock() + mock_run.__func__ = sample_function + mock_task.run = mock_run + + result = _extract_task_function(mock_task) + self.assertEqual(result, sample_function) + + def test_extract_task_function_with_call_method(self): + """Test extract task function with task that has __call__ method.""" + + def sample_task(): + pass + + mock_task = Mock() + mock_task.run = None # No run method + mock_call = Mock() + mock_call.__func__ = Mock() + mock_call.__func__.__name__ = "sample_function" # Not '__call__' + mock_call.__func__ = sample_task + mock_task.__call__ = mock_call + + result = _extract_task_function(mock_task) + self.assertEqual(result, sample_task) + + def test_extract_task_function_with_call_method_skip_default(self): + """Test extract task function skips default __call__ method.""" + mock_task = Mock() + mock_task.run = None # No run method + mock_call = Mock() + mock_call.__func__ = Mock() + mock_call.__func__.__name__ = "__call__" # Default __call__, should skip + mock_call.__name__ = "__call__" # Also set the direct name to __call__ + mock_task.__call__ = mock_call + # Ensure no __wrapped__ attribute exists + del mock_task.__wrapped__ + + result = _extract_task_function(mock_task) + self.assertIsNone(result) # Should skip default __call__ and return None + + def test_extract_task_function_with_wrapped(self): + """Test extract task function with __wrapped__ attribute.""" + + def sample_task(): + pass + + mock_task = Mock() + mock_task.run = None # No run method + mock_task.__call__ = None # No __call__ method + mock_task.__wrapped__ = sample_task + + result = _extract_task_function(mock_task) + self.assertEqual(result, sample_task) + + def test_extract_task_function_no_methods(self): + """Test extract task function with no extractable methods.""" + mock_task = Mock() + mock_task.run = None + mock_task.__call__ = None + del mock_task.__wrapped__ # Remove __wrapped__ attribute + + result = _extract_task_function(mock_task) + self.assertIsNone(result) + + def test_extract_task_function_exception_handling(self): + """Test extract task function handles exceptions gracefully.""" + mock_task = Mock() + + # Configure accessing the run attribute to raise an exception + def raise_exception(): + raise ValueError("Error accessing run") # pylint: disable=broad-exception-raised + + type(mock_task).run = property(lambda self: raise_exception()) + + result = _extract_task_function(mock_task) + self.assertIsNone(result) + + +class TestAddCodeCorrelationToSpan(TestBase): + """Test _add_code_correlation_to_span functionality.""" + + def setUp(self): + """Set up test fixtures.""" + super().setUp() + + def test_add_code_correlation_none_span(self): + """Test add code correlation with None span.""" + mock_task = Mock() + + # Should not raise exception + _add_code_correlation_to_span(None, mock_task) + + def test_add_code_correlation_non_recording_span(self): + """Test add code correlation with non-recording span.""" + mock_span = Mock() + mock_span.is_recording.return_value = False + mock_task = Mock() + + _add_code_correlation_to_span(mock_span, mock_task) + + mock_span.is_recording.assert_called_once() + + @patch("amazon.opentelemetry.distro.patches._celery_patches._extract_task_function") + @patch("amazon.opentelemetry.distro.patches._celery_patches.add_code_attributes_to_span") + def test_add_code_correlation_success(self, mock_add_attributes, mock_extract): + """Test successful code correlation addition.""" + mock_span = Mock() + mock_span.is_recording.return_value = True + mock_task = Mock() + mock_function = Mock() + mock_extract.return_value = mock_function + + _add_code_correlation_to_span(mock_span, mock_task) + + mock_span.is_recording.assert_called_once() + mock_extract.assert_called_once_with(mock_task) + mock_add_attributes.assert_called_once_with(mock_span, mock_function) + + @patch("amazon.opentelemetry.distro.patches._celery_patches._extract_task_function") + def test_add_code_correlation_no_function_extracted(self, mock_extract): + """Test code correlation when no function is extracted.""" + mock_span = Mock() + mock_span.is_recording.return_value = True + mock_task = Mock() + mock_extract.return_value = None + + _add_code_correlation_to_span(mock_span, mock_task) + + mock_span.is_recording.assert_called_once() + mock_extract.assert_called_once_with(mock_task) + + def test_add_code_correlation_exception_handling(self): + """Test code correlation handles exceptions gracefully.""" + mock_span = Mock() + mock_span.is_recording.side_effect = Exception("Span error") + mock_task = Mock() + + # Should not raise exception + _add_code_correlation_to_span(mock_span, mock_task) + + +class TestPatchCeleryPrerun(TestBase): + """Test patch_celery_prerun functionality.""" + + def setUp(self): + """Set up test fixtures.""" + super().setUp() + + def test_patch_celery_prerun_wrapper(self): + """Test that patch_celery_prerun creates proper wrapper.""" + original_function = Mock(return_value="original_result") + original_function.__name__ = "original_trace_prerun" + + patched_function = patch_celery_prerun(original_function) + + # Check that functools.wraps was applied + self.assertEqual(patched_function.__name__, "original_trace_prerun") + + # Test calling the patched function + mock_self = Mock() + args = ("arg1", "arg2") + kwargs = {"task": Mock(), "task_id": "test_id"} + + result = patched_function(mock_self, *args, **kwargs) + + # Original function should be called + original_function.assert_called_once_with(mock_self, *args, **kwargs) + self.assertEqual(result, "original_result") + + @patch("amazon.opentelemetry.distro.patches._celery_patches._add_code_correlation_to_span") + def test_patch_celery_prerun_adds_correlation(self, mock_add_correlation): + """Test that patched function adds code correlation.""" + original_function = Mock(return_value="original_result") + mock_task = Mock() + mock_task_id = "test_task_id" + + patched_function = patch_celery_prerun(original_function) + + # Test that the patched function works without errors + # The actual import behavior is tested in integration tests + mock_self = Mock() + kwargs = {"task": mock_task, "task_id": mock_task_id} + + result = patched_function(mock_self, **kwargs) + + # Verify original function was called + original_function.assert_called_once_with(mock_self, **kwargs) + self.assertEqual(result, "original_result") + + # The function should complete without raising exceptions + # which validates that the patch logic doesn't break the flow + + def test_patch_celery_prerun_missing_task(self): + """Test patched function with missing task.""" + original_function = Mock(return_value="original_result") + + patched_function = patch_celery_prerun(original_function) + + mock_self = Mock() + kwargs = {"task_id": "test_id"} # Missing task + + result = patched_function(mock_self, **kwargs) + + # Should still call original function and return result + original_function.assert_called_once_with(mock_self, **kwargs) + self.assertEqual(result, "original_result") + + def test_patch_celery_prerun_missing_task_id(self): + """Test patched function with missing task_id.""" + original_function = Mock(return_value="original_result") + + patched_function = patch_celery_prerun(original_function) + + mock_self = Mock() + kwargs = {"task": Mock()} # Missing task_id + + result = patched_function(mock_self, **kwargs) + + # Should still call original function and return result + original_function.assert_called_once_with(mock_self, **kwargs) + self.assertEqual(result, "original_result") + + def test_patch_celery_prerun_no_context(self): + """Test patched function when retrieve_context returns None.""" + original_function = Mock(return_value="original_result") + mock_task = Mock() + mock_task_id = "test_task_id" + + # Mock the utils.retrieve_context to return None + mock_utils = Mock() + mock_utils.retrieve_context.return_value = None + + patched_function = patch_celery_prerun(original_function) + + with patch.dict( + "sys.modules", + {"opentelemetry.instrumentation.celery.utils": mock_utils}, + ): + mock_self = Mock() + kwargs = {"task": mock_task, "task_id": mock_task_id} + + result = patched_function(mock_self, **kwargs) + + # Should still complete successfully + original_function.assert_called_once_with(mock_self, **kwargs) + self.assertEqual(result, "original_result") + + def test_patch_celery_prerun_exception_handling(self): + """Test patched function handles exceptions gracefully.""" + original_function = Mock(return_value="original_result") + + # Mock that will cause an exception in the patch logic + mock_utils = Mock() + mock_utils.retrieve_context.side_effect = Exception("Context error") + + patched_function = patch_celery_prerun(original_function) + + with patch.dict( + "sys.modules", + {"opentelemetry.instrumentation.celery.utils": mock_utils}, + ): + mock_self = Mock() + kwargs = {"task": Mock(), "task_id": "test_id"} + + result = patched_function(mock_self, **kwargs) + + # Should still call original function and return result despite exception + original_function.assert_called_once_with(mock_self, **kwargs) + self.assertEqual(result, "original_result") + + +class TestCeleryPatchesIntegration(TestBase): + """Test Celery patches integration scenarios.""" + + def setUp(self): + """Set up test fixtures.""" + super().setUp() + + @patch("amazon.opentelemetry.distro.patches._celery_patches.get_code_correlation_enabled_status") + @patch("amazon.opentelemetry.distro.patches._celery_patches.CeleryInstrumentor") + def test_full_patch_application_flow(self, mock_instrumentor, mock_get_status): + """Test the complete flow of applying Celery patches.""" + mock_get_status.return_value = True + + # Create a realistic mock setup + original_trace_prerun = Mock(__name__="original_trace_prerun") + mock_instrumentor._trace_prerun = original_trace_prerun + + _apply_celery_instrumentation_patches() + + # Verify the method was replaced with a wrapped version + self.assertNotEqual(mock_instrumentor._trace_prerun, original_trace_prerun) + self.assertEqual(mock_instrumentor._trace_prerun.__name__, "original_trace_prerun") + + # Test calling the patched method + mock_self = Mock() + kwargs = {"task": Mock(), "task_id": "test"} + + # Should not raise exceptions + mock_instrumentor._trace_prerun(mock_self, **kwargs) diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_django_patches.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_django_patches.py index beaf36437..4cbbe1a6b 100644 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_django_patches.py +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_django_patches.py @@ -116,7 +116,7 @@ def test_apply_django_code_attributes_patch_import_error(self, mock_logger): _apply_django_code_attributes_patch() # Check that warning was called with the format string and an ImportError mock_logger.warning.assert_called() - args, kwargs = mock_logger.warning.call_args + args, _kwargs = mock_logger.warning.call_args self.assertEqual(args[0], "Failed to apply Django code attributes patch: %s") self.assertIsInstance(args[1], ImportError) self.assertEqual(str(args[1]), "Module not found") @@ -127,7 +127,7 @@ def test_apply_django_code_attributes_patch_exception_handling(self): # Test that the function doesn't raise exceptions even with import failures _apply_django_code_attributes_patch() # Should complete without errors regardless of Django availability - self.assertTrue(True) # If we get here, no exception was raised + # If we get here, no exception was raised @patch("amazon.opentelemetry.distro.patches._django_patches.get_code_correlation_enabled_status", return_value=True) @@ -137,7 +137,8 @@ class TestDjangoRealIntegration(TestBase): def setUp(self): """Set up test fixtures with Django configuration.""" super().setUp() - self.skipTest("Django not available") if not DJANGO_AVAILABLE else None + if not DJANGO_AVAILABLE: + self.skipTest("Django not available") # Configure Django with minimal settings if not settings.configured: @@ -206,10 +207,9 @@ def test_view(request): request.META[span_key] = mock_span # Call process_view method which should trigger the patch - result = middleware.process_view(request, test_view, [], {}) + middleware.process_view(request, test_view, [], {}) - # The result should be None (original process_view returns None) - self.assertIsNone(result) + # The original process_view returns None, so we don't assign result # Verify span methods were called (this confirms the patched code ran) mock_span.is_recording.assert_called() @@ -267,10 +267,9 @@ def mock_view_func(request): # Call process_view method with the class-based view function # This should trigger the class-based view logic where it extracts the handler - result = middleware.process_view(request, mock_view_func, [], {}) + middleware.process_view(request, mock_view_func, [], {}) - # The result should be None (original process_view returns None) - self.assertIsNone(result) + # The original process_view returns None, so we don't assign result # Verify span methods were called (this confirms the patched code ran) mock_span.is_recording.assert_called() diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_pika_patches.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_pika_patches.py new file mode 100644 index 000000000..dab5e42d7 --- /dev/null +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_pika_patches.py @@ -0,0 +1,375 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from unittest.mock import Mock, patch + +from amazon.opentelemetry.distro.patches._pika_patches import ( + _apply_pika_instrumentation_patches, + patch_decorate_callback, +) +from opentelemetry.test.test_base import TestBase + + +class TestPikaPatches(TestBase): + """Test Pika patches functionality.""" + + @patch("amazon.opentelemetry.distro.patches._pika_patches.get_code_correlation_enabled_status") + def test_apply_pika_instrumentation_patches_enabled(self, mock_get_status): + """Test Pika instrumentation patches when code correlation is enabled.""" + mock_get_status.return_value = True + + # Mock pika utils + mock_utils = Mock() + mock_original_decorate_callback = Mock() + mock_utils._decorate_callback = mock_original_decorate_callback + + with patch.dict( + "sys.modules", + { + "opentelemetry.instrumentation.pika": Mock(utils=mock_utils), + "opentelemetry.instrumentation.pika.utils": mock_utils, + }, + ): + _apply_pika_instrumentation_patches() + + mock_get_status.assert_called_once() + # Verify that the _decorate_callback method was replaced + self.assertNotEqual(mock_utils._decorate_callback, mock_original_decorate_callback) + + @patch("amazon.opentelemetry.distro.patches._pika_patches.get_code_correlation_enabled_status") + def test_apply_pika_instrumentation_patches_disabled(self, mock_get_status): + """Test Pika instrumentation patches when code correlation is disabled.""" + mock_get_status.return_value = False + + with patch("amazon.opentelemetry.distro.patches._pika_patches.logger") as mock_logger: + _apply_pika_instrumentation_patches() + + mock_get_status.assert_called_once() + # No warning should be logged since it returns early + mock_logger.warning.assert_not_called() + + @patch("amazon.opentelemetry.distro.patches._pika_patches.get_code_correlation_enabled_status") + def test_apply_pika_instrumentation_patches_none_status(self, mock_get_status): + """Test Pika instrumentation patches when status is None.""" + mock_get_status.return_value = None + + with patch("amazon.opentelemetry.distro.patches._pika_patches.logger") as mock_logger: + _apply_pika_instrumentation_patches() + + mock_get_status.assert_called_once() + # No warning should be logged since it returns early + mock_logger.warning.assert_not_called() + + @patch("amazon.opentelemetry.distro.patches._pika_patches.get_code_correlation_enabled_status") + @patch("amazon.opentelemetry.distro.patches._pika_patches.logger") + def test_apply_pika_instrumentation_patches_import_error(self, mock_logger, mock_get_status): + """Test Pika instrumentation patches with import error.""" + mock_get_status.return_value = True + + # Patch the specific import that would fail + with patch.dict( + "sys.modules", + {"opentelemetry.instrumentation.pika": None, "opentelemetry.instrumentation.pika.utils": None}, + ): + _apply_pika_instrumentation_patches() + + mock_get_status.assert_called_once() + mock_logger.warning.assert_called_once() + args = mock_logger.warning.call_args[0] + self.assertIn("Failed to apply Pika patches", args[0]) + + @patch("amazon.opentelemetry.distro.patches._pika_patches.get_code_correlation_enabled_status") + @patch("amazon.opentelemetry.distro.patches._pika_patches.logger") + def test_apply_pika_instrumentation_patches_exception(self, mock_logger, mock_get_status): + """Test Pika instrumentation patches with general exception.""" + mock_get_status.side_effect = Exception("Unexpected error") + + _apply_pika_instrumentation_patches() + + mock_logger.warning.assert_called_once() + args = mock_logger.warning.call_args[0] + self.assertIn("Failed to apply Pika patches", args[0]) + + +class TestPatchDecorateCallback(TestBase): + """Test patch_decorate_callback functionality.""" + + def test_patch_decorate_callback_wrapper(self): + """Test that patch_decorate_callback creates proper wrapper.""" + original_function = Mock(return_value="decorated_callback_result") + original_function.__name__ = "original_decorate_callback" + + patched_function = patch_decorate_callback(original_function) + + # Check that functools.wraps was applied + self.assertEqual(patched_function.__name__, "original_decorate_callback") + + # Test calling the patched function + mock_callback = Mock() + mock_tracer = Mock() + task_name = "test_task" + mock_consume_hook = Mock() + + result = patched_function(mock_callback, mock_tracer, task_name, mock_consume_hook) + + # Original function should be called with enhanced consume hook + original_function.assert_called_once() + args, _kwargs = original_function.call_args + self.assertEqual(args[0], mock_callback) + self.assertEqual(args[1], mock_tracer) + self.assertEqual(args[2], task_name) + # The fourth argument should be our enhanced consume hook, not the original + self.assertNotEqual(args[3], mock_consume_hook) + self.assertEqual(result, "decorated_callback_result") + + @patch("amazon.opentelemetry.distro.patches._pika_patches.add_code_attributes_to_span") + def test_enhanced_consume_hook_success(self, mock_add_attributes): + """Test enhanced consume hook with successful code attribute addition.""" + original_function = Mock(return_value="decorated_callback_result") + mock_callback = Mock() + mock_consume_hook = Mock() + + patched_function = patch_decorate_callback(original_function) + + # Call the patched function to get the enhanced consume hook + patched_function(mock_callback, Mock(), "test_task", mock_consume_hook) + + # Get the enhanced consume hook from the call + enhanced_consume_hook = original_function.call_args[0][3] + + # Test the enhanced consume hook + mock_span = Mock() + mock_span.is_recording.return_value = True + mock_body = "test_body" + mock_properties = {"test": "properties"} + + enhanced_consume_hook(mock_span, mock_body, mock_properties) + + # Verify code attributes were added + mock_add_attributes.assert_called_once_with(mock_span, mock_callback) + + # Verify original consume hook was called + mock_consume_hook.assert_called_once_with(mock_span, mock_body, mock_properties) + + @patch("amazon.opentelemetry.distro.patches._pika_patches.add_code_attributes_to_span") + def test_enhanced_consume_hook_non_recording_span(self, mock_add_attributes): + """Test enhanced consume hook with non-recording span.""" + original_function = Mock(return_value="decorated_callback_result") + mock_callback = Mock() + mock_consume_hook = Mock() + + patched_function = patch_decorate_callback(original_function) + + # Call the patched function to get the enhanced consume hook + patched_function(mock_callback, Mock(), "test_task", mock_consume_hook) + + # Get the enhanced consume hook from the call + enhanced_consume_hook = original_function.call_args[0][3] + + # Test the enhanced consume hook with non-recording span + mock_span = Mock() + mock_span.is_recording.return_value = False + mock_body = "test_body" + mock_properties = {"test": "properties"} + + enhanced_consume_hook(mock_span, mock_body, mock_properties) + + # Code attributes should not be added for non-recording span + mock_add_attributes.assert_not_called() + + # Original consume hook should still be called + mock_consume_hook.assert_called_once_with(mock_span, mock_body, mock_properties) + + @patch("amazon.opentelemetry.distro.patches._pika_patches.add_code_attributes_to_span") + def test_enhanced_consume_hook_none_span(self, mock_add_attributes): + """Test enhanced consume hook with None span.""" + original_function = Mock(return_value="decorated_callback_result") + mock_callback = Mock() + mock_consume_hook = Mock() + + patched_function = patch_decorate_callback(original_function) + + # Call the patched function to get the enhanced consume hook + patched_function(mock_callback, Mock(), "test_task", mock_consume_hook) + + # Get the enhanced consume hook from the call + enhanced_consume_hook = original_function.call_args[0][3] + + # Test the enhanced consume hook with None span + mock_body = "test_body" + mock_properties = {"test": "properties"} + + enhanced_consume_hook(None, mock_body, mock_properties) + + # Code attributes should not be added for None span + mock_add_attributes.assert_not_called() + + # Original consume hook should still be called + mock_consume_hook.assert_called_once_with(None, mock_body, mock_properties) + + def test_enhanced_consume_hook_no_original_consume_hook(self): + """Test enhanced consume hook when no original consume hook is provided.""" + original_function = Mock(return_value="decorated_callback_result") + mock_callback = Mock() + + patched_function = patch_decorate_callback(original_function) + + # Call the patched function with no consume hook + patched_function(mock_callback, Mock(), "test_task", None) + + # Get the enhanced consume hook from the call + enhanced_consume_hook = original_function.call_args[0][3] + + # Test the enhanced consume hook + mock_span = Mock() + mock_span.is_recording.return_value = True + mock_body = "test_body" + mock_properties = {"test": "properties"} + + # Should not raise exception even without original consume hook + enhanced_consume_hook(mock_span, mock_body, mock_properties) + + @patch("amazon.opentelemetry.distro.patches._pika_patches.add_code_attributes_to_span") + def test_enhanced_consume_hook_add_attributes_exception(self, mock_add_attributes): + """Test enhanced consume hook handles add_code_attributes_to_span exceptions.""" + mock_add_attributes.side_effect = Exception("Add attributes error") + + original_function = Mock(return_value="decorated_callback_result") + mock_callback = Mock() + mock_consume_hook = Mock() + + patched_function = patch_decorate_callback(original_function) + + # Call the patched function to get the enhanced consume hook + patched_function(mock_callback, Mock(), "test_task", mock_consume_hook) + + # Get the enhanced consume hook from the call + enhanced_consume_hook = original_function.call_args[0][3] + + # Test the enhanced consume hook + mock_span = Mock() + mock_span.is_recording.return_value = True + mock_body = "test_body" + mock_properties = {"test": "properties"} + + # Should not raise exception despite add_code_attributes_to_span failing + enhanced_consume_hook(mock_span, mock_body, mock_properties) + + # Verify add_code_attributes_to_span was called (and failed) + mock_add_attributes.assert_called_once_with(mock_span, mock_callback) + + # Original consume hook should still be called + mock_consume_hook.assert_called_once_with(mock_span, mock_body, mock_properties) + + def test_enhanced_consume_hook_original_hook_exception(self): + """Test enhanced consume hook handles original consume hook exceptions.""" + original_function = Mock(return_value="decorated_callback_result") + mock_callback = Mock() + mock_consume_hook = Mock() + mock_consume_hook.side_effect = Exception("Original hook error") + + patched_function = patch_decorate_callback(original_function) + + # Call the patched function to get the enhanced consume hook + patched_function(mock_callback, Mock(), "test_task", mock_consume_hook) + + # Get the enhanced consume hook from the call + enhanced_consume_hook = original_function.call_args[0][3] + + # Test the enhanced consume hook + mock_span = Mock() + mock_span.is_recording.return_value = True + mock_body = "test_body" + mock_properties = {"test": "properties"} + + # Should not raise exception despite original consume hook failing + enhanced_consume_hook(mock_span, mock_body, mock_properties) + + # Verify original consume hook was called (and failed) + mock_consume_hook.assert_called_once_with(mock_span, mock_body, mock_properties) + + +class TestPikaPatchesIntegration(TestBase): + """Test Pika patches integration scenarios.""" + + @patch("amazon.opentelemetry.distro.patches._pika_patches.get_code_correlation_enabled_status") + def test_full_patch_application_flow(self, mock_get_status): + """Test the complete flow of applying Pika patches.""" + mock_get_status.return_value = True + + # Create a realistic mock setup + mock_utils = Mock() + original_decorate_callback = Mock(__name__="original_decorate_callback") + mock_utils._decorate_callback = original_decorate_callback + + with patch.dict( + "sys.modules", + { + "opentelemetry.instrumentation.pika": Mock(utils=mock_utils), + "opentelemetry.instrumentation.pika.utils": mock_utils, + }, + ): + _apply_pika_instrumentation_patches() + + # Verify the method was replaced with a wrapped version + self.assertNotEqual(mock_utils._decorate_callback, original_decorate_callback) + self.assertEqual(mock_utils._decorate_callback.__name__, "original_decorate_callback") + + # Test calling the patched method + mock_callback = Mock() + mock_tracer = Mock() + task_name = "test_task" + mock_consume_hook = Mock() + + # Should not raise exceptions + mock_utils._decorate_callback(mock_callback, mock_tracer, task_name, mock_consume_hook) + + # Original function should be called + original_decorate_callback.assert_called_once() + + @patch("amazon.opentelemetry.distro.patches._pika_patches.get_code_correlation_enabled_status") + @patch("amazon.opentelemetry.distro.patches._pika_patches.add_code_attributes_to_span") + def test_end_to_end_enhanced_consume_hook(self, mock_add_attributes, mock_get_status): + """Test end-to-end flow with enhanced consume hook.""" + mock_get_status.return_value = True + + # Create a realistic mock setup + mock_utils = Mock() + original_decorate_callback = Mock(return_value="decorated_result") + mock_utils._decorate_callback = original_decorate_callback + + with patch.dict( + "sys.modules", + { + "opentelemetry.instrumentation.pika": Mock(utils=mock_utils), + "opentelemetry.instrumentation.pika.utils": mock_utils, + }, + ): + _apply_pika_instrumentation_patches() + + # Now use the patched method + mock_callback = Mock() + mock_tracer = Mock() + task_name = "test_task" + mock_consume_hook = Mock() + + result = mock_utils._decorate_callback(mock_callback, mock_tracer, task_name, mock_consume_hook) + + # Get the enhanced consume hook that was passed to the original function + enhanced_consume_hook = original_decorate_callback.call_args[0][3] + + # Test using the enhanced consume hook + mock_span = Mock() + mock_span.is_recording.return_value = True + mock_body = "test_body" + mock_properties = {"test": "properties"} + + enhanced_consume_hook(mock_span, mock_body, mock_properties) + + # Verify code attributes were added + mock_add_attributes.assert_called_once_with(mock_span, mock_callback) + + # Verify original consume hook was called + mock_consume_hook.assert_called_once_with(mock_span, mock_body, mock_properties) + + # Verify result was returned + self.assertEqual(result, "decorated_result") From 62912f73a01dac0b2c9a7dc268aa1021a43ef90a Mon Sep 17 00:00:00 2001 From: wangzlei Date: Fri, 17 Oct 2025 13:21:40 -0700 Subject: [PATCH 2/4] improve coverage --- .../distro/patches/test_aio_pika_patches.py | 125 +++++++++++++++++- 1 file changed, 124 insertions(+), 1 deletion(-) diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_aio_pika_patches.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_aio_pika_patches.py index 3240f9b9b..c1378fbf1 100644 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_aio_pika_patches.py +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_aio_pika_patches.py @@ -1,8 +1,9 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +import asyncio import unittest -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, Mock, patch from amazon.opentelemetry.distro.patches._aio_pika_patches import ( _apply_aio_pika_instrumentation_patches, @@ -41,6 +42,128 @@ def test_patch_callback_decorator_decorate(self): # Verify we got a function back (the enhanced decorated callback) self.assertTrue(callable(result)) + @patch("amazon.opentelemetry.distro.patches._aio_pika_patches.trace.get_current_span") + @patch("amazon.opentelemetry.distro.patches._aio_pika_patches.add_code_attributes_to_span") + def test_enhanced_callback_with_no_span(self, mock_add_attributes, mock_get_span): + """Test enhanced_callback when no current span exists""" + # Arrange + mock_get_span.return_value = None + original_decorate = Mock() + mock_callback = AsyncMock() + mock_message = Mock() + + # Create patched decorate function + patched_decorate = patch_callback_decorator_decorate(original_decorate) + + # Call patched decorate to get the enhanced callback + patched_decorate(Mock(), mock_callback) + + # Get the enhanced callback from the call + enhanced_callback = original_decorate.call_args[0][1] + + # Act + asyncio.run(enhanced_callback(mock_message)) + + # Assert + mock_get_span.assert_called_once() + mock_add_attributes.assert_not_called() + mock_callback.assert_called_once_with(mock_message) + + @patch("amazon.opentelemetry.distro.patches._aio_pika_patches.trace.get_current_span") + @patch("amazon.opentelemetry.distro.patches._aio_pika_patches.add_code_attributes_to_span") + def test_enhanced_callback_with_non_recording_span(self, mock_add_attributes, mock_get_span): + """Test enhanced_callback when span is not recording""" + # Arrange + mock_span = Mock() + mock_span.is_recording.return_value = False + mock_get_span.return_value = mock_span + + original_decorate = Mock() + mock_callback = AsyncMock() + mock_message = Mock() + + # Create patched decorate function + patched_decorate = patch_callback_decorator_decorate(original_decorate) + + # Call patched decorate to get the enhanced callback + patched_decorate(Mock(), mock_callback) + + # Get the enhanced callback from the call + enhanced_callback = original_decorate.call_args[0][1] + + # Act + asyncio.run(enhanced_callback(mock_message)) + + # Assert + mock_get_span.assert_called_once() + mock_span.is_recording.assert_called_once() + mock_add_attributes.assert_not_called() + mock_callback.assert_called_once_with(mock_message) + + @patch("amazon.opentelemetry.distro.patches._aio_pika_patches.trace.get_current_span") + @patch("amazon.opentelemetry.distro.patches._aio_pika_patches.add_code_attributes_to_span") + def test_enhanced_callback_with_exception_in_add_attributes(self, mock_add_attributes, mock_get_span): + """Test enhanced_callback when add_code_attributes_to_span raises exception""" + # Arrange + mock_span = Mock() + mock_span.is_recording.return_value = True + mock_get_span.return_value = mock_span + mock_add_attributes.side_effect = Exception("Test exception") + + original_decorate = Mock() + mock_callback = AsyncMock() + mock_message = Mock() + + # Create patched decorate function + patched_decorate = patch_callback_decorator_decorate(original_decorate) + + # Call patched decorate to get the enhanced callback + patched_decorate(Mock(), mock_callback) + + # Get the enhanced callback from the call + enhanced_callback = original_decorate.call_args[0][1] + + # Act + asyncio.run(enhanced_callback(mock_message)) + + # Assert + mock_get_span.assert_called_once() + mock_span.is_recording.assert_called_once() + mock_add_attributes.assert_called_once_with(mock_span, mock_callback) + # Should still call original callback despite exception + mock_callback.assert_called_once_with(mock_message) + + @patch("amazon.opentelemetry.distro.patches._aio_pika_patches.trace.get_current_span") + @patch("amazon.opentelemetry.distro.patches._aio_pika_patches.add_code_attributes_to_span") + def test_enhanced_callback_successful_execution(self, mock_add_attributes, mock_get_span): + """Test enhanced_callback normal execution path""" + # Arrange + mock_span = Mock() + mock_span.is_recording.return_value = True + mock_get_span.return_value = mock_span + + original_decorate = Mock() + mock_callback = AsyncMock() + mock_message = Mock() + + # Create patched decorate function + patched_decorate = patch_callback_decorator_decorate(original_decorate) + + # Call patched decorate to get the enhanced callback + patched_decorate(Mock(), mock_callback) + + # Get the enhanced callback from the call + enhanced_callback = original_decorate.call_args[0][1] + + # Act + asyncio.run(enhanced_callback(mock_message)) + + # Assert + mock_get_span.assert_called_once() + mock_span.is_recording.assert_called_once() + mock_add_attributes.assert_called_once_with(mock_span, mock_callback) + mock_callback.assert_called_once_with(mock_message) + @patch("amazon.opentelemetry.distro.patches._aio_pika_patches.get_code_correlation_enabled_status") def test_apply_aio_pika_instrumentation_patches_disabled(self, mock_get_status): mock_get_status.return_value = False From 936c4b9a57a67cef311374f2df4d3a00a787ba7a Mon Sep 17 00:00:00 2001 From: wangzlei Date: Fri, 17 Oct 2025 14:01:48 -0700 Subject: [PATCH 3/4] Fix an old flaky unit test bug --- tox.ini | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tox.ini b/tox.ini index 77e29c2a7..e22c09ca8 100644 --- a/tox.ini +++ b/tox.ini @@ -22,8 +22,8 @@ deps = setenv = ; TODO: The two repos branches need manual updated over time, need to figure out a more sustainable solution. - CORE_REPO="git+https://github.com/open-telemetry/opentelemetry-python.git@release/v1.25.x-0.46bx" - CONTRIB_REPO="git+https://github.com/open-telemetry/opentelemetry-python-contrib.git@release/v1.25.x-0.46bx" + CORE_REPO="git+https://github.com/open-telemetry/opentelemetry-python.git@release/v1.31.x-0.52bx" + CONTRIB_REPO="git+https://github.com/open-telemetry/opentelemetry-python-contrib.git@release/v1.31.x-0.52bx" changedir = test-aws-opentelemetry-distro: aws-opentelemetry-distro/tests From ad7f7a82075a050d7b4db39f1ddae5b038ecfdb5 Mon Sep 17 00:00:00 2001 From: wangzlei Date: Fri, 17 Oct 2025 14:09:39 -0700 Subject: [PATCH 4/4] Fix an old flaky unit test bug --- tox.ini | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tox.ini b/tox.ini index e22c09ca8..720de7c4d 100644 --- a/tox.ini +++ b/tox.ini @@ -22,8 +22,8 @@ deps = setenv = ; TODO: The two repos branches need manual updated over time, need to figure out a more sustainable solution. - CORE_REPO="git+https://github.com/open-telemetry/opentelemetry-python.git@release/v1.31.x-0.52bx" - CONTRIB_REPO="git+https://github.com/open-telemetry/opentelemetry-python-contrib.git@release/v1.31.x-0.52bx" + CORE_REPO="git+https://github.com/open-telemetry/opentelemetry-python.git@release/v1.33.x-0.54bx" + CONTRIB_REPO="git+https://github.com/open-telemetry/opentelemetry-python-contrib.git@release/v1.33.x-0.54bx" changedir = test-aws-opentelemetry-distro: aws-opentelemetry-distro/tests