From c64b51ebdc477792632b5b0ea5311ba6c420f07f Mon Sep 17 00:00:00 2001 From: wangzlei Date: Thu, 2 Oct 2025 11:28:17 -0700 Subject: [PATCH 1/8] support code attributes for lambda --- .../distro/aws_opentelemetry_configurator.py | 2 +- .../distro/code_correlation/__init__.py | 12 ++-- .../code_correlation/test_code_correlation.py | 58 +++++++++---------- .../test_aws_opentelementry_configurator.py | 28 ++++----- .../instrumentation/aws_lambda/__init__.py | 23 ++++++++ 5 files changed, 73 insertions(+), 50 deletions(-) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/aws_opentelemetry_configurator.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/aws_opentelemetry_configurator.py index 5cce1ced0..1eb183c3c 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/aws_opentelemetry_configurator.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/aws_opentelemetry_configurator.py @@ -616,7 +616,7 @@ def _is_application_signals_runtime_enabled(): ) -def _get_code_correlation_enabled_status() -> Optional[bool]: +def get_code_correlation_enabled_status() -> Optional[bool]: """ Get the code correlation enabled status from environment variable. diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/code_correlation/__init__.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/code_correlation/__init__.py index 420145419..b4d7819b0 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/code_correlation/__init__.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/code_correlation/__init__.py @@ -21,7 +21,7 @@ CODE_LINE_NUMBER = "code.line.number" -def _add_code_attributes_to_span(span, func: Callable[..., Any]) -> None: +def add_code_attributes_to_span(span, func: Callable[..., Any]) -> None: """ Add code-related attributes to a span based on a Python function. @@ -68,7 +68,7 @@ def _add_code_attributes_to_span(span, func: Callable[..., Any]) -> None: pass -def add_code_attributes_to_span(func: Callable[..., Any]) -> Callable[..., Any]: +def record_code_attributes(func: Callable[..., Any]) -> Callable[..., Any]: """ Decorator to automatically add code attributes to the current OpenTelemetry span. @@ -81,12 +81,12 @@ def add_code_attributes_to_span(func: Callable[..., Any]) -> Callable[..., Any]: This decorator supports both synchronous and asynchronous functions. Usage: - @add_code_attributes_to_span + @record_code_attributes def my_sync_function(): # Sync function implementation pass - @add_code_attributes_to_span + @record_code_attributes async def my_async_function(): # Async function implementation pass @@ -109,7 +109,7 @@ async def async_wrapper(*args, **kwargs): try: current_span = trace.get_current_span() if current_span: - _add_code_attributes_to_span(current_span, func) + add_code_attributes_to_span(current_span, func) except Exception: # pylint: disable=broad-exception-caught # Silently handle any unexpected errors pass @@ -126,7 +126,7 @@ def sync_wrapper(*args, **kwargs): try: current_span = trace.get_current_span() if current_span: - _add_code_attributes_to_span(current_span, func) + add_code_attributes_to_span(current_span, func) except Exception: # pylint: disable=broad-exception-caught # Silently handle any unexpected errors pass diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/code_correlation/test_code_correlation.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/code_correlation/test_code_correlation.py index f75d7e0f3..ba787cd8e 100644 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/code_correlation/test_code_correlation.py +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/code_correlation/test_code_correlation.py @@ -9,8 +9,8 @@ CODE_FILE_PATH, CODE_FUNCTION_NAME, CODE_LINE_NUMBER, - _add_code_attributes_to_span, add_code_attributes_to_span, + record_code_attributes, ) from opentelemetry.trace import Span @@ -26,7 +26,7 @@ def test_constants_values(self): class TestAddCodeAttributesToSpan(TestCase): - """Test the _add_code_attributes_to_span function.""" + """Test the add_code_attributes_to_span function.""" def setUp(self): """Set up test fixtures.""" @@ -39,7 +39,7 @@ def test_add_code_attributes_to_recording_span(self): def test_function(): pass - _add_code_attributes_to_span(self.mock_span, test_function) + add_code_attributes_to_span(self.mock_span, test_function) # Verify function name attribute is set self.mock_span.set_attribute.assert_any_call(CODE_FUNCTION_NAME, "test_function") @@ -59,7 +59,7 @@ def test_add_code_attributes_to_non_recording_span(self): def test_function(): pass - _add_code_attributes_to_span(self.mock_span, test_function) + add_code_attributes_to_span(self.mock_span, test_function) # Verify no attributes are set self.mock_span.set_attribute.assert_not_called() @@ -71,7 +71,7 @@ def test_add_code_attributes_function_without_code(self): mock_func.__name__ = "mock_function" delattr(mock_func, "__code__") - _add_code_attributes_to_span(self.mock_span, mock_func) + add_code_attributes_to_span(self.mock_span, mock_func) # Verify only function name attribute is set self.mock_span.set_attribute.assert_called_once_with(CODE_FUNCTION_NAME, "mock_function") @@ -79,7 +79,7 @@ def test_add_code_attributes_function_without_code(self): def test_add_code_attributes_builtin_function(self): """Test handling of built-in functions.""" # Use a built-in function like len - _add_code_attributes_to_span(self.mock_span, len) + add_code_attributes_to_span(self.mock_span, len) # Verify only function name attribute is set self.mock_span.set_attribute.assert_called_once_with(CODE_FUNCTION_NAME, "len") @@ -93,7 +93,7 @@ def test_add_code_attributes_function_without_name(self): mock_func.__code__.co_filename = "/test/file.py" mock_func.__code__.co_firstlineno = 10 - _add_code_attributes_to_span(self.mock_span, mock_func) + add_code_attributes_to_span(self.mock_span, mock_func) # Verify function name uses str() representation self.mock_span.set_attribute.assert_any_call(CODE_FUNCTION_NAME, str(mock_func)) @@ -110,7 +110,7 @@ def test_add_code_attributes_exception_handling(self): mock_func.__code__.co_firstlineno = MagicMock(side_effect=Exception("Test exception")) # This should not raise an exception - _add_code_attributes_to_span(self.mock_span, mock_func) + add_code_attributes_to_span(self.mock_span, mock_func) # Verify function name and file path are still set self.mock_span.set_attribute.assert_any_call(CODE_FUNCTION_NAME, "test_func") @@ -125,7 +125,7 @@ def test_function(): pass # This should not raise an exception - _add_code_attributes_to_span(self.mock_span, test_function) + add_code_attributes_to_span(self.mock_span, test_function) # Verify no attributes are set due to exception self.mock_span.set_attribute.assert_not_called() @@ -143,7 +143,7 @@ def test_add_code_attributes_co_filename_exception(self): mock_func.__code__ = mock_code # This should not raise an exception - _add_code_attributes_to_span(self.mock_span, mock_func) + add_code_attributes_to_span(self.mock_span, mock_func) # Verify function name and line number are still set, but not file path self.mock_span.set_attribute.assert_any_call(CODE_FUNCTION_NAME, "test_func") @@ -165,7 +165,7 @@ def test_add_code_attributes_co_firstlineno_exception(self): mock_func.__code__ = mock_code # This should not raise an exception - _add_code_attributes_to_span(self.mock_span, mock_func) + add_code_attributes_to_span(self.mock_span, mock_func) # Verify function name and file path are still set, but not line number self.mock_span.set_attribute.assert_any_call(CODE_FUNCTION_NAME, "test_func") @@ -187,7 +187,7 @@ def test_add_code_attributes_co_filename_type_error(self): mock_func.__code__ = mock_code # This should not raise an exception - _add_code_attributes_to_span(self.mock_span, mock_func) + add_code_attributes_to_span(self.mock_span, mock_func) # Verify function name and line number are still set, but not file path self.mock_span.set_attribute.assert_any_call(CODE_FUNCTION_NAME, "test_func") @@ -197,8 +197,8 @@ def test_add_code_attributes_co_filename_type_error(self): self.mock_span.set_attribute.assert_any_call(CODE_FILE_PATH, MagicMock()) -class TestAddCodeAttributesToSpanDecorator(TestCase): - """Test the add_code_attributes_to_span decorator.""" +class TestRecordCodeAttributesDecorator(TestCase): + """Test the record_code_attributes decorator.""" def setUp(self): """Set up test fixtures.""" @@ -210,7 +210,7 @@ def test_decorator_sync_function(self, mock_get_current_span): """Test decorator with synchronous function.""" mock_get_current_span.return_value = self.mock_span - @add_code_attributes_to_span + @record_code_attributes def test_sync_function(arg1, arg2=None): return f"sync result: {arg1}, {arg2}" @@ -228,7 +228,7 @@ def test_decorator_async_function(self, mock_get_current_span): """Test decorator with asynchronous function.""" mock_get_current_span.return_value = self.mock_span - @add_code_attributes_to_span + @record_code_attributes async def test_async_function(arg1, arg2=None): return f"async result: {arg1}, {arg2}" @@ -251,7 +251,7 @@ def test_decorator_no_current_span(self, mock_get_current_span): """Test decorator when there's no current span.""" mock_get_current_span.return_value = None - @add_code_attributes_to_span + @record_code_attributes def test_function(): return "test result" @@ -269,7 +269,7 @@ def test_decorator_exception_handling(self, mock_get_current_span): """Test decorator handles exceptions gracefully.""" mock_get_current_span.side_effect = Exception("Test exception") - @add_code_attributes_to_span + @record_code_attributes def test_function(): return "test result" @@ -282,7 +282,7 @@ def test_function(): def test_decorator_preserves_function_metadata(self): """Test that decorator preserves original function metadata.""" - @add_code_attributes_to_span + @record_code_attributes def test_function(): """Test function docstring.""" return "test result" @@ -303,8 +303,8 @@ async def async_func(): pass # Apply decorator to both - decorated_sync = add_code_attributes_to_span(sync_func) - decorated_async = add_code_attributes_to_span(async_func) + decorated_sync = record_code_attributes(sync_func) + decorated_async = record_code_attributes(async_func) # Check that sync function returns a regular function self.assertFalse(asyncio.iscoroutinefunction(decorated_sync)) @@ -317,7 +317,7 @@ def test_decorator_with_function_that_raises_exception(self, mock_get_current_sp """Test decorator with function that raises exception.""" mock_get_current_span.return_value = self.mock_span - @add_code_attributes_to_span + @record_code_attributes def test_function(): raise ValueError("Test function exception") @@ -333,7 +333,7 @@ def test_decorator_with_async_function_that_raises_exception(self, mock_get_curr """Test decorator with async function that raises exception.""" mock_get_current_span.return_value = self.mock_span - @add_code_attributes_to_span + @record_code_attributes async def test_async_function(): raise ValueError("Test async function exception") @@ -349,15 +349,15 @@ async def test_async_function(): # Verify span attributes were still set before exception self.mock_span.set_attribute.assert_any_call(CODE_FUNCTION_NAME, "test_async_function") - @patch("amazon.opentelemetry.distro.code_correlation._add_code_attributes_to_span") + @patch("amazon.opentelemetry.distro.code_correlation.add_code_attributes_to_span") @patch("amazon.opentelemetry.distro.code_correlation.trace.get_current_span") def test_decorator_internal_exception_handling_sync(self, mock_get_current_span, mock_add_attributes): """Test that decorator handles internal exceptions gracefully in sync function.""" mock_get_current_span.return_value = self.mock_span - # Make _add_code_attributes_to_span raise an exception + # Make add_code_attributes_to_span raise an exception mock_add_attributes.side_effect = Exception("Internal exception") - @add_code_attributes_to_span + @record_code_attributes def test_function(): return "test result" @@ -367,15 +367,15 @@ def test_function(): # Verify the function still works correctly despite internal exception self.assertEqual(result, "test result") - @patch("amazon.opentelemetry.distro.code_correlation._add_code_attributes_to_span") + @patch("amazon.opentelemetry.distro.code_correlation.add_code_attributes_to_span") @patch("amazon.opentelemetry.distro.code_correlation.trace.get_current_span") def test_decorator_internal_exception_handling_async(self, mock_get_current_span, mock_add_attributes): """Test that decorator handles internal exceptions gracefully in async function.""" mock_get_current_span.return_value = self.mock_span - # Make _add_code_attributes_to_span raise an exception + # Make add_code_attributes_to_span raise an exception mock_add_attributes.side_effect = Exception("Internal exception") - @add_code_attributes_to_span + @record_code_attributes async def test_async_function(): return "async test result" diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_aws_opentelementry_configurator.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_aws_opentelementry_configurator.py index 7ab15a626..93e363f3e 100644 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_aws_opentelementry_configurator.py +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_aws_opentelementry_configurator.py @@ -40,12 +40,12 @@ _export_unsampled_span_for_agent_observability, _export_unsampled_span_for_lambda, _fetch_logs_header, - _get_code_correlation_enabled_status, _init_logging, _is_application_signals_enabled, _is_application_signals_runtime_enabled, _is_defer_to_workers_enabled, _is_wsgi_master_process, + get_code_correlation_enabled_status, ) from amazon.opentelemetry.distro.aws_opentelemetry_distro import AwsOpenTelemetryDistro from amazon.opentelemetry.distro.aws_span_metrics_processor import AwsSpanMetricsProcessor @@ -1428,60 +1428,60 @@ def test_create_emf_exporter_cloudwatch_exporter_import_error( mock_logger.error.assert_called_once() def test_get_code_correlation_enabled_status(self): - """Test _get_code_correlation_enabled_status function with various environment variable values""" + """Test get_code_correlation_enabled_status function with various environment variable values""" # Test when environment variable is not set (default state) os.environ.pop(CODE_CORRELATION_ENABLED_CONFIG, None) - result = _get_code_correlation_enabled_status() + result = get_code_correlation_enabled_status() self.assertIsNone(result) # Test when environment variable is set to 'true' (case insensitive) os.environ[CODE_CORRELATION_ENABLED_CONFIG] = "true" - result = _get_code_correlation_enabled_status() + result = get_code_correlation_enabled_status() self.assertTrue(result) os.environ[CODE_CORRELATION_ENABLED_CONFIG] = "TRUE" - result = _get_code_correlation_enabled_status() + result = get_code_correlation_enabled_status() self.assertTrue(result) os.environ[CODE_CORRELATION_ENABLED_CONFIG] = "True" - result = _get_code_correlation_enabled_status() + result = get_code_correlation_enabled_status() self.assertTrue(result) # Test when environment variable is set to 'false' (case insensitive) os.environ[CODE_CORRELATION_ENABLED_CONFIG] = "false" - result = _get_code_correlation_enabled_status() + result = get_code_correlation_enabled_status() self.assertFalse(result) os.environ[CODE_CORRELATION_ENABLED_CONFIG] = "FALSE" - result = _get_code_correlation_enabled_status() + result = get_code_correlation_enabled_status() self.assertFalse(result) os.environ[CODE_CORRELATION_ENABLED_CONFIG] = "False" - result = _get_code_correlation_enabled_status() + result = get_code_correlation_enabled_status() self.assertFalse(result) # Test with leading/trailing whitespace os.environ[CODE_CORRELATION_ENABLED_CONFIG] = " true " - result = _get_code_correlation_enabled_status() + result = get_code_correlation_enabled_status() self.assertTrue(result) os.environ[CODE_CORRELATION_ENABLED_CONFIG] = " false " - result = _get_code_correlation_enabled_status() + result = get_code_correlation_enabled_status() self.assertFalse(result) # Test invalid values (should return None and log warning) os.environ[CODE_CORRELATION_ENABLED_CONFIG] = "invalid" - result = _get_code_correlation_enabled_status() + result = get_code_correlation_enabled_status() self.assertIsNone(result) # Test another invalid value os.environ[CODE_CORRELATION_ENABLED_CONFIG] = "yes" - result = _get_code_correlation_enabled_status() + result = get_code_correlation_enabled_status() self.assertIsNone(result) # Test empty string (invalid) os.environ[CODE_CORRELATION_ENABLED_CONFIG] = "" - result = _get_code_correlation_enabled_status() + result = get_code_correlation_enabled_status() self.assertIsNone(result) # Clean up diff --git a/lambda-layer/src/opentelemetry/instrumentation/aws_lambda/__init__.py b/lambda-layer/src/opentelemetry/instrumentation/aws_lambda/__init__.py index 23f005f6a..fc8418acc 100644 --- a/lambda-layer/src/opentelemetry/instrumentation/aws_lambda/__init__.py +++ b/lambda-layer/src/opentelemetry/instrumentation/aws_lambda/__init__.py @@ -91,6 +91,18 @@ def custom_event_context_extractor(lambda_event): from opentelemetry.trace import Span, SpanKind, TracerProvider, get_tracer, get_tracer_provider from opentelemetry.trace.status import Status, StatusCode +# Import code correlation functionality +try: + from amazon.opentelemetry.distro.aws_opentelemetry_configurator import get_code_correlation_enabled_status + from amazon.opentelemetry.distro.code_correlation import add_code_attributes_to_span +except ImportError: + # If code correlation module is not available, define no-op functions + def add_code_attributes_to_span(span, func): + pass + + def get_code_correlation_enabled_status(): + return None + logger = logging.getLogger(__name__) _HANDLER = "_HANDLER" @@ -303,6 +315,17 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches account_id, ) + # Add code-level information attributes to the span if enabled + if get_code_correlation_enabled_status() is True: + try: + add_code_attributes_to_span(span, call_wrapped) + except Exception as exc: + # Log but don't fail the instrumentation + logger.debug( + "Failed to add code attributes to lambda span: %s", + str(exc) + ) + exception = None result = None try: From d4163816a6a00fbcdb51649b5e8ff9a7ebf6f536 Mon Sep 17 00:00:00 2001 From: wangzlei Date: Thu, 2 Oct 2025 21:34:36 -0700 Subject: [PATCH 2/8] support code attributes for flask --- .flake8 | 3 + .../distro/patches/_flask_patches.py | 143 ++++++++ .../distro/patches/_instrumentation_patch.py | 7 + .../distro/patches/test_flask_patches.py | 307 ++++++++++++++++++ .../test_aws_opentelementry_configurator.py | 7 + .../instrumentation/aws_lambda/__init__.py | 4 +- 6 files changed, 469 insertions(+), 2 deletions(-) create mode 100644 aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_flask_patches.py create mode 100644 aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_flask_patches.py diff --git a/.flake8 b/.flake8 index e55e479dd..10d984636 100644 --- a/.flake8 +++ b/.flake8 @@ -24,3 +24,6 @@ exclude = mock_collector_service_pb2.py mock_collector_service_pb2.pyi mock_collector_service_pb2_grpc.py + lambda-layer/terraform/lambda/.terraform + lambda-layer/sample-apps/build + samples diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_flask_patches.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_flask_patches.py new file mode 100644 index 000000000..5fbb638a0 --- /dev/null +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_flask_patches.py @@ -0,0 +1,143 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +# Modifications Copyright The OpenTelemetry Authors. Licensed under the Apache License 2.0 License. + +from logging import getLogger + +from amazon.opentelemetry.distro.aws_opentelemetry_configurator import get_code_correlation_enabled_status + +_logger = getLogger(__name__) + + +def _apply_flask_instrumentation_patches() -> None: + """Flask instrumentation patches + + Applies patches to provide code attributes support for Flask instrumentation. + This patches the Flask instrumentation to automatically add code attributes + to spans by decorating view functions with current_span_code_attributes. + """ + if get_code_correlation_enabled_status() is True: + _apply_flask_code_attributes_patch() + + +def _apply_flask_code_attributes_patch() -> None: + """Flask instrumentation patch for code attributes + + This patch modifies the Flask instrumentation to automatically apply + the current_span_code_attributes decorator to all view functions when + the Flask app is instrumented. + + The patch: + 1. Imports current_span_code_attributes decorator from AWS distro utils + 2. Hooks Flask's add_url_rule method during _instrument by patching Flask class + 3. Hooks Flask's dispatch_request method to handle deferred view function binding + 4. Automatically decorates view functions as they are registered or at request time + 5. Adds code.function.name, code.file.path, and code.line.number to spans + 6. Provides cleanup during _uninstrument + """ + try: + # Import Flask instrumentation classes and AWS decorator + import flask + + from amazon.opentelemetry.distro.code_correlation import record_code_attributes + from opentelemetry.instrumentation.flask import FlaskInstrumentor + + # Store the original _instrument and _uninstrument methods + original_instrument = FlaskInstrumentor._instrument + original_uninstrument = FlaskInstrumentor._uninstrument + + # Store reference to original Flask methods + original_flask_add_url_rule = flask.Flask.add_url_rule + original_flask_dispatch_request = flask.Flask.dispatch_request + + def _decorate_view_func(view_func, endpoint=None): + """Helper function to decorate a view function with code attributes.""" + try: + if view_func and callable(view_func): + # Check if function is already decorated (avoid double decoration) + if not hasattr(view_func, "_current_span_code_attributes_decorated"): + # Apply decorator + decorated_view_func = record_code_attributes(view_func) + # Mark as decorated to avoid double decoration + decorated_view_func._current_span_code_attributes_decorated = True + decorated_view_func._original_view_func = view_func + return decorated_view_func + return view_func + except Exception as e: + _logger.warning("Failed to apply code attributes decorator to view function %s: %s", endpoint, e) + return view_func + + def _wrapped_add_url_rule(self, rule, endpoint=None, view_func=None, **options): + """Wrapped Flask.add_url_rule method with code attributes decoration.""" + # Apply decorator to view function if available + if view_func: + view_func = _decorate_view_func(view_func, endpoint) + + return original_flask_add_url_rule(self, rule, endpoint, view_func, **options) + + def _wrapped_dispatch_request(self): + """Wrapped Flask.dispatch_request method to handle deferred view function binding.""" + try: + # Get the current request context + from flask import request + + # Check if there's an endpoint for this request + endpoint = request.endpoint + if endpoint and endpoint in self.view_functions: + view_func = self.view_functions[endpoint] + + # Check if the view function needs decoration + if view_func and callable(view_func): + if not hasattr(view_func, "_current_span_code_attributes_decorated"): + # Decorate the view function and replace it in view_functions + decorated_view_func = _decorate_view_func(view_func, endpoint) + if decorated_view_func != view_func: + self.view_functions[endpoint] = decorated_view_func + _logger.debug( + "Applied code attributes decorator to deferred view function for endpoint: %s", + endpoint, + ) + + except Exception as e: + _logger.warning("Failed to process deferred view function decoration: %s", e) + + # Call the original dispatch_request method + return original_flask_dispatch_request(self) + + def patched_instrument(self, **kwargs): + """Patched _instrument method with Flask method wrapping""" + # Store original methods if not already stored + if not hasattr(self, "_original_flask_add_url_rule"): + self._original_flask_add_url_rule = flask.Flask.add_url_rule + self._original_flask_dispatch_request = flask.Flask.dispatch_request + + # Wrap Flask methods with code attributes decoration + flask.Flask.add_url_rule = _wrapped_add_url_rule + flask.Flask.dispatch_request = _wrapped_dispatch_request + + # Call the original _instrument method + original_instrument(self, **kwargs) + + def patched_uninstrument(self, **kwargs): + """Patched _uninstrument method with Flask method restoration""" + # Call the original _uninstrument method first + original_uninstrument(self, **kwargs) + + # Restore original Flask methods if they exist + if hasattr(self, "_original_flask_add_url_rule"): + try: + flask.Flask.add_url_rule = self._original_flask_add_url_rule + flask.Flask.dispatch_request = self._original_flask_dispatch_request + delattr(self, "_original_flask_add_url_rule") + delattr(self, "_original_flask_dispatch_request") + except Exception as e: + _logger.warning("Failed to restore original Flask methods: %s", e) + + # Apply the patches to FlaskInstrumentor + FlaskInstrumentor._instrument = patched_instrument + FlaskInstrumentor._uninstrument = patched_uninstrument + + _logger.debug("Flask instrumentation code attributes patch applied successfully") + + except Exception as e: + _logger.warning("Failed to apply Flask code attributes patch: %s", e) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_instrumentation_patch.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_instrumentation_patch.py index 7cc5611f7..46b80abd1 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_instrumentation_patch.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_instrumentation_patch.py @@ -71,6 +71,13 @@ def apply_instrumentation_patches() -> None: # TODO: Remove patch after syncing with upstream v1.34.0 or later _apply_starlette_instrumentation_patches() + if is_installed("flask"): + # pylint: disable=import-outside-toplevel + # Delay import to only occur if patches is safe to apply (e.g. the instrumented library is installed). + from amazon.opentelemetry.distro.patches._flask_patches import _apply_flask_instrumentation_patches + + _apply_flask_instrumentation_patches() + # No need to check if library is installed as this patches opentelemetry.sdk, # which must be installed for the distro to work at all. _apply_resource_detector_patches() diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_flask_patches.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_flask_patches.py new file mode 100644 index 000000000..3a9f1896d --- /dev/null +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_flask_patches.py @@ -0,0 +1,307 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from unittest.mock import patch + +import flask +from werkzeug.test import Client +from werkzeug.wrappers import Response + +from amazon.opentelemetry.distro.patches._flask_patches import _apply_flask_instrumentation_patches +from opentelemetry import trace +from opentelemetry.instrumentation.flask import FlaskInstrumentor +from opentelemetry.test.test_base import TestBase + +# Store truly original Flask methods at module level before any patches +_ORIGINAL_FLASK_ADD_URL_RULE = flask.Flask.add_url_rule +_ORIGINAL_FLASK_DISPATCH_REQUEST = flask.Flask.dispatch_request + +# Store original FlaskInstrumentor methods before any patches +_ORIGINAL_FLASK_INSTRUMENTOR_INSTRUMENT = FlaskInstrumentor._instrument +_ORIGINAL_FLASK_INSTRUMENTOR_UNINSTRUMENT = FlaskInstrumentor._uninstrument + + +class TestFlaskPatchesRealApp(TestBase): + """Test Flask patches using a real Flask application.""" + + def setUp(self): + """Set up test fixtures.""" + super().setUp() + + # Always start with clean Flask methods + flask.Flask.add_url_rule = _ORIGINAL_FLASK_ADD_URL_RULE + flask.Flask.dispatch_request = _ORIGINAL_FLASK_DISPATCH_REQUEST + + # Always start with clean FlaskInstrumentor methods + FlaskInstrumentor._instrument = _ORIGINAL_FLASK_INSTRUMENTOR_INSTRUMENT + FlaskInstrumentor._uninstrument = _ORIGINAL_FLASK_INSTRUMENTOR_UNINSTRUMENT + + # Create real Flask app + self.app = flask.Flask(__name__) + + # Add test routes + @self.app.route("/hello/") + def hello(name): + return f"Hello {name}!" + + @self.app.route("/error") + def error_endpoint(): + raise ValueError("Test error") + + @self.app.route("/simple") + def simple(): + return "OK" + + # Create test client + self.client = Client(self.app, Response) + + def tearDown(self): + """Clean up after tests.""" + super().tearDown() + + # Always restore original Flask methods to avoid contamination between tests + flask.Flask.add_url_rule = _ORIGINAL_FLASK_ADD_URL_RULE + flask.Flask.dispatch_request = _ORIGINAL_FLASK_DISPATCH_REQUEST + + # Always restore original FlaskInstrumentor methods to avoid contamination between tests + FlaskInstrumentor._instrument = _ORIGINAL_FLASK_INSTRUMENTOR_INSTRUMENT + FlaskInstrumentor._uninstrument = _ORIGINAL_FLASK_INSTRUMENTOR_UNINSTRUMENT + + # Clear any stored class attributes from patches + for attr_name in list(vars(FlaskInstrumentor).keys()): + if attr_name.startswith("_original_flask_"): + delattr(FlaskInstrumentor, attr_name) + + # CRITICAL: Clear instance attributes from the singleton FlaskInstrumentor instance + # FlaskInstrumentor is a singleton, so we need to clean up the instance attributes + instrumentor_instance = FlaskInstrumentor() + instance_attrs_to_remove = [] + for attr_name in dir(instrumentor_instance): + if attr_name.startswith("_original_flask_"): + instance_attrs_to_remove.append(attr_name) + + for attr_name in instance_attrs_to_remove: + if hasattr(instrumentor_instance, attr_name): + delattr(instrumentor_instance, attr_name) + + # Clean up instrumentor - use global uninstrument + try: + FlaskInstrumentor().uninstrument() + except Exception: # pylint: disable=broad-exception-caught + pass + + @patch("amazon.opentelemetry.distro.patches._flask_patches.get_code_correlation_enabled_status") + def test_flask_patches_with_real_app(self, mock_get_status): + """Test Flask patches with real Flask app covering various scenarios.""" + mock_get_status.return_value = True + + # Store original Flask methods - use the module level constants + original_add_url_rule = _ORIGINAL_FLASK_ADD_URL_RULE + original_dispatch_request = _ORIGINAL_FLASK_DISPATCH_REQUEST + + # Apply patches FIRST + _apply_flask_instrumentation_patches() + + # Verify that get_status was called + mock_get_status.assert_called_once() + + # Create instrumentor and manually call _instrument to trigger Flask method wrapping + instrumentor = FlaskInstrumentor() + instrumentor._instrument() + + # Check if Flask methods were wrapped by patches + current_add_url_rule = flask.Flask.add_url_rule + current_dispatch_request = flask.Flask.dispatch_request + + # Test that Flask methods are actually wrapped - this is the core functionality + self.assertNotEqual(current_add_url_rule, original_add_url_rule, "Flask.add_url_rule should be wrapped") + self.assertNotEqual( + current_dispatch_request, original_dispatch_request, "Flask.dispatch_request should be wrapped" + ) + + # Test a request to trigger the patches + instrumentor.instrument_app(self.app) + resp = self.client.get("/hello/world") + self.assertEqual(200, resp.status_code) + + # Test uninstrumentation - this should restore original Flask methods + instrumentor._uninstrument() + + # Check if Flask methods were restored + restored_add_url_rule = flask.Flask.add_url_rule + restored_dispatch_request = flask.Flask.dispatch_request + + # Methods should be restored to original after uninstrument + self.assertEqual(restored_add_url_rule, original_add_url_rule, "Flask.add_url_rule should be restored") + self.assertEqual( + restored_dispatch_request, original_dispatch_request, "Flask.dispatch_request should be restored" + ) + + @patch("amazon.opentelemetry.distro.patches._flask_patches.get_code_correlation_enabled_status") + def test_flask_patches_disabled(self, mock_get_status): + """Test Flask patches when code correlation is disabled.""" + mock_get_status.return_value = False + + # Apply patches (should not modify anything) + _apply_flask_instrumentation_patches() + + # Instrument the app normally + instrumentor = FlaskInstrumentor() + instrumentor.instrument_app(self.app) + + # Verify that get_status was called + mock_get_status.assert_called_once() + + # Make a request + resp = self.client.get("/hello/world") + self.assertEqual(200, resp.status_code) + self.assertEqual([b"Hello world!"], list(resp.response)) + + # Check spans were still generated (normal instrumentation) + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + + span = spans[0] + self.assertEqual(span.name, "GET /hello/") + self.assertEqual(span.kind, trace.SpanKind.SERVER) + + # Clean up - avoid Flask instrumentation issues + try: + instrumentor.uninstrument_app(self.app) + except Exception: + pass # Flask instrumentation cleanup may fail, that's ok + + @patch("amazon.opentelemetry.distro.patches._flask_patches.get_code_correlation_enabled_status") + def test_flask_patches_import_error_handling(self, mock_get_status): + """Test Flask patches with import errors.""" + mock_get_status.return_value = True + + # Test that patches handle import errors gracefully by mocking sys.modules + import sys + + original_modules = sys.modules.copy() + + try: + # Remove flask from sys.modules to simulate import error + if "flask" in sys.modules: + del sys.modules["flask"] + + # Should not raise exception even with missing flask + _apply_flask_instrumentation_patches() + + # Verify status was checked + mock_get_status.assert_called_once() + + finally: + # Restore original modules + sys.modules.clear() + sys.modules.update(original_modules) + + @patch("amazon.opentelemetry.distro.patches._flask_patches.get_code_correlation_enabled_status") + def test_flask_patches_view_function_decoration(self, mock_get_status): + """Test Flask patches view function decoration edge cases.""" + mock_get_status.return_value = True + + # Create instrumentor and apply patches + instrumentor = FlaskInstrumentor() + instrumentor.instrument_app(self.app) + _apply_flask_instrumentation_patches() + instrumentor._instrument() + + # Test adding routes with None view_func (edge case) + try: + self.app.add_url_rule("/test_none", "test_none", None) + except Exception: + pass # Expected to handle gracefully + + # Test adding routes with non-callable view_func + try: + self.app.add_url_rule("/test_string", "test_string", "not_callable") + except Exception: + pass # Expected to handle gracefully + + # Test route with lambda (should be decorated) + def lambda_func(): + return "lambda response" + + self.app.add_url_rule("/test_lambda", "test_lambda", lambda_func) + + # Clean up - don't call uninstrument_app to avoid Flask instrumentation issues + pass + + @patch("amazon.opentelemetry.distro.patches._flask_patches.get_code_correlation_enabled_status") + def test_flask_patches_dispatch_request_coverage(self, mock_get_status): + """Test Flask patches dispatch_request method coverage.""" + mock_get_status.return_value = True + + # Create a special app with deferred view function binding + test_app = flask.Flask(__name__) + + # Add route after creating app but before applying patches + @test_app.route("/deferred") + def deferred_view(): + return "deferred" + + # Create instrumentor and apply patches + instrumentor = FlaskInstrumentor() + _apply_flask_instrumentation_patches() + instrumentor._instrument() + instrumentor.instrument_app(test_app) + + # Create test client and make request to trigger dispatch_request + client = Client(test_app, Response) + resp = client.get("/deferred") + self.assertEqual(200, resp.status_code) + + @patch("amazon.opentelemetry.distro.patches._flask_patches.get_code_correlation_enabled_status") + def test_flask_patches_uninstrument_error_handling(self, mock_get_status): + """Test Flask patches uninstrument error handling.""" + mock_get_status.return_value = True + + # Create instrumentor and apply patches + instrumentor = FlaskInstrumentor() + _apply_flask_instrumentation_patches() + instrumentor._instrument() + + # Manually break the stored references to trigger error handling + if hasattr(instrumentor, "_original_flask_add_url_rule"): + # Set invalid values to trigger exceptions during restoration + instrumentor._original_flask_add_url_rule = None + instrumentor._original_flask_dispatch_request = None + + # This should trigger error handling in patched_uninstrument + try: + instrumentor._uninstrument() + except Exception: + pass # Expected to handle gracefully + + @patch("amazon.opentelemetry.distro.patches._flask_patches.get_code_correlation_enabled_status") + def test_flask_patches_code_correlation_import_error(self, mock_get_status): + """Test Flask patches when code_correlation import fails.""" + mock_get_status.return_value = True + + # Mock import error for code_correlation module + import sys + + original_modules = sys.modules.copy() + + try: + # Remove code_correlation module to simulate import error + modules_to_remove = [ + "amazon.opentelemetry.distro.code_correlation", + "amazon.opentelemetry.distro.code_correlation.record_code_attributes", + ] + for module in modules_to_remove: + if module in sys.modules: + del sys.modules[module] + + # Create instrumentor and apply patches - should handle import error gracefully + instrumentor = FlaskInstrumentor() + _apply_flask_instrumentation_patches() + + # Try to trigger the patched methods + instrumentor._instrument() + + finally: + # Restore original modules + sys.modules.clear() + sys.modules.update(original_modules) diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_aws_opentelementry_configurator.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_aws_opentelementry_configurator.py index 93e363f3e..4fcd71cb5 100644 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_aws_opentelementry_configurator.py +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_aws_opentelementry_configurator.py @@ -94,6 +94,13 @@ class TestAwsOpenTelemetryConfigurator(TestCase): @classmethod def setUpClass(cls): + # Store original environment variables to restore later + cls._original_env = {} + for key in list(os.environ.keys()): + if key.startswith("OTEL_"): + cls._original_env[key] = os.environ[key] + del os.environ[key] + # Run AwsOpenTelemetryDistro to set up environment, then validate expected env values. aws_open_telemetry_distro: AwsOpenTelemetryDistro = AwsOpenTelemetryDistro() aws_open_telemetry_distro.configure(apply_patches=False) diff --git a/lambda-layer/src/opentelemetry/instrumentation/aws_lambda/__init__.py b/lambda-layer/src/opentelemetry/instrumentation/aws_lambda/__init__.py index fc8418acc..b88e65438 100644 --- a/lambda-layer/src/opentelemetry/instrumentation/aws_lambda/__init__.py +++ b/lambda-layer/src/opentelemetry/instrumentation/aws_lambda/__init__.py @@ -99,7 +99,7 @@ def custom_event_context_extractor(lambda_event): # If code correlation module is not available, define no-op functions def add_code_attributes_to_span(span, func): pass - + def get_code_correlation_enabled_status(): return None @@ -322,7 +322,7 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches except Exception as exc: # Log but don't fail the instrumentation logger.debug( - "Failed to add code attributes to lambda span: %s", + "Failed to add code attributes to lambda span: %s", str(exc) ) From cae3f4b30f72ca042bd3c55476305aa2fed8b4b8 Mon Sep 17 00:00:00 2001 From: wangzlei Date: Fri, 3 Oct 2025 21:07:36 -0700 Subject: [PATCH 3/8] support code attributes for fastapi and starlette --- .../distro/patches/_fastapi_patches.py | 117 ++++ .../distro/patches/_flask_patches.py | 16 +- .../distro/patches/_instrumentation_patch.py | 7 + .../distro/patches/_starlette_patches.py | 104 +++- .../distro/patches/test_fastapi_patches.py | 229 +++++++ .../distro/patches/test_starlette_patches.py | 577 ++++++++++++++++-- dev-requirements.txt | 3 + 7 files changed, 1006 insertions(+), 47 deletions(-) create mode 100644 aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_fastapi_patches.py create mode 100644 aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_fastapi_patches.py diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_fastapi_patches.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_fastapi_patches.py new file mode 100644 index 000000000..f3ee8dbca --- /dev/null +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_fastapi_patches.py @@ -0,0 +1,117 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +# Modifications Copyright The OpenTelemetry Authors. Licensed under the Apache License 2.0 License. + +from logging import getLogger + +from amazon.opentelemetry.distro.aws_opentelemetry_configurator import get_code_correlation_enabled_status + +_logger = getLogger(__name__) + + +def _apply_fastapi_instrumentation_patches() -> None: + """FastAPI instrumentation patches + + Applies patches to provide code attributes support for FastAPI instrumentation. + This patches the Flask instrumentation to automatically add code attributes + to spans by decorating view functions with current_span_code_attributes. + """ + if get_code_correlation_enabled_status() is True: + _apply_fastapi_code_attributes_patch() + + +def _apply_fastapi_code_attributes_patch() -> None: + """FastAPI instrumentation patch for code attributes + + This patch modifies the FastAPI instrumentation to automatically apply + the current_span_code_attributes decorator to all endpoint functions when + the FastAPI app is instrumented. + + The patch: + 1. Imports current_span_code_attributes decorator from AWS distro utils + 2. Hooks FastAPI's APIRouter.add_api_route method during instrumentation + 3. Automatically decorates endpoint functions as they are registered + 4. Adds code.function.name, code.file.path, and code.line.number to spans + 5. Provides cleanup during uninstrumentation + """ + try: + # Import FastAPI instrumentation classes and AWS decorator + from fastapi import routing + + from amazon.opentelemetry.distro.code_correlation import record_code_attributes + from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor + + # Store the original _instrument and _uninstrument methods + original_instrument = FastAPIInstrumentor._instrument + original_uninstrument = FastAPIInstrumentor._uninstrument + + def _wrapped_add_api_route(original_add_api_route_method): + """Wrapper for APIRouter.add_api_route method.""" + + def wrapper(self, *args, **kwargs): + # Apply current_span_code_attributes decorator to endpoint function + try: + # Get endpoint function from args or kwargs + endpoint = None + if len(args) >= 2: + endpoint = args[1] + else: + endpoint = kwargs.get("endpoint") + + if endpoint and callable(endpoint): + # Check if function is already decorated (avoid double decoration) + if not hasattr(endpoint, "_current_span_code_attributes_decorated"): + # Apply decorator + decorated_endpoint = record_code_attributes(endpoint) + # Mark as decorated to avoid double decoration + decorated_endpoint._current_span_code_attributes_decorated = True + decorated_endpoint._original_endpoint = endpoint + + # Replace endpoint in args or kwargs + if len(args) >= 2: + args = list(args) + args[1] = decorated_endpoint + args = tuple(args) + elif "endpoint" in kwargs: + kwargs["endpoint"] = decorated_endpoint + + except Exception as exc: # pylint: disable=broad-exception-caught + _logger.warning("Failed to apply code attributes decorator to endpoint: %s", exc) + + return original_add_api_route_method(self, *args, **kwargs) + + return wrapper + + def patched_instrument(self, **kwargs): + """Patched _instrument method with APIRouter.add_api_route wrapping""" + # Store original add_api_route method if not already stored + if not hasattr(self, "_original_apirouter"): + self._original_apirouter = routing.APIRouter.add_api_route + + # Wrap APIRouter.add_api_route with code attributes decoration + routing.APIRouter.add_api_route = _wrapped_add_api_route(self._original_apirouter) + + # Call the original _instrument method + original_instrument(self, **kwargs) + + def patched_uninstrument(self, **kwargs): + """Patched _uninstrument method with APIRouter.add_api_route restoration""" + # Call the original _uninstrument method first + original_uninstrument(self, **kwargs) + + # Restore original APIRouter.add_api_route method if it exists + if hasattr(self, "_original_apirouter"): + try: + routing.APIRouter.add_api_route = self._original_apirouter + delattr(self, "_original_apirouter") + except Exception as exc: # pylint: disable=broad-exception-caught + _logger.warning("Failed to restore original APIRouter.add_api_route method: %s", exc) + + # Apply the patches to FastAPIInstrumentor + FastAPIInstrumentor._instrument = patched_instrument + FastAPIInstrumentor._uninstrument = patched_uninstrument + + _logger.debug("FastAPI instrumentation code attributes patch applied successfully") + + except Exception as exc: # pylint: disable=broad-exception-caught + _logger.warning("Failed to apply FastAPI code attributes patch: %s", exc) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_flask_patches.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_flask_patches.py index 5fbb638a0..b25c34b50 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_flask_patches.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_flask_patches.py @@ -63,8 +63,8 @@ def _decorate_view_func(view_func, endpoint=None): decorated_view_func._original_view_func = view_func return decorated_view_func return view_func - except Exception as e: - _logger.warning("Failed to apply code attributes decorator to view function %s: %s", endpoint, e) + except Exception as exc: # pylint: disable=broad-exception-caught + _logger.warning("Failed to apply code attributes decorator to view function %s: %s", endpoint, exc) return view_func def _wrapped_add_url_rule(self, rule, endpoint=None, view_func=None, **options): @@ -98,8 +98,8 @@ def _wrapped_dispatch_request(self): endpoint, ) - except Exception as e: - _logger.warning("Failed to process deferred view function decoration: %s", e) + except Exception as exc: # pylint: disable=broad-exception-caught + _logger.warning("Failed to process deferred view function decoration: %s", exc) # Call the original dispatch_request method return original_flask_dispatch_request(self) @@ -130,8 +130,8 @@ def patched_uninstrument(self, **kwargs): flask.Flask.dispatch_request = self._original_flask_dispatch_request delattr(self, "_original_flask_add_url_rule") delattr(self, "_original_flask_dispatch_request") - except Exception as e: - _logger.warning("Failed to restore original Flask methods: %s", e) + except Exception as exc: # pylint: disable=broad-exception-caught + _logger.warning("Failed to restore original Flask methods: %s", exc) # Apply the patches to FlaskInstrumentor FlaskInstrumentor._instrument = patched_instrument @@ -139,5 +139,5 @@ def patched_uninstrument(self, **kwargs): _logger.debug("Flask instrumentation code attributes patch applied successfully") - except Exception as e: - _logger.warning("Failed to apply Flask code attributes patch: %s", e) + except Exception as exc: # pylint: disable=broad-exception-caught + _logger.warning("Failed to apply Flask code attributes patch: %s", exc) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_instrumentation_patch.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_instrumentation_patch.py index 46b80abd1..2f1f4bba5 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_instrumentation_patch.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_instrumentation_patch.py @@ -78,6 +78,13 @@ def apply_instrumentation_patches() -> None: _apply_flask_instrumentation_patches() + if is_installed("fastapi"): + # pylint: disable=import-outside-toplevel + # Delay import to only occur if patches is safe to apply (e.g. the instrumented library is installed). + from amazon.opentelemetry.distro.patches._fastapi_patches import _apply_fastapi_instrumentation_patches + + _apply_fastapi_instrumentation_patches() + # No need to check if library is installed as this patches opentelemetry.sdk, # which must be installed for the distro to work at all. _apply_resource_detector_patches() diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_starlette_patches.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_starlette_patches.py index 385fb0b59..00d0d89b1 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_starlette_patches.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_starlette_patches.py @@ -5,15 +5,26 @@ from typing import Collection from amazon.opentelemetry.distro._utils import is_agent_observability_enabled +from amazon.opentelemetry.distro.aws_opentelemetry_configurator import get_code_correlation_enabled_status _logger: Logger = getLogger(__name__) +def _apply_starlette_instrumentation_patches() -> None: + """Apply patches to the Starlette instrumentation. + + This applies both version compatibility patches and code attributes support. + """ + _apply_starlette_version_patches() + if get_code_correlation_enabled_status() is True: + _apply_starlette_code_attributes_patch() + + # Upstream fix available in OpenTelemetry 1.34.0/0.55b0 (2025-06-04) # Reference: https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3456 # TODO: Remove this patch after upgrading to version 1.34.0 or later -def _apply_starlette_instrumentation_patches() -> None: - """Apply patches to the Starlette instrumentation. +def _apply_starlette_version_patches() -> None: + """Apply version compatibility patches to the Starlette instrumentation. This patch modifies the instrumentation_dependencies method in the starlette instrumentation to loose an upper version constraint for auto-instrumentation @@ -53,3 +64,92 @@ def patched_init(self, app, **kwargs): _logger.debug("Successfully patched Starlette instrumentation_dependencies method") except Exception as exc: # pylint: disable=broad-except _logger.warning("Failed to apply Starlette instrumentation patches: %s", exc) + + +def _apply_starlette_code_attributes_patch() -> None: + """Starlette instrumentation patch for code attributes + + This patch modifies Starlette Route class to automatically apply + the record_code_attributes decorator to endpoint functions when + routes are created. + + The patch: + 1. Imports record_code_attributes decorator from AWS distro code correlation + 2. Hooks Starlette's Route.__init__ method during instrumentation + 3. Automatically decorates endpoint functions as routes are created + 4. Adds code.function.name, code.file.path, and code.line.number to spans + 5. Provides cleanup during uninstrumentation + """ + try: + # Import Starlette routing classes and AWS decorator + from starlette.routing import Route + + from amazon.opentelemetry.distro.code_correlation import record_code_attributes + from opentelemetry.instrumentation.starlette import StarletteInstrumentor + + # Store the original _instrument and _uninstrument methods + original_instrument = StarletteInstrumentor._instrument + original_uninstrument = StarletteInstrumentor._uninstrument + + # Store reference to original Route.__init__ + original_route_init = Route.__init__ + + def _decorate_endpoint(endpoint): + """Helper function to decorate an endpoint function with code attributes.""" + try: + if endpoint and callable(endpoint): + # Check if function is already decorated (avoid double decoration) + if not hasattr(endpoint, "_current_span_code_attributes_decorated"): + # Apply decorator + decorated_endpoint = record_code_attributes(endpoint) + # Mark as decorated to avoid double decoration + decorated_endpoint._current_span_code_attributes_decorated = True + decorated_endpoint._original_endpoint = endpoint + return decorated_endpoint + return endpoint + except Exception as exc: # pylint: disable=broad-exception-caught + _logger.warning("Failed to apply code attributes decorator to endpoint: %s", exc) + return endpoint + + def _wrapped_route_init(self, path, endpoint=None, **kwargs): + """Wrapped Route.__init__ method with code attributes decoration.""" + # Decorate endpoint if provided + if endpoint: + endpoint = _decorate_endpoint(endpoint) + + # Call the original Route.__init__ with decorated endpoint + return original_route_init(self, path, endpoint=endpoint, **kwargs) + + def patched_instrument(self, **kwargs): + """Patched _instrument method with Route.__init__ wrapping""" + # Store original Route.__init__ method if not already stored + if not hasattr(self, "_original_route_init"): + self._original_route_init = Route.__init__ + + # Wrap Route.__init__ with code attributes decoration + Route.__init__ = _wrapped_route_init + + # Call the original _instrument method + original_instrument(self, **kwargs) + + def patched_uninstrument(self, **kwargs): + """Patched _uninstrument method with Route.__init__ restoration""" + # Call the original _uninstrument method first + original_uninstrument(self, **kwargs) + + # Restore original Route.__init__ method if it exists + if hasattr(self, "_original_route_init"): + try: + Route.__init__ = self._original_route_init + delattr(self, "_original_route_init") + except Exception as exc: # pylint: disable=broad-exception-caught + _logger.warning("Failed to restore original Route.__init__ method: %s", exc) + + # Apply the patches to StarletteInstrumentor + StarletteInstrumentor._instrument = patched_instrument + StarletteInstrumentor._uninstrument = patched_uninstrument + + _logger.debug("Starlette instrumentation code attributes patch applied successfully") + + except Exception as exc: # pylint: disable=broad-exception-caught + _logger.warning("Failed to apply Starlette code attributes patch: %s", exc) diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_fastapi_patches.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_fastapi_patches.py new file mode 100644 index 000000000..bf5afcb98 --- /dev/null +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_fastapi_patches.py @@ -0,0 +1,229 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +import sys +from unittest.mock import patch + +from fastapi import APIRouter, FastAPI + +from amazon.opentelemetry.distro.patches._fastapi_patches import _apply_fastapi_instrumentation_patches +from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor +from opentelemetry.test.test_base import TestBase + +# Store original methods at module level before any patches +_ORIGINAL_FASTAPI_INSTRUMENTOR_INSTRUMENT = FastAPIInstrumentor._instrument +_ORIGINAL_FASTAPI_INSTRUMENTOR_UNINSTRUMENT = FastAPIInstrumentor._uninstrument +_ORIGINAL_APIROUTER_ADD_API_ROUTE = APIRouter.add_api_route + + +class TestFastAPIPatchesRealApp(TestBase): + """Test FastAPI patches functionality.""" + + def setUp(self): + """Set up test fixtures.""" + super().setUp() + + # Restore original methods + APIRouter.add_api_route = _ORIGINAL_APIROUTER_ADD_API_ROUTE + FastAPIInstrumentor._instrument = _ORIGINAL_FASTAPI_INSTRUMENTOR_INSTRUMENT + FastAPIInstrumentor._uninstrument = _ORIGINAL_FASTAPI_INSTRUMENTOR_UNINSTRUMENT + + # Create FastAPI app + self.app = FastAPI() + + @self.app.get("/hello/{name}") + async def hello(name: str): + return {"message": f"Hello {name}!"} + + def tearDown(self): + """Clean up after tests.""" + super().tearDown() + + # Restore original methods + APIRouter.add_api_route = _ORIGINAL_APIROUTER_ADD_API_ROUTE + FastAPIInstrumentor._instrument = _ORIGINAL_FASTAPI_INSTRUMENTOR_INSTRUMENT + FastAPIInstrumentor._uninstrument = _ORIGINAL_FASTAPI_INSTRUMENTOR_UNINSTRUMENT + + # Clean up instrumentor attributes + instrumentor_instance = FastAPIInstrumentor() + for attr_name in list(vars(FastAPIInstrumentor).keys()): + if attr_name.startswith("_original_apirouter"): + delattr(FastAPIInstrumentor, attr_name) + + for attr_name in [attr for attr in dir(instrumentor_instance) if attr.startswith("_original_apirouter")]: + if hasattr(instrumentor_instance, attr_name): + delattr(instrumentor_instance, attr_name) + + try: + FastAPIInstrumentor().uninstrument() + except Exception: + pass + + @patch("amazon.opentelemetry.distro.patches._fastapi_patches.get_code_correlation_enabled_status") + def test_fastapi_patches_with_real_app(self, mock_get_status): + """Test FastAPI patches core functionality.""" + mock_get_status.return_value = True + original_add_api_route = _ORIGINAL_APIROUTER_ADD_API_ROUTE + + # Apply patches + _apply_fastapi_instrumentation_patches() + mock_get_status.assert_called_once() + + # Test method wrapping + instrumentor = FastAPIInstrumentor() + instrumentor._instrument() + + current_add_api_route = APIRouter.add_api_route + self.assertNotEqual(current_add_api_route, original_add_api_route) + + # Test app instrumentation + instrumentor.instrument_app(self.app) + self.assertIsNotNone(self.app) + + # Test uninstrumentation + instrumentor._uninstrument() + restored_add_api_route = APIRouter.add_api_route + self.assertEqual(restored_add_api_route, original_add_api_route) + + @patch("amazon.opentelemetry.distro.patches._fastapi_patches.get_code_correlation_enabled_status") + def test_fastapi_patches_disabled(self, mock_get_status): + """Test FastAPI patches when disabled.""" + mock_get_status.return_value = False + + _apply_fastapi_instrumentation_patches() + instrumentor = FastAPIInstrumentor() + instrumentor.instrument_app(self.app) + + mock_get_status.assert_called_once() + self.assertIsNotNone(self.app) + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 0) + + @patch("amazon.opentelemetry.distro.patches._fastapi_patches.get_code_correlation_enabled_status") + def test_fastapi_patches_import_error_handling(self, mock_get_status): + """Test FastAPI patches with import errors.""" + mock_get_status.return_value = True + original_modules = sys.modules.copy() + + try: + # Simulate import error + modules_to_remove = [ + "fastapi.routing", + "amazon.opentelemetry.distro.code_correlation", + "opentelemetry.instrumentation.fastapi", + ] + for module in modules_to_remove: + if module in sys.modules: + del sys.modules[module] + + _apply_fastapi_instrumentation_patches() + mock_get_status.assert_called_once() + + finally: + sys.modules.clear() + sys.modules.update(original_modules) + + @patch("amazon.opentelemetry.distro.patches._fastapi_patches.get_code_correlation_enabled_status") + def test_fastapi_patches_endpoint_decoration(self, mock_get_status): + """Test endpoint decoration functionality.""" + mock_get_status.return_value = True + + instrumentor = FastAPIInstrumentor() + instrumentor.instrument_app(self.app) + _apply_fastapi_instrumentation_patches() + instrumentor._instrument() + + # Test adding routes + async def async_endpoint(): + return {"message": "async endpoint"} + + router = APIRouter() + router.add_api_route("/test_async", async_endpoint, methods=["GET"]) + self.app.include_router(router) + + self.assertTrue(len(self.app.routes) > 0) + + @patch("amazon.opentelemetry.distro.patches._fastapi_patches.get_code_correlation_enabled_status") + def test_fastapi_patches_uninstrument_error_handling(self, mock_get_status): + """Test uninstrument error handling.""" + mock_get_status.return_value = True + + instrumentor = FastAPIInstrumentor() + _apply_fastapi_instrumentation_patches() + instrumentor._instrument() + + # Break stored references to trigger error handling + if hasattr(instrumentor, "_original_apirouter"): + instrumentor._original_apirouter = None + + try: + instrumentor._uninstrument() + except Exception: + pass # Expected to handle gracefully + + @patch("amazon.opentelemetry.distro.patches._fastapi_patches.get_code_correlation_enabled_status") + def test_fastapi_patches_code_correlation_import_error(self, mock_get_status): + """Test code correlation import error handling.""" + mock_get_status.return_value = True + original_modules = sys.modules.copy() + + try: + # Remove code_correlation module to simulate import error + modules_to_remove = [ + "amazon.opentelemetry.distro.code_correlation", + "amazon.opentelemetry.distro.code_correlation.record_code_attributes", + ] + for module in modules_to_remove: + if module in sys.modules: + del sys.modules[module] + + instrumentor = FastAPIInstrumentor() + _apply_fastapi_instrumentation_patches() + instrumentor._instrument() + + finally: + sys.modules.clear() + sys.modules.update(original_modules) + + @patch("amazon.opentelemetry.distro.patches._fastapi_patches.get_code_correlation_enabled_status") + def test_fastapi_patches_double_decoration_prevention(self, mock_get_status): + """Test prevention of double decoration.""" + mock_get_status.return_value = True + + _apply_fastapi_instrumentation_patches() + instrumentor = FastAPIInstrumentor() + instrumentor._instrument() + + # Create pre-decorated endpoint + async def test_endpoint(): + return {"message": "test"} + + test_endpoint._current_span_code_attributes_decorated = True + + router = APIRouter() + router.add_api_route("/test_double", test_endpoint, methods=["GET"]) + self.app.include_router(router) + + self.assertTrue(len(self.app.routes) > 0) + + @patch("amazon.opentelemetry.distro.patches._fastapi_patches.get_code_correlation_enabled_status") + def test_fastapi_patches_none_endpoint_handling(self, mock_get_status): + """Test handling of None endpoints.""" + mock_get_status.return_value = True + + _apply_fastapi_instrumentation_patches() + instrumentor = FastAPIInstrumentor() + instrumentor._instrument() + + router = APIRouter() + + # Test None endpoint handling + try: + router.add_api_route("/test_none", None, methods=["GET"]) + except Exception: + pass # Expected to handle gracefully + + try: + router.add_api_route("/test_string", "not_callable", methods=["GET"]) + except Exception: + pass # Expected to handle gracefully diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_starlette_patches.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_starlette_patches.py index 3de5f0bde..16f0a72cd 100644 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_starlette_patches.py +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_starlette_patches.py @@ -3,7 +3,10 @@ from unittest import TestCase from unittest.mock import MagicMock, patch -from amazon.opentelemetry.distro.patches._starlette_patches import _apply_starlette_instrumentation_patches +from amazon.opentelemetry.distro.patches._starlette_patches import ( + _apply_starlette_code_attributes_patch, + _apply_starlette_instrumentation_patches, +) class TestStarlettePatch(TestCase): @@ -83,58 +86,558 @@ def test_starlette_patch_handles_import_error(self, mock_logger): args = mock_logger.warning.call_args[0] self.assertIn("Failed to apply Starlette instrumentation patches", args[0]) + +class TestStarletteCodeAttributesPatch(TestCase): + """Test the Starlette code attributes instrumentation patches using real Route class.""" + + def setUp(self): + """Set up test fixtures.""" + + # Sample endpoint functions for testing + def sample_endpoint(): + return {"message": "Hello World"} + + def another_endpoint(): + return {"message": "Another endpoint"} + + self.sample_endpoint = sample_endpoint + self.another_endpoint = another_endpoint + + @patch("amazon.opentelemetry.distro.patches._starlette_patches._logger") + def test_code_attributes_patch_applied_successfully(self, mock_logger): + """Test that the code attributes patch is applied successfully using real Route class.""" + try: + from starlette.routing import Route + except ImportError: + self.skipTest("Starlette not available") + + # Create a mock StarletteInstrumentor class with proper methods + class MockStarletteInstrumentor: + def __init__(self): + pass + + def _instrument(self, **kwargs): + pass + + def _uninstrument(self, **kwargs): + pass + + mock_instrumentor_class = MockStarletteInstrumentor + mock_instrumentor = MockStarletteInstrumentor() + + # Mock the code correlation decorator + mock_record_code_attributes = MagicMock() + + def mock_decorator(func): + """Mock decorator that marks function as decorated.""" + + # Create a wrapper that preserves the original function + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + # Mark as decorated + wrapper._current_span_code_attributes_decorated = True + wrapper._original_endpoint = func + wrapper.__name__ = getattr(func, "__name__", "decorated_endpoint") + return wrapper + + mock_record_code_attributes.side_effect = mock_decorator + + with patch.dict( + "sys.modules", + { + "opentelemetry.instrumentation.starlette": MagicMock(StarletteInstrumentor=mock_instrumentor_class), + "amazon.opentelemetry.distro.code_correlation": MagicMock( + record_code_attributes=mock_record_code_attributes + ), + }, + ): + # Store original Route.__init__ + original_route_init = Route.__init__ + + try: + # Apply the patch + _apply_starlette_code_attributes_patch() + + # Verify the instrumentor methods were patched + self.assertTrue(hasattr(mock_instrumentor_class, "_instrument")) + self.assertTrue(hasattr(mock_instrumentor_class, "_uninstrument")) + + # Call the patched _instrument method to set up instrumentation + mock_instrumentor._instrument() + + # Verify Route.__init__ was modified + self.assertNotEqual(Route.__init__, original_route_init) + + # Create a route with the patched Route class + route = Route("/test", endpoint=self.sample_endpoint) + + # Verify the endpoint was decorated + mock_record_code_attributes.assert_called_once_with(self.sample_endpoint) + + # Test that the route was created successfully + self.assertEqual(route.path, "/test") + self.assertIsNotNone(route.endpoint) + + # Verify the endpoint is decorated + self.assertTrue(hasattr(route.endpoint, "_current_span_code_attributes_decorated")) + self.assertEqual(route.endpoint._original_endpoint, self.sample_endpoint) + + # Test uninstrumentation + mock_instrumentor._uninstrument() + + # Verify Route.__init__ was restored + self.assertEqual(Route.__init__, original_route_init) + + # Verify logging + mock_logger.debug.assert_called_with( + "Starlette instrumentation code attributes patch applied successfully" + ) + + finally: + # Restore original Route.__init__ + Route.__init__ = original_route_init + + @patch("amazon.opentelemetry.distro.patches._starlette_patches._logger") + def test_code_attributes_patch_with_none_endpoint(self, mock_logger): + """Test that the patch handles None endpoint gracefully.""" + try: + from starlette.routing import Route + except ImportError: + self.skipTest("Starlette not available") + + mock_instrumentor_class = MagicMock() + mock_instrumentor = MagicMock() + mock_instrumentor_class.return_value = mock_instrumentor + + mock_record_code_attributes = MagicMock() + + with patch.dict( + "sys.modules", + { + "opentelemetry.instrumentation.starlette": MagicMock(StarletteInstrumentor=mock_instrumentor_class), + "amazon.opentelemetry.distro.code_correlation": MagicMock( + record_code_attributes=mock_record_code_attributes + ), + }, + ): + original_route_init = Route.__init__ + + try: + _apply_starlette_code_attributes_patch() + mock_instrumentor_class._instrument(mock_instrumentor) + + # Create route with None endpoint + route = Route("/test", endpoint=None) + + # Verify no decoration was attempted + mock_record_code_attributes.assert_not_called() + + # Verify route was created successfully + self.assertEqual(route.path, "/test") + self.assertIsNone(route.endpoint) + + finally: + Route.__init__ = original_route_init + + @patch("amazon.opentelemetry.distro.patches._starlette_patches._logger") + def test_code_attributes_patch_avoids_double_decoration(self, mock_logger): + """Test that the patch avoids double decoration of endpoints.""" + try: + from starlette.routing import Route + except ImportError: + self.skipTest("Starlette not available") + + mock_instrumentor_class = MagicMock() + mock_instrumentor = MagicMock() + mock_instrumentor_class.return_value = mock_instrumentor + + mock_record_code_attributes = MagicMock() + + # Create an already decorated endpoint + def already_decorated_endpoint(): + return {"message": "Already decorated"} + + already_decorated_endpoint._current_span_code_attributes_decorated = True + + with patch.dict( + "sys.modules", + { + "opentelemetry.instrumentation.starlette": MagicMock(StarletteInstrumentor=mock_instrumentor_class), + "amazon.opentelemetry.distro.code_correlation": MagicMock( + record_code_attributes=mock_record_code_attributes + ), + }, + ): + original_route_init = Route.__init__ + + try: + _apply_starlette_code_attributes_patch() + mock_instrumentor_class._instrument(mock_instrumentor) + + # Create route with already decorated endpoint + route = Route("/test", endpoint=already_decorated_endpoint) + + # Verify no additional decoration was attempted + mock_record_code_attributes.assert_not_called() + + # Verify route was created successfully + self.assertEqual(route.path, "/test") + self.assertEqual(route.endpoint, already_decorated_endpoint) + + finally: + Route.__init__ = original_route_init + + @patch("amazon.opentelemetry.distro.patches._starlette_patches._logger") + def test_code_attributes_patch_handles_non_callable_endpoint(self, mock_logger): + """Test that the patch handles non-callable endpoints gracefully.""" + try: + from starlette.routing import Route + except ImportError: + self.skipTest("Starlette not available") + + mock_instrumentor_class = MagicMock() + mock_instrumentor = MagicMock() + mock_instrumentor_class.return_value = mock_instrumentor + + mock_record_code_attributes = MagicMock() + + # Non-callable endpoint + non_callable_endpoint = "not_a_function" + + with patch.dict( + "sys.modules", + { + "opentelemetry.instrumentation.starlette": MagicMock(StarletteInstrumentor=mock_instrumentor_class), + "amazon.opentelemetry.distro.code_correlation": MagicMock( + record_code_attributes=mock_record_code_attributes + ), + }, + ): + original_route_init = Route.__init__ + + try: + _apply_starlette_code_attributes_patch() + mock_instrumentor_class._instrument(mock_instrumentor) + + # Create route with non-callable endpoint + route = Route("/test", endpoint=non_callable_endpoint) + + # Verify no decoration was attempted + mock_record_code_attributes.assert_not_called() + + # Verify route was created successfully + self.assertEqual(route.path, "/test") + self.assertEqual(route.endpoint, non_callable_endpoint) + + finally: + Route.__init__ = original_route_init + + @patch("amazon.opentelemetry.distro.patches._starlette_patches._logger") + def test_code_attributes_patch_handles_decorator_error(self, mock_logger): + """Test that the patch handles decorator errors gracefully.""" + try: + from starlette.routing import Route + except ImportError: + self.skipTest("Starlette not available") + + # Create a mock StarletteInstrumentor class with proper methods + class MockStarletteInstrumentor: + def __init__(self): + pass + + def _instrument(self, **kwargs): + pass + + def _uninstrument(self, **kwargs): + pass + + mock_instrumentor_class = MockStarletteInstrumentor + mock_instrumentor = MockStarletteInstrumentor() + + # Mock decorator that raises exception + mock_record_code_attributes = MagicMock() + mock_record_code_attributes.side_effect = RuntimeError("Decorator failed") + + with patch.dict( + "sys.modules", + { + "opentelemetry.instrumentation.starlette": MagicMock(StarletteInstrumentor=mock_instrumentor_class), + "amazon.opentelemetry.distro.code_correlation": MagicMock( + record_code_attributes=mock_record_code_attributes + ), + }, + ): + original_route_init = Route.__init__ + + try: + _apply_starlette_code_attributes_patch() + mock_instrumentor._instrument() + + # Create route - should not raise exception despite decorator error + route = Route("/test", endpoint=self.sample_endpoint) + + # Verify route was created successfully with original endpoint + self.assertEqual(route.path, "/test") + self.assertEqual(route.endpoint, self.sample_endpoint) + + # Verify warning was logged + mock_logger.warning.assert_called() + args = mock_logger.warning.call_args[0] + self.assertIn("Failed to apply code attributes decorator to endpoint", args[0]) + + finally: + Route.__init__ = original_route_init + @patch("amazon.opentelemetry.distro.patches._starlette_patches._logger") - def test_starlette_patch_handles_attribute_error(self, mock_logger): - """Test that the patch handles attribute errors gracefully.""" + def test_code_attributes_patch_uninstrument_restores_original(self, mock_logger): + """Test that uninstrumentation properly restores the original Route.__init__.""" + try: + from starlette.routing import Route + except ImportError: + self.skipTest("Starlette not available") + + # Create a mock StarletteInstrumentor class with proper methods + class MockStarletteInstrumentor: + def __init__(self): + pass + + def _instrument(self, **kwargs): + pass + + def _uninstrument(self, **kwargs): + pass + + mock_instrumentor_class = MockStarletteInstrumentor + mock_instrumentor = MockStarletteInstrumentor() + + mock_record_code_attributes = MagicMock() + + with patch.dict( + "sys.modules", + { + "opentelemetry.instrumentation.starlette": MagicMock(StarletteInstrumentor=mock_instrumentor_class), + "amazon.opentelemetry.distro.code_correlation": MagicMock( + record_code_attributes=mock_record_code_attributes + ), + }, + ): + original_route_init = Route.__init__ + + try: + _apply_starlette_code_attributes_patch() - # Create a metaclass that raises AttributeError when setting class attributes - class ErrorMeta(type): - def __setattr__(cls, name, value): - if name == "instrumentation_dependencies": - raise AttributeError("Cannot set attribute") - super().__setattr__(name, value) + # Apply instrumentation - this is when Route.__init__ gets wrapped + mock_instrumentor._instrument() + patched_init = Route.__init__ - # Create a class with the error-raising metaclass - class MockStarletteInstrumentor(metaclass=ErrorMeta): - pass + # Verify Route.__init__ was patched + self.assertNotEqual(patched_init, original_route_init) - # Create a mock module - mock_starlette_module = MagicMock() - mock_starlette_module.StarletteInstrumentor = MockStarletteInstrumentor + # Apply uninstrumentation + mock_instrumentor._uninstrument() - with patch.dict("sys.modules", {"opentelemetry.instrumentation.starlette": mock_starlette_module}): + # Verify Route.__init__ was restored + self.assertEqual(Route.__init__, original_route_init) + + finally: + Route.__init__ = original_route_init + + @patch("amazon.opentelemetry.distro.patches._starlette_patches._logger") + def test_code_attributes_patch_route_with_kwargs(self, mock_logger): + """Test that the patch works with routes that have additional kwargs.""" + try: + from starlette.routing import Route + except ImportError: + self.skipTest("Starlette not available") + + # Create a mock StarletteInstrumentor class with proper methods + class MockStarletteInstrumentor: + def __init__(self): + pass + + def _instrument(self, **kwargs): + pass + + def _uninstrument(self, **kwargs): + pass + + mock_instrumentor_class = MockStarletteInstrumentor + mock_instrumentor = MockStarletteInstrumentor() + + mock_record_code_attributes = MagicMock() + mock_record_code_attributes.side_effect = lambda func: func + + with patch.dict( + "sys.modules", + { + "opentelemetry.instrumentation.starlette": MagicMock(StarletteInstrumentor=mock_instrumentor_class), + "amazon.opentelemetry.distro.code_correlation": MagicMock( + record_code_attributes=mock_record_code_attributes + ), + }, + ): + original_route_init = Route.__init__ + + try: + _apply_starlette_code_attributes_patch() + mock_instrumentor._instrument() + + # Create route with additional kwargs + route = Route("/test", endpoint=self.sample_endpoint, methods=["GET", "POST"], name="test_route") + + # Verify decorator was called + mock_record_code_attributes.assert_called_once_with(self.sample_endpoint) + + # Verify route was created successfully with all attributes + self.assertEqual(route.path, "/test") + # Starlette automatically adds HEAD when GET is specified + self.assertIn("GET", route.methods) + self.assertIn("POST", route.methods) + self.assertEqual(route.name, "test_route") + self.assertIsNotNone(route.endpoint) + + finally: + Route.__init__ = original_route_init + + @patch("amazon.opentelemetry.distro.patches._starlette_patches._logger") + def test_code_attributes_patch_handles_import_error(self, mock_logger): + """Test that the patch handles import errors gracefully.""" + # Mock import failure + with patch.dict("sys.modules", {"starlette.routing": None}): # This should not raise an exception - _apply_starlette_instrumentation_patches() + _apply_starlette_code_attributes_patch() # Verify warning was logged mock_logger.warning.assert_called_once() args = mock_logger.warning.call_args[0] - self.assertIn("Failed to apply Starlette instrumentation patches", args[0]) - - def test_starlette_patch_logs_failure_with_no_logger_patch(self): # pylint: disable=no-self-use - """Test that the patch handles exceptions gracefully without logger mock.""" - # Mock the import to fail - with patch.dict("sys.modules", {"opentelemetry.instrumentation.starlette": None}): - # This should not raise an exception even without logger mock - _apply_starlette_instrumentation_patches() + self.assertIn("Failed to apply Starlette code attributes patch", args[0]) @patch("amazon.opentelemetry.distro.patches._starlette_patches._logger") - def test_starlette_patch_with_exception_during_import(self, mock_logger): - """Test that the patch handles exceptions during import.""" - - # Create a module that raises exception when accessing StarletteInstrumentor - class FailingModule: - @property - def StarletteInstrumentor(self): # pylint: disable=invalid-name - raise RuntimeError("Import failed") + def test_code_attributes_patch_handles_general_exception(self, mock_logger): + """Test that the patch handles general exceptions gracefully.""" - failing_module = FailingModule() + # Mock import to cause exception - simulate an issue in module loading + def failing_import(*args, **kwargs): + raise RuntimeError("General failure") - with patch.dict("sys.modules", {"opentelemetry.instrumentation.starlette": failing_module}): + with patch("builtins.__import__", side_effect=failing_import): # This should not raise an exception - _apply_starlette_instrumentation_patches() + _apply_starlette_code_attributes_patch() # Verify warning was logged mock_logger.warning.assert_called_once() args = mock_logger.warning.call_args[0] - self.assertIn("Failed to apply Starlette instrumentation patches", args[0]) + self.assertIn("Failed to apply Starlette code attributes patch", args[0]) + + def test_code_attributes_patch_multiple_routes(self): + """Test that the patch works correctly with multiple routes.""" + try: + from starlette.routing import Route + except ImportError: + self.skipTest("Starlette not available") + + # Create a mock StarletteInstrumentor class with proper methods + class MockStarletteInstrumentor: + def __init__(self): + pass + + def _instrument(self, **kwargs): + pass + + def _uninstrument(self, **kwargs): + pass + + mock_instrumentor_class = MockStarletteInstrumentor + mock_instrumentor = MockStarletteInstrumentor() + + mock_record_code_attributes = MagicMock() + mock_record_code_attributes.side_effect = lambda func: func + + with patch.dict( + "sys.modules", + { + "opentelemetry.instrumentation.starlette": MagicMock(StarletteInstrumentor=mock_instrumentor_class), + "amazon.opentelemetry.distro.code_correlation": MagicMock( + record_code_attributes=mock_record_code_attributes + ), + }, + ): + original_route_init = Route.__init__ + + try: + _apply_starlette_code_attributes_patch() + mock_instrumentor._instrument() + + # Create multiple routes + route1 = Route("/test1", endpoint=self.sample_endpoint) + route2 = Route("/test2", endpoint=self.another_endpoint) + + # Verify both endpoints were decorated + self.assertEqual(mock_record_code_attributes.call_count, 2) + mock_record_code_attributes.assert_any_call(self.sample_endpoint) + mock_record_code_attributes.assert_any_call(self.another_endpoint) + + # Verify routes were created successfully + self.assertEqual(route1.path, "/test1") + self.assertEqual(route2.path, "/test2") + self.assertIsNotNone(route1.endpoint) + self.assertIsNotNone(route2.endpoint) + + finally: + Route.__init__ = original_route_init + + def test_code_attributes_patch_route_class_methods(self): + """Test that the patch preserves Route class methods and attributes.""" + try: + from starlette.routing import Route + except ImportError: + self.skipTest("Starlette not available") + + # Create a mock StarletteInstrumentor class with proper methods + class MockStarletteInstrumentor: + def __init__(self): + pass + + def _instrument(self, **kwargs): + pass + + def _uninstrument(self, **kwargs): + pass + + mock_instrumentor_class = MockStarletteInstrumentor + mock_instrumentor = MockStarletteInstrumentor() + + mock_record_code_attributes = MagicMock() + mock_record_code_attributes.side_effect = lambda func: func + + with patch.dict( + "sys.modules", + { + "opentelemetry.instrumentation.starlette": MagicMock(StarletteInstrumentor=mock_instrumentor_class), + "amazon.opentelemetry.distro.code_correlation": MagicMock( + record_code_attributes=mock_record_code_attributes + ), + }, + ): + original_route_init = Route.__init__ + + try: + _apply_starlette_code_attributes_patch() + mock_instrumentor._instrument() + + # Create a route + route = Route("/test", endpoint=self.sample_endpoint, methods=["GET"]) + + # Verify Route methods still work + self.assertTrue(hasattr(route, "matches")) + self.assertTrue(hasattr(route, "url_path_for")) + + # Test that the route still functions as expected + self.assertEqual(route.path, "/test") + # Starlette automatically adds HEAD when GET is specified + self.assertIn("GET", route.methods) + self.assertIn("HEAD", route.methods) + + finally: + Route.__init__ = original_route_init diff --git a/dev-requirements.txt b/dev-requirements.txt index 6383be420..646ad8e24 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -15,3 +15,6 @@ requests==2.32.4 ruamel.yaml==0.17.21 flaky==3.7.0 botocore==1.34.67 +flask>=2.0.0 +fastapi>=0.68.0 +starlette>=0.14.2 From a80e98d1eaad4ddd65824ca719f93d7c0dd1f72d Mon Sep 17 00:00:00 2001 From: wangzlei Date: Fri, 3 Oct 2025 21:43:07 -0700 Subject: [PATCH 4/8] improve coverage --- ..._aws_cw_otlp_batch_log_record_processor.py | 80 +++++++++ .../sampler/test_aws_xray_remote_sampler.py | 140 ++++++++++++++++ .../sampler/test_aws_xray_sampling_client.py | 158 ++++++++++++++++++ 3 files changed, 378 insertions(+) diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/logs/test_aws_cw_otlp_batch_log_record_processor.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/logs/test_aws_cw_otlp_batch_log_record_processor.py index f22c18492..748ddf199 100644 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/logs/test_aws_cw_otlp_batch_log_record_processor.py +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/logs/test_aws_cw_otlp_batch_log_record_processor.py @@ -272,6 +272,86 @@ def test_force_flush_exports_only_one_batch(self, _, __, ___): exported_batch = args[0] self.assertEqual(len(exported_batch), 5) + @patch( + "amazon.opentelemetry.distro.exporter.otlp.aws.logs._aws_cw_otlp_batch_log_record_processor.attach", + return_value=MagicMock(), + ) + @patch("amazon.opentelemetry.distro.exporter.otlp.aws.logs._aws_cw_otlp_batch_log_record_processor.detach") + @patch("amazon.opentelemetry.distro.exporter.otlp.aws.logs._aws_cw_otlp_batch_log_record_processor.set_value") + @patch("amazon.opentelemetry.distro.exporter.otlp.aws.logs._aws_cw_otlp_batch_log_record_processor._logger") + def test_export_handles_exception_gracefully(self, mock_logger, _, __, ___): + """Tests that exceptions during export are caught and logged""" + # Setup exporter to raise an exception + self.mock_exporter.export.side_effect = Exception("Export failed") + + # Add logs to queue + test_logs = self.generate_test_log_data(log_body="test message", count=2) + for log in test_logs: + self.processor._queue.appendleft(log) + + # Call _export - should not raise exception + self.processor._export(batch_strategy=BatchLogExportStrategy.EXPORT_ALL) + + # Verify exception was logged + mock_logger.exception.assert_called_once() + call_args = mock_logger.exception.call_args[0] + self.assertIn("Exception while exporting logs:", call_args[0]) + + # Queue should be empty even though export failed + self.assertEqual(len(self.processor._queue), 0) + + @patch("amazon.opentelemetry.distro.exporter.otlp.aws.logs._aws_cw_otlp_batch_log_record_processor._logger") + def test_estimate_log_size_debug_logging_on_depth_exceeded(self, mock_logger): + """Tests that debug logging occurs when depth limit is exceeded""" + # Create deeply nested structure that exceeds depth limit + depth_limit = 1 + log_body = {"level1": {"level2": {"level3": {"level4": "this should trigger debug log"}}}} + + test_logs = self.generate_test_log_data(log_body=log_body, count=1) + + # Call with limited depth that will be exceeded + self.processor._estimate_log_size(log=test_logs[0], depth=depth_limit) + + # Verify debug logging was called + mock_logger.debug.assert_called_once() + call_args = mock_logger.debug.call_args[0] + self.assertIn("Max log depth of", call_args[0]) + self.assertIn("exceeded", call_args[0]) + + def test_estimate_utf8_size_static_method(self): + """Tests the _estimate_utf8_size static method with various strings""" + # Test ASCII only string + ascii_result = AwsCloudWatchOtlpBatchLogRecordProcessor._estimate_utf8_size("hello") + self.assertEqual(ascii_result, 5) # 5 ASCII chars = 5 bytes + + # Test mixed ASCII and non-ASCII + mixed_result = AwsCloudWatchOtlpBatchLogRecordProcessor._estimate_utf8_size("café") + self.assertEqual(mixed_result, 7) # 3 ASCII + 1 non-ASCII (4 bytes) = 7 bytes + + # Test non-ASCII only + non_ascii_result = AwsCloudWatchOtlpBatchLogRecordProcessor._estimate_utf8_size("深入") + self.assertEqual(non_ascii_result, 8) # 2 non-ASCII chars * 4 bytes = 8 bytes + + # Test empty string + empty_result = AwsCloudWatchOtlpBatchLogRecordProcessor._estimate_utf8_size("") + self.assertEqual(empty_result, 0) + + def test_constructor_with_custom_parameters(self): + """Tests constructor with custom parameters""" + custom_processor = AwsCloudWatchOtlpBatchLogRecordProcessor( + exporter=self.mock_exporter, + schedule_delay_millis=5000, + max_export_batch_size=100, + export_timeout_millis=10000, + max_queue_size=2000, + ) + + # Verify exporter is stored + self.assertEqual(custom_processor._exporter, self.mock_exporter) + + # Verify parameters are passed to parent constructor + self.assertEqual(custom_processor._max_export_batch_size, 100) + @staticmethod def generate_test_log_data( log_body, diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_aws_xray_remote_sampler.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_aws_xray_remote_sampler.py index c83c4eab0..549e8e637 100644 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_aws_xray_remote_sampler.py +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_aws_xray_remote_sampler.py @@ -304,3 +304,143 @@ def test_non_parent_based_xray_sampler_updates_statistics_thrice_for_one_parent_ non_parent_based_xray_sampler._rules_timer.cancel() non_parent_based_xray_sampler._targets_timer.cancel() + + def test_create_remote_sampler_with_none_resource(self): + """Tests creating remote sampler with None resource""" + with patch("amazon.opentelemetry.distro.sampler.aws_xray_remote_sampler._logger") as mock_logger: + self.rs = AwsXRayRemoteSampler(resource=None) + + # Verify warning was logged for None resource + mock_logger.warning.assert_called_once_with( + "OTel Resource provided is `None`. Defaulting to empty resource" + ) + + # Verify empty resource was set + self.assertIsNotNone(self.rs._root._root._AwsXRayRemoteSampler__resource) + self.assertEqual(len(self.rs._root._root._AwsXRayRemoteSampler__resource.attributes), 0) + + def test_create_remote_sampler_with_small_polling_interval(self): + """Tests creating remote sampler with polling interval < 10""" + with patch("amazon.opentelemetry.distro.sampler.aws_xray_remote_sampler._logger") as mock_logger: + self.rs = AwsXRayRemoteSampler(resource=Resource.get_empty(), polling_interval=5) # Less than 10 + + # Verify info log was called for small polling interval + mock_logger.info.assert_any_call("`polling_interval` is `None` or too small. Defaulting to %s", 300) + + # Verify default polling interval was set + self.assertEqual(self.rs._root._root._AwsXRayRemoteSampler__polling_interval, 300) + + def test_create_remote_sampler_with_none_endpoint(self): + """Tests creating remote sampler with None endpoint""" + with patch("amazon.opentelemetry.distro.sampler.aws_xray_remote_sampler._logger") as mock_logger: + self.rs = AwsXRayRemoteSampler(resource=Resource.get_empty(), endpoint=None) + + # Verify info log was called for None endpoint + mock_logger.info.assert_any_call("`endpoint` is `None`. Defaulting to %s", "http://127.0.0.1:2000") + + @patch("requests.Session.post", side_effect=mocked_requests_get) + def test_should_sample_with_expired_rule_cache(self, mock_post=None): + """Tests should_sample behavior when rule cache is expired""" + self.rs = AwsXRayRemoteSampler(resource=Resource.get_empty()) + + # Mock rule cache to be expired + with patch.object(self.rs._root._root._AwsXRayRemoteSampler__rule_cache, "expired", return_value=True): + with patch("amazon.opentelemetry.distro.sampler.aws_xray_remote_sampler._logger") as mock_logger: + # Call should_sample when cache is expired + result = self.rs._root._root.should_sample(None, 0, "test_span") + + # Verify debug log was called + mock_logger.debug.assert_called_once_with("Rule cache is expired so using fallback sampling strategy") + + # Verify fallback sampler was used (should return some result) + self.assertIsNotNone(result) + + @patch("requests.Session.post", side_effect=mocked_requests_get) + def test_refresh_rules_when_targets_require_it(self, mock_post=None): + """Tests that sampling rules are refreshed when targets polling indicates it""" + self.rs = AwsXRayRemoteSampler(resource=Resource.get_empty()) + + # Mock the rule cache update_sampling_targets to return refresh_rules=True + with patch.object( + self.rs._root._root._AwsXRayRemoteSampler__rule_cache, + "update_sampling_targets", + return_value=(True, None), # refresh_rules=True, min_polling_interval=None + ): + # Mock get_and_update_sampling_rules to track if it was called + with patch.object( + self.rs._root._root, "_AwsXRayRemoteSampler__get_and_update_sampling_rules" + ) as mock_update_rules: + # Call the method that should trigger rule refresh + self.rs._root._root._AwsXRayRemoteSampler__get_and_update_sampling_targets() + + # Verify that rules were refreshed + mock_update_rules.assert_called_once() + + @patch("requests.Session.post", side_effect=mocked_requests_get) + def test_update_target_polling_interval(self, mock_post=None): + """Tests that target polling interval is updated when targets polling returns new interval""" + self.rs = AwsXRayRemoteSampler(resource=Resource.get_empty()) + + # Mock the rule cache update_sampling_targets to return new polling interval + new_interval = 500 + with patch.object( + self.rs._root._root._AwsXRayRemoteSampler__rule_cache, + "update_sampling_targets", + return_value=(False, new_interval), # refresh_rules=False, min_polling_interval=500 + ): + # Store original interval + original_interval = self.rs._root._root._AwsXRayRemoteSampler__target_polling_interval + + # Call the method that should update polling interval + self.rs._root._root._AwsXRayRemoteSampler__get_and_update_sampling_targets() + + # Verify that polling interval was updated + self.assertEqual(self.rs._root._root._AwsXRayRemoteSampler__target_polling_interval, new_interval) + self.assertNotEqual(original_interval, new_interval) + + def test_generate_client_id_format(self): + """Tests that client ID generation produces correctly formatted hex string""" + self.rs = AwsXRayRemoteSampler(resource=Resource.get_empty()) + client_id = self.rs._root._root._AwsXRayRemoteSampler__client_id + + # Verify client ID is 24 characters long + self.assertEqual(len(client_id), 24) + + # Verify all characters are valid hex characters + valid_hex_chars = set("0123456789abcdef") + for char in client_id: + self.assertIn(char, valid_hex_chars) + + def test_internal_sampler_get_description(self): + """Tests get_description method of internal _AwsXRayRemoteSampler""" + internal_sampler = _AwsXRayRemoteSampler(resource=Resource.get_empty()) + + try: + description = internal_sampler.get_description() + self.assertEqual(description, "_AwsXRayRemoteSampler{remote sampling with AWS X-Ray}") + finally: + # Clean up timers + internal_sampler._rules_timer.cancel() + internal_sampler._targets_timer.cancel() + + @patch("requests.Session.post", side_effect=mocked_requests_get) + def test_rule_and_target_pollers_start_correctly(self, mock_post=None): + """Tests that both rule and target pollers are started and configured correctly""" + self.rs = AwsXRayRemoteSampler(resource=Resource.get_empty()) + + # Verify timers are created and started + self.assertIsNotNone(self.rs._root._root._rules_timer) + self.assertIsNotNone(self.rs._root._root._targets_timer) + + # Verify timers are daemon threads + self.assertTrue(self.rs._root._root._rules_timer.daemon) + self.assertTrue(self.rs._root._root._targets_timer.daemon) + + # Verify jitter values are within expected ranges + rule_jitter = self.rs._root._root._AwsXRayRemoteSampler__rule_polling_jitter + target_jitter = self.rs._root._root._AwsXRayRemoteSampler__target_polling_jitter + + self.assertGreaterEqual(rule_jitter, 0.0) + self.assertLessEqual(rule_jitter, 5.0) + self.assertGreaterEqual(target_jitter, 0.0) + self.assertLessEqual(target_jitter, 0.1) diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_aws_xray_sampling_client.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_aws_xray_sampling_client.py index 37bc2fe08..1cba69077 100644 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_aws_xray_sampling_client.py +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_aws_xray_sampling_client.py @@ -238,3 +238,161 @@ def test_urls_excluded_from_sampling(self): URLLib3Instrumentor().uninstrument() RequestsInstrumentor().uninstrument() + + def test_constructor_with_none_endpoint(self): + """Tests constructor behavior when endpoint is None""" + with self.assertLogs(_sampling_client_logger, level="ERROR") as cm: + # Constructor will log error but then crash on concatenation + with self.assertRaises(TypeError): + _AwsXRaySamplingClient(endpoint=None) + + # Verify error log was called before the crash + self.assertIn("endpoint must be specified", cm.output[0]) + + def test_constructor_with_log_level(self): + """Tests constructor sets log level when specified""" + original_level = _sampling_client_logger.level + try: + _AwsXRaySamplingClient("http://test.com", log_level=logging.DEBUG) + self.assertEqual(_sampling_client_logger.level, logging.DEBUG) + finally: + # Reset log level + _sampling_client_logger.setLevel(original_level) + + @patch("requests.Session.post") + def test_get_sampling_rules_none_response(self, mock_post): + """Tests get_sampling_rules when response is None""" + mock_post.return_value = None + client = _AwsXRaySamplingClient("http://127.0.0.1:2000") + + with self.assertLogs(_sampling_client_logger, level="ERROR") as cm: + sampling_rules = client.get_sampling_rules() + + # Verify error log and empty result + self.assertIn("GetSamplingRules response is None", cm.output[0]) + self.assertEqual(len(sampling_rules), 0) + + @patch("requests.Session.post") + def test_get_sampling_rules_request_exception(self, mock_post): + """Tests get_sampling_rules when RequestException occurs""" + mock_post.side_effect = requests.exceptions.RequestException("Connection error") + client = _AwsXRaySamplingClient("http://127.0.0.1:2000") + + with self.assertLogs(_sampling_client_logger, level="ERROR") as cm: + sampling_rules = client.get_sampling_rules() + + # Verify error log and empty result + self.assertIn("Request error occurred", cm.output[0]) + self.assertIn("Connection error", cm.output[0]) + self.assertEqual(len(sampling_rules), 0) + + @patch("requests.Session.post") + def test_get_sampling_rules_json_decode_error(self, mock_post): + """Tests get_sampling_rules when JSON decode error occurs""" + # Mock response that raises JSONDecodeError when .json() is called + mock_response = mock_post.return_value + mock_response.json.side_effect = json.JSONDecodeError("Invalid JSON", "doc", 0) + + client = _AwsXRaySamplingClient("http://127.0.0.1:2000") + + with self.assertLogs(_sampling_client_logger, level="ERROR") as cm: + sampling_rules = client.get_sampling_rules() + + # Verify error log and empty result + self.assertIn("Error in decoding JSON response", cm.output[0]) + self.assertEqual(len(sampling_rules), 0) + + @patch("requests.Session.post") + def test_get_sampling_rules_general_exception(self, mock_post): + """Tests get_sampling_rules when general exception occurs""" + mock_post.side_effect = Exception("Unexpected error") + client = _AwsXRaySamplingClient("http://127.0.0.1:2000") + + with self.assertLogs(_sampling_client_logger, level="ERROR") as cm: + sampling_rules = client.get_sampling_rules() + + # Verify error log and empty result + self.assertIn("Error occurred when attempting to fetch rules", cm.output[0]) + self.assertIn("Unexpected error", cm.output[0]) + self.assertEqual(len(sampling_rules), 0) + + @patch("requests.Session.post") + def test_get_sampling_targets_none_response(self, mock_post): + """Tests get_sampling_targets when response is None""" + mock_post.return_value = None + client = _AwsXRaySamplingClient("http://127.0.0.1:2000") + + with self.assertLogs(_sampling_client_logger, level="DEBUG") as cm: + response = client.get_sampling_targets([]) + + # Verify debug log and default response + self.assertIn("GetSamplingTargets response is None", cm.output[0]) + self.assertEqual(response.SamplingTargetDocuments, []) + self.assertEqual(response.UnprocessedStatistics, []) + self.assertEqual(response.LastRuleModification, 0.0) + + @patch("requests.Session.post") + def test_get_sampling_targets_invalid_response_format(self, mock_post): + """Tests get_sampling_targets when response format is invalid""" + # Missing required fields + mock_post.return_value.configure_mock(**{"json.return_value": {"InvalidField": "value"}}) + client = _AwsXRaySamplingClient("http://127.0.0.1:2000") + + with self.assertLogs(_sampling_client_logger, level="DEBUG") as cm: + response = client.get_sampling_targets([]) + + # Verify debug log and default response + self.assertIn("getSamplingTargets response is invalid", cm.output[0]) + self.assertEqual(response.SamplingTargetDocuments, []) + self.assertEqual(response.UnprocessedStatistics, []) + self.assertEqual(response.LastRuleModification, 0.0) + + @patch("requests.Session.post") + def test_get_sampling_targets_request_exception(self, mock_post): + """Tests get_sampling_targets when RequestException occurs""" + mock_post.side_effect = requests.exceptions.RequestException("Network error") + client = _AwsXRaySamplingClient("http://127.0.0.1:2000") + + with self.assertLogs(_sampling_client_logger, level="DEBUG") as cm: + response = client.get_sampling_targets([]) + + # Verify debug log and default response + self.assertIn("Request error occurred", cm.output[0]) + self.assertIn("Network error", cm.output[0]) + self.assertEqual(response.SamplingTargetDocuments, []) + self.assertEqual(response.UnprocessedStatistics, []) + self.assertEqual(response.LastRuleModification, 0.0) + + @patch("requests.Session.post") + def test_get_sampling_targets_json_decode_error(self, mock_post): + """Tests get_sampling_targets when JSON decode error occurs""" + # Mock response that raises JSONDecodeError when .json() is called + mock_response = mock_post.return_value + mock_response.json.side_effect = json.JSONDecodeError("Invalid JSON", "doc", 0) + + client = _AwsXRaySamplingClient("http://127.0.0.1:2000") + + with self.assertLogs(_sampling_client_logger, level="DEBUG") as cm: + response = client.get_sampling_targets([]) + + # Verify debug log and default response + self.assertIn("Error in decoding JSON response", cm.output[0]) + self.assertEqual(response.SamplingTargetDocuments, []) + self.assertEqual(response.UnprocessedStatistics, []) + self.assertEqual(response.LastRuleModification, 0.0) + + @patch("requests.Session.post") + def test_get_sampling_targets_general_exception(self, mock_post): + """Tests get_sampling_targets when general exception occurs""" + mock_post.side_effect = Exception("Unexpected error") + client = _AwsXRaySamplingClient("http://127.0.0.1:2000") + + with self.assertLogs(_sampling_client_logger, level="DEBUG") as cm: + response = client.get_sampling_targets([]) + + # Verify debug log and default response + self.assertIn("Error occurred when attempting to fetch targets", cm.output[0]) + self.assertIn("Unexpected error", cm.output[0]) + self.assertEqual(response.SamplingTargetDocuments, []) + self.assertEqual(response.UnprocessedStatistics, []) + self.assertEqual(response.LastRuleModification, 0.0) From 4ea6fd2662ae8e4bae3279cb39a90e9a10f42182 Mon Sep 17 00:00:00 2001 From: wangzlei Date: Fri, 3 Oct 2025 22:31:58 -0700 Subject: [PATCH 5/8] fix lint, add dev dependencies --- .pylintrc | 2 +- .../distro/patches/_fastapi_patches.py | 8 +++++--- .../distro/patches/_flask_patches.py | 12 +++++++----- .../distro/patches/_starlette_patches.py | 12 ++++++++---- .../distro/patches/test_flask_patches.py | 15 +++++++-------- .../sampler/test_aws_xray_remote_sampler.py | 10 +++++++--- eachdist.ini | 2 +- tox.ini | 3 +++ 8 files changed, 39 insertions(+), 25 deletions(-) diff --git a/.pylintrc b/.pylintrc index 94c9656e4..ea6886f45 100644 --- a/.pylintrc +++ b/.pylintrc @@ -7,7 +7,7 @@ extension-pkg-whitelist=cassandra # Add list of files or directories to be excluded. They should be base names, not # paths. -ignore=CVS,gen,Dockerfile,docker-compose.yml,README.md,requirements.txt,mock_collector_service_pb2.py,mock_collector_service_pb2.pyi,mock_collector_service_pb2_grpc.py +ignore=CVS,gen,Dockerfile,docker-compose.yml,README.md,requirements.txt,mock_collector_service_pb2.py,mock_collector_service_pb2.pyi,mock_collector_service_pb2_grpc.py,pyproject.toml,db.sqlite3 # Add files or directories matching the regex patterns to be excluded. The # regex matches against base names, not paths. diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_fastapi_patches.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_fastapi_patches.py index f3ee8dbca..68d56c79b 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_fastapi_patches.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_fastapi_patches.py @@ -36,10 +36,12 @@ def _apply_fastapi_code_attributes_patch() -> None: """ try: # Import FastAPI instrumentation classes and AWS decorator - from fastapi import routing + from fastapi import routing # pylint: disable=import-outside-toplevel - from amazon.opentelemetry.distro.code_correlation import record_code_attributes - from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor + from amazon.opentelemetry.distro.code_correlation import ( # pylint: disable=import-outside-toplevel + record_code_attributes, + ) + from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor # pylint: disable=import-outside-toplevel # Store the original _instrument and _uninstrument methods original_instrument = FastAPIInstrumentor._instrument diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_flask_patches.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_flask_patches.py index b25c34b50..d1067081a 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_flask_patches.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_flask_patches.py @@ -20,7 +20,7 @@ def _apply_flask_instrumentation_patches() -> None: _apply_flask_code_attributes_patch() -def _apply_flask_code_attributes_patch() -> None: +def _apply_flask_code_attributes_patch() -> None: # pylint: disable=too-many-statements """Flask instrumentation patch for code attributes This patch modifies the Flask instrumentation to automatically apply @@ -37,10 +37,12 @@ def _apply_flask_code_attributes_patch() -> None: """ try: # Import Flask instrumentation classes and AWS decorator - import flask + import flask # pylint: disable=import-outside-toplevel - from amazon.opentelemetry.distro.code_correlation import record_code_attributes - from opentelemetry.instrumentation.flask import FlaskInstrumentor + from amazon.opentelemetry.distro.code_correlation import ( # pylint: disable=import-outside-toplevel + record_code_attributes, + ) + from opentelemetry.instrumentation.flask import FlaskInstrumentor # pylint: disable=import-outside-toplevel # Store the original _instrument and _uninstrument methods original_instrument = FlaskInstrumentor._instrument @@ -79,7 +81,7 @@ def _wrapped_dispatch_request(self): """Wrapped Flask.dispatch_request method to handle deferred view function binding.""" try: # Get the current request context - from flask import request + from flask import request # pylint: disable=import-outside-toplevel # Check if there's an endpoint for this request endpoint = request.endpoint diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_starlette_patches.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_starlette_patches.py index 00d0d89b1..91f633bd1 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_starlette_patches.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_starlette_patches.py @@ -82,10 +82,14 @@ def _apply_starlette_code_attributes_patch() -> None: """ try: # Import Starlette routing classes and AWS decorator - from starlette.routing import Route - - from amazon.opentelemetry.distro.code_correlation import record_code_attributes - from opentelemetry.instrumentation.starlette import StarletteInstrumentor + from starlette.routing import Route # pylint: disable=import-outside-toplevel + + from amazon.opentelemetry.distro.code_correlation import ( # pylint: disable=import-outside-toplevel + record_code_attributes, + ) + from opentelemetry.instrumentation.starlette import ( # pylint: disable=import-outside-toplevel + StarletteInstrumentor, + ) # Store the original _instrument and _uninstrument methods original_instrument = StarletteInstrumentor._instrument diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_flask_patches.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_flask_patches.py index 3a9f1896d..effc1475b 100644 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_flask_patches.py +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/patches/test_flask_patches.py @@ -39,9 +39,9 @@ def setUp(self): self.app = flask.Flask(__name__) # Add test routes - @self.app.route("/hello/") - def hello(name): - return f"Hello {name}!" + @self.app.route("/hello") + def hello(): + return "Hello!" @self.app.route("/error") def error_endpoint(): @@ -120,7 +120,7 @@ def test_flask_patches_with_real_app(self, mock_get_status): # Test a request to trigger the patches instrumentor.instrument_app(self.app) - resp = self.client.get("/hello/world") + resp = self.client.get("/hello") self.assertEqual(200, resp.status_code) # Test uninstrumentation - this should restore original Flask methods @@ -152,16 +152,16 @@ def test_flask_patches_disabled(self, mock_get_status): mock_get_status.assert_called_once() # Make a request - resp = self.client.get("/hello/world") + resp = self.client.get("/hello") self.assertEqual(200, resp.status_code) - self.assertEqual([b"Hello world!"], list(resp.response)) + self.assertEqual([b"Hello!"], list(resp.response)) # Check spans were still generated (normal instrumentation) spans = self.memory_exporter.get_finished_spans() self.assertEqual(len(spans), 1) span = spans[0] - self.assertEqual(span.name, "GET /hello/") + self.assertEqual(span.name, "GET /hello") self.assertEqual(span.kind, trace.SpanKind.SERVER) # Clean up - avoid Flask instrumentation issues @@ -226,7 +226,6 @@ def lambda_func(): self.app.add_url_rule("/test_lambda", "test_lambda", lambda_func) # Clean up - don't call uninstrument_app to avoid Flask instrumentation issues - pass @patch("amazon.opentelemetry.distro.patches._flask_patches.get_code_correlation_enabled_status") def test_flask_patches_dispatch_request_coverage(self, mock_get_status): diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_aws_xray_remote_sampler.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_aws_xray_remote_sampler.py index 549e8e637..5533a2c93 100644 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_aws_xray_remote_sampler.py +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_aws_xray_remote_sampler.py @@ -344,10 +344,14 @@ def test_should_sample_with_expired_rule_cache(self, mock_post=None): self.rs = AwsXRayRemoteSampler(resource=Resource.get_empty()) # Mock rule cache to be expired - with patch.object(self.rs._root._root._AwsXRayRemoteSampler__rule_cache, "expired", return_value=True): - with patch("amazon.opentelemetry.distro.sampler.aws_xray_remote_sampler._logger") as mock_logger: + with patch.object( + self.rs._root._root._AwsXRayRemoteSampler__rule_cache, "expired", return_value=True + ): # pylint: disable=not-context-manager + with patch( + "amazon.opentelemetry.distro.sampler.aws_xray_remote_sampler._logger" + ) as mock_logger: # pylint: disable=not-context-manager # Call should_sample when cache is expired - result = self.rs._root._root.should_sample(None, 0, "test_span") + result = self.rs._root._root.should_sample(None, 0, "test_span") # pylint: disable=not-context-manager # Verify debug log was called mock_logger.debug.assert_called_once_with("Rule cache is expired so using fallback sampling strategy") diff --git a/eachdist.ini b/eachdist.ini index c7427b1e9..285983c30 100644 --- a/eachdist.ini +++ b/eachdist.ini @@ -4,7 +4,7 @@ [lintroots] extraroots=scripts/, sample-applications/, contract-tests/images/ -subglob=*.py,tests/,test/,src/*, simple-client-server, applications/django/** +subglob=*.py,src/*, simple-client-server, applications/django/** [testroots] extraroots=tests/ diff --git a/tox.ini b/tox.ini index f9d0e2120..72ed6fc8e 100644 --- a/tox.ini +++ b/tox.ini @@ -15,6 +15,9 @@ deps = -c dev-requirements.txt test: pytest test: pytest-cov + test: fastapi + test: starlette + test: flask setenv = ; TODO: The two repos branches need manual updated over time, need to figure out a more sustainable solution. From 0a50295e03406af1e5047c97f7eaf96a22fb5cb5 Mon Sep 17 00:00:00 2001 From: wangzlei Date: Fri, 3 Oct 2025 22:40:29 -0700 Subject: [PATCH 6/8] add opentelemetry test dependency --- tox.ini | 1 + 1 file changed, 1 insertion(+) diff --git a/tox.ini b/tox.ini index 72ed6fc8e..def5dc92e 100644 --- a/tox.ini +++ b/tox.ini @@ -38,6 +38,7 @@ commands_pre = test: pip install "opentelemetry-sdk[test] @ {env:CORE_REPO}#egg=opentelemetry-sdk&subdirectory=opentelemetry-sdk" test: pip install "opentelemetry-instrumentation[test] @ {env:CONTRIB_REPO}#egg=opentelemetry-instrumentation&subdirectory=opentelemetry-instrumentation" test: pip install "opentelemetry-exporter-otlp[test] @ {env:CORE_REPO}#egg=opentelemetry-exporter-otlp&subdirectory=exporter/opentelemetry-exporter-otlp" + test: pip install "opentelemetry-test-utils @ {env:CORE_REPO}#egg=opentelemetry-test-utils&subdirectory=tests/opentelemetry-test-utils" aws-opentelemetry-distro: pip install {toxinidir}/aws-opentelemetry-distro commands = From 8b8a79f96f26b49d6582bb5e516a2ab255d15635 Mon Sep 17 00:00:00 2001 From: wangzlei Date: Sun, 5 Oct 2025 18:31:30 -0700 Subject: [PATCH 7/8] update comments --- .flake8 | 1 - .../src/amazon/opentelemetry/distro/patches/_fastapi_patches.py | 2 +- .../src/amazon/opentelemetry/distro/patches/_flask_patches.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.flake8 b/.flake8 index 10d984636..cb1b3d122 100644 --- a/.flake8 +++ b/.flake8 @@ -26,4 +26,3 @@ exclude = mock_collector_service_pb2_grpc.py lambda-layer/terraform/lambda/.terraform lambda-layer/sample-apps/build - samples diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_fastapi_patches.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_fastapi_patches.py index 68d56c79b..6f5840635 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_fastapi_patches.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_fastapi_patches.py @@ -14,7 +14,7 @@ def _apply_fastapi_instrumentation_patches() -> None: Applies patches to provide code attributes support for FastAPI instrumentation. This patches the Flask instrumentation to automatically add code attributes - to spans by decorating view functions with current_span_code_attributes. + to spans by decorating view functions with record_code_attributes. """ if get_code_correlation_enabled_status() is True: _apply_fastapi_code_attributes_patch() diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_flask_patches.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_flask_patches.py index d1067081a..c975b8af1 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_flask_patches.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_flask_patches.py @@ -14,7 +14,7 @@ def _apply_flask_instrumentation_patches() -> None: Applies patches to provide code attributes support for Flask instrumentation. This patches the Flask instrumentation to automatically add code attributes - to spans by decorating view functions with current_span_code_attributes. + to spans by decorating view functions with record_code_attributes. """ if get_code_correlation_enabled_status() is True: _apply_flask_code_attributes_patch() From 9ac3321524e05c6144036f9d71402a42ba226376 Mon Sep 17 00:00:00 2001 From: wangzlei Date: Mon, 6 Oct 2025 18:59:42 -0700 Subject: [PATCH 8/8] fix comments --- .flake8 | 2 -- .../src/amazon/opentelemetry/distro/patches/_fastapi_patches.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/.flake8 b/.flake8 index cb1b3d122..e55e479dd 100644 --- a/.flake8 +++ b/.flake8 @@ -24,5 +24,3 @@ exclude = mock_collector_service_pb2.py mock_collector_service_pb2.pyi mock_collector_service_pb2_grpc.py - lambda-layer/terraform/lambda/.terraform - lambda-layer/sample-apps/build diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_fastapi_patches.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_fastapi_patches.py index 6f5840635..1f98fc58d 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_fastapi_patches.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_fastapi_patches.py @@ -13,7 +13,7 @@ def _apply_fastapi_instrumentation_patches() -> None: """FastAPI instrumentation patches Applies patches to provide code attributes support for FastAPI instrumentation. - This patches the Flask instrumentation to automatically add code attributes + This patches the FastAPI instrumentation to automatically add code attributes to spans by decorating view functions with record_code_attributes. """ if get_code_correlation_enabled_status() is True: