diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_instrumentation_patch.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_instrumentation_patch.py index df066bca2..7cc5611f7 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_instrumentation_patch.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_instrumentation_patch.py @@ -67,7 +67,7 @@ def apply_instrumentation_patches() -> None: from amazon.opentelemetry.distro.patches._starlette_patches import _apply_starlette_instrumentation_patches # Starlette auto-instrumentation v0.54b includes a strict dependency version check - # This restriction was removed in v1.34.0/0.55b0. Applying temporary patch for Genesis launch + # This restriction was removed in v1.34.0/0.55b0. Applying temporary patch for Bedrock AgentCore launch # TODO: Remove patch after syncing with upstream v1.34.0 or later _apply_starlette_instrumentation_patches() diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_starlette_patches.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_starlette_patches.py index b3bcd624b..385fb0b59 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_starlette_patches.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_starlette_patches.py @@ -4,6 +4,8 @@ from logging import Logger, getLogger from typing import Collection +from amazon.opentelemetry.distro._utils import is_agent_observability_enabled + _logger: Logger = getLogger(__name__) @@ -18,6 +20,7 @@ def _apply_starlette_instrumentation_patches() -> None: """ try: # pylint: disable=import-outside-toplevel + from opentelemetry.instrumentation.asgi import OpenTelemetryMiddleware from opentelemetry.instrumentation.starlette import StarletteInstrumentor # Patch starlette dependencies version check @@ -28,6 +31,25 @@ def patched_instrumentation_dependencies(self) -> Collection[str]: # Apply the patch StarletteInstrumentor.instrumentation_dependencies = patched_instrumentation_dependencies + # pylint: disable=line-too-long + # Patch to exclude http receive/send ASGI event spans from Bedrock AgentCore, + # this Middleware instrumentation is injected internally by Starlette Instrumentor, see: + # https://github.com/open-telemetry/opentelemetry-python-contrib/blob/51da0a766e5d3cbc746189e10c9573163198cfcd/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py#L573 + # + # Issue for tracking a feature to customize this setting within Starlette: + # https://github.com/open-telemetry/opentelemetry-python-contrib/issues/3725 + if is_agent_observability_enabled(): + original_init = OpenTelemetryMiddleware.__init__ + + def patched_init(self, app, **kwargs): + original_init(self, app, **kwargs) + if hasattr(self, "exclude_receive_span"): + self.exclude_receive_span = True + if hasattr(self, "exclude_send_span"): + self.exclude_send_span = True + + OpenTelemetryMiddleware.__init__ = patched_init + _logger.debug("Successfully patched Starlette instrumentation_dependencies method") except Exception as exc: # pylint: disable=broad-except _logger.warning("Failed to apply Starlette instrumentation patches: %s", exc) diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_starlette_patches.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_starlette_patches.py index e0bbf4270..3de5f0bde 100644 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_starlette_patches.py +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_starlette_patches.py @@ -12,31 +12,63 @@ class TestStarlettePatch(TestCase): @patch("amazon.opentelemetry.distro.patches._starlette_patches._logger") def test_starlette_patch_applied_successfully(self, mock_logger): """Test that the Starlette instrumentation patch is applied successfully.""" - # Create a mock StarletteInstrumentor class - mock_instrumentor_class = MagicMock() - mock_instrumentor_class.__name__ = "StarletteInstrumentor" - - # Create a mock module - mock_starlette_module = MagicMock() - mock_starlette_module.StarletteInstrumentor = mock_instrumentor_class - - # Mock the import - with patch.dict("sys.modules", {"opentelemetry.instrumentation.starlette": mock_starlette_module}): - # Apply the patch - _apply_starlette_instrumentation_patches() - - # Verify the instrumentation_dependencies method was replaced - self.assertTrue(hasattr(mock_instrumentor_class, "instrumentation_dependencies")) - - # Test the patched method returns the expected value - mock_instance = MagicMock() - result = mock_instrumentor_class.instrumentation_dependencies(mock_instance) - self.assertEqual(result, ("starlette >= 0.13",)) - - # Verify logging - mock_logger.debug.assert_called_once_with( - "Successfully patched Starlette instrumentation_dependencies method" - ) + for agent_enabled in [True, False]: + with self.subTest(agent_enabled=agent_enabled): + with patch.dict("os.environ", {"AGENT_OBSERVABILITY_ENABLED": "true" if agent_enabled else "false"}): + # Create a mock StarletteInstrumentor class + mock_instrumentor_class = MagicMock() + mock_instrumentor_class.__name__ = "StarletteInstrumentor" + + def create_middleware_class(): + class MockMiddleware: + def __init__(self, app, **kwargs): + pass + + return MockMiddleware + + mock_middleware_class = create_middleware_class() + + mock_starlette_module = MagicMock() + mock_starlette_module.StarletteInstrumentor = mock_instrumentor_class + + mock_asgi_module = MagicMock() + mock_asgi_module.OpenTelemetryMiddleware = mock_middleware_class + + with patch.dict( + "sys.modules", + { + "opentelemetry.instrumentation.starlette": mock_starlette_module, + "opentelemetry.instrumentation.asgi": mock_asgi_module, + }, + ): + # Apply the patch + _apply_starlette_instrumentation_patches() + + # Verify the instrumentation_dependencies method was replaced + self.assertTrue(hasattr(mock_instrumentor_class, "instrumentation_dependencies")) + + # Test the patched method returns the expected value + mock_instance = MagicMock() + result = mock_instrumentor_class.instrumentation_dependencies(mock_instance) + self.assertEqual(result, ("starlette >= 0.13",)) + + mock_middleware_instance = MagicMock() + mock_middleware_instance.exclude_receive_span = False + mock_middleware_instance.exclude_send_span = False + mock_middleware_class.__init__(mock_middleware_instance, "app") + + # Test middleware patching sets exclude flags + if agent_enabled: + self.assertTrue(mock_middleware_instance.exclude_receive_span) + self.assertTrue(mock_middleware_instance.exclude_send_span) + else: + self.assertFalse(mock_middleware_instance.exclude_receive_span) + self.assertFalse(mock_middleware_instance.exclude_send_span) + + # Verify logging + mock_logger.debug.assert_called_with( + "Successfully patched Starlette instrumentation_dependencies method" + ) @patch("amazon.opentelemetry.distro.patches._starlette_patches._logger") def test_starlette_patch_handles_import_error(self, mock_logger):