From c34f3b92d76892863054bebae79ba584baeef898 Mon Sep 17 00:00:00 2001 From: liustve Date: Wed, 18 Jun 2025 23:17:16 +0000 Subject: [PATCH 1/2] add logs pipeline --- .../distro/aws_opentelemetry_configurator.py | 15 +- .../logs/aws_batch_log_record_processor.py | 160 ++++++++++++ .../otlp/aws/logs/otlp_aws_logs_exporter.py | 161 +++++++++++- .../otlp/aws/common/test_aws_auth_session.py | 63 +++++ .../aws_batch_log_record_processor_test.py | 236 ++++++++++++++++++ .../aws/logs/otlp_aws_logs_exporter_test.py | 180 +++++++++++++ 6 files changed, 810 insertions(+), 5 deletions(-) create mode 100644 aws-opentelemetry-distro/src/amazon/opentelemetry/distro/exporter/otlp/aws/logs/aws_batch_log_record_processor.py create mode 100644 aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/common/test_aws_auth_session.py create mode 100644 aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/logs/aws_batch_log_record_processor_test.py create mode 100644 aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/logs/otlp_aws_logs_exporter_test.py diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/aws_opentelemetry_configurator.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/aws_opentelemetry_configurator.py index a08374bbe..b21bc6151 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/aws_opentelemetry_configurator.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/aws_opentelemetry_configurator.py @@ -4,7 +4,7 @@ import os import re from logging import NOTSET, Logger, getLogger -from typing import ClassVar, Dict, List, Type, Union +from typing import ClassVar, Dict, List, Optional, Type, Union from importlib_metadata import version from typing_extensions import override @@ -22,6 +22,7 @@ AwsMetricAttributesSpanExporterBuilder, ) from amazon.opentelemetry.distro.aws_span_metrics_processor_builder import AwsSpanMetricsProcessorBuilder +from amazon.opentelemetry.distro.exporter.otlp.aws.logs.aws_batch_log_record_processor import AwsBatchLogRecordProcessor from amazon.opentelemetry.distro.exporter.otlp.aws.logs.otlp_aws_logs_exporter import OTLPAwsLogExporter from amazon.opentelemetry.distro.exporter.otlp.aws.traces.otlp_aws_span_exporter import OTLPAwsSpanExporter from amazon.opentelemetry.distro.otlp_udp_exporter import OTLPUdpSpanExporter @@ -181,7 +182,9 @@ def _init_logging( # Provides a default OTLP log exporter when none is specified. # This is the behavior for the logs exporters for other languages. 
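
The hunk below narrows that default: the fallback OTLP log exporter is now only wired in when OTEL_LOGS_EXPORTER is set to a value other than "none". A standalone sketch of the selection rule (illustrative only, not part of the patch; the dict value stands in for the real exporter class):

    import os

    def resolve_log_exporters(configured: dict) -> dict:
        # Fall back to OTLP only when the user requested a logs exporter
        # and did not explicitly disable logging with "none".
        logs_exporter = os.environ.get("OTEL_LOGS_EXPORTER")
        if not configured and logs_exporter and logs_exporter.lower() != "none":
            return {"otlp": "OTLPLogExporter"}
        return configured
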
-    if not exporters:
+    logs_exporter = os.environ.get("OTEL_LOGS_EXPORTER")
+
+    if not exporters and logs_exporter and (logs_exporter.lower() != "none"):
         exporters = {"otlp": OTLPLogExporter}
 
     provider = LoggerProvider(resource=resource)
@@ -190,7 +193,11 @@ def _init_logging(
     for _, exporter_class in exporters.items():
         exporter_args: Dict[str, any] = {}
         log_exporter = _customize_logs_exporter(exporter_class(**exporter_args), resource)
-        provider.add_log_record_processor(BatchLogRecordProcessor(exporter=log_exporter))
+
+        if isinstance(log_exporter, OTLPAwsLogExporter) and is_agent_observability_enabled():
+            provider.add_log_record_processor(AwsBatchLogRecordProcessor(exporter=log_exporter))
+        else:
+            provider.add_log_record_processor(BatchLogRecordProcessor(exporter=log_exporter))
 
     handler = LoggingHandler(level=NOTSET, logger_provider=provider)
 
@@ -532,7 +539,7 @@ def _is_lambda_environment():
     return AWS_LAMBDA_FUNCTION_NAME_CONFIG in os.environ
 
 
-def _is_aws_otlp_endpoint(otlp_endpoint: str = None, service: str = "xray") -> bool:
+def _is_aws_otlp_endpoint(otlp_endpoint: Optional[str] = None, service: str = "xray") -> bool:
     """Is the given endpoint an AWS OTLP endpoint?"""
 
     pattern = AWS_TRACES_OTLP_ENDPOINT_PATTERN if service == "xray" else AWS_LOGS_OTLP_ENDPOINT_PATTERN
diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/exporter/otlp/aws/logs/aws_batch_log_record_processor.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/exporter/otlp/aws/logs/aws_batch_log_record_processor.py
new file mode 100644
index 000000000..8feada9a0
--- /dev/null
+++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/exporter/otlp/aws/logs/aws_batch_log_record_processor.py
@@ -0,0 +1,160 @@
+import logging
+from typing import Mapping, Optional, Sequence, cast
+
+from amazon.opentelemetry.distro.exporter.otlp.aws.logs.otlp_aws_logs_exporter import OTLPAwsLogExporter
+from opentelemetry.context import (
+    _SUPPRESS_INSTRUMENTATION_KEY,
+    attach,
+    detach,
+    set_value,
+)
+from opentelemetry.sdk._logs import LogData
+from opentelemetry.sdk._logs._internal.export import BatchLogExportStrategy
+from opentelemetry.sdk._logs.export import BatchLogRecordProcessor
+from opentelemetry.util.types import AnyValue
+
+_logger = logging.getLogger(__name__)
+
+
+class AwsBatchLogRecordProcessor(BatchLogRecordProcessor):
+    _BASE_LOG_BUFFER_BYTE_SIZE = (
+        2000  # Buffer size in bytes to account for log metadata not included in the body size calculation
+    )
+    _MAX_LOG_REQUEST_BYTE_SIZE = (
+        1048576  # https://docs.aws.amazon.com/AmazonCloudWatch/latest/monitoring/CloudWatch-OTLPEndpoint.html
+    )
+
+    def __init__(
+        self,
+        exporter: OTLPAwsLogExporter,
+        schedule_delay_millis: Optional[float] = None,
+        max_export_batch_size: Optional[int] = None,
+        export_timeout_millis: Optional[float] = None,
+        max_queue_size: Optional[int] = None,
+    ):
+
+        super().__init__(
+            exporter=exporter,
+            schedule_delay_millis=schedule_delay_millis,
+            max_export_batch_size=max_export_batch_size,
+            export_timeout_millis=export_timeout_millis,
+            max_queue_size=max_queue_size,
+        )
+
+        self._exporter = exporter
+
+    # https://github.com/open-telemetry/opentelemetry-python/blob/main/opentelemetry-sdk/src/opentelemetry/sdk/_shared_internal/__init__.py#L143
+    def _export(self, batch_strategy: BatchLogExportStrategy) -> None:
+        """
+        Preserves the existing batching behavior, but exports intermediate, smaller log batches
+        whenever the size of the data in the batch is at or above AWS CloudWatch's maximum
+        request size limit of 1 MB.
+
+        - Data size of exported batches will ALWAYS be <= 1 MB except for the case below:
+        - If the data size of an exported batch is ever > 1 MB then the batch size is guaranteed to be 1
+        """
+        with self._export_lock:
+            iteration = 0
+            # We could see concurrent export calls from worker and force_flush. We call _should_export_batch
+            # once the lock is obtained to see if we still need to make the requested export.
+            while self._should_export_batch(batch_strategy, iteration):
+                iteration += 1
+                token = attach(set_value(_SUPPRESS_INSTRUMENTATION_KEY, True))
+                try:
+                    batch_length = min(self._max_export_batch_size, len(self._queue))
+                    batch_data_size = 0
+                    batch = []
+
+                    for _ in range(batch_length):
+                        log_data: LogData = self._queue.pop()
+                        log_size = self._BASE_LOG_BUFFER_BYTE_SIZE + self._get_any_value_size(log_data.log_record.body)
+
+                        if batch and (batch_data_size + log_size > self._MAX_LOG_REQUEST_BYTE_SIZE):
+                            # if batch_data_size > MAX_LOG_REQUEST_BYTE_SIZE then len(batch) == 1
+                            if batch_data_size > self._MAX_LOG_REQUEST_BYTE_SIZE:
+                                if self._is_gen_ai_log(batch[0]):
+                                    self._exporter.set_gen_ai_log_flag()
+
+                            self._exporter.export(batch)
+                            batch_data_size = 0
+                            batch = []
+
+                        batch_data_size += log_size
+                        batch.append(log_data)
+
+                    if batch:
+                        # if batch_data_size > MAX_LOG_REQUEST_BYTE_SIZE then len(batch) == 1
+                        if batch_data_size > self._MAX_LOG_REQUEST_BYTE_SIZE:
+                            if self._is_gen_ai_log(batch[0]):
+                                self._exporter.set_gen_ai_log_flag()
+
+                        self._exporter.export(batch)
+                except Exception as e:  # pylint: disable=broad-exception-caught
+                    _logger.exception("Exception while exporting logs: " + str(e))
+                detach(token)
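
The flush rule above can be checked in isolation: records are appended until adding the next one would cross the 1 MB limit, so only a single oversized record can ever travel alone. A minimal standalone sketch of the same accumulate-and-flush logic (the byte sizes are made up for illustration):

    MAX_LOG_REQUEST_BYTE_SIZE = 1048576  # same 1 MB CloudWatch OTLP request limit

    def split_batches(sizes):
        batches, current, current_size = [], [], 0
        for size in sizes:
            # Flush before appending a record that would push the batch over the limit.
            if current and current_size + size > MAX_LOG_REQUEST_BYTE_SIZE:
                batches.append(current)
                current, current_size = [], 0
            current.append(size)
            current_size += size
        if current:
            batches.append(current)
        return batches

    # A 0.9 MB record, a 1.2 MB record and two small ones: only the oversized
    # record ends up alone in a batch whose data size exceeds 1 MB.
    print([len(b) for b in split_batches([943718, 1258291, 1000, 1000])])  # [1, 1, 2]
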
+    def _get_any_value_size(self, val: AnyValue, depth: int = 3) -> int:
+        """
+        Only used to indicate whether we should export a batch of size 1 or not.
+        Calculates the size in bytes of an AnyValue object.
+        Will process complex AnyValue structures up to the specified depth limit.
+        If the depth limit of the AnyValue structure is exceeded, returns 0.
+
+        Args:
+            val: The AnyValue object to calculate size for
+            depth: Maximum depth to traverse in nested structures (default: 3)
+
+        Returns:
+            int: Total size of the AnyValue object in bytes
+        """
+        # Use a stack to prevent excessive recursive calls.
+        stack = [(val, 0)]
+        size: int = 0
+
+        while stack:
+            # Small optimization: we can stop calculating the size once it reaches the 1 MB limit.
+            if size >= self._MAX_LOG_REQUEST_BYTE_SIZE:
+                return size
+
+            next_val, current_depth = stack.pop()
+
+            if isinstance(next_val, (str, bytes)):
+                size += len(next_val)
+                continue
+
+            if isinstance(next_val, bool):
+                size += 4 if next_val else 5
+                continue
+
+            if isinstance(next_val, (float, int)):
+                size += len(str(next_val))
+                continue
+
+            if current_depth <= depth:
+                if isinstance(next_val, Sequence):
+                    for content in next_val:
+                        stack.append((cast(AnyValue, content), current_depth + 1))
+
+                if isinstance(next_val, Mapping):
+                    for key, content in next_val.items():
+                        size += len(key)
+                        stack.append((content, current_depth + 1))
+            else:
+                _logger.debug("Max log depth exceeded. Log data size will not be accurately calculated.")
+                return 0
+
+        return size
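
Note that the depth guard only fires when a container (Mapping/Sequence) sits past the limit; scalars are sized wherever they appear. A hedged example against a processor with a mocked exporter (mirroring the PR's own tests; the expected values are worked out by hand):

    from unittest.mock import MagicMock

    processor = AwsBatchLogRecordProcessor(exporter=MagicMock())
    deep = {"a": {"b": {"c": {"d": {"e": "X" * 500}}}}}  # mappings nested five levels deep
    print(processor._get_any_value_size(deep, depth=3))  # 0 -> a mapping past the limit aborts sizing
    print(processor._get_any_value_size(deep, depth=5))  # 505 -> five 1-char keys plus the 500-char leaf
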
+ """ + gen_ai_instrumentations = { + "openinference.instrumentation.langchain", + "openinference.instrumentation.crewai", + "opentelemetry.instrumentation.langchain", + "crewai.telemetry", + "openlit.otel.tracing", + } + + return log_data.instrumentation_scope.name in gen_ai_instrumentations diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/exporter/otlp/aws/logs/otlp_aws_logs_exporter.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/exporter/otlp/aws/logs/otlp_aws_logs_exporter.py index 048632c06..64203b434 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/exporter/otlp/aws/logs/otlp_aws_logs_exporter.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/exporter/otlp/aws/logs/otlp_aws_logs_exporter.py @@ -1,14 +1,41 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -from typing import Dict, Optional +import gzip +import logging +from io import BytesIO +from time import sleep +from typing import Dict, Optional, Sequence + +import requests from amazon.opentelemetry.distro.exporter.otlp.aws.common.aws_auth_session import AwsAuthSession +from opentelemetry.exporter.otlp.proto.common._internal import ( + _create_exp_backoff_generator, +) +from opentelemetry.exporter.otlp.proto.common._log_encoder import encode_logs from opentelemetry.exporter.otlp.proto.http import Compression from opentelemetry.exporter.otlp.proto.http._log_exporter import OTLPLogExporter +from opentelemetry.sdk._logs import ( + LogData, +) +from opentelemetry.sdk._logs.export import ( + LogExportResult, +) + +_logger = logging.getLogger(__name__) class OTLPAwsLogExporter(OTLPLogExporter): + _LARGE_LOG_HEADER = "x-aws-truncatable-fields" + _LARGE_GEN_AI_LOG_PATH_HEADER = ( + "\\$['resourceLogs'][0]['scopeLogs'][0]['logRecords'][0]['body']" + "['kvlistValue']['values'][*]['value']['kvlistValue']['values'][*]" + "['value']['arrayValue']['values'][*]['kvlistValue']['values'][*]" + "['value']['stringValue']" + ) + _RETRY_AFTER_HEADER = "Retry-After" # https://opentelemetry.io/docs/specs/otlp/#otlphttp-throttling + def __init__( self, endpoint: Optional[str] = None, @@ -18,6 +45,7 @@ def __init__( headers: Optional[Dict[str, str]] = None, timeout: Optional[int] = None, ): + self._gen_ai_log_flag = False self._aws_region = None if endpoint: @@ -34,3 +62,134 @@ def __init__( compression=Compression.Gzip, session=AwsAuthSession(aws_region=self._aws_region, service="logs"), ) + + # https://github.com/open-telemetry/opentelemetry-python/blob/main/exporter/opentelemetry-exporter-otlp-proto-http/src/opentelemetry/exporter/otlp/proto/http/_log_exporter/__init__.py#L167 + def export(self, batch: Sequence[LogData]) -> LogExportResult: + """ + Exports the given batch of OTLP log data. + Behaviors of how this export will work - + + 1. Always compresses the serialized data into gzip before sending. + + 2. If self._gen_ai_log_flag is enabled, the log data is > 1 MB a + and the assumption is that the log is a normalized gen.ai LogEvent. + - inject the {LARGE_LOG_HEADER} into the header. + + 3. 
Retry behavior is now the following: + - if the response contains a status code that is retryable and the response contains Retry-After in its + headers, the serialized data will be exported after that set delay + + - if the response does not contain that Retry-After header, default back to the current iteration of the + exponential backoff delay + """ + + if self._shutdown: + _logger.warning("Exporter already shutdown, ignoring batch") + return LogExportResult.FAILURE + + serialized_data = encode_logs(batch).SerializeToString() + + gzip_data = BytesIO() + with gzip.GzipFile(fileobj=gzip_data, mode="w") as gzip_stream: + gzip_stream.write(serialized_data) + + data = gzip_data.getvalue() + + backoff = _create_exp_backoff_generator(max_value=self._MAX_RETRY_TIMEOUT) + + while True: + resp = self._send(data) + + if resp.ok: + return LogExportResult.SUCCESS + + if not self._retryable(resp): + _logger.error( + "Failed to export logs batch code: %s, reason: %s", + resp.status_code, + resp.text, + ) + self._gen_ai_log_flag = False + return LogExportResult.FAILURE + + # https://opentelemetry.io/docs/specs/otlp/#otlphttp-throttling + maybe_retry_after = resp.headers.get(self._RETRY_AFTER_HEADER, None) + + # Set the next retry delay to the value of the Retry-After response in the headers. + # If Retry-After is not present in the headers, default to the next iteration of the + # exponential backoff strategy. + + delay = self._parse_retryable_header(maybe_retry_after) + + if delay == -1: + delay = next(backoff, self._MAX_RETRY_TIMEOUT) + + if delay == self._MAX_RETRY_TIMEOUT: + _logger.error( + "Transient error %s encountered while exporting logs batch. " + "No Retry-After header found and all backoff retries exhausted. " + "Logs will not be exported.", + resp.reason, + ) + self._gen_ai_log_flag = False + return LogExportResult.FAILURE + + _logger.warning( + "Transient error %s encountered while exporting logs batch, retrying in %ss.", + resp.reason, + delay, + ) + + sleep(delay) + + def set_gen_ai_log_flag(self): + """ + Sets a flag that indicates the current log batch contains + a generative AI log record that exceeds the CloudWatch Logs size limit (1MB). + """ + self._gen_ai_log_flag = True + + def _send(self, serialized_data: bytes): + try: + response = self._session.post( + url=self._endpoint, + headers={self._LARGE_LOG_HEADER: self._LARGE_GEN_AI_LOG_PATH_HEADER} if self._gen_ai_log_flag else None, + data=serialized_data, + verify=self._certificate_file, + timeout=self._timeout, + cert=self._client_cert, + ) + return response + except ConnectionError: + response = self._session.post( + url=self._endpoint, + headers={self._LARGE_LOG_HEADER: self._LARGE_GEN_AI_LOG_PATH_HEADER} if self._gen_ai_log_flag else None, + data=serialized_data, + verify=self._certificate_file, + timeout=self._timeout, + cert=self._client_cert, + ) + return response + + @staticmethod + def _retryable(resp: requests.Response) -> bool: + """ + Is it a retryable response? 
+ """ + + return resp.status_code in (429, 503) or OTLPLogExporter._retryable(resp) + + @staticmethod + def _parse_retryable_header(retry_header: Optional[str]) -> float: + """ + Converts the given retryable header into a delay in seconds, returns -1 if there's no header + or error with the parsing + """ + if not retry_header: + return -1 + + try: + val = float(retry_header) + return val if val >= 0 else -1 + except ValueError: + return -1 diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/common/test_aws_auth_session.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/common/test_aws_auth_session.py new file mode 100644 index 000000000..e0c62b89d --- /dev/null +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/common/test_aws_auth_session.py @@ -0,0 +1,63 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from unittest import TestCase +from unittest.mock import patch + +import requests +from botocore.credentials import Credentials + +from amazon.opentelemetry.distro.exporter.otlp.aws.common.aws_auth_session import AwsAuthSession + +AWS_OTLP_TRACES_ENDPOINT = "https://xray.us-east-1.amazonaws.com/v1/traces" +AWS_OTLP_LOGS_ENDPOINT = "https://logs.us-east-1.amazonaws.com/v1/logs" + +AUTHORIZATION_HEADER = "Authorization" +X_AMZ_DATE_HEADER = "X-Amz-Date" +X_AMZ_SECURITY_TOKEN_HEADER = "X-Amz-Security-Token" + +mock_credentials = Credentials(access_key="test_access_key", secret_key="test_secret_key", token="test_session_token") + + +class TestAwsAuthSession(TestCase): + @patch("pkg_resources.get_distribution", side_effect=ImportError("test error")) + @patch.dict("sys.modules", {"botocore": None}, clear=False) + @patch("requests.Session.request", return_value=requests.Response()) + def test_aws_auth_session_no_botocore(self, _, __): + """Tests that aws_auth_session will not inject SigV4 Headers if botocore is not installed.""" + + session = AwsAuthSession("us-east-1", "xray") + actual_headers = {"test": "test"} + + session.request("POST", AWS_OTLP_TRACES_ENDPOINT, data="", headers=actual_headers) + + self.assertNotIn(AUTHORIZATION_HEADER, actual_headers) + self.assertNotIn(X_AMZ_DATE_HEADER, actual_headers) + self.assertNotIn(X_AMZ_SECURITY_TOKEN_HEADER, actual_headers) + + @patch("requests.Session.request", return_value=requests.Response()) + @patch("botocore.session.Session.get_credentials", return_value=None) + def test_aws_auth_session_no_credentials(self, _, __): + """Tests that aws_auth_session will not inject SigV4 Headers if retrieving credentials returns None.""" + + session = AwsAuthSession("us-east-1", "xray") + actual_headers = {"test": "test"} + + session.request("POST", AWS_OTLP_TRACES_ENDPOINT, data="", headers=actual_headers) + + self.assertNotIn(AUTHORIZATION_HEADER, actual_headers) + self.assertNotIn(X_AMZ_DATE_HEADER, actual_headers) + self.assertNotIn(X_AMZ_SECURITY_TOKEN_HEADER, actual_headers) + + @patch("requests.Session.request", return_value=requests.Response()) + @patch("botocore.session.Session.get_credentials", return_value=mock_credentials) + def test_aws_auth_session(self, _, __): + """Tests that aws_auth_session will inject SigV4 Headers if botocore is installed.""" + + session = AwsAuthSession("us-east-1", "xray") + actual_headers = {"test": "test"} + + session.request("POST", AWS_OTLP_TRACES_ENDPOINT, data="", headers=actual_headers) + + self.assertIn(AUTHORIZATION_HEADER, actual_headers) + 
self.assertIn(X_AMZ_DATE_HEADER, actual_headers) + self.assertIn(X_AMZ_SECURITY_TOKEN_HEADER, actual_headers) diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/logs/aws_batch_log_record_processor_test.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/logs/aws_batch_log_record_processor_test.py new file mode 100644 index 000000000..1abf680f1 --- /dev/null +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/logs/aws_batch_log_record_processor_test.py @@ -0,0 +1,236 @@ +import time +import unittest +from typing import List +from unittest.mock import MagicMock, patch + +from amazon.opentelemetry.distro.exporter.otlp.aws.logs.aws_batch_log_record_processor import ( + AwsBatchLogRecordProcessor, + BatchLogExportStrategy, +) +from opentelemetry._logs.severity import SeverityNumber +from opentelemetry.sdk._logs import LogData, LogRecord +from opentelemetry.sdk._logs.export import LogExportResult +from opentelemetry.sdk.util.instrumentation import InstrumentationScope +from opentelemetry.trace import TraceFlags +from opentelemetry.util.types import AnyValue + + +class TestAwsBatchLogRecordProcessor(unittest.TestCase): + + def setUp(self): + self.mock_exporter = MagicMock() + self.mock_exporter.export.return_value = LogExportResult.SUCCESS + + self.processor = AwsBatchLogRecordProcessor(exporter=self.mock_exporter) + + def test_process_log_data_nested_structure(self): + """Tests that the processor correctly handles nested structures (dict/list)""" + message_size = 400 + depth = 2 + + nested_dict_log_body = self.generate_nested_log_body( + depth=depth, expected_body="X" * message_size, create_map=True + ) + nested_array_log_body = self.generate_nested_log_body( + depth=depth, expected_body="X" * message_size, create_map=False + ) + + dict_size = self.processor._get_any_value_size(val=nested_dict_log_body, depth=depth) + array_size = self.processor._get_any_value_size(val=nested_array_log_body, depth=depth) + + # Asserting almost equal to account for key lengths in the Log object body + self.assertAlmostEqual(dict_size, message_size, delta=20) + self.assertAlmostEqual(array_size, message_size, delta=20) + + def test_process_log_data_nested_structure_exceeds_depth(self): + """Tests that the processor returns 0 for nested structure that exceeds the depth limit""" + message_size = 400 + log_body = "X" * message_size + + nested_dict_log_body = self.generate_nested_log_body(depth=4, expected_body=log_body, create_map=True) + nested_array_log_body = self.generate_nested_log_body(depth=4, expected_body=log_body, create_map=False) + + dict_size = self.processor._get_any_value_size(val=nested_dict_log_body, depth=3) + array_size = self.processor._get_any_value_size(val=nested_array_log_body, depth=3) + + self.assertEqual(dict_size, 0) + self.assertEqual(array_size, 0) + + def test_process_log_data_nested_structure_size_exceeds_max_log_size(self): + """Tests that the processor returns prematurely if the size already exceeds _MAX_LOG_REQUEST_BYTE_SIZE""" + log_body = { + "smallKey": "X" * (self.processor._MAX_LOG_REQUEST_BYTE_SIZE // 2), + "bigKey": "X" * (self.processor._MAX_LOG_REQUEST_BYTE_SIZE + 1), + } + + nested_dict_log_body = self.generate_nested_log_body(depth=0, expected_body=log_body, create_map=True) + nested_array_log_body = self.generate_nested_log_body(depth=0, expected_body=log_body, create_map=False) + + dict_size = self.processor._get_any_value_size(val=nested_dict_log_body) + array_size = 
self.processor._get_any_value_size(val=nested_array_log_body) + + self.assertAlmostEqual(dict_size, self.processor._MAX_LOG_REQUEST_BYTE_SIZE, delta=20) + self.assertAlmostEqual(array_size, self.processor._MAX_LOG_REQUEST_BYTE_SIZE, delta=20) + + def test_process_log_data_primitive(self): + + primitives: List[AnyValue] = ["test", b"test", 1, 1.2, True, False, None] + expected_sizes = [4, 4, 1, 3, 4, 5, 0] + + for i in range(len(primitives)): + body = primitives[i] + expected_size = expected_sizes[i] + + actual_size = self.processor._get_any_value_size(body) + self.assertEqual(actual_size, expected_size) + + @patch( + "amazon.opentelemetry.distro.exporter.otlp.aws.logs.aws_batch_log_record_processor.attach", + return_value=MagicMock(), + ) + @patch("amazon.opentelemetry.distro.exporter.otlp.aws.logs.aws_batch_log_record_processor.detach") + @patch("amazon.opentelemetry.distro.exporter.otlp.aws.logs.aws_batch_log_record_processor.set_value") + def test_export_single_batch_under_size_limit(self, _, __, ___): + """Tests that export is only called once if a single batch is under the size limit""" + log_count = 10 + log_body = "test" + test_logs = self.generate_test_log_data(count=log_count, log_body=log_body) + total_data_size = 0 + + for log in test_logs: + size = self.processor._get_any_value_size(log.log_record.body) + total_data_size += size + self.processor._queue.appendleft(log) + + self.processor._export(batch_strategy=BatchLogExportStrategy.EXPORT_ALL) + args, _ = self.mock_exporter.export.call_args + actual_batch = args[0] + + self.assertLess(total_data_size, self.processor._MAX_LOG_REQUEST_BYTE_SIZE) + self.assertEqual(len(self.processor._queue), 0) + self.assertEqual(len(actual_batch), log_count) + self.mock_exporter.export.assert_called_once() + self.mock_exporter.set_gen_ai_log_flag.assert_not_called() + + @patch( + "amazon.opentelemetry.distro.exporter.otlp.aws.logs.aws_batch_log_record_processor.attach", + return_value=MagicMock(), + ) + @patch("amazon.opentelemetry.distro.exporter.otlp.aws.logs.aws_batch_log_record_processor.detach") + @patch("amazon.opentelemetry.distro.exporter.otlp.aws.logs.aws_batch_log_record_processor.set_value") + def test_export_single_batch_all_logs_over_size_limit(self, _, __, ___): + """Should make multiple export calls of batch size 1 to export logs of size > 1 MB. 
+ But should only call set_gen_ai_log_flag if it's a Gen AI log event.""" + + large_log_body = "X" * (self.processor._MAX_LOG_REQUEST_BYTE_SIZE + 1) + non_gen_ai_test_logs = self.generate_test_log_data(count=3, log_body=large_log_body) + gen_ai_test_logs = [] + + gen_ai_scopes = [ + "openinference.instrumentation.langchain", + "openinference.instrumentation.crewai", + "opentelemetry.instrumentation.langchain", + "crewai.telemetry", + "openlit.otel.tracing", + ] + + for gen_ai_scope in gen_ai_scopes: + gen_ai_test_logs.extend( + self.generate_test_log_data( + count=1, log_body=large_log_body, instrumentation_scope=InstrumentationScope(gen_ai_scope, "1.0.0") + ) + ) + + test_logs = gen_ai_test_logs + non_gen_ai_test_logs + + for log in test_logs: + self.processor._queue.appendleft(log) + + self.processor._export(batch_strategy=BatchLogExportStrategy.EXPORT_ALL) + + self.assertEqual(len(self.processor._queue), 0) + self.assertEqual(self.mock_exporter.export.call_count, 3 + len(gen_ai_test_logs)) + self.assertEqual(self.mock_exporter.set_gen_ai_log_flag.call_count, len(gen_ai_test_logs)) + + batches = self.mock_exporter.export.call_args_list + + for batch in batches: + self.assertEqual(len(batch[0]), 1) + + @patch( + "amazon.opentelemetry.distro.exporter.otlp.aws.logs.aws_batch_log_record_processor.attach", + return_value=MagicMock(), + ) + @patch("amazon.opentelemetry.distro.exporter.otlp.aws.logs.aws_batch_log_record_processor.detach") + @patch("amazon.opentelemetry.distro.exporter.otlp.aws.logs.aws_batch_log_record_processor.set_value") + def test_export_single_batch_some_logs_over_size_limit(self, _, __, ___): + """Should make calls to export smaller sub-batch logs""" + large_log_body = "X" * (self.processor._MAX_LOG_REQUEST_BYTE_SIZE + 1) + gen_ai_scope = InstrumentationScope("openinference.instrumentation.langchain", "1.0.0") + small_log_body = "X" * ( + int(self.processor._MAX_LOG_REQUEST_BYTE_SIZE / 10) - self.processor._BASE_LOG_BUFFER_BYTE_SIZE + ) + test_logs = self.generate_test_log_data(count=3, log_body=large_log_body, instrumentation_scope=gen_ai_scope) + # 1st, 2nd, 3rd batch = size 1 + # 4th batch = size 10 + # 5th batch = size 2 + small_logs = self.generate_test_log_data(count=12, log_body=small_log_body, instrumentation_scope=gen_ai_scope) + + test_logs.extend(small_logs) + + for log in test_logs: + self.processor._queue.appendleft(log) + + self.processor._export(batch_strategy=BatchLogExportStrategy.EXPORT_ALL) + + self.assertEqual(len(self.processor._queue), 0) + self.assertEqual(self.mock_exporter.export.call_count, 5) + self.assertEqual(self.mock_exporter.set_gen_ai_log_flag.call_count, 3) + + batches = self.mock_exporter.export.call_args_list + + expected_sizes = { + 0: 1, # 1st batch (index 1) should have 1 log + 1: 1, # 2nd batch (index 1) should have 1 log + 2: 1, # 3rd batch (index 2) should have 1 log + 3: 10, # 4th batch (index 3) should have 10 logs + 4: 2, # 5th batch (index 4) should have 2 logs + } + + for i, call in enumerate(batches): + batch = call[0][0] + expected_size = expected_sizes[i] + self.assertEqual(len(batch), expected_size) + + def generate_test_log_data( + self, log_body: AnyValue, count=5, instrumentation_scope=InstrumentationScope("test-scope", "1.0.0") + ) -> List[LogData]: + logs = [] + for i in range(count): + record = LogRecord( + timestamp=int(time.time_ns()), + trace_id=int(f"0x{i + 1:032x}", 16), + span_id=int(f"0x{i + 1:016x}", 16), + trace_flags=TraceFlags(1), + severity_text="INFO", + severity_number=SeverityNumber.INFO, + 
body=log_body, + attributes={"test.attribute": f"value-{i + 1}"}, + ) + + log_data = LogData(log_record=record, instrumentation_scope=instrumentation_scope) + logs.append(log_data) + + return logs + + @staticmethod + def generate_nested_log_body(depth=0, expected_body: AnyValue = "test", create_map=True): + if depth < 0: + return expected_body + + if create_map: + return { + "key": TestAwsBatchLogRecordProcessor.generate_nested_log_body(depth - 1, expected_body, create_map) + } + + return [TestAwsBatchLogRecordProcessor.generate_nested_log_body(depth - 1, expected_body, create_map)] diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/logs/otlp_aws_logs_exporter_test.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/logs/otlp_aws_logs_exporter_test.py new file mode 100644 index 000000000..9f6d84b32 --- /dev/null +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/logs/otlp_aws_logs_exporter_test.py @@ -0,0 +1,180 @@ +import time +from unittest import TestCase +from unittest.mock import patch + +import requests +from requests.structures import CaseInsensitiveDict + +from amazon.opentelemetry.distro.exporter.otlp.aws.logs.otlp_aws_logs_exporter import OTLPAwsLogExporter +from opentelemetry._logs.severity import SeverityNumber +from opentelemetry.sdk._logs import LogData, LogRecord +from opentelemetry.sdk._logs.export import ( + LogExportResult, +) +from opentelemetry.sdk.util.instrumentation import InstrumentationScope +from opentelemetry.trace import TraceFlags + + +class TestOTLPAwsLogsExporter(TestCase): + _ENDPOINT = "https://logs.us-west-2.amazonaws.com/v1/logs" + good_response = requests.Response() + good_response.status_code = 200 + + non_retryable_response = requests.Response() + non_retryable_response.status_code = 404 + + retryable_response_no_header = requests.Response() + retryable_response_no_header.status_code = 429 + + retryable_response_header = requests.Response() + retryable_response_header.headers = CaseInsensitiveDict({"Retry-After": "10"}) + retryable_response_header.status_code = 503 + + retryable_response_bad_header = requests.Response() + retryable_response_bad_header.headers = CaseInsensitiveDict({"Retry-After": "-12"}) + retryable_response_bad_header.status_code = 503 + + def setUp(self): + self.logs = self.generate_test_log_data() + self.exporter = OTLPAwsLogExporter(endpoint=self._ENDPOINT) + + @patch("requests.Session.request", return_value=good_response) + def test_export_success(self, mock_request): + """Tests that the exporter always compresses the serialized logs with gzip before exporting.""" + result = self.exporter.export(self.logs) + + mock_request.assert_called_once() + + _, kwargs = mock_request.call_args + data = kwargs.get("data", None) + + self.assertEqual(result, LogExportResult.SUCCESS) + + # Gzip first 10 bytes are reserved for metadata headers: + # https://www.loc.gov/preservation/digital/formats/fdd/fdd000599.shtml?loclr=blogsig + self.assertIsNotNone(data) + self.assertTrue(len(data) >= 10) + self.assertEqual(data[0:2], b"\x1f\x8b") + + @patch("requests.Session.request", return_value=good_response) + def test_export_gen_ai_logs(self, mock_request): + """Tests that when set_gen_ai_log_flag is set, the exporter includes the LLO header in the request.""" + + self.exporter.set_gen_ai_log_flag() + + result = self.exporter.export(self.logs) + + mock_request.assert_called_once() + + _, kwargs = mock_request.call_args + headers = kwargs.get("headers", 
None) + + self.assertEqual(result, LogExportResult.SUCCESS) + self.assertIsNotNone(headers) + self.assertIn(self.exporter._LARGE_LOG_HEADER, headers) + self.assertEqual(headers[self.exporter._LARGE_LOG_HEADER], self.exporter._LARGE_GEN_AI_LOG_PATH_HEADER) + + @patch("requests.Session.request", return_value=good_response) + def test_should_not_export_if_shutdown(self, mock_request): + """Tests that no export request is made if the exporter is shutdown.""" + self.exporter.shutdown() + result = self.exporter.export(self.logs) + + mock_request.assert_not_called() + self.assertEqual(result, LogExportResult.FAILURE) + + @patch("requests.Session.request", return_value=non_retryable_response) + def test_should_not_export_again_if_not_retryable(self, mock_request): + """Tests that only one export request is made if the response status code is non-retryable.""" + result = self.exporter.export(self.logs) + mock_request.assert_called_once() + + self.assertEqual(result, LogExportResult.FAILURE) + + @patch( + "amazon.opentelemetry.distro.exporter.otlp.aws.logs.otlp_aws_logs_exporter.sleep", side_effect=lambda x: None + ) + @patch("requests.Session.request", return_value=retryable_response_no_header) + def test_should_export_again_with_backoff_if_retryable_and_no_retry_after_header(self, mock_request, mock_sleep): + """Tests that multiple export requests are made with exponential delay if the response status code is retryable. + But there is no Retry-After header.""" + result = self.exporter.export(self.logs) + + # 1, 2, 4, 8, 16, 32 delays + self.assertEqual(mock_sleep.call_count, 6) + + delays = mock_sleep.call_args_list + + for i in range(len(delays)): + self.assertEqual(delays[i][0][0], 2**i) + + # Number of calls: 1 + len(1, 2, 4, 8, 16, 32 delays) + self.assertEqual(mock_request.call_count, 7) + self.assertEqual(result, LogExportResult.FAILURE) + + @patch( + "amazon.opentelemetry.distro.exporter.otlp.aws.logs.otlp_aws_logs_exporter.sleep", side_effect=lambda x: None + ) + @patch( + "requests.Session.request", + side_effect=[retryable_response_header, retryable_response_header, retryable_response_header, good_response], + ) + def test_should_export_again_with_server_delay_if_retryable_and_retry_after_header(self, mock_request, mock_sleep): + """Tests that multiple export requests are made with the server's suggested + delay if the response status code is retryable and there is a Retry-After header.""" + result = self.exporter.export(self.logs) + delays = mock_sleep.call_args_list + + for i in range(len(delays)): + self.assertEqual(delays[i][0][0], 10) + + self.assertEqual(mock_sleep.call_count, 3) + self.assertEqual(mock_request.call_count, 4) + self.assertEqual(result, LogExportResult.SUCCESS) + + @patch( + "amazon.opentelemetry.distro.exporter.otlp.aws.logs.otlp_aws_logs_exporter.sleep", side_effect=lambda x: None + ) + @patch( + "requests.Session.request", + side_effect=[ + retryable_response_bad_header, + retryable_response_bad_header, + retryable_response_bad_header, + good_response, + ], + ) + def test_should_export_again_with_backoff_delay_if_retryable_and_bad_retry_after_header( + self, mock_request, mock_sleep + ): + """Tests that multiple export requests are made with exponential delay if the response status code is retryable. 
But the Retry-After header is invalid or malformed."""
+        result = self.exporter.export(self.logs)
+        delays = mock_sleep.call_args_list
+
+        for i in range(len(delays)):
+            self.assertEqual(delays[i][0][0], 2**i)
+
+        self.assertEqual(mock_sleep.call_count, 3)
+        self.assertEqual(mock_request.call_count, 4)
+        self.assertEqual(result, LogExportResult.SUCCESS)
+
+    def generate_test_log_data(self, count=5):
+        logs = []
+        for i in range(count):
+            record = LogRecord(
+                timestamp=int(time.time_ns()),
+                trace_id=int(f"0x{i + 1:032x}", 16),
+                span_id=int(f"0x{i + 1:016x}", 16),
+                trace_flags=TraceFlags(1),
+                severity_text="INFO",
+                severity_number=SeverityNumber.INFO,
+                body=f"Test log {i + 1}",
+                attributes={"test.attribute": f"value-{i + 1}"},
+            )
+
+            log_data = LogData(log_record=record, instrumentation_scope=InstrumentationScope("test-scope", "1.0.0"))
+
+            logs.append(log_data)
+
+        return logs

From 09c9006ca039d3599eaf07e99ee4e474a2b97b40 Mon Sep 17 00:00:00 2001
From: liustve
Date: Thu, 3 Jul 2025 17:17:47 +0000
Subject: [PATCH 2/2] cleanup botocore usage

---
 .github/workflows/release-lambda.yml          |   4 +-
 aws-opentelemetry-distro/pyproject.toml       | 108 +--
 .../distro/_aws_attribute_keys.py             |   1 +
 .../distro/_aws_metric_attribute_generator.py |   2 +-
 .../distro/_aws_span_processing_util.py       |  10 -
 .../src/amazon/opentelemetry/distro/_utils.py |  57 +-
 .../distro/aws_opentelemetry_configurator.py  | 296 ++++++--
 .../distro/aws_opentelemetry_distro.py        |  57 +-
 .../aws/metrics/_cloudwatch_log_client.py     | 380 ++++++++++
 .../metrics/aws_cloudwatch_emf_exporter.py    | 631 +++++++++++++++++
 .../otlp/aws/common/aws_auth_session.py       |  72 +-
 ..._aws_cw_otlp_batch_log_record_processor.py | 258 +++++++
 .../logs/aws_batch_log_record_processor.py    | 160 -----
 .../otlp/aws/logs/otlp_aws_logs_exporter.py   | 164 +++--
 .../otlp/aws/traces/otlp_aws_span_exporter.py |  56 +-
 .../opentelemetry/distro/llo_handler.py       | 557 +++++++++++++++
 .../distro/patches/_bedrock_patches.py        | 224 +-----
 .../distro/patches/_botocore_patches.py       |  26 +-
 .../test_aws_cloudwatch_emf_exporter.py       | 625 +++++++++++++++++
 .../aws/metrics/test_cloudwatch_log_client.py | 584 ++++++++++++++++
 .../otlp/aws/common/test_aws_auth_session.py  |  20 +-
 .../aws_batch_log_record_processor_test.py    | 236 -------
 .../aws/logs/otlp_aws_logs_exporter_test.py   | 180 -----
 ..._aws_cw_otlp_batch_log_record_processor.py | 310 +++++++++
 .../aws/logs/test_otlp_aws_logs_exporter.py   | 250 +++++++
 .../aws/traces/test_otlp_aws_span_exporter.py | 196 ++++++
 .../llo_handler/test_llo_handler_base.py      |  57 ++
 .../test_llo_handler_collection.py            | 269 ++++++++
 .../llo_handler/test_llo_handler_events.py    | 651 ++++++++++++++++++
 .../test_llo_handler_frameworks.py            | 444 ++++++++++++
 .../llo_handler/test_llo_handler_patterns.py  | 112 +++
 .../test_llo_handler_processing.py            | 328 +++++++++
 .../distro/test_aws_auth_session.py           |  63 --
 .../test_aws_metric_attribute_generator.py    |   2 +-
 .../test_aws_opentelementry_configurator.py   | 480 ++++++++++++-
 .../distro/test_aws_opentelemetry_distro.py   | 219 +++++-
 .../distro/test_instrumentation_patch.py      | 306 ++------
 .../opentelemetry/distro/test_llo_handler.py  |  40 ++
 .../distro/test_otlp_aws_span_exporter.py     |  29 -
 .../amazon/opentelemetry/distro/test_utils.py | 175 +++++
 .../applications/botocore/botocore_server.py  |  28 +-
 .../applications/botocore/requirements.txt    |   2 -
 .../applications/django/requirements.txt      |   2 -
 .../mysql-connector/requirements.txt          |   2 -
 .../applications/mysqlclient/requirements.txt |   2 -
 .../applications/psycopg2/requirements.txt    |   2
- .../applications/pymysql/requirements.txt | 2 - .../applications/requests/requirements.txt | 2 - .../images/mock-collector/pyproject.toml | 6 +- .../images/mock-collector/requirements.txt | 6 +- contract-tests/tests/pyproject.toml | 4 +- .../test/amazon/botocore/botocore_test.py | 41 +- 52 files changed, 7193 insertions(+), 1545 deletions(-) create mode 100644 aws-opentelemetry-distro/src/amazon/opentelemetry/distro/exporter/aws/metrics/_cloudwatch_log_client.py create mode 100644 aws-opentelemetry-distro/src/amazon/opentelemetry/distro/exporter/aws/metrics/aws_cloudwatch_emf_exporter.py create mode 100644 aws-opentelemetry-distro/src/amazon/opentelemetry/distro/exporter/otlp/aws/logs/_aws_cw_otlp_batch_log_record_processor.py delete mode 100644 aws-opentelemetry-distro/src/amazon/opentelemetry/distro/exporter/otlp/aws/logs/aws_batch_log_record_processor.py create mode 100644 aws-opentelemetry-distro/src/amazon/opentelemetry/distro/llo_handler.py create mode 100644 aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/aws/metrics/test_aws_cloudwatch_emf_exporter.py create mode 100644 aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/aws/metrics/test_cloudwatch_log_client.py delete mode 100644 aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/logs/aws_batch_log_record_processor_test.py delete mode 100644 aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/logs/otlp_aws_logs_exporter_test.py create mode 100644 aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/logs/test_aws_cw_otlp_batch_log_record_processor.py create mode 100644 aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/logs/test_otlp_aws_logs_exporter.py create mode 100644 aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/traces/test_otlp_aws_span_exporter.py create mode 100644 aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/llo_handler/test_llo_handler_base.py create mode 100644 aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/llo_handler/test_llo_handler_collection.py create mode 100644 aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/llo_handler/test_llo_handler_events.py create mode 100644 aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/llo_handler/test_llo_handler_frameworks.py create mode 100644 aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/llo_handler/test_llo_handler_patterns.py create mode 100644 aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/llo_handler/test_llo_handler_processing.py delete mode 100644 aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_aws_auth_session.py create mode 100644 aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_llo_handler.py delete mode 100644 aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_otlp_aws_span_exporter.py create mode 100644 aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_utils.py diff --git a/.github/workflows/release-lambda.yml b/.github/workflows/release-lambda.yml index 0dd0d6344..d390b1836 100644 --- a/.github/workflows/release-lambda.yml +++ b/.github/workflows/release-lambda.yml @@ -9,10 +9,10 @@ on: aws_region: description: 'Deploy to aws regions' required: true - default: 'us-east-1, us-east-2, us-west-1, us-west-2, ap-south-1, ap-northeast-3, ap-northeast-2, ap-southeast-1, ap-southeast-2, ap-northeast-1, ca-central-1, eu-central-1, eu-west-1, eu-west-2, eu-west-3, eu-north-1, 
sa-east-1, af-south-1, ap-east-1, ap-south-2, ap-southeast-3, ap-southeast-4, eu-central-2, eu-south-1, eu-south-2, il-central-1, me-central-1, me-south-1' + default: 'us-east-1, us-east-2, us-west-1, us-west-2, ap-south-1, ap-northeast-3, ap-northeast-2, ap-southeast-1, ap-southeast-2, ap-northeast-1, ca-central-1, eu-central-1, eu-west-1, eu-west-2, eu-west-3, eu-north-1, sa-east-1, af-south-1, ap-east-1, ap-south-2, ap-southeast-3, ap-southeast-4, eu-central-2, eu-south-1, eu-south-2, il-central-1, me-central-1, me-south-1, ap-southeast-5, ap-southeast-7, mx-central-1, ca-west-1, cn-north-1, cn-northwest-1' env: - COMMERCIAL_REGIONS: us-east-1, us-east-2, us-west-1, us-west-2, ap-south-1, ap-northeast-3, ap-northeast-2, ap-southeast-1, ap-southeast-2, ap-northeast-1, ca-central-1, eu-central-1, eu-west-1, eu-west-2, eu-west-3, eu-north-1, sa-east-1 + COMMERCIAL_REGIONS: us-east-1, us-east-2, us-west-1, us-west-2, ap-south-1, ap-northeast-3, ap-northeast-2, ap-southeast-1, ap-southeast-2, ap-northeast-1, ca-central-1, eu-central-1, eu-west-1, eu-west-2, eu-west-3, eu-north-1, sa-east-1, ap-southeast-5, ap-southeast-7, mx-central-1, ca-west-1, cn-north-1, cn-northwest-1 LAYER_NAME: AWSOpenTelemetryDistroPython permissions: diff --git a/aws-opentelemetry-distro/pyproject.toml b/aws-opentelemetry-distro/pyproject.toml index 3d8eadbc1..f8984854d 100644 --- a/aws-opentelemetry-distro/pyproject.toml +++ b/aws-opentelemetry-distro/pyproject.toml @@ -24,62 +24,62 @@ classifiers = [ ] dependencies = [ - "opentelemetry-api == 1.27.0", - "opentelemetry-sdk == 1.27.0", - "opentelemetry-exporter-otlp-proto-grpc == 1.27.0", - "opentelemetry-exporter-otlp-proto-http == 1.27.0", - "opentelemetry-propagator-b3 == 1.27.0", - "opentelemetry-propagator-jaeger == 1.27.0", - "opentelemetry-exporter-otlp-proto-common == 1.27.0", + "opentelemetry-api == 1.33.1", + "opentelemetry-sdk == 1.33.1", + "opentelemetry-exporter-otlp-proto-grpc == 1.33.1", + "opentelemetry-exporter-otlp-proto-http == 1.33.1", + "opentelemetry-propagator-b3 == 1.33.1", + "opentelemetry-propagator-jaeger == 1.33.1", + "opentelemetry-exporter-otlp-proto-common == 1.33.1", "opentelemetry-sdk-extension-aws == 2.0.2", "opentelemetry-propagator-aws-xray == 1.0.1", - "opentelemetry-distro == 0.48b0", - "opentelemetry-processor-baggage == 0.48b0", - "opentelemetry-propagator-ot-trace == 0.48b0", - "opentelemetry-instrumentation == 0.48b0", - "opentelemetry-instrumentation-aws-lambda == 0.48b0", - "opentelemetry-instrumentation-aio-pika == 0.48b0", - "opentelemetry-instrumentation-aiohttp-client == 0.48b0", - "opentelemetry-instrumentation-aiopg == 0.48b0", - "opentelemetry-instrumentation-asgi == 0.48b0", - "opentelemetry-instrumentation-asyncpg == 0.48b0", - "opentelemetry-instrumentation-boto == 0.48b0", - "opentelemetry-instrumentation-boto3sqs == 0.48b0", - "opentelemetry-instrumentation-botocore == 0.48b0", - "opentelemetry-instrumentation-celery == 0.48b0", - "opentelemetry-instrumentation-confluent-kafka == 0.48b0", - "opentelemetry-instrumentation-dbapi == 0.48b0", - "opentelemetry-instrumentation-django == 0.48b0", - "opentelemetry-instrumentation-elasticsearch == 0.48b0", - "opentelemetry-instrumentation-falcon == 0.48b0", - "opentelemetry-instrumentation-fastapi == 0.48b0", - "opentelemetry-instrumentation-flask == 0.48b0", - "opentelemetry-instrumentation-grpc == 0.48b0", - "opentelemetry-instrumentation-httpx == 0.48b0", - "opentelemetry-instrumentation-jinja2 == 0.48b0", - "opentelemetry-instrumentation-kafka-python == 0.48b0", - 
"opentelemetry-instrumentation-logging == 0.48b0", - "opentelemetry-instrumentation-mysql == 0.48b0", - "opentelemetry-instrumentation-mysqlclient == 0.48b0", - "opentelemetry-instrumentation-pika == 0.48b0", - "opentelemetry-instrumentation-psycopg2 == 0.48b0", - "opentelemetry-instrumentation-pymemcache == 0.48b0", - "opentelemetry-instrumentation-pymongo == 0.48b0", - "opentelemetry-instrumentation-pymysql == 0.48b0", - "opentelemetry-instrumentation-pyramid == 0.48b0", - "opentelemetry-instrumentation-redis == 0.48b0", - "opentelemetry-instrumentation-remoulade == 0.48b0", - "opentelemetry-instrumentation-requests == 0.48b0", - "opentelemetry-instrumentation-sqlalchemy == 0.48b0", - "opentelemetry-instrumentation-sqlite3 == 0.48b0", - "opentelemetry-instrumentation-starlette == 0.48b0", - "opentelemetry-instrumentation-system-metrics == 0.48b0", - "opentelemetry-instrumentation-tornado == 0.48b0", - "opentelemetry-instrumentation-tortoiseorm == 0.48b0", - "opentelemetry-instrumentation-urllib == 0.48b0", - "opentelemetry-instrumentation-urllib3 == 0.48b0", - "opentelemetry-instrumentation-wsgi == 0.48b0", - "opentelemetry-instrumentation-cassandra == 0.48b0", + "opentelemetry-distro == 0.54b1", + "opentelemetry-processor-baggage == 0.54b1", + "opentelemetry-propagator-ot-trace == 0.54b1", + "opentelemetry-instrumentation == 0.54b1", + "opentelemetry-instrumentation-aws-lambda == 0.54b1", + "opentelemetry-instrumentation-aio-pika == 0.54b1", + "opentelemetry-instrumentation-aiohttp-client == 0.54b1", + "opentelemetry-instrumentation-aiopg == 0.54b1", + "opentelemetry-instrumentation-asgi == 0.54b1", + "opentelemetry-instrumentation-asyncpg == 0.54b1", + "opentelemetry-instrumentation-boto == 0.54b1", + "opentelemetry-instrumentation-boto3sqs == 0.54b1", + "opentelemetry-instrumentation-botocore == 0.54b1", + "opentelemetry-instrumentation-celery == 0.54b1", + "opentelemetry-instrumentation-confluent-kafka == 0.54b1", + "opentelemetry-instrumentation-dbapi == 0.54b1", + "opentelemetry-instrumentation-django == 0.54b1", + "opentelemetry-instrumentation-elasticsearch == 0.54b1", + "opentelemetry-instrumentation-falcon == 0.54b1", + "opentelemetry-instrumentation-fastapi == 0.54b1", + "opentelemetry-instrumentation-flask == 0.54b1", + "opentelemetry-instrumentation-grpc == 0.54b1", + "opentelemetry-instrumentation-httpx == 0.54b1", + "opentelemetry-instrumentation-jinja2 == 0.54b1", + "opentelemetry-instrumentation-kafka-python == 0.54b1", + "opentelemetry-instrumentation-logging == 0.54b1", + "opentelemetry-instrumentation-mysql == 0.54b1", + "opentelemetry-instrumentation-mysqlclient == 0.54b1", + "opentelemetry-instrumentation-pika == 0.54b1", + "opentelemetry-instrumentation-psycopg2 == 0.54b1", + "opentelemetry-instrumentation-pymemcache == 0.54b1", + "opentelemetry-instrumentation-pymongo == 0.54b1", + "opentelemetry-instrumentation-pymysql == 0.54b1", + "opentelemetry-instrumentation-pyramid == 0.54b1", + "opentelemetry-instrumentation-redis == 0.54b1", + "opentelemetry-instrumentation-remoulade == 0.54b1", + "opentelemetry-instrumentation-requests == 0.54b1", + "opentelemetry-instrumentation-sqlalchemy == 0.54b1", + "opentelemetry-instrumentation-sqlite3 == 0.54b1", + "opentelemetry-instrumentation-starlette == 0.54b1", + "opentelemetry-instrumentation-system-metrics == 0.54b1", + "opentelemetry-instrumentation-tornado == 0.54b1", + "opentelemetry-instrumentation-tortoiseorm == 0.54b1", + "opentelemetry-instrumentation-urllib == 0.54b1", + "opentelemetry-instrumentation-urllib3 == 
0.54b1", + "opentelemetry-instrumentation-wsgi == 0.54b1", + "opentelemetry-instrumentation-cassandra == 0.54b1", ] [project.optional-dependencies] diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/_aws_attribute_keys.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/_aws_attribute_keys.py index 23ba661af..71e675cd3 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/_aws_attribute_keys.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/_aws_attribute_keys.py @@ -33,3 +33,4 @@ AWS_LAMBDA_FUNCTION_NAME: str = "aws.lambda.function.name" AWS_LAMBDA_RESOURCEMAPPING_ID: str = "aws.lambda.resource_mapping.id" AWS_LAMBDA_FUNCTION_ARN: str = "aws.lambda.function.arn" +AWS_SERVICE_TYPE: str = "aws.service.type" diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/_aws_metric_attribute_generator.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/_aws_metric_attribute_generator.py index ec5b693ed..173f8492b 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/_aws_metric_attribute_generator.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/_aws_metric_attribute_generator.py @@ -35,7 +35,6 @@ ) from amazon.opentelemetry.distro._aws_resource_attribute_configurator import get_service_attribute from amazon.opentelemetry.distro._aws_span_processing_util import ( - GEN_AI_REQUEST_MODEL, LOCAL_ROOT, MAX_KEYWORD_LENGTH, SQL_KEYWORD_PATTERN, @@ -60,6 +59,7 @@ from amazon.opentelemetry.distro.sqs_url_parser import SqsUrlParser from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import BoundedAttributes, ReadableSpan +from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_REQUEST_MODEL from opentelemetry.semconv.trace import SpanAttributes # Pertinent OTEL attribute keys diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/_aws_span_processing_util.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/_aws_span_processing_util.py index 21e19afa9..d2a039861 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/_aws_span_processing_util.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/_aws_span_processing_util.py @@ -26,16 +26,6 @@ # Max keyword length supported by parsing into remote_operation from DB_STATEMENT MAX_KEYWORD_LENGTH = 27 -# TODO: Use Semantic Conventions once upgrade to 0.47b0 -GEN_AI_REQUEST_MODEL: str = "gen_ai.request.model" -GEN_AI_SYSTEM: str = "gen_ai.system" -GEN_AI_REQUEST_MAX_TOKENS: str = "gen_ai.request.max_tokens" -GEN_AI_REQUEST_TEMPERATURE: str = "gen_ai.request.temperature" -GEN_AI_REQUEST_TOP_P: str = "gen_ai.request.top_p" -GEN_AI_RESPONSE_FINISH_REASONS: str = "gen_ai.response.finish_reasons" -GEN_AI_USAGE_INPUT_TOKENS: str = "gen_ai.usage.input_tokens" -GEN_AI_USAGE_OUTPUT_TOKENS: str = "gen_ai.usage.output_tokens" - # Get dialect keywords retrieved from dialect_keywords.json file. 
# Only meant to be invoked by SQL_KEYWORD_PATTERN and unit tests diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/_utils.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/_utils.py index 149f9ad29..25c60d14a 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/_utils.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/_utils.py @@ -2,10 +2,11 @@ # SPDX-License-Identifier: Apache-2.0 import os -import sys +from importlib.metadata import PackageNotFoundError, version from logging import Logger, getLogger +from typing import Optional -import pkg_resources +from packaging.requirements import Requirement _logger: Logger = getLogger(__name__) @@ -14,18 +15,60 @@ def is_installed(req: str) -> bool: """Is the given required package installed?""" - - if req in sys.modules and sys.modules[req] is not None: - return True + req = Requirement(req) try: - pkg_resources.get_distribution(req) - except Exception as exc: # pylint: disable=broad-except + dist_version = version(req.name) + except PackageNotFoundError as exc: _logger.debug("Skipping instrumentation patch: package %s, exception: %s", req, exc) return False + + if not list(req.specifier.filter([dist_version])): + _logger.debug( + "instrumentation for package %s is available but version %s is installed. Skipping.", + req, + dist_version, + ) + return False return True def is_agent_observability_enabled() -> bool: """Is the Agentic AI monitoring flag set to true?""" return os.environ.get(AGENT_OBSERVABILITY_ENABLED, "false").lower() == "true" + + +IS_BOTOCORE_INSTALLED: bool = is_installed("botocore") + + +def get_aws_session(): + """Returns a botocore session only if botocore is installed, otherwise None. + + We do this to prevent runtime errors for ADOT customers that do not need + any features that require botocore. + """ + if IS_BOTOCORE_INSTALLED: + # pylint: disable=import-outside-toplevel + from botocore.session import Session + + session = Session() + # Botocore only looks up AWS_DEFAULT_REGION when creating a session/client + # See: https://docs.aws.amazon.com/sdkref/latest/guide/feature-region.html#feature-region-sdk-compat + region = os.environ.get("AWS_REGION") or os.environ.get("AWS_DEFAULT_REGION") + if region: + session.set_config_variable("region", region) + return session + return None + + +def get_aws_region() -> Optional[str]: + """Get AWS region from environment or botocore session. + + Returns the AWS region in the following priority order: + 1. AWS_REGION environment variable + 2. AWS_DEFAULT_REGION environment variable + 3. botocore session's region (if botocore is available) + 4. None if no region can be determined + """ + botocore_session = get_aws_session() + return botocore_session.get_config_variable("region") if botocore_session else None diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/aws_opentelemetry_configurator.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/aws_opentelemetry_configurator.py index b21bc6151..e8997885b 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/aws_opentelemetry_configurator.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/aws_opentelemetry_configurator.py @@ -1,17 +1,18 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 # Modifications Copyright The OpenTelemetry Authors. Licensed under the Apache License 2.0 License. 
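
The version-aware is_installed rewrite in the _utils.py hunk above filters the installed distribution's version through the requirement's specifier before reporting success. Hypothetical usage, assuming an environment where botocore 1.34.0 is installed (illustration only):

    from amazon.opentelemetry.distro._utils import is_installed

    is_installed("botocore")           # True  - bare name, any installed version passes
    is_installed("botocore >= 1.0")    # True  - 1.34.0 satisfies the specifier
    is_installed("botocore < 1.0")     # False - installed version is filtered out
    is_installed("no-such-package")    # False - PackageNotFoundError is caught
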
+import logging import os import re -from logging import NOTSET, Logger, getLogger -from typing import ClassVar, Dict, List, Optional, Type, Union +from logging import Logger, getLogger +from typing import ClassVar, Dict, List, NamedTuple, Optional, Type, Union from importlib_metadata import version from typing_extensions import override -from amazon.opentelemetry.distro._aws_attribute_keys import AWS_LOCAL_SERVICE +from amazon.opentelemetry.distro._aws_attribute_keys import AWS_LOCAL_SERVICE, AWS_SERVICE_TYPE from amazon.opentelemetry.distro._aws_resource_attribute_configurator import get_service_attribute -from amazon.opentelemetry.distro._utils import is_agent_observability_enabled +from amazon.opentelemetry.distro._utils import get_aws_session, is_agent_observability_enabled from amazon.opentelemetry.distro.always_record_sampler import AlwaysRecordSampler from amazon.opentelemetry.distro.attribute_propagating_span_processor_builder import ( AttributePropagatingSpanProcessorBuilder, @@ -22,13 +23,11 @@ AwsMetricAttributesSpanExporterBuilder, ) from amazon.opentelemetry.distro.aws_span_metrics_processor_builder import AwsSpanMetricsProcessorBuilder -from amazon.opentelemetry.distro.exporter.otlp.aws.logs.aws_batch_log_record_processor import AwsBatchLogRecordProcessor -from amazon.opentelemetry.distro.exporter.otlp.aws.logs.otlp_aws_logs_exporter import OTLPAwsLogExporter -from amazon.opentelemetry.distro.exporter.otlp.aws.traces.otlp_aws_span_exporter import OTLPAwsSpanExporter from amazon.opentelemetry.distro.otlp_udp_exporter import OTLPUdpSpanExporter from amazon.opentelemetry.distro.sampler.aws_xray_remote_sampler import AwsXRayRemoteSampler from amazon.opentelemetry.distro.scope_based_exporter import ScopeBasedPeriodicExportingMetricReader from amazon.opentelemetry.distro.scope_based_filtering_view import ScopeBasedRetainingView +from opentelemetry._events import set_event_logger_provider from opentelemetry._logs import get_logger_provider, set_logger_provider from opentelemetry.exporter.otlp.proto.http._log_exporter import OTLPLogExporter from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as OTLPHttpOTLPMetricExporter @@ -43,7 +42,9 @@ _import_id_generator, _import_sampler, _OTelSDKConfigurator, + _patch_basic_config, ) +from opentelemetry.sdk._events import EventLoggerProvider from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler from opentelemetry.sdk._logs.export import BatchLogRecordProcessor, LogExporter from opentelemetry.sdk.environment_variables import ( @@ -99,13 +100,29 @@ AWS_OTLP_LOGS_GROUP_HEADER = "x-aws-log-group" AWS_OTLP_LOGS_STREAM_HEADER = "x-aws-log-stream" +AWS_EMF_METRICS_NAMESPACE = "x-aws-metric-namespace" # UDP package size is not larger than 64KB LAMBDA_SPAN_EXPORT_BATCH_SIZE = 10 +OTEL_TRACES_EXPORTER = "OTEL_TRACES_EXPORTER" +OTEL_LOGS_EXPORTER = "OTEL_LOGS_EXPORTER" +OTEL_METRICS_EXPORTER = "OTEL_METRICS_EXPORTER" +OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT = "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT" +OTEL_TRACES_SAMPLER = "OTEL_TRACES_SAMPLER" +OTEL_PYTHON_DISABLED_INSTRUMENTATIONS = "OTEL_PYTHON_DISABLED_INSTRUMENTATIONS" +OTEL_PYTHON_LOGGING_AUTO_INSTRUMENTATION_ENABLED = "OTEL_PYTHON_LOGGING_AUTO_INSTRUMENTATION_ENABLED" + _logger: Logger = getLogger(__name__) +class OtlpLogHeaderSetting(NamedTuple): + log_group: Optional[str] + log_stream: Optional[str] + namespace: Optional[str] + is_valid: bool + + class AwsOpenTelemetryConfigurator(_OTelSDKConfigurator): """ This 
AwsOpenTelemetryConfigurator extends _OTelSDKConfigurator configuration with the following change:
@@ -134,6 +151,11 @@ def _configure(self, **kwargs):
 # Long term, we wish to contribute this to upstream to improve initialization customizability and reduce dependency on
 # internal logic.
 def _initialize_components():
+    # Remove 'awsemf' from OTEL_METRICS_EXPORTER if present, to prevent validation errors
+    # from _import_exporters in the OTel dependencies, which would otherwise try to load it.
+    # We plan to contribute the EMF exporter upstream to support OTel metrics in the SDK.
+    is_emf_enabled = _check_emf_exporter_enabled()
+
     trace_exporters, metric_exporters, log_exporters = _import_exporters(
         _get_exporter_names("traces"),
         _get_exporter_names("metrics"),
@@ -154,7 +176,7 @@ def _initialize_components():
             AwsEksResourceDetector(),
             AwsEcsResourceDetector(),
         ]
-        if not _is_lambda_environment()
+        if not (_is_lambda_environment() or is_agent_observability_enabled())
         else []
     )
@@ -169,39 +191,36 @@ def _initialize_components():
         sampler=sampler,
         resource=resource,
     )
-    _init_metrics(metric_exporters, resource)
+
+    _init_metrics(metric_exporters, resource, is_emf_enabled)
     logging_enabled = os.getenv(_OTEL_PYTHON_LOGGING_AUTO_INSTRUMENTATION_ENABLED, "false")
     if logging_enabled.strip().lower() == "true":
         _init_logging(log_exporters, resource)
 
 
 def _init_logging(
-    exporters: Dict[str, Type[LogExporter]],
-    resource: Resource = None,
+    exporters: dict[str, Type[LogExporter]],
+    resource: Optional[Resource] = None,
+    setup_logging_handler: bool = True,
 ):
-
-    # Provides a default OTLP log exporter when none is specified.
-    # This is the behavior for the logs exporters for other languages.
-    logs_exporter = os.environ.get("OTEL_LOGS_EXPORTER")
-
-    if not exporters and logs_exporter and (logs_exporter.lower() != "none"):
-        exporters = {"otlp": OTLPLogExporter}
-
     provider = LoggerProvider(resource=resource)
     set_logger_provider(provider)
 
     for _, exporter_class in exporters.items():
-        exporter_args: Dict[str, any] = {}
-        log_exporter = _customize_logs_exporter(exporter_class(**exporter_args), resource)
+        exporter_args = {}
+        _customize_log_record_processor(
+            logger_provider=provider, log_exporter=_customize_logs_exporter(exporter_class(**exporter_args))
+        )
-
-        if isinstance(log_exporter, OTLPAwsLogExporter) and is_agent_observability_enabled():
-            provider.add_log_record_processor(AwsBatchLogRecordProcessor(exporter=log_exporter))
-        else:
-            provider.add_log_record_processor(BatchLogRecordProcessor(exporter=log_exporter))
+    event_logger_provider = EventLoggerProvider(logger_provider=provider)
+    set_event_logger_provider(event_logger_provider)
 
-    handler = LoggingHandler(level=NOTSET, logger_provider=provider)
+    if setup_logging_handler:
+        _patch_basic_config()
 
-    getLogger().addHandler(handler)
+        # Add OTel handler
+        handler = LoggingHandler(level=logging.NOTSET, logger_provider=provider)
+        logging.getLogger().addHandler(handler)
 
 
 def _init_tracing(
@@ -234,6 +253,7 @@ def _init_tracing(
 def _init_metrics(
     exporters_or_readers: Dict[str, Union[Type[MetricExporter], Type[MetricReader]]],
     resource: Resource = None,
+    is_emf_enabled: bool = False,
 ):
     metric_readers = []
     views = []
@@ -246,7 +266,7 @@ def _init_metrics(
     else:
         metric_readers.append(PeriodicExportingMetricReader(exporter_or_reader_class(**exporter_args)))
 
-    _customize_metric_exporters(metric_readers, views)
+    _customize_metric_exporters(metric_readers, views, is_emf_enabled)
 
     provider = MeterProvider(resource=resource, metric_readers=metric_readers, views=views)
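To make the awsemf handling above concrete, here is a small sketch of the environment rewrite that _check_emf_exporter_enabled (added later in this patch) performs; the variable values are hypothetical:

import os

os.environ["OTEL_METRICS_EXPORTER"] = "awsemf,otlp"  # hypothetical user configuration
is_emf_enabled = _check_emf_exporter_enabled()
# is_emf_enabled is now True and OTEL_METRICS_EXPORTER has been rewritten to "otlp",
# so the subsequent _import_exporters() call never sees the unknown "awsemf" name.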
     set_meter_provider(provider)
@@ -272,6 +292,19 @@ def _export_unsampled_span_for_lambda(trace_provider: TracerProvider, resource:
     )
 
 
+def _export_unsampled_span_for_agent_observability(trace_provider: TracerProvider, resource: Resource = None):
+    if not is_agent_observability_enabled():
+        return
+
+    traces_endpoint = os.environ.get(OTEL_EXPORTER_OTLP_TRACES_ENDPOINT)
+    if traces_endpoint and _is_aws_otlp_endpoint(traces_endpoint):
+        endpoint = traces_endpoint.lower()
+        region = endpoint.split(".")[1]
+
+        span_exporter = _create_aws_otlp_exporter(endpoint=endpoint, service="xray", region=region)
+        trace_provider.add_span_processor(BatchUnsampledSpanProcessor(span_exporter=span_exporter))
+
+
 def _is_defer_to_workers_enabled():
     return os.environ.get(OTEL_AWS_PYTHON_DEFER_TO_WORKERS_ENABLED_CONFIG, "false").strip().lower() == "true"
 
@@ -364,25 +397,18 @@ def _customize_span_exporter(span_exporter: SpanExporter, resource: Resource) ->
         traces_endpoint = os.environ.get(AWS_XRAY_DAEMON_ADDRESS_CONFIG, "127.0.0.1:2000")
         span_exporter = OTLPUdpSpanExporter(endpoint=traces_endpoint)
 
-    if _is_aws_otlp_endpoint(traces_endpoint, "xray"):
+    if traces_endpoint and _is_aws_otlp_endpoint(traces_endpoint, "xray"):
         _logger.info("Detected using AWS OTLP Traces Endpoint.")
 
         if isinstance(span_exporter, OTLPSpanExporter):
-            if is_agent_observability_enabled():
-                # Span exporter needs an instance of logger provider in ai agent
-                # observability case because we need to split input/output prompts
-                # from span attributes and send them to the logs pipeline per
-                # the new Gen AI semantic convention from OTel
-                # ref: https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-events/
-                span_exporter = OTLPAwsSpanExporter(endpoint=traces_endpoint, logger_provider=get_logger_provider())
-            else:
-                span_exporter = OTLPAwsSpanExporter(endpoint=traces_endpoint)
+            endpoint = traces_endpoint.lower()
+            region = endpoint.split(".")[1]
+            return _create_aws_otlp_exporter(endpoint=traces_endpoint, service="xray", region=region)
 
-        else:
-            _logger.warning(
-                "Improper configuration see: please export/set "
-                "OTEL_EXPORTER_OTLP_TRACES_PROTOCOL=http/protobuf and OTEL_TRACES_EXPORTER=otlp"
-            )
+        _logger.warning(
+            "Improper configuration: please export/set "
+            "OTEL_EXPORTER_OTLP_TRACES_PROTOCOL=http/protobuf and OTEL_TRACES_EXPORTER=otlp"
+        )
 
     if not _is_application_signals_enabled():
         return span_exporter
@@ -390,17 +416,35 @@ def _customize_span_exporter(span_exporter: SpanExporter, resource: Resource) ->
     return AwsMetricAttributesSpanExporterBuilder(span_exporter, resource).build()
 
 
-def _customize_logs_exporter(log_exporter: LogExporter, resource: Resource) -> LogExporter:
+def _customize_log_record_processor(logger_provider: LoggerProvider, log_exporter: Optional[LogExporter]) -> None:
+    if not log_exporter:
+        return
+
+    if is_agent_observability_enabled():
+        # pylint: disable=import-outside-toplevel
+        from amazon.opentelemetry.distro.exporter.otlp.aws.logs._aws_cw_otlp_batch_log_record_processor import (
+            AwsCloudWatchOtlpBatchLogRecordProcessor,
+        )
+
+        logger_provider.add_log_record_processor(AwsCloudWatchOtlpBatchLogRecordProcessor(exporter=log_exporter))
+    else:
+        logger_provider.add_log_record_processor(BatchLogRecordProcessor(exporter=log_exporter))
+
+
+def _customize_logs_exporter(log_exporter: LogExporter) -> LogExporter:
    logs_endpoint = os.environ.get(OTEL_EXPORTER_OTLP_LOGS_ENDPOINT)
 
-    if _is_aws_otlp_endpoint(logs_endpoint, "logs"):
+    if logs_endpoint and _is_aws_otlp_endpoint(logs_endpoint, "logs"):
+
_logger.info("Detected using AWS OTLP Logs Endpoint.") - if isinstance(log_exporter, OTLPLogExporter) and _validate_logs_headers(): + if isinstance(log_exporter, OTLPLogExporter) and _validate_and_fetch_logs_header().is_valid: + endpoint = logs_endpoint.lower() + region = endpoint.split(".")[1] # Setting default compression mode to Gzip as this is the behavior in upstream's # collector otlp http exporter: # https://github.com/open-telemetry/opentelemetry-collector/tree/main/exporter/otlphttpexporter - return OTLPAwsLogExporter(endpoint=logs_endpoint) + return _create_aws_otlp_exporter(endpoint=logs_endpoint, service="logs", region=region) _logger.warning( "Improper configuration see: please export/set " @@ -415,9 +459,14 @@ def _customize_span_processors(provider: TracerProvider, resource: Resource) -> if _is_lambda_environment(): provider.add_span_processor(AwsLambdaSpanProcessor()) + # We always send 100% spans to Genesis platform for agent observability because + # AI applications typically have low throughput traffic patterns and require + # comprehensive monitoring to catch subtle failure modes like hallucinations + # and quality degradation that sampling could miss. # Add session.id baggage attribute to span attributes to support AI Agent use cases # enabling session ID tracking in spans. if is_agent_observability_enabled(): + _export_unsampled_span_for_agent_observability(provider, resource) def session_id_predicate(baggage_key: str) -> bool: return baggage_key == "session.id" @@ -449,7 +498,9 @@ def session_id_predicate(baggage_key: str) -> bool: return -def _customize_metric_exporters(metric_readers: List[MetricReader], views: List[View]) -> None: +def _customize_metric_exporters( + metric_readers: List[MetricReader], views: List[View], is_emf_enabled: bool = False +) -> None: if _is_application_signals_runtime_enabled(): _get_runtime_metric_views(views, 0 == len(metric_readers)) @@ -461,6 +512,11 @@ def _customize_metric_exporters(metric_readers: List[MetricReader], views: List[ ) metric_readers.append(scope_based_periodic_exporting_metric_reader) + if is_emf_enabled: + emf_exporter = create_emf_exporter() + if emf_exporter: + metric_readers.append(PeriodicExportingMetricReader(emf_exporter)) + def _get_runtime_metric_views(views: List[View], retain_runtime_only: bool) -> None: runtime_metrics_scope_name = SYSTEM_METRICS_INSTRUMENTATION_SCOPE_NAME @@ -516,7 +572,15 @@ def _customize_resource(resource: Resource) -> Resource: if is_unknown: _logger.debug("No valid service name found") - return resource.merge(Resource.create({AWS_LOCAL_SERVICE: service_name})) + custom_attributes = {AWS_LOCAL_SERVICE: service_name} + + if is_agent_observability_enabled(): + # Add aws.service.type if it doesn't exist in the resource + if resource and resource.attributes.get(AWS_SERVICE_TYPE) is None: + # Set a default agent type for AI agent observability + custom_attributes[AWS_SERVICE_TYPE] = "gen_ai_agent" + + return resource.merge(Resource.create(custom_attributes)) def _is_application_signals_enabled(): @@ -542,15 +606,15 @@ def _is_lambda_environment(): def _is_aws_otlp_endpoint(otlp_endpoint: Optional[str] = None, service: str = "xray") -> bool: """Is the given endpoint an AWS OTLP endpoint?""" - pattern = AWS_TRACES_OTLP_ENDPOINT_PATTERN if service == "xray" else AWS_LOGS_OTLP_ENDPOINT_PATTERN - if not otlp_endpoint: return False + pattern = AWS_TRACES_OTLP_ENDPOINT_PATTERN if service == "xray" else AWS_LOGS_OTLP_ENDPOINT_PATTERN + return bool(re.match(pattern, otlp_endpoint.lower())) -def 
+def _validate_and_fetch_logs_header() -> OtlpLogHeaderSetting:
     """Checks if x-aws-log-group and x-aws-log-stream are present in the headers in order to send logs to
     AWS OTLP Logs endpoint."""
@@ -561,8 +625,11 @@ def _validate_logs_headers() -> bool:
             "Improper configuration: Please configure the environment variable OTEL_EXPORTER_OTLP_LOGS_HEADERS "
             "to include x-aws-log-group and x-aws-log-stream"
         )
-        return False
+        return OtlpLogHeaderSetting(None, None, None, False)
 
+    log_group = None
+    log_stream = None
+    namespace = None
     filtered_log_headers_count = 0
 
     for pair in logs_headers.split(","):
@@ -570,17 +637,24 @@ def _validate_logs_headers() -> bool:
             split = pair.split("=", 1)
             key = split[0]
             value = split[1]
-            if key in (AWS_OTLP_LOGS_GROUP_HEADER, AWS_OTLP_LOGS_STREAM_HEADER) and value:
+            if key == AWS_OTLP_LOGS_GROUP_HEADER and value:
+                log_group = value
+                filtered_log_headers_count += 1
+            elif key == AWS_OTLP_LOGS_STREAM_HEADER and value:
+                log_stream = value
                 filtered_log_headers_count += 1
+            elif key == AWS_EMF_METRICS_NAMESPACE and value:
+                namespace = value
 
-    if filtered_log_headers_count != 2:
+    is_valid = filtered_log_headers_count == 2 and log_group is not None and log_stream is not None
+
+    if not is_valid:
         _logger.warning(
             "Improper configuration: Please configure the environment variable OTEL_EXPORTER_OTLP_LOGS_HEADERS "
             "to have values for x-aws-log-group and x-aws-log-stream"
         )
-        return False
-    return True
+    return OtlpLogHeaderSetting(log_group, log_stream, namespace, is_valid)
 
 
 def _get_metric_export_interval():
@@ -651,3 +725,111 @@ def create_exporter(self):
         )
 
         raise RuntimeError(f"Unsupported AWS Application Signals export protocol: {protocol} ")
+
+
+def _check_emf_exporter_enabled() -> bool:
+    """
+    Checks if OTEL_METRICS_EXPORTER contains "awsemf", removes it if present,
+    and updates the environment variable.
+
+    Removing 'awsemf' prevents validation errors from _import_exporters in the OTel
+    dependencies, which would otherwise try to load it as an exporter. We plan to
+    contribute the EMF exporter upstream to support OTel metrics in the SDK.
+
+    Returns:
+        bool: True if "awsemf" was found and removed, False otherwise.
+    """
+    # Get the current exporter value
+    exporter_value = os.environ.get("OTEL_METRICS_EXPORTER", "")
+
+    # Check if it's empty
+    if not exporter_value:
+        return False
+
+    # Split by comma and convert to list
+    exporters = [exp.strip() for exp in exporter_value.split(",")]
+
+    # Check if awsemf is in the list
+    if "awsemf" not in exporters:
+        return False
+
+    # Remove awsemf from the list
+    exporters.remove("awsemf")
+
+    # Join the remaining exporters and update the environment variable
+    new_value = ",".join(exporters) if exporters else ""
+
+    # Set the new value (or unset if empty)
+    if new_value:
+        os.environ["OTEL_METRICS_EXPORTER"] = new_value
+    elif "OTEL_METRICS_EXPORTER" in os.environ:
+        del os.environ["OTEL_METRICS_EXPORTER"]
+
+    return True
+
+
+def create_emf_exporter():
+    """Create and configure the CloudWatch EMF exporter."""
+    try:
+        session = get_aws_session()
+        # Check if botocore is available before importing the EMF exporter
+        if not session:
+            _logger.warning("botocore is not installed. EMF exporter requires botocore")
+            return None
+
+        # pylint: disable=import-outside-toplevel
+        from amazon.opentelemetry.distro.exporter.aws.metrics.aws_cloudwatch_emf_exporter import (
+            AwsCloudWatchEmfExporter,
+        )
+
+        log_header_setting = _validate_and_fetch_logs_header()
+
+        if not log_header_setting.is_valid:
+            return None
+
+        return AwsCloudWatchEmfExporter(
+            session=session,
+            namespace=log_header_setting.namespace,
+            log_group_name=log_header_setting.log_group,
+            log_stream_name=log_header_setting.log_stream,
+        )
+    # pylint: disable=broad-exception-caught
+    except Exception as error:
+        _logger.error("Failed to create EMF exporter: %s", error)
+        return None
+
+
+def _create_aws_otlp_exporter(endpoint: str, service: str, region: str):
+    """Create and configure the AWS OTLP exporters."""
+    try:
+        session = get_aws_session()
+        # Check if botocore is available before importing the AWS exporter
+        if not session:
+            _logger.warning("SigV4 authentication requires botocore to be installed")
+            return None
+
+        # pylint: disable=import-outside-toplevel
+        from amazon.opentelemetry.distro.exporter.otlp.aws.logs.otlp_aws_logs_exporter import OTLPAwsLogExporter
+        from amazon.opentelemetry.distro.exporter.otlp.aws.traces.otlp_aws_span_exporter import OTLPAwsSpanExporter
+
+        if service == "xray":
+            if is_agent_observability_enabled():
+                # The span exporter needs an instance of the logger provider in the AI agent
+                # observability case because we need to split input/output prompts
+                # from span attributes and send them to the logs pipeline, per
+                # the new Gen AI semantic convention from OTel.
+                # ref: https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-events/
+                return OTLPAwsSpanExporter(
+                    session=session, endpoint=endpoint, aws_region=region, logger_provider=get_logger_provider()
+                )
+
+            return OTLPAwsSpanExporter(session=session, endpoint=endpoint, aws_region=region)
+
+        if service == "logs":
+            return OTLPAwsLogExporter(session=session, aws_region=region)
+
+        return None
+    # pylint: disable=broad-exception-caught
+    except Exception as error:
+        _logger.error("Failed to create AWS OTLP exporter: %s", error)
+        return None
diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/aws_opentelemetry_distro.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/aws_opentelemetry_distro.py
index 9bca8acd1..a7f73f4e1 100644
--- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/aws_opentelemetry_distro.py
+++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/aws_opentelemetry_distro.py
@@ -4,6 +4,19 @@
 import sys
 from logging import Logger, getLogger
 
+from amazon.opentelemetry.distro._utils import get_aws_region, is_agent_observability_enabled
+from amazon.opentelemetry.distro.aws_opentelemetry_configurator import (
+    APPLICATION_SIGNALS_ENABLED_CONFIG,
+    OTEL_EXPORTER_OTLP_LOGS_ENDPOINT,
+    OTEL_EXPORTER_OTLP_TRACES_ENDPOINT,
+    OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT,
+    OTEL_LOGS_EXPORTER,
+    OTEL_METRICS_EXPORTER,
+    OTEL_PYTHON_DISABLED_INSTRUMENTATIONS,
+    OTEL_PYTHON_LOGGING_AUTO_INSTRUMENTATION_ENABLED,
+    OTEL_TRACES_EXPORTER,
+    OTEL_TRACES_SAMPLER,
+)
 from amazon.opentelemetry.distro.patches._instrumentation_patch import apply_instrumentation_patches
 from opentelemetry.distro import OpenTelemetryDistro
 from opentelemetry.environment_variables import OTEL_PROPAGATORS, OTEL_PYTHON_ID_GENERATOR
@@ -57,13 +70,53 @@ def _configure(self, **kwargs):
 
         os.environ.setdefault(OTEL_EXPORTER_OTLP_PROTOCOL, "http/protobuf")
 
-        super(AwsOpenTelemetryDistro, self)._configure()
-
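Both factory helpers above fail soft: when botocore is unavailable or construction raises, they log and return None instead of breaking the instrumented application. A hedged sketch of the intended call pattern, mirroring the _customize_metric_exporters change earlier in this patch:

emf_exporter = create_emf_exporter()
if emf_exporter:
    metric_readers.append(PeriodicExportingMetricReader(emf_exporter))
# With botocore absent, emf_exporter is None and the pipeline simply runs without EMF.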
os.environ.setdefault(OTEL_PROPAGATORS, "xray,tracecontext,b3,b3multi") os.environ.setdefault(OTEL_PYTHON_ID_GENERATOR, "xray") os.environ.setdefault( OTEL_EXPORTER_OTLP_METRICS_DEFAULT_HISTOGRAM_AGGREGATION, "base2_exponential_bucket_histogram" ) + if is_agent_observability_enabled(): + # "otlp" is already native OTel default, but we set them here to be explicit + # about intended configuration for agent observability + os.environ.setdefault(OTEL_TRACES_EXPORTER, "otlp") + os.environ.setdefault(OTEL_LOGS_EXPORTER, "otlp") + os.environ.setdefault(OTEL_METRICS_EXPORTER, "awsemf") + + # Set GenAI capture content default + os.environ.setdefault(OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT, "true") + + region = get_aws_region() + + # Set OTLP endpoints with AWS region if not already set + if region: + os.environ.setdefault( + OTEL_EXPORTER_OTLP_TRACES_ENDPOINT, f"https://xray.{region}.amazonaws.com/v1/traces" + ) + os.environ.setdefault(OTEL_EXPORTER_OTLP_LOGS_ENDPOINT, f"https://logs.{region}.amazonaws.com/v1/logs") + else: + _logger.warning( + "AWS region could not be determined. OTLP endpoints will not be automatically configured. " + "Please set AWS_REGION environment variable or configure OTLP endpoints manually." + ) + + # Set sampler default + os.environ.setdefault(OTEL_TRACES_SAMPLER, "parentbased_always_on") + + # Set disabled instrumentations default + os.environ.setdefault( + OTEL_PYTHON_DISABLED_INSTRUMENTATIONS, + "http,sqlalchemy,psycopg2,pymysql,sqlite3,aiopg,asyncpg,mysql_connector," + "botocore,boto3,urllib3,requests,starlette", + ) + + # Set logging auto instrumentation default + os.environ.setdefault(OTEL_PYTHON_LOGGING_AUTO_INSTRUMENTATION_ENABLED, "true") + + # Disable AWS Application Signals by default + os.environ.setdefault(APPLICATION_SIGNALS_ENABLED_CONFIG, "false") + + super(AwsOpenTelemetryDistro, self)._configure() + if kwargs.get("apply_patches", True): apply_instrumentation_patches() diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/exporter/aws/metrics/_cloudwatch_log_client.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/exporter/aws/metrics/_cloudwatch_log_client.py new file mode 100644 index 000000000..b7daac12b --- /dev/null +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/exporter/aws/metrics/_cloudwatch_log_client.py @@ -0,0 +1,380 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +# pylint: disable=no-self-use + +import logging +import time +import uuid +from typing import Any, Dict, List, Optional + +from botocore.exceptions import ClientError +from botocore.session import Session + +logger = logging.getLogger(__name__) + + +class LogEventBatch: + """ + Container for a batch of CloudWatch log events with metadata. + + Tracks the log events, total byte size, and timestamps + for efficient batching and validation. + """ + + def __init__(self): + """Initialize an empty log event batch.""" + self.log_events: List[Dict[str, Any]] = [] + self.byte_total: int = 0 + self.min_timestamp_ms: int = 0 + self.max_timestamp_ms: int = 0 + self.created_timestamp_ms: int = int(time.time() * 1000) + + def add_event(self, log_event: Dict[str, Any], event_size: int) -> None: + """ + Add a log event to the batch. 
+ + Args: + log_event: The log event to add + event_size: The byte size of the event + """ + self.log_events.append(log_event) + self.byte_total += event_size + + # Update timestamp tracking + timestamp = log_event.get("timestamp", 0) + if self.min_timestamp_ms == 0 or timestamp < self.min_timestamp_ms: + self.min_timestamp_ms = timestamp + if timestamp > self.max_timestamp_ms: + self.max_timestamp_ms = timestamp + + def is_empty(self) -> bool: + """Check if the batch is empty.""" + return len(self.log_events) == 0 + + def size(self) -> int: + """Get the number of events in the batch.""" + return len(self.log_events) + + def clear(self) -> None: + """Clear the batch.""" + self.log_events.clear() + self.byte_total = 0 + self.min_timestamp_ms = 0 + self.max_timestamp_ms = 0 + self.created_timestamp_ms = int(time.time() * 1000) + + +class CloudWatchLogClient: + """ + CloudWatch Logs client for batching and sending log events. + + This class handles the batching logic and CloudWatch Logs API interactions + for sending EMF logs efficiently while respecting CloudWatch Logs constraints. + """ + + # Constants for CloudWatch Logs limits + # http://docs.aws.amazon.com/AmazonCloudWatch/latest/logs/cloudwatch_limits_cwl.html + # http://docs.aws.amazon.com/AmazonCloudWatchLogs/latest/APIReference/API_PutLogEvents.html + CW_MAX_EVENT_PAYLOAD_BYTES = 256 * 1024 # 256KB + CW_MAX_REQUEST_EVENT_COUNT = 10000 + CW_PER_EVENT_HEADER_BYTES = 26 + BATCH_FLUSH_INTERVAL = 60 * 1000 + CW_MAX_REQUEST_PAYLOAD_BYTES = 1 * 1024 * 1024 # 1MB + CW_TRUNCATED_SUFFIX = "[Truncated...]" + # None of the log events in the batch can be older than 14 days + CW_EVENT_TIMESTAMP_LIMIT_PAST = 14 * 24 * 60 * 60 * 1000 + # None of the log events in the batch can be more than 2 hours in the future. + CW_EVENT_TIMESTAMP_LIMIT_FUTURE = 2 * 60 * 60 * 1000 + + def __init__( + self, + log_group_name: str, + session: Session, + log_stream_name: Optional[str] = None, + aws_region: Optional[str] = None, + **kwargs, + ): + """ + Initialize the CloudWatch Logs client. 
+ + Args: + log_group_name: CloudWatch log group name + log_stream_name: CloudWatch log stream name (auto-generated if None) + aws_region: AWS region (auto-detected if None) + **kwargs: Additional arguments passed to botocore client + """ + self.log_group_name = log_group_name + self.log_stream_name = log_stream_name or self._generate_log_stream_name() + self.logs_client = session.create_client("logs", region_name=aws_region, **kwargs) + + # Event batch to store logs before sending to CloudWatch + self._event_batch = None + + def _generate_log_stream_name(self) -> str: + """Generate a unique log stream name.""" + unique_id = str(uuid.uuid4())[:8] + return f"otel-python-{unique_id}" + + def _create_log_group_if_needed(self): + """Create log group if it doesn't exist.""" + try: + self.logs_client.create_log_group(logGroupName=self.log_group_name) + logger.info("Created log group: %s", self.log_group_name) + except ClientError as error: + if error.response.get("Error", {}).get("Code") == "ResourceAlreadyExistsException": + logger.debug("Log group %s already exists", self.log_group_name) + else: + logger.error("Failed to create log group %s : %s", self.log_group_name, error) + raise + + def _create_log_stream_if_needed(self): + """Create log stream if it doesn't exist.""" + try: + self.logs_client.create_log_stream(logGroupName=self.log_group_name, logStreamName=self.log_stream_name) + logger.info("Created log stream: %s", self.log_stream_name) + except ClientError as error: + if error.response.get("Error", {}).get("Code") == "ResourceAlreadyExistsException": + logger.debug("Log stream %s already exists", self.log_stream_name) + else: + logger.error("Failed to create log stream %s : %s", self.log_stream_name, error) + raise + + def _validate_log_event(self, log_event: Dict) -> bool: + """ + Validate the log event according to CloudWatch Logs constraints. + Implements the same validation logic as the Go version. + + Args: + log_event: The log event to validate + + Returns: + bool: True if valid, False otherwise + """ + + # Check empty message + if not log_event.get("message") or not log_event.get("message").strip(): + logger.error("Empty log event message") + return False + + message = log_event.get("message", "") + timestamp = log_event.get("timestamp", 0) + + # Check message size + message_size = len(message) + self.CW_PER_EVENT_HEADER_BYTES + if message_size > self.CW_MAX_EVENT_PAYLOAD_BYTES: + logger.warning( + "Log event size %s exceeds maximum allowed size %s. Truncating.", + message_size, + self.CW_MAX_EVENT_PAYLOAD_BYTES, + ) + max_message_size = ( + self.CW_MAX_EVENT_PAYLOAD_BYTES - self.CW_PER_EVENT_HEADER_BYTES - len(self.CW_TRUNCATED_SUFFIX) + ) + log_event["message"] = message[:max_message_size] + self.CW_TRUNCATED_SUFFIX + + # Check timestamp constraints + current_time = int(time.time() * 1000) # Current time in milliseconds + event_time = timestamp + + # Calculate the time difference + time_diff = current_time - event_time + + # Check if too old or too far in the future + if time_diff > self.CW_EVENT_TIMESTAMP_LIMIT_PAST or time_diff < -self.CW_EVENT_TIMESTAMP_LIMIT_FUTURE: + logger.error( + "Log event timestamp %s is either older than 14 days or more than 2 hours in the future. " + "Current time: %s", + event_time, + current_time, + ) + return False + + return True + + def _create_event_batch(self) -> LogEventBatch: + """ + Create a new log event batch. 
+ + Returns: + LogEventBatch: A new event batch + """ + return LogEventBatch() + + def _event_batch_exceeds_limit(self, batch: LogEventBatch, next_event_size: int) -> bool: + """ + Check if adding the next event would exceed CloudWatch Logs limits. + + Args: + batch: The current batch + next_event_size: Size of the next event in bytes + + Returns: + bool: True if adding the next event would exceed limits + """ + return ( + batch.size() >= self.CW_MAX_REQUEST_EVENT_COUNT + or batch.byte_total + next_event_size > self.CW_MAX_REQUEST_PAYLOAD_BYTES + ) + + def _is_batch_active(self, batch: LogEventBatch, target_timestamp_ms: int) -> bool: + """ + Check if the event batch spans more than 24 hours. + + Args: + batch: The event batch + target_timestamp_ms: The timestamp of the event to add + + Returns: + bool: True if the batch is active and can accept the event + """ + # New log event batch + if batch.min_timestamp_ms == 0 or batch.max_timestamp_ms == 0: + return True + + # Check if adding the event would make the batch span more than 24 hours + if target_timestamp_ms - batch.min_timestamp_ms > 24 * 3600 * 1000: + return False + + if batch.max_timestamp_ms - target_timestamp_ms > 24 * 3600 * 1000: + return False + + # flush the event batch when reached 60s interval + current_time = int(time.time() * 1000) + if current_time - batch.created_timestamp_ms >= self.BATCH_FLUSH_INTERVAL: + return False + + return True + + def _sort_log_events(self, batch: LogEventBatch) -> None: + """ + Sort log events in the batch by timestamp. + + Args: + batch: The event batch + """ + batch.log_events = sorted(batch.log_events, key=lambda x: x["timestamp"]) + + def _send_log_batch(self, batch: LogEventBatch) -> None: + """ + Send a batch of log events to CloudWatch Logs. + Creates log group and stream lazily if they don't exist. 
+ + Args: + batch: The event batch + """ + if batch.is_empty(): + return None + + # Sort log events by timestamp + self._sort_log_events(batch) + + # Prepare the PutLogEvents request + put_log_events_input = { + "logGroupName": self.log_group_name, + "logStreamName": self.log_stream_name, + "logEvents": batch.log_events, + } + + start_time = time.time() + + try: + # Make the PutLogEvents call + response = self.logs_client.put_log_events(**put_log_events_input) + + elapsed_ms = int((time.time() - start_time) * 1000) + logger.debug( + "Successfully sent %s log events (%s KB) in %s ms", + batch.size(), + batch.byte_total / 1024, + elapsed_ms, + ) + + return response + + except ClientError as error: + # Handle resource not found errors by creating log group/stream + error_code = error.response.get("Error", {}).get("Code") + if error_code == "ResourceNotFoundException": + logger.info("Log group or stream not found, creating resources and retrying") + + try: + # Create log group first + self._create_log_group_if_needed() + # Then create log stream + self._create_log_stream_if_needed() + + # Retry the PutLogEvents call + response = self.logs_client.put_log_events(**put_log_events_input) + + elapsed_ms = int((time.time() - start_time) * 1000) + logger.debug( + "Successfully sent %s log events (%s KB) in %s ms after creating resources", + batch.size(), + batch.byte_total / 1024, + elapsed_ms, + ) + + return response + + except ClientError as retry_error: + logger.error("Failed to send log events after creating resources: %s", retry_error) + raise + else: + logger.error("Failed to send log events: %s", error) + raise + + def send_log_event(self, log_event: Dict[str, Any]): + """ + Send a log event to CloudWatch Logs. + + This function implements the same logic as the Go version in the OTel Collector. + It batches log events according to CloudWatch Logs constraints and sends them + when the batch is full or spans more than 24 hours. + + Args: + log_event: The log event to send + """ + try: + # Validate the log event + if not self._validate_log_event(log_event): + return + + # Calculate event size + event_size = len(log_event["message"]) + self.CW_PER_EVENT_HEADER_BYTES + + # Initialize event batch if needed + if self._event_batch is None: + self._event_batch = self._create_event_batch() + + # Check if we need to send the current batch and create a new one + current_batch = self._event_batch + if self._event_batch_exceeds_limit(current_batch, event_size) or not self._is_batch_active( + current_batch, log_event["timestamp"] + ): + # Send the current batch + self._send_log_batch(current_batch) + # Create a new batch + self._event_batch = self._create_event_batch() + current_batch = self._event_batch + + # Add the log event to the batch + current_batch.add_event(log_event, event_size) + + except Exception as error: + logger.error("Failed to process log event: %s", error) + raise + + def flush_pending_events(self) -> bool: + """ + Flush any pending log events. 
+ + Returns: + True if successful, False otherwise + """ + if self._event_batch is not None and not self._event_batch.is_empty(): + current_batch = self._event_batch + self._send_log_batch(current_batch) + self._event_batch = self._create_event_batch() + logger.debug("CloudWatchLogClient flushed the buffered log events") + return True diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/exporter/aws/metrics/aws_cloudwatch_emf_exporter.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/exporter/aws/metrics/aws_cloudwatch_emf_exporter.py new file mode 100644 index 000000000..643897da0 --- /dev/null +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/exporter/aws/metrics/aws_cloudwatch_emf_exporter.py @@ -0,0 +1,631 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +# pylint: disable=no-self-use + +import json +import logging +import math +import time +from collections import defaultdict +from typing import Any, Dict, List, Optional, Tuple + +from opentelemetry.sdk.metrics import Counter +from opentelemetry.sdk.metrics import Histogram as HistogramInstr +from opentelemetry.sdk.metrics import ObservableCounter, ObservableGauge, ObservableUpDownCounter, UpDownCounter +from opentelemetry.sdk.metrics._internal.point import Metric +from opentelemetry.sdk.metrics.export import ( + AggregationTemporality, + ExponentialHistogram, + Gauge, + Histogram, + MetricExporter, + MetricExportResult, + MetricsData, + NumberDataPoint, + Sum, +) +from opentelemetry.sdk.metrics.view import ExponentialBucketHistogramAggregation +from opentelemetry.sdk.resources import Resource +from opentelemetry.util.types import Attributes + +from ._cloudwatch_log_client import CloudWatchLogClient + +logger = logging.getLogger(__name__) + + +class MetricRecord: + """The metric data unified representation of all OTel metrics for OTel to CW EMF conversion.""" + + def __init__(self, metric_name: str, metric_unit: str, metric_description: str): + """ + Initialize metric record. + + Args: + metric_name: Name of the metric + metric_unit: Unit of the metric + metric_description: Description of the metric + """ + # Instrument metadata + self.name = metric_name + self.unit = metric_unit + self.description = metric_description + + # Will be set by conversion methods + self.timestamp: Optional[int] = None + self.attributes: Attributes = {} + + # Different metric type data - only one will be set per record + self.value: Optional[float] = None + self.sum_data: Optional[Any] = None + self.histogram_data: Optional[Any] = None + self.exp_histogram_data: Optional[Any] = None + + +class AwsCloudWatchEmfExporter(MetricExporter): + """ + OpenTelemetry metrics exporter for CloudWatch EMF format. + + This exporter converts OTel metrics into CloudWatch EMF logs which are then + sent to CloudWatch Logs. CloudWatch Logs automatically extracts the metrics + from the EMF logs. 
+ + https://docs.aws.amazon.com/AmazonCloudWatch/latest/monitoring/CloudWatch_Embedded_Metric_Format_Specification.html + + """ + + # CloudWatch EMF supported units + # Ref: https://docs.aws.amazon.com/AmazonCloudWatch/latest/APIReference/API_MetricDatum.html + EMF_SUPPORTED_UNITS = { + "Seconds", + "Microseconds", + "Milliseconds", + "Bytes", + "Kilobytes", + "Megabytes", + "Gigabytes", + "Terabytes", + "Bits", + "Kilobits", + "Megabits", + "Gigabits", + "Terabits", + "Percent", + "Count", + "Bytes/Second", + "Kilobytes/Second", + "Megabytes/Second", + "Gigabytes/Second", + "Terabytes/Second", + "Bits/Second", + "Kilobits/Second", + "Megabits/Second", + "Gigabits/Second", + "Terabits/Second", + "Count/Second", + "None", + } + + # OTel to CloudWatch unit mapping + # Ref: opentelemetry-collector-contrib/blob/main/exporter/awsemfexporter/grouped_metric.go#L188 + UNIT_MAPPING = { + "1": "", + "ns": "", + "ms": "Milliseconds", + "s": "Seconds", + "us": "Microseconds", + "By": "Bytes", + "bit": "Bits", + } + + def __init__( + self, + namespace: str = "default", + log_group_name: str = None, + log_stream_name: Optional[str] = None, + aws_region: Optional[str] = None, + preferred_temporality: Optional[Dict[type, AggregationTemporality]] = None, + preferred_aggregation: Optional[Dict[type, Any]] = None, + **kwargs, + ): + """ + Initialize the CloudWatch EMF exporter. + + Args: + namespace: CloudWatch namespace for metrics + log_group_name: CloudWatch log group name + log_stream_name: CloudWatch log stream name (auto-generated if None) + aws_region: AWS region (auto-detected if None) + preferred_temporality: Optional dictionary mapping instrument types to aggregation temporality + preferred_aggregation: Optional dictionary mapping instrument types to preferred aggregation + **kwargs: Additional arguments passed to botocore client + """ + # Set up temporality preference default to DELTA if customers not set + if preferred_temporality is None: + preferred_temporality = { + Counter: AggregationTemporality.DELTA, + HistogramInstr: AggregationTemporality.DELTA, + ObservableCounter: AggregationTemporality.DELTA, + ObservableGauge: AggregationTemporality.DELTA, + ObservableUpDownCounter: AggregationTemporality.DELTA, + UpDownCounter: AggregationTemporality.DELTA, + } + + # Set up aggregation preference default to exponential histogram for histogram metrics + if preferred_aggregation is None: + preferred_aggregation = { + HistogramInstr: ExponentialBucketHistogramAggregation(), + } + + super().__init__(preferred_temporality, preferred_aggregation) + + self.namespace = namespace + self.log_group_name = log_group_name + + # Initialize CloudWatch Logs client + self.log_client = CloudWatchLogClient( + log_group_name=log_group_name, log_stream_name=log_stream_name, aws_region=aws_region, **kwargs + ) + + def _get_metric_name(self, record: MetricRecord) -> Optional[str]: + """Get the metric name from the metric record or data point.""" + + try: + if record.name: + return record.name + except AttributeError: + pass + # Return None if no valid metric name found + return None + + def _get_unit(self, record: MetricRecord) -> Optional[str]: + """Get CloudWatch unit from MetricRecord unit.""" + unit = record.unit + + if not unit: + return None + + # First check if unit is already a supported EMF unit + if unit in self.EMF_SUPPORTED_UNITS: + return unit + + # Map from OTel unit to CloudWatch unit + mapped_unit = self.UNIT_MAPPING.get(unit) + + return mapped_unit + + def _get_dimension_names(self, attributes: 
Attributes) -> List[str]: + """Extract dimension names from attributes.""" + # Implement dimension selection logic + # For now, use all attributes as dimensions + return list(attributes.keys()) + + def _get_attributes_key(self, attributes: Attributes) -> str: + """ + Create a hashable key from attributes for grouping metrics. + + Args: + attributes: The attributes dictionary + + Returns: + A string representation of sorted attributes key-value pairs + """ + # Sort the attributes to ensure consistent keys + sorted_attrs = sorted(attributes.items()) + # Create a string representation of the attributes + return str(sorted_attrs) + + def _normalize_timestamp(self, timestamp_ns: int) -> int: + """ + Normalize a nanosecond timestamp to milliseconds for CloudWatch. + + Args: + timestamp_ns: Timestamp in nanoseconds + + Returns: + Timestamp in milliseconds + """ + # Convert from nanoseconds to milliseconds + return timestamp_ns // 1_000_000 + + def _create_metric_record(self, metric_name: str, metric_unit: str, metric_description: str) -> MetricRecord: + """ + Creates the intermediate metric data structure that standardizes different otel metric representation + and will be used to generate EMF events. The base record + establishes the instrument schema (name/unit/description) that will be populated + with dimensions, timestamps, and values during metric processing. + + Args: + metric_name: Name of the metric + metric_unit: Unit of the metric + metric_description: Description of the metric + + Returns: + A MetricRecord object + """ + return MetricRecord(metric_name, metric_unit, metric_description) + + def _convert_gauge_and_sum(self, metric: Metric, data_point: NumberDataPoint) -> MetricRecord: + """Convert a Gauge or Sum metric datapoint to a metric record. + + Args: + metric: The metric object + data_point: The datapoint to convert + + Returns: + MetricRecord with populated timestamp, attributes, and value + """ + # Create base record + record = self._create_metric_record(metric.name, metric.unit, metric.description) + + # Set timestamp + timestamp_ms = ( + self._normalize_timestamp(data_point.time_unix_nano) + if data_point.time_unix_nano is not None + else int(time.time() * 1000) + ) + record.timestamp = timestamp_ms + + # Set attributes + record.attributes = data_point.attributes + + # Set the value directly for both Gauge and Sum + record.value = data_point.value + + return record + + def _convert_histogram(self, metric: Metric, data_point: Any) -> MetricRecord: + """Convert a Histogram metric datapoint to a metric record. 
+ + https://github.com/open-telemetry/opentelemetry-collector-contrib/blob/main/exporter/awsemfexporter/datapoint.go#L87 + + Args: + metric: The metric object + data_point: The datapoint to convert + + Returns: + MetricRecord with populated timestamp, attributes, and histogram_data + """ + # Create base record + record = self._create_metric_record(metric.name, metric.unit, metric.description) + + # Set timestamp + timestamp_ms = ( + self._normalize_timestamp(data_point.time_unix_nano) + if data_point.time_unix_nano is not None + else int(time.time() * 1000) + ) + record.timestamp = timestamp_ms + + # Set attributes + record.attributes = data_point.attributes + + # For Histogram, set the histogram_data + record.histogram_data = { + "Count": data_point.count, + "Sum": data_point.sum, + "Min": data_point.min, + "Max": data_point.max, + } + return record + + # pylint: disable=too-many-locals + def _convert_exp_histogram(self, metric: Metric, data_point: Any) -> MetricRecord: + """ + Convert an ExponentialHistogram metric datapoint to a metric record. + + This function follows the logic of CalculateDeltaDatapoints in the Go implementation, + converting exponential buckets to their midpoint values. + + Ref: + https://github.com/open-telemetry/opentelemetry-collector-contrib/issues/22626 + + Args: + metric: The metric object + data_point: The datapoint to convert + + Returns: + MetricRecord with populated timestamp, attributes, and exp_histogram_data + """ + + # Create base record + record = self._create_metric_record(metric.name, metric.unit, metric.description) + + # Set timestamp + timestamp_ms = ( + self._normalize_timestamp(data_point.time_unix_nano) + if data_point.time_unix_nano is not None + else int(time.time() * 1000) + ) + record.timestamp = timestamp_ms + + # Set attributes + record.attributes = data_point.attributes + + # Initialize arrays for values and counts + array_values = [] + array_counts = [] + + # Get scale + scale = data_point.scale + # Calculate base using the formula: 2^(2^(-scale)) + base = math.pow(2, math.pow(2, float(-scale))) + + # Process positive buckets + if data_point.positive and data_point.positive.bucket_counts: + positive_offset = getattr(data_point.positive, "offset", 0) + positive_bucket_counts = data_point.positive.bucket_counts + + bucket_begin = 0 + bucket_end = 0 + + for bucket_index, count in enumerate(positive_bucket_counts): + index = bucket_index + positive_offset + + if bucket_begin == 0: + bucket_begin = math.pow(base, float(index)) + else: + bucket_begin = bucket_end + + bucket_end = math.pow(base, float(index + 1)) + + # Calculate midpoint value of the bucket + metric_val = (bucket_begin + bucket_end) / 2 + + # Only include buckets with positive counts + if count > 0: + array_values.append(metric_val) + array_counts.append(float(count)) + + # Process zero bucket + zero_count = getattr(data_point, "zero_count", 0) + if zero_count > 0: + array_values.append(0) + array_counts.append(float(zero_count)) + + # Process negative buckets + if data_point.negative and data_point.negative.bucket_counts: + negative_offset = getattr(data_point.negative, "offset", 0) + negative_bucket_counts = data_point.negative.bucket_counts + + bucket_begin = 0 + bucket_end = 0 + + for bucket_index, count in enumerate(negative_bucket_counts): + index = bucket_index + negative_offset + + if bucket_end == 0: + bucket_end = -math.pow(base, float(index)) + else: + bucket_end = bucket_begin + + bucket_begin = -math.pow(base, float(index + 1)) + + # Calculate midpoint value of 
the bucket + metric_val = (bucket_begin + bucket_end) / 2 + + # Only include buckets with positive counts + if count > 0: + array_values.append(metric_val) + array_counts.append(float(count)) + + # Set the histogram data in the format expected by CloudWatch EMF + record.exp_histogram_data = { + "Values": array_values, + "Counts": array_counts, + "Count": data_point.count, + "Sum": data_point.sum, + "Max": data_point.max, + "Min": data_point.min, + } + + return record + + def _group_by_attributes_and_timestamp(self, record: MetricRecord) -> Tuple[str, int]: + """Group metric record by attributes and timestamp. + + Args: + record: The metric record + + Returns: + A tuple key for grouping + """ + # Create a key for grouping based on attributes + attrs_key = self._get_attributes_key(record.attributes) + return (attrs_key, record.timestamp) + + def _create_emf_log( + self, metric_records: List[MetricRecord], resource: Resource, timestamp: Optional[int] = None + ) -> Dict: + """ + Create EMF log dictionary from metric records. + + Since metric_records is already grouped by attributes, this function + creates a single EMF log for all records. + """ + # Start with base structure + emf_log = {"_aws": {"Timestamp": timestamp or int(time.time() * 1000), "CloudWatchMetrics": []}} + + # Set with latest EMF version schema + # opentelemetry-collector-contrib/blob/main/exporter/awsemfexporter/metric_translator.go#L414 + emf_log["Version"] = "1" + + # Add resource attributes to EMF log but not as dimensions + # OTel collector EMF Exporter has a resource_to_telemetry_conversion flag that will convert resource attributes + # as regular metric attributes(potential dimensions). However, for this SDK EMF implementation, + # we align with the OpenTelemetry concept that all metric attributes are treated as dimensions. + # And have resource attributes as just additional metadata in EMF, added otel.resource as prefix to distinguish. 
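Stepping back to _convert_exp_histogram above, a short worked example of the bucket midpoint math, using hypothetical datapoint values:

base = 2 ** (2 ** -0)  # 2.0 when scale == 0
# positive offset = 0, positive bucket_counts = [1, 2]:
#   bucket 0 covers (1, 2]  ->  midpoint (1 + 2) / 2 = 1.5, count 1
#   bucket 1 covers (2, 4]  ->  midpoint (2 + 4) / 2 = 3.0, count 2
# with zero_count = 4 and no negative buckets, the record ends up with:
#   exp_histogram_data["Values"] == [1.5, 3.0, 0]
#   exp_histogram_data["Counts"] == [1.0, 2.0, 4.0]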
+ if resource and resource.attributes: + for key, value in resource.attributes.items(): + emf_log[f"otel.resource.{key}"] = str(value) + + # Initialize collections for dimensions and metrics + metric_definitions = [] + # Collect attributes from all records (they should be the same for all records in the group) + # Only collect once from the first record and apply to all records + all_attributes = ( + metric_records[0].attributes + if metric_records and len(metric_records) > 0 and metric_records[0].attributes + else {} + ) + + # Process each metric record + for record in metric_records: + + metric_name = self._get_metric_name(record) + + # Skip processing if metric name is None or empty + if not metric_name: + continue + + # Create metric data dict + metric_data = {"Name": metric_name} + + unit = self._get_unit(record) + if unit: + metric_data["Unit"] = unit + + # Process different types of aggregations + if record.exp_histogram_data: + # Base2 Exponential Histogram + emf_log[metric_name] = record.exp_histogram_data + elif record.histogram_data: + # Regular Histogram metrics + emf_log[metric_name] = record.histogram_data + elif record.value is not None: + # Gauge, Sum, and other aggregations + emf_log[metric_name] = record.value + else: + logger.debug("Skipping metric %s as it does not have valid metric value", metric_name) + continue + + # Add to metric definitions list + metric_definitions.append(metric_data) + + # Get dimension names from collected attributes + dimension_names = self._get_dimension_names(all_attributes) + + # Add attribute values to the root of the EMF log + for name, value in all_attributes.items(): + emf_log[name] = str(value) + + # Add the single dimension set to CloudWatch Metrics if we have dimensions and metrics + if dimension_names and metric_definitions: + emf_log["_aws"]["CloudWatchMetrics"].append( + {"Namespace": self.namespace, "Dimensions": [dimension_names], "Metrics": metric_definitions} + ) + + return emf_log + + def _send_log_event(self, log_event: Dict[str, Any]): + """ + Send a log event to CloudWatch Logs using the log client. + + Args: + log_event: The log event to send + """ + self.log_client.send_log_event(log_event) + + # pylint: disable=too-many-nested-blocks,unused-argument,too-many-branches + def export( + self, metrics_data: MetricsData, timeout_millis: Optional[int] = None, **_kwargs: Any + ) -> MetricExportResult: + """ + Export metrics as EMF logs to CloudWatch. + + Groups metrics by attributes and timestamp before creating EMF logs. 
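+
+        For illustration, one grouped batch might serialize to an EMF document shaped
+        roughly like this (names and values hypothetical):
+
+            {
+                "_aws": {"Timestamp": 1700000000000, "CloudWatchMetrics": [
+                    {"Namespace": "MyApp", "Dimensions": [["service.name"]],
+                     "Metrics": [{"Name": "latency", "Unit": "Milliseconds"}]}]},
+                "Version": "1",
+                "otel.resource.telemetry.sdk.language": "python",
+                "service.name": "checkout",
+                "latency": 12.5
+            }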
+ + Args: + metrics_data: MetricsData containing resource metrics and scope metrics + timeout_millis: Optional timeout in milliseconds + **kwargs: Additional keyword arguments + + Returns: + MetricExportResult indicating success or failure + """ + try: + if not metrics_data.resource_metrics: + return MetricExportResult.SUCCESS + + # Process all metrics from all resource metrics and scope metrics + for resource_metrics in metrics_data.resource_metrics: + for scope_metrics in resource_metrics.scope_metrics: + # Dictionary to group metrics by attributes and timestamp + grouped_metrics = defaultdict(list) + + # Process all metrics in this scope + for metric in scope_metrics.metrics: + # Skip if metric.data is None or no data_points exists + try: + if not (metric.data and metric.data.data_points): + continue + except AttributeError: + # Metric doesn't have data or data_points attribute + continue + + # Process metrics based on type + metric_type = type(metric.data) + if metric_type in (Gauge, Sum): + for dp in metric.data.data_points: + record = self._convert_gauge_and_sum(metric, dp) + grouped_metrics[self._group_by_attributes_and_timestamp(record)].append(record) + elif metric_type == Histogram: + for dp in metric.data.data_points: + record = self._convert_histogram(metric, dp) + grouped_metrics[self._group_by_attributes_and_timestamp(record)].append(record) + elif metric_type == ExponentialHistogram: + for dp in metric.data.data_points: + record = self._convert_exp_histogram(metric, dp) + grouped_metrics[self._group_by_attributes_and_timestamp(record)].append(record) + else: + logger.debug("Unsupported Metric Type: %s", metric_type) + + # Now process each group separately to create one EMF log per group + for (_, timestamp_ms), metric_records in grouped_metrics.items(): + if not metric_records: + continue + + # Create and send EMF log for this batch of metrics + self._send_log_event( + { + "message": json.dumps( + self._create_emf_log(metric_records, resource_metrics.resource, timestamp_ms) + ), + "timestamp": timestamp_ms, + } + ) + + return MetricExportResult.SUCCESS + # pylint: disable=broad-exception-caught + # capture all types of exceptions to not interrupt the instrumented services + except Exception as error: + logger.error("Failed to export metrics: %s", error) + return MetricExportResult.FAILURE + + def force_flush(self, timeout_millis: int = 10000) -> bool: # pylint: disable=unused-argument + """ + Force flush any pending metrics. + + Args: + timeout_millis: Timeout in milliseconds + + Returns: + True if successful, False otherwise + """ + self.log_client.flush_pending_events() + logger.debug("AwsCloudWatchEmfExporter force flushes the buffered metrics") + return True + + def shutdown(self, timeout_millis: Optional[int] = None, **_kwargs: Any) -> bool: + """ + Shutdown the exporter. + Override to handle timeout and other keyword arguments, but do nothing. 
+ + Args: + timeout_millis: Ignored timeout in milliseconds + **kwargs: Ignored additional keyword arguments + """ + # Force flush any remaining batched events + self.force_flush(timeout_millis) + logger.debug("AwsCloudWatchEmfExporter shutdown called with timeout_millis=%s", timeout_millis) + return True diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/exporter/otlp/aws/common/aws_auth_session.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/exporter/otlp/aws/common/aws_auth_session.py index 2c383592b..564bfe9e2 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/exporter/otlp/aws/common/aws_auth_session.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/exporter/otlp/aws/common/aws_auth_session.py @@ -4,8 +4,9 @@ import logging import requests - -from amazon.opentelemetry.distro._utils import is_installed +from botocore.auth import SigV4Auth +from botocore.awsrequest import AWSRequest +from botocore.session import Session _logger = logging.getLogger(__name__) @@ -33,57 +34,36 @@ class AwsAuthSession(requests.Session): service (str): The AWS service name for signing (e.g., "logs" or "xray") """ - def __init__(self, aws_region, service): - - self._has_required_dependencies = False - - # Requires botocore to be installed to sign the headers. However, - # some users might not need to use this authenticator. In order not conflict - # with existing behavior, we check for botocore before initializing this exporter. - - if aws_region and service and is_installed("botocore"): - # pylint: disable=import-outside-toplevel - from botocore import auth, awsrequest, session - - self._boto_auth = auth - self._boto_aws_request = awsrequest - self._boto_session = session.Session() - - self._aws_region = aws_region - self._service = service - self._has_required_dependencies = True - - else: - _logger.error( - "botocore is required to enable SigV4 Authentication. 
Please install it using `pip install botocore`",
-            )
+    def __init__(self, aws_region: str, service: str, session: Session):
+        self._aws_region: str = aws_region
+        self._service: str = service
+        self._session: Session = session
 
         super().__init__()
 
     def request(self, method, url, *args, data=None, headers=None, **kwargs):
-        if self._has_required_dependencies:
-
-            credentials = self._boto_session.get_credentials()
-
-            if credentials is not None:
-                signer = self._boto_auth.SigV4Auth(credentials, self._service, self._aws_region)
-
-                request = self._boto_aws_request.AWSRequest(
-                    method="POST",
-                    url=url,
-                    data=data,
-                    headers={"Content-Type": "application/x-protobuf"},
-                )
+        credentials = self._session.get_credentials()
+
+        if credentials:
+            signer = SigV4Auth(credentials, self._service, self._aws_region)
+            request = AWSRequest(
+                method="POST",
+                url=url,
+                data=data,
+                headers={"Content-Type": "application/x-protobuf"},
+            )
 
-                try:
-                    signer.add_auth(request)
+            try:
+                signer.add_auth(request)
 
-                    if headers is None:
-                        headers = {}
+                if headers is None:
+                    headers = {}
 
-                    headers.update(dict(request.headers))
+                headers.update(dict(request.headers))
 
-                except Exception as signing_error:  # pylint: disable=broad-except
-                    _logger.error("Failed to sign request: %s", signing_error)
+            except Exception as signing_error:  # pylint: disable=broad-except
+                _logger.error("Failed to sign request: %s", signing_error)
+        else:
+            _logger.error("Failed to load AWS credentials; request will be sent unsigned")
 
         return super().request(method=method, url=url, *args, data=data, headers=headers, **kwargs)
diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/exporter/otlp/aws/logs/_aws_cw_otlp_batch_log_record_processor.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/exporter/otlp/aws/logs/_aws_cw_otlp_batch_log_record_processor.py
new file mode 100644
index 000000000..fe90e1f90
--- /dev/null
+++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/exporter/otlp/aws/logs/_aws_cw_otlp_batch_log_record_processor.py
@@ -0,0 +1,258 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+# Modifications Copyright The OpenTelemetry Authors. Licensed under the Apache License 2.0 License.
+
+import logging
+from typing import Mapping, Optional, Sequence, cast
+
+from amazon.opentelemetry.distro.exporter.otlp.aws.logs.otlp_aws_logs_exporter import OTLPAwsLogExporter
+from opentelemetry.context import _SUPPRESS_INSTRUMENTATION_KEY, attach, detach, set_value
+from opentelemetry.sdk._logs import LogData
+from opentelemetry.sdk._logs._internal.export import BatchLogExportStrategy
+from opentelemetry.sdk._logs.export import BatchLogRecordProcessor
+from opentelemetry.util.types import AnyValue
+
+_logger = logging.getLogger(__name__)
+
+
+class AwsCloudWatchOtlpBatchLogRecordProcessor(BatchLogRecordProcessor):
+    """
+    Custom implementation of BatchLogRecordProcessor that manages log record batching
+    with size-based constraints to prevent exceeding AWS CloudWatch Logs OTLP endpoint request size limits.
+
+    This processor still exports every log up to _MAX_LOG_REQUEST_BYTE_SIZE, but rather than performing
+    exactly one export per batch, it estimates log sizes and performs multiple batch exports,
+    where each exported batch satisfies an additional constraint:
+
+    If the batch to be exported will have a data size of > 1 MB:
+    The batch will be split into multiple exports of sub-batches of data size <= 1 MB.
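+
+    For instance, with queued logs estimated at 600 KB, 600 KB, and 300 KB (illustrative sizes),
+    the processor would issue two exports: [600 KB] and [600 KB, 300 KB], because appending the
+    second log to the first would push that request past the 1 MB ceiling.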
+
+    A unique case is when a sub-batch has a data size > 1 MB; that sub-batch will then contain exactly one log.
+    """
+
+    # OTel log events include fixed metadata attributes, so the estimated metadata size
+    # can be calculated, on a best-effort basis, as:
+    # service.name (255 chars) + cloud.resource_id (max ARN length) + telemetry.xxx (~20 chars) +
+    # common attributes (255 chars) +
+    # scope + flags + traceId + spanId + numeric/timestamp fields + ...
+    # Example log structure:
+    # {
+    #     "resource": {
+    #         "attributes": {
+    #             "aws.local.service": "example-service123",
+    #             "telemetry.sdk.language": "python",
+    #             "service.name": "my-application",
+    #             "cloud.resource_id": "example-resource",
+    #             "aws.log.group.names": "example-log-group",
+    #             "aws.ai.agent.type": "default",
+    #             "telemetry.sdk.version": "1.x.x",
+    #             "telemetry.auto.version": "0.x.x",
+    #             "telemetry.sdk.name": "opentelemetry"
+    #         }
+    #     },
+    #     "scope": {"name": "example.instrumentation.library"},
+    #     "timeUnixNano": 1234567890123456789,
+    #     "observedTimeUnixNano": 1234567890987654321,
+    #     "severityNumber": 9,
+    #     "body": {...},
+    #     "attributes": {...},
+    #     "flags": 1,
+    #     "traceId": "abcd1234efgh5678ijkl9012mnop3456",
+    #     "spanId": "1234abcd5678efgh"
+    # }
+    # 2000 might be a bit of an overestimate, but it's better to overestimate the size of the log
+    # and suffer a small performance impact with batching than it is to underestimate and risk
+    # a large log being dropped when sent to the AWS OTLP endpoint.
+    _BASE_LOG_BUFFER_BYTE_SIZE = 2000
+
+    _MAX_LOG_REQUEST_BYTE_SIZE = (
+        1048576  # Maximum uncompressed/unserialized bytes / request -
+        # https://docs.aws.amazon.com/AmazonCloudWatch/latest/monitoring/CloudWatch-OTLPEndpoint.html
+    )
+
+    def __init__(
+        self,
+        exporter: OTLPAwsLogExporter,
+        schedule_delay_millis: Optional[float] = None,
+        max_export_batch_size: Optional[int] = None,
+        export_timeout_millis: Optional[float] = None,
+        max_queue_size: Optional[int] = None,
+    ):
+
+        super().__init__(
+            exporter=exporter,
+            schedule_delay_millis=schedule_delay_millis,
+            max_export_batch_size=max_export_batch_size,
+            export_timeout_millis=export_timeout_millis,
+            max_queue_size=max_queue_size,
+        )
+
+        self._exporter = exporter
+
+    def _export(self, batch_strategy: BatchLogExportStrategy) -> None:
+        """
+        Explicitly overrides the upstream _export method to add AWS CloudWatch size-based batching.
+        See:
+        https://github.com/open-telemetry/opentelemetry-python/blob/bb21ebd46d070c359eee286c97bdf53bfd06759d/opentelemetry-sdk/src/opentelemetry/sdk/_shared_internal/__init__.py#L143
+
+        Preserves existing batching behavior but will intermediately export small log batches if
+        the size of the data in the batch is estimated to be at or above AWS CloudWatch's
+        maximum request size limit of 1 MB.
+
+        - Estimated data size of exported batches will typically be <= 1 MB, except for the case below:
+        - If the estimated data size of an exported batch is ever > 1 MB, then the batch size is guaranteed to be 1.
+        """
+        with self._export_lock:
+            iteration = 0
+            # We could see concurrent export calls from worker and force_flush. We call _should_export_batch
+            # once the lock is obtained to see if we still need to make the requested export.
+    def _export(self, batch_strategy: BatchLogExportStrategy) -> None:
+        """
+        Explicitly overrides the upstream _export method to add AWS CloudWatch size-based batching.
+        See:
+        https://github.com/open-telemetry/opentelemetry-python/blob/bb21ebd46d070c359eee286c97bdf53bfd06759d/opentelemetry-sdk/src/opentelemetry/sdk/_shared_internal/__init__.py#L143
+
+        Preserves the existing batching behavior but will export smaller intermediate batches if
+        the size of the data in the batch is estimated to be at or above AWS CloudWatch's
+        maximum request size limit of 1 MB.
+
+        - Estimated data size of exported batches will typically be <= 1 MB, except for the case below:
+        - If the estimated data size of an exported batch is ever > 1 MB, then the batch size is guaranteed to be 1
+        """
+        with self._export_lock:
+            iteration = 0
+            # We could see concurrent export calls from worker and force_flush. We call _should_export_batch
+            # once the lock is obtained to see if we still need to make the requested export.
+            while self._should_export_batch(batch_strategy, iteration):
+                iteration += 1
+                token = attach(set_value(_SUPPRESS_INSTRUMENTATION_KEY, True))
+                try:
+                    batch_length = min(self._max_export_batch_size, len(self._queue))
+                    batch_data_size = 0
+                    batch = []
+
+                    for _ in range(batch_length):
+                        log_data: LogData = self._queue.pop()
+                        log_size = self._estimate_log_size(log_data)
+
+                        if batch and (batch_data_size + log_size > self._MAX_LOG_REQUEST_BYTE_SIZE):
+                            self._exporter.export(batch)
+                            batch_data_size = 0
+                            batch = []
+
+                        batch_data_size += log_size
+                        batch.append(log_data)
+
+                    if batch:
+                        self._exporter.export(batch)
+                except Exception as exception:  # pylint: disable=broad-exception-caught
+                    _logger.exception("Exception while exporting logs: %s", exception)
+                detach(token)
+
+    def _estimate_log_size(self, log: LogData, depth: int = 3) -> int:  # pylint: disable=too-many-branches
+        """
+        Estimates the size in bytes of a log by calculating the size of its body and its attributes
+        and adding a buffer amount to account for other log metadata information.
+
+        Features:
+        - Processes complex log structures up to the specified depth limit
+        - Includes cycle detection to prevent processing the same content more than once
+        - Returns a truncated calculation if the depth limit is exceeded
+
+        We set depth to 3 as this is the minimum required depth to estimate our consolidated Gen AI log events:
+
+        Example structure:
+        {
+            "output": {
+                "messages": [
+                    {
+                        "content": "Hello, World!",
+                        "role": "assistant"
+                    }
+                ]
+            },
+            "input": {
+                "messages": [
+                    {
+                        "content": "Say Hello, World!",
+                        "role": "user"
+                    }
+                ]
+            }
+        }
+
+        Args:
+            log: The Log object to calculate size for
+            depth: Maximum depth to traverse in nested structures (default: 3)
+
+        Returns:
+            int: The estimated size of the log object in bytes
+        """
+
+        # Queue contains tuples of (log_content, depth) where:
+        # - log_content is the current piece of log data being processed
+        # - depth tracks how many levels deep we've traversed to reach this content
+        # - body starts at depth 0 since it's an AnyValue object
+        # - Attributes start at depth -1 since it's a Mapping[str, AnyValue] - when traversed, we will
+        #   start processing its keys at depth 0
+        queue = [(log.log_record.body, 0), (log.log_record.attributes, -1)]
+
+        # Track visited complex log contents to avoid calculating the same one more than once
+        visited = set()
+
+        size: int = self._BASE_LOG_BUFFER_BYTE_SIZE
+
+        while queue:
+            new_queue = []
+
+            for data in queue:
+                # Small optimization: we can stop calculating the size once it reaches the 1 MB limit.
+ if size >= self._MAX_LOG_REQUEST_BYTE_SIZE: + return size + + next_val, current_depth = data + + if next_val is None: + continue + + if isinstance(next_val, bytes): + size += len(next_val) + continue + + if isinstance(next_val, (str, float, int, bool)): + size += AwsCloudWatchOtlpBatchLogRecordProcessor._estimate_utf8_size(str(next_val)) + continue + + # next_val must be Sequence["AnyValue"] or Mapping[str, "AnyValue"] + # See: https://github.com/open-telemetry/opentelemetry-python/blob/\ + # 9426d6da834cfb4df7daedd4426bba0aa83165b5/opentelemetry-api/src/opentelemetry/util/types.py#L20 + if current_depth <= depth: + obj_id = id( + next_val + ) # Guaranteed to be unique, see: https://www.w3schools.com/python/ref_func_id.asp + if obj_id in visited: + continue + visited.add(obj_id) + + if isinstance(next_val, Sequence): + for content in next_val: + new_queue.append((cast(AnyValue, content), current_depth + 1)) + + if isinstance(next_val, Mapping): + for key, content in next_val.items(): + size += len(key) + new_queue.append((content, current_depth + 1)) + else: + _logger.debug( + "Max log depth of %s exceeded. Log data size will not be accurately calculated.", depth + ) + + queue = new_queue + + return size + + @staticmethod + def _estimate_utf8_size(s: str): + ascii_count = 0 + non_ascii_count = 0 + + for char in s: + if ord(char) < 128: + ascii_count += 1 + else: + non_ascii_count += 1 + + # Estimate: ASCII chars (1 byte) + upper bound of non-ASCII chars 4 bytes + return ascii_count + (non_ascii_count * 4) + + # Only export the logs once to avoid the race condition of the worker thread and force flush thread + # https://github.com/open-telemetry/opentelemetry-python/issues/3193 + # https://github.com/open-telemetry/opentelemetry-python/blob/main/opentelemetry-sdk/src/opentelemetry/sdk/_shared_internal/__init__.py#L199 + def force_flush(self, timeout_millis: Optional[int] = None) -> bool: + if self._shutdown: + return False + self._export(BatchLogExportStrategy.EXPORT_AT_LEAST_ONE_BATCH) + return True diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/exporter/otlp/aws/logs/aws_batch_log_record_processor.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/exporter/otlp/aws/logs/aws_batch_log_record_processor.py deleted file mode 100644 index 8feada9a0..000000000 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/exporter/otlp/aws/logs/aws_batch_log_record_processor.py +++ /dev/null @@ -1,160 +0,0 @@ -import logging -from typing import Mapping, Optional, Sequence, cast - -from amazon.opentelemetry.distro.exporter.otlp.aws.logs.otlp_aws_logs_exporter import OTLPAwsLogExporter -from opentelemetry.context import ( - _SUPPRESS_INSTRUMENTATION_KEY, - attach, - detach, - set_value, -) -from opentelemetry.sdk._logs import LogData -from opentelemetry.sdk._logs._internal.export import BatchLogExportStrategy -from opentelemetry.sdk._logs.export import BatchLogRecordProcessor -from opentelemetry.util.types import AnyValue - -_logger = logging.getLogger(__name__) - - -class AwsBatchLogRecordProcessor(BatchLogRecordProcessor): - _BASE_LOG_BUFFER_BYTE_SIZE = ( - 2000 # Buffer size in bytes to account for log metadata not included in the body size calculation - ) - _MAX_LOG_REQUEST_BYTE_SIZE = ( - 1048576 # https://docs.aws.amazon.com/AmazonCloudWatch/latest/monitoring/CloudWatch-OTLPEndpoint.html - ) - - def __init__( - self, - exporter: OTLPAwsLogExporter, - schedule_delay_millis: Optional[float] = None, - max_export_batch_size: Optional[int] = None, - 
export_timeout_millis: Optional[float] = None, - max_queue_size: Optional[int] = None, - ): - - super().__init__( - exporter=exporter, - schedule_delay_millis=schedule_delay_millis, - max_export_batch_size=max_export_batch_size, - export_timeout_millis=export_timeout_millis, - max_queue_size=max_queue_size, - ) - - self._exporter = exporter - - # https://github.com/open-telemetry/opentelemetry-python/blob/main/opentelemetry-sdk/src/opentelemetry/sdk/_shared_internal/__init__.py#L143 - def _export(self, batch_strategy: BatchLogExportStrategy) -> None: - """ - Preserves existing batching behavior but will intermediarly export small log batches if - the size of the data in the batch is at orabove AWS CloudWatch's maximum request size limit of 1 MB. - - - Data size of exported batches will ALWAYS be <= 1 MB except for the case below: - - If the data size of an exported batch is ever > 1 MB then the batch size is guaranteed to be 1 - """ - with self._export_lock: - iteration = 0 - # We could see concurrent export calls from worker and force_flush. We call _should_export_batch - # once the lock is obtained to see if we still need to make the requested export. - while self._should_export_batch(batch_strategy, iteration): - iteration += 1 - token = attach(set_value(_SUPPRESS_INSTRUMENTATION_KEY, True)) - try: - batch_length = min(self._max_export_batch_size, len(self._queue)) - batch_data_size = 0 - batch = [] - - for _ in range(batch_length): - log_data: LogData = self._queue.pop() - log_size = self._BASE_LOG_BUFFER_BYTE_SIZE + self._get_any_value_size(log_data.log_record.body) - - if batch and (batch_data_size + log_size > self._MAX_LOG_REQUEST_BYTE_SIZE): - # if batch_data_size > MAX_LOG_REQUEST_BYTE_SIZE then len(batch) == 1 - if batch_data_size > self._MAX_LOG_REQUEST_BYTE_SIZE: - if self._is_gen_ai_log(batch[0]): - self._exporter.set_gen_ai_log_flag() - - self._exporter.export(batch) - batch_data_size = 0 - batch = [] - - batch_data_size += log_size - batch.append(log_data) - - if batch: - # if batch_data_size > MAX_LOG_REQUEST_BYTE_SIZE then len(batch) == 1 - if batch_data_size > self._MAX_LOG_REQUEST_BYTE_SIZE: - if self._is_gen_ai_log(batch[0]): - self._exporter.set_gen_ai_log_flag() - - self._exporter.export(batch) - except Exception as e: # pylint: disable=broad-exception-caught - _logger.exception("Exception while exporting logs: " + str(e)) - detach(token) - - def _get_any_value_size(self, val: AnyValue, depth: int = 3) -> int: - """ - Only used to indicate whether we should export a batch log size of 1 or not. - Calculates the size in bytes of an AnyValue object. - Will processs complex AnyValue structures up to the specified depth limit. - If the depth limit of the AnyValue structure is exceeded, returns 0. - - Args: - val: The AnyValue object to calculate size for - depth: Maximum depth to traverse in nested structures (default: 3) - - Returns: - int: Total size of the AnyValue object in bytes - """ - # Use a stack to prevent excessive recursive calls. - stack = [(val, 0)] - size: int = 0 - - while stack: - # small optimization. We can stop calculating the size once it reaches the 1 MB limit. 
- if size >= self._MAX_LOG_REQUEST_BYTE_SIZE: - return size - - next_val, current_depth = stack.pop() - - if isinstance(next_val, (str, bytes)): - size += len(next_val) - continue - - if isinstance(next_val, bool): - size += 4 if next_val else 5 - continue - - if isinstance(next_val, (float, int)): - size += len(str(next_val)) - continue - - if current_depth <= depth: - if isinstance(next_val, Sequence): - for content in next_val: - stack.append((cast(AnyValue, content), current_depth + 1)) - - if isinstance(next_val, Mapping): - for key, content in next_val.items(): - size += len(key) - stack.append((content, current_depth + 1)) - else: - _logger.debug("Max log depth exceeded. Log data size will not be accurately calculated.") - return 0 - - return size - - @staticmethod - def _is_gen_ai_log(log_data: LogData) -> bool: - """ - Is the log a Gen AI log event? - """ - gen_ai_instrumentations = { - "openinference.instrumentation.langchain", - "openinference.instrumentation.crewai", - "opentelemetry.instrumentation.langchain", - "crewai.telemetry", - "openlit.otel.tracing", - } - - return log_data.instrumentation_scope.name in gen_ai_instrumentations diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/exporter/otlp/aws/logs/otlp_aws_logs_exporter.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/exporter/otlp/aws/logs/otlp_aws_logs_exporter.py index 64203b434..4ed3649c3 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/exporter/otlp/aws/logs/otlp_aws_logs_exporter.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/exporter/otlp/aws/logs/otlp_aws_logs_exporter.py @@ -1,43 +1,48 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +# Modifications Copyright The OpenTelemetry Authors. Licensed under the Apache License 2.0 License. 
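The gzip framing this exporter's export() applies to the serialized payload amounts to the following sketch (the helper name is invented for illustration):

    import gzip
    from io import BytesIO

    def compress_payload(serialized_data: bytes) -> bytes:
        # Same framing used by export(): write the OTLP bytes through a gzip stream
        buf = BytesIO()
        with gzip.GzipFile(fileobj=buf, mode="w") as gzip_stream:
            gzip_stream.write(serialized_data)
        return buf.getvalue()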
 import gzip
 import logging
+import random
 from io import BytesIO
-from time import sleep
+from threading import Event
+from time import time
 from typing import Dict, Optional, Sequence
 
-import requests
+from botocore.session import Session
+from requests import Response
+from requests.exceptions import ConnectionError as RequestsConnectionError
+from requests.structures import CaseInsensitiveDict
 
 from amazon.opentelemetry.distro.exporter.otlp.aws.common.aws_auth_session import AwsAuthSession
-from opentelemetry.exporter.otlp.proto.common._internal import (
-    _create_exp_backoff_generator,
-)
 from opentelemetry.exporter.otlp.proto.common._log_encoder import encode_logs
 from opentelemetry.exporter.otlp.proto.http import Compression
 from opentelemetry.exporter.otlp.proto.http._log_exporter import OTLPLogExporter
-from opentelemetry.sdk._logs import (
-    LogData,
-)
-from opentelemetry.sdk._logs.export import (
-    LogExportResult,
-)
+from opentelemetry.sdk._logs import LogData
+from opentelemetry.sdk._logs.export import LogExportResult
 
 _logger = logging.getLogger(__name__)
+_MAX_RETRYS = 6
 
 
 class OTLPAwsLogExporter(OTLPLogExporter):
-    _LARGE_LOG_HEADER = "x-aws-truncatable-fields"
-    _LARGE_GEN_AI_LOG_PATH_HEADER = (
-        "\\$['resourceLogs'][0]['scopeLogs'][0]['logRecords'][0]['body']"
-        "['kvlistValue']['values'][*]['value']['kvlistValue']['values'][*]"
-        "['value']['arrayValue']['values'][*]['kvlistValue']['values'][*]"
-        "['value']['stringValue']"
-    )
-    _RETRY_AFTER_HEADER = "Retry-After"  # https://opentelemetry.io/docs/specs/otlp/#otlphttp-throttling
+    """
+    This exporter extends the functionality of the OTLPLogExporter to allow logs to be exported
+    to the CloudWatch Logs OTLP endpoint https://logs.[AWSRegion].amazonaws.com/v1/logs. Utilizes the botocore
+    library to sign and directly inject SigV4 Authentication into the exported request's headers.
+
+    See: https://docs.aws.amazon.com/AmazonCloudWatch/latest/monitoring/CloudWatch-OTLPEndpoint.html
+    """
+
+    _RETRY_AFTER_HEADER = "Retry-After"  # See: https://opentelemetry.io/docs/specs/otlp/#otlphttp-throttling
 
     def __init__(
         self,
+        aws_region: str,
+        session: Session,
+        log_group: Optional[str] = None,
+        log_stream: Optional[str] = None,
         endpoint: Optional[str] = None,
         certificate_file: Optional[str] = None,
         client_key_file: Optional[str] = None,
@@ -45,11 +50,14 @@ def __init__(
         headers: Optional[Dict[str, str]] = None,
         timeout: Optional[int] = None,
     ):
-        self._gen_ai_log_flag = False
-        self._aws_region = None
+        self._aws_region = aws_region
 
-        if endpoint:
-            self._aws_region = endpoint.split(".")[1]
+        if log_group and log_stream:
+            log_headers = {"x-aws-log-group": log_group, "x-aws-log-stream": log_stream}
+            if headers:
+                headers.update(log_headers)
+            else:
+                headers = log_headers
 
         OTLPLogExporter.__init__(
             self,
@@ -60,27 +68,21 @@ def __init__(
             headers,
             timeout,
             compression=Compression.Gzip,
-            session=AwsAuthSession(aws_region=self._aws_region, service="logs"),
+            session=AwsAuthSession(session=session, aws_region=self._aws_region, service="logs"),
         )
+        self._shutdown_event = Event()
 
-    # https://github.com/open-telemetry/opentelemetry-python/blob/main/exporter/opentelemetry-exporter-otlp-proto-http/src/opentelemetry/exporter/otlp/proto/http/_log_exporter/__init__.py#L167
     def export(self, batch: Sequence[LogData]) -> LogExportResult:
         """
-        Exports the given batch of OTLP log data.
-        Behaviors of how this export will work -
-
-        1. Always compresses the serialized data into gzip before sending.
+        Exports log batch with AWS-specific enhancements over the base OTLPLogExporter.
 
-        2. If self._gen_ai_log_flag is enabled, the log data is > 1 MB a
-        and the assumption is that the log is a normalized gen.ai LogEvent.
-            - inject the {LARGE_LOG_HEADER} into the header.
+        Key differences from upstream OTLPLogExporter:
+        1. Respects Retry-After header from server responses for proper throttling
+        2. Treats HTTP 429 (Too Many Requests) as a retryable exception
+        3. Always compresses data with gzip before sending
 
-        3. Retry behavior is now the following:
-            - if the response contains a status code that is retryable and the response contains Retry-After in its
-            headers, the serialized data will be exported after that set delay
-
-            - if the response does not contain that Retry-After header, default back to the current iteration of the
-            exponential backoff delay
+        Upstream implementation does not support Retry-After header:
+        https://github.com/open-telemetry/opentelemetry-python/blob/acae2c232b101d3e447a82a7161355d66aa06fa2/exporter/opentelemetry-exporter-otlp-proto-http/src/opentelemetry/exporter/otlp/proto/http/_log_exporter/__init__.py#L167
         """
 
         if self._shutdown:
@@ -88,97 +90,93 @@ def export(self, batch: Sequence[LogData]) -> LogExportResult:
             return LogExportResult.FAILURE
 
         serialized_data = encode_logs(batch).SerializeToString()
-
         gzip_data = BytesIO()
         with gzip.GzipFile(fileobj=gzip_data, mode="w") as gzip_stream:
             gzip_stream.write(serialized_data)
-
         data = gzip_data.getvalue()
 
-        backoff = _create_exp_backoff_generator(max_value=self._MAX_RETRY_TIMEOUT)
+        deadline_sec = time() + self._timeout
+        retry_num = 0
 
+        # This loop will eventually terminate because:
+        # 1) The export request will eventually either succeed or fail permanently
+        # 2) Maximum retries (_MAX_RETRYS = 6) will be reached
+        # 3) Deadline timeout will be exceeded
+        # 4) Non-retryable errors (4xx except 429) immediately exit the loop
        while True:
-            resp = self._send(data)
+            resp = self._send(data, deadline_sec - time())
 
             if resp.ok:
                 return LogExportResult.SUCCESS
 
-            if not self._retryable(resp):
+            backoff_seconds = self._get_retry_delay_sec(resp.headers, retry_num)
+            is_retryable = self._retryable(resp)
+
+            if not is_retryable or retry_num + 1 == _MAX_RETRYS or backoff_seconds > (deadline_sec - time()):
                 _logger.error(
                     "Failed to export logs batch code: %s, reason: %s",
                     resp.status_code,
                     resp.text,
                 )
-                self._gen_ai_log_flag = False
-                return LogExportResult.FAILURE
-
-            # https://opentelemetry.io/docs/specs/otlp/#otlphttp-throttling
-            maybe_retry_after = resp.headers.get(self._RETRY_AFTER_HEADER, None)
-
-            # Set the next retry delay to the value of the Retry-After response in the headers.
-            # If Retry-After is not present in the headers, default to the next iteration of the
-            # exponential backoff strategy.
-
-            delay = self._parse_retryable_header(maybe_retry_after)
-
-            if delay == -1:
-                delay = next(backoff, self._MAX_RETRY_TIMEOUT)
-
-            if delay == self._MAX_RETRY_TIMEOUT:
-                _logger.error(
-                    "Transient error %s encountered while exporting logs batch. "
-                    "No Retry-After header found and all backoff retries exhausted. "
-                    "Logs will not be exported.",
-                    resp.reason,
-                )
-                self._gen_ai_log_flag = False
                 return LogExportResult.FAILURE
 
             _logger.warning(
-                "Transient error %s encountered while exporting logs batch, retrying in %ss.",
+                "Transient error %s encountered while exporting logs batch, retrying in %.2fs.",
                 resp.reason,
-                delay,
+                backoff_seconds,
             )
+            # Use an interruptible sleep that can be cut short by shutdown
+            if self._shutdown_event.wait(backoff_seconds):
+                _logger.info("Export interrupted by shutdown")
+                return LogExportResult.FAILURE
 
-            sleep(delay)
+            retry_num += 1
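The retry scheduling above, together with _get_retry_delay_sec below, reduces to this illustrative sketch (the function name is invented): honor Retry-After when the server provides it, otherwise back off exponentially with +/-20% jitter.

    import random

    def next_delay_sec(retry_after_header, retry_num):
        if retry_after_header is not None:
            try:
                return float(retry_after_header)  # server-directed delay wins
            except ValueError:
                pass
        return 2**retry_num * random.uniform(0.8, 1.2)  # 1s, 2s, 4s, ... +/-20%

    # e.g. next_delay_sec(None, 3) is roughly 6.4 to 9.6 seconds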
" - "Logs will not be exported.", - resp.reason, - ) - self._gen_ai_log_flag = False return LogExportResult.FAILURE _logger.warning( - "Transient error %s encountered while exporting logs batch, retrying in %ss.", + "Transient error %s encountered while exporting logs batch, retrying in %.2fs.", resp.reason, - delay, + backoff_seconds, ) + # Use interruptible sleep that can be interrupted by shutdown + if self._shutdown_event.wait(backoff_seconds): + _logger.info("Export interrupted by shutdown") + return LogExportResult.FAILURE - sleep(delay) + retry_num += 1 - def set_gen_ai_log_flag(self): - """ - Sets a flag that indicates the current log batch contains - a generative AI log record that exceeds the CloudWatch Logs size limit (1MB). - """ - self._gen_ai_log_flag = True + def shutdown(self) -> None: + """Shutdown the exporter and interrupt any ongoing waits.""" + self._shutdown_event.set() + return super().shutdown() - def _send(self, serialized_data: bytes): + def _send(self, serialized_data: bytes, timeout_sec: float): try: response = self._session.post( url=self._endpoint, - headers={self._LARGE_LOG_HEADER: self._LARGE_GEN_AI_LOG_PATH_HEADER} if self._gen_ai_log_flag else None, data=serialized_data, verify=self._certificate_file, - timeout=self._timeout, + timeout=timeout_sec, cert=self._client_cert, ) return response - except ConnectionError: + except RequestsConnectionError: response = self._session.post( url=self._endpoint, - headers={self._LARGE_LOG_HEADER: self._LARGE_GEN_AI_LOG_PATH_HEADER} if self._gen_ai_log_flag else None, data=serialized_data, verify=self._certificate_file, - timeout=self._timeout, + timeout=timeout_sec, cert=self._client_cert, ) return response @staticmethod - def _retryable(resp: requests.Response) -> bool: + def _retryable(resp: Response) -> bool: """ - Is it a retryable response? + Logic based on https://opentelemetry.io/docs/specs/otlp/#otlphttp-throttling """ + # See: https://opentelemetry.io/docs/specs/otlp/#otlphttp-throttling return resp.status_code in (429, 503) or OTLPLogExporter._retryable(resp) + def _get_retry_delay_sec(self, headers: CaseInsensitiveDict, retry_num: int) -> float: + """ + Get retry delay in seconds from headers or backoff strategy. + """ + # Check for Retry-After header first, then use exponential backoff with jitter + retry_after_delay = self._parse_retryable_header(headers.get(self._RETRY_AFTER_HEADER)) + if retry_after_delay > -1: + return retry_after_delay + # multiplying by a random number between .8 and 1.2 introduces a +/-20% jitter to each backoff. + return 2**retry_num * random.uniform(0.8, 1.2) + @staticmethod def _parse_retryable_header(retry_header: Optional[str]) -> float: """ diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/exporter/otlp/aws/traces/otlp_aws_span_exporter.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/exporter/otlp/aws/traces/otlp_aws_span_exporter.py index 47a4d693e..3589121d9 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/exporter/otlp/aws/traces/otlp_aws_span_exporter.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/exporter/otlp/aws/traces/otlp_aws_span_exporter.py @@ -1,17 +1,37 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
# SPDX-License-Identifier: Apache-2.0 -from typing import Dict, Optional +import logging +from typing import Dict, Optional, Sequence +from botocore.session import Session + +from amazon.opentelemetry.distro._utils import is_agent_observability_enabled from amazon.opentelemetry.distro.exporter.otlp.aws.common.aws_auth_session import AwsAuthSession +from amazon.opentelemetry.distro.llo_handler import LLOHandler +from opentelemetry._logs import get_logger_provider from opentelemetry.exporter.otlp.proto.http import Compression from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter from opentelemetry.sdk._logs import LoggerProvider +from opentelemetry.sdk.trace import ReadableSpan +from opentelemetry.sdk.trace.export import SpanExportResult + +_logger = logging.getLogger(__name__) class OTLPAwsSpanExporter(OTLPSpanExporter): + """ + This exporter extends the functionality of the OTLPSpanExporter to allow spans to be exported + to the XRay OTLP endpoint https://xray.[AWSRegion].amazonaws.com/v1/traces. Utilizes the + AwsAuthSession to sign and directly inject SigV4 Authentication to the exported request's headers. + + See: https://docs.aws.amazon.com/AmazonCloudWatch/latest/monitoring/CloudWatch-OTLPEndpoint.html + """ + def __init__( self, + aws_region: str, + session: Session, endpoint: Optional[str] = None, certificate_file: Optional[str] = None, client_key_file: Optional[str] = None, @@ -21,11 +41,9 @@ def __init__( compression: Optional[Compression] = None, logger_provider: Optional[LoggerProvider] = None, ): - self._aws_region = None + self._aws_region = aws_region self._logger_provider = logger_provider - - if endpoint: - self._aws_region = endpoint.split(".")[1] + self._llo_handler = None OTLPSpanExporter.__init__( self, @@ -36,5 +54,31 @@ def __init__( headers, timeout, compression, - session=AwsAuthSession(aws_region=self._aws_region, service="xray"), + session=AwsAuthSession(session=session, aws_region=self._aws_region, service="xray"), ) + + def _ensure_llo_handler(self): + """Lazily initialize LLO handler when needed to avoid initialization order issues""" + if self._llo_handler is None and is_agent_observability_enabled(): + if self._logger_provider is None: + try: + self._logger_provider = get_logger_provider() + except Exception as exc: # pylint: disable=broad-exception-caught + _logger.debug("Failed to get logger provider: %s", exc) + return False + + if self._logger_provider: + self._llo_handler = LLOHandler(self._logger_provider) + return True + + return self._llo_handler is not None + + def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult: + try: + if is_agent_observability_enabled() and self._ensure_llo_handler(): + llo_processed_spans = self._llo_handler.process_spans(spans) + return super().export(llo_processed_spans) + except Exception: # pylint: disable=broad-exception-caught + return SpanExportResult.FAILURE + + return super().export(spans) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/llo_handler.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/llo_handler.py new file mode 100644 index 000000000..42506d905 --- /dev/null +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/llo_handler.py @@ -0,0 +1,557 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 +import logging +import re +from enum import Enum +from typing import Any, Dict, List, Optional, Sequence, TypedDict + +from opentelemetry._events import Event +from opentelemetry.attributes import BoundedAttributes +from opentelemetry.sdk._events import EventLoggerProvider +from opentelemetry.sdk._logs import LoggerProvider +from opentelemetry.sdk.trace import Event as SpanEvent +from opentelemetry.sdk.trace import ReadableSpan +from opentelemetry.util import types + +ROLE_SYSTEM = "system" +ROLE_USER = "user" +ROLE_ASSISTANT = "assistant" + +_logger = logging.getLogger(__name__) + + +class PatternType(str, Enum): + """Types of LLO attribute patterns.""" + + INDEXED = "indexed" + DIRECT = "direct" + + +class PatternConfig(TypedDict, total=False): + """Configuration for an LLO pattern.""" + + type: PatternType + regex: Optional[str] + role_key: Optional[str] + role: Optional[str] + default_role: Optional[str] + source: str + + +LLO_PATTERNS: Dict[str, PatternConfig] = { + "gen_ai.prompt.{index}.content": { + "type": PatternType.INDEXED, + "regex": r"^gen_ai\.prompt\.(\d+)\.content$", + "role_key": "gen_ai.prompt.{index}.role", + "default_role": "unknown", + "source": "prompt", + }, + "gen_ai.completion.{index}.content": { + "type": PatternType.INDEXED, + "regex": r"^gen_ai\.completion\.(\d+)\.content$", + "role_key": "gen_ai.completion.{index}.role", + "default_role": "unknown", + "source": "completion", + }, + "llm.input_messages.{index}.message.content": { + "type": PatternType.INDEXED, + "regex": r"^llm\.input_messages\.(\d+)\.message\.content$", + "role_key": "llm.input_messages.{index}.message.role", + "default_role": ROLE_USER, + "source": "input", + }, + "llm.output_messages.{index}.message.content": { + "type": PatternType.INDEXED, + "regex": r"^llm\.output_messages\.(\d+)\.message\.content$", + "role_key": "llm.output_messages.{index}.message.role", + "default_role": ROLE_ASSISTANT, + "source": "output", + }, + "traceloop.entity.input": { + "type": PatternType.DIRECT, + "role": ROLE_USER, + "source": "input", + }, + "traceloop.entity.output": { + "type": PatternType.DIRECT, + "role": ROLE_ASSISTANT, + "source": "output", + }, + "crewai.crew.tasks_output": { + "type": PatternType.DIRECT, + "role": ROLE_ASSISTANT, + "source": "output", + }, + "crewai.crew.result": { + "type": PatternType.DIRECT, + "role": ROLE_ASSISTANT, + "source": "result", + }, + "gen_ai.prompt": { + "type": PatternType.DIRECT, + "role": ROLE_USER, + "source": "prompt", + }, + "gen_ai.completion": { + "type": PatternType.DIRECT, + "role": ROLE_ASSISTANT, + "source": "completion", + }, + "gen_ai.content.revised_prompt": { + "type": PatternType.DIRECT, + "role": ROLE_SYSTEM, + "source": "prompt", + }, + "gen_ai.agent.actual_output": { + "type": PatternType.DIRECT, + "role": ROLE_ASSISTANT, + "source": "output", + }, + "gen_ai.agent.human_input": { + "type": PatternType.DIRECT, + "role": ROLE_USER, + "source": "input", + }, + "input.value": { + "type": PatternType.DIRECT, + "role": ROLE_USER, + "source": "input", + }, + "output.value": { + "type": PatternType.DIRECT, + "role": ROLE_ASSISTANT, + "source": "output", + }, + "system_prompt": { + "type": PatternType.DIRECT, + "role": ROLE_SYSTEM, + "source": "prompt", + }, + "tool.result": { + "type": PatternType.DIRECT, + "role": ROLE_ASSISTANT, + "source": "output", + }, + "llm.prompts": { + "type": PatternType.DIRECT, + "role": ROLE_USER, + "source": "prompt", + }, +} + + +class LLOHandler: + """ + Utility class for handling 
Large Language Objects (LLO) in OpenTelemetry spans. + + LLOHandler performs three primary functions: + 1. Identifies Large Language Objects (LLO) content in spans + 2. Extracts and transforms these attributes into OpenTelemetry Gen AI Events + 3. Filters LLO from spans to maintain privacy and reduce span size + + The handler uses a configuration-driven approach with a pattern registry that defines + all supported LLO attribute patterns and their extraction rules. This makes it easy + to add support for new frameworks without modifying the core logic. + """ + + def __init__(self, logger_provider: LoggerProvider): + """ + Initialize an LLOHandler with the specified logger provider. + + This constructor sets up the event logger provider and compiles patterns + from the pattern registry for efficient matching. + + Args: + logger_provider: The OpenTelemetry LoggerProvider used for emitting events. + Global LoggerProvider instance injected from our AwsOpenTelemetryConfigurator + """ + self._logger_provider = logger_provider + self._event_logger_provider = EventLoggerProvider(logger_provider=self._logger_provider) + + self._build_pattern_matchers() + + def _build_pattern_matchers(self) -> None: + """ + Build efficient pattern matching structures from the pattern registry. + + Creates: + - Set of exact match patterns for O(1) lookups + - List of compiled regex patterns for indexed patterns + - Mapping of patterns to their configurations + """ + self._exact_match_patterns = set() + self._regex_patterns = [] + self._pattern_configs = {} + + for pattern_key, config in LLO_PATTERNS.items(): + if config["type"] == PatternType.DIRECT: + self._exact_match_patterns.add(pattern_key) + self._pattern_configs[pattern_key] = config + elif config["type"] == PatternType.INDEXED: + if regex_str := config.get("regex"): + compiled_regex = re.compile(regex_str) + self._regex_patterns.append((compiled_regex, pattern_key, config)) + + def _collect_all_llo_messages(self, span: ReadableSpan, attributes: types.Attributes) -> List[Dict[str, Any]]: + """ + Collect all LLO messages from attributes using the pattern registry. + + This is the main collection method that processes all patterns defined + in the registry and extracts messages accordingly. + + Args: + span: The source ReadableSpan containing the attributes + attributes: Dictionary of attributes to process + + Returns: + List[Dict[str, Any]]: List of message dictionaries with 'content', 'role', and 'source' keys + """ + messages = [] + + if attributes is None: + return messages + + for attr_key, value in attributes.items(): + if attr_key in self._exact_match_patterns: + config = self._pattern_configs[attr_key] + messages.append( + {"content": value, "role": config.get("role", "unknown"), "source": config.get("source", "unknown")} + ) + + indexed_messages = self._collect_indexed_messages(attributes) + messages.extend(indexed_messages) + + return messages + + def _collect_indexed_messages(self, attributes: types.Attributes) -> List[Dict[str, Any]]: + """ + Collect messages from indexed patterns (e.g., gen_ai.prompt.0.content). + + Handles patterns with numeric indices and their associated role attributes. 
+ + Args: + attributes: Dictionary of attributes to process + + Returns: + List[Dict[str, Any]]: List of message dictionaries + """ + indexed_messages = {} + + if attributes is None: + return [] + + for attr_key, value in attributes.items(): + for regex, pattern_key, config in self._regex_patterns: + match = regex.match(attr_key) + if match: + index = int(match.group(1)) + + role = config.get("default_role", "unknown") + if role_key_template := config.get("role_key"): + role_key = role_key_template.replace("{index}", str(index)) + role = attributes.get(role_key, role) + + key = (pattern_key, index) + indexed_messages[key] = {"content": value, "role": role, "source": config.get("source", "unknown")} + break + + sorted_keys = sorted(indexed_messages.keys(), key=lambda k: (k[0], k[1])) + return [indexed_messages[k] for k in sorted_keys] + + def _collect_llo_attributes_from_span(self, span: ReadableSpan) -> Dict[str, Any]: + """ + Collect all LLO attributes from a span's attributes and events. + + Args: + span: The span to collect LLO attributes from + + Returns: + Dictionary of all LLO attributes found in the span + """ + all_llo_attributes = {} + + # Collect from span attributes + if span.attributes is not None: + for key, value in span.attributes.items(): + if self._is_llo_attribute(key): + all_llo_attributes[key] = value + + # Collect from span events + if span.events: + for event in span.events: + if event.attributes: + for key, value in event.attributes.items(): + if self._is_llo_attribute(key): + all_llo_attributes[key] = value + + return all_llo_attributes + + # pylint: disable-next=no-self-use + def _update_span_attributes(self, span: ReadableSpan, filtered_attributes: types.Attributes) -> None: + """ + Update span attributes, preserving BoundedAttributes if present. + + Args: + span: The span to update + filtered_attributes: The filtered attributes to set + """ + if filtered_attributes is not None and isinstance(span.attributes, BoundedAttributes): + span._attributes = BoundedAttributes( + maxlen=span.attributes.maxlen, + attributes=filtered_attributes, + immutable=span.attributes._immutable, + max_value_len=span.attributes.max_value_len, + ) + else: + span._attributes = filtered_attributes + + def process_spans(self, spans: Sequence[ReadableSpan]) -> List[ReadableSpan]: + """ + Processes a sequence of spans to extract and filter LLO attributes. + + For each span, this method: + 1. Collects all LLO attributes from span attributes and all span events + 2. Emits a single consolidated Gen AI Event with all collected LLO content + 3. Filters out LLO attributes from the span and its events to maintain privacy + 4. 
Preserves non-LLO attributes in the span + + Handles LLO attributes from multiple frameworks: + - Traceloop (indexed prompt/completion patterns and entity input/output) + - OpenLit (direct prompt/completion patterns, including from span events) + - OpenInference (input/output values and structured messages) + - Strands SDK (system prompts and tool results) + - CrewAI (tasks output and results) + + Args: + spans: A sequence of OpenTelemetry ReadableSpan objects to process + + Returns: + List[ReadableSpan]: Modified spans with LLO attributes removed + """ + modified_spans = [] + + for span in spans: + # Collect all LLO attributes from both span attributes and events + all_llo_attributes = self._collect_llo_attributes_from_span(span) + + # Emit a single consolidated event if we found any LLO attributes + if all_llo_attributes: + self._emit_llo_attributes(span, all_llo_attributes) + + # Filter span attributes + filtered_attributes = None + if span.attributes is not None: + filtered_attributes = self._filter_attributes(span.attributes) + + # Update span attributes + self._update_span_attributes(span, filtered_attributes) + + # Filter span events + self._filter_span_events(span) + + modified_spans.append(span) + + return modified_spans + + def _filter_span_events(self, span: ReadableSpan) -> None: + """ + Filter LLO attributes from span events. + + This method removes LLO attributes from event attributes while preserving + the event structure and non-LLO attributes. + + Args: + span: The ReadableSpan to filter events for + + Returns: + None: The span is modified in-place + """ + if not span.events: + return + + updated_events = [] + + for event in span.events: + if not event.attributes: + updated_events.append(event) + continue + + updated_event_attributes = self._filter_attributes(event.attributes) + + if updated_event_attributes is not None and len(updated_event_attributes) != len(event.attributes): + limit = None + if isinstance(event.attributes, BoundedAttributes): + limit = event.attributes.maxlen + + updated_event = SpanEvent( + name=event.name, attributes=updated_event_attributes, timestamp=event.timestamp, limit=limit + ) + + updated_events.append(updated_event) + else: + updated_events.append(event) + + span._events = updated_events + + # pylint: disable-next=no-self-use + def _group_messages_by_type(self, messages: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, str]]]: + """ + Group messages into input and output categories based on role and source. 
+
+        Args:
+            messages: List of message dictionaries with 'role', 'content', and 'source' keys
+
+        Returns:
+            Dictionary with 'input' and 'output' lists of messages
+        """
+        input_messages = []
+        output_messages = []
+
+        for message in messages:
+            role = message.get("role", "unknown")
+            content = message.get("content", "")
+            formatted_message = {"role": role, "content": content}
+
+            if role in [ROLE_SYSTEM, ROLE_USER]:
+                input_messages.append(formatted_message)
+            elif role == ROLE_ASSISTANT:
+                output_messages.append(formatted_message)
+            else:
+                # Route based on source for non-standard roles
+                if any(key in message.get("source", "") for key in ["completion", "output", "result"]):
+                    output_messages.append(formatted_message)
+                else:
+                    input_messages.append(formatted_message)
+
+        return {"input": input_messages, "output": output_messages}
+
+    def _emit_llo_attributes(
+        self, span: ReadableSpan, attributes: types.Attributes, event_timestamp: Optional[int] = None
+    ) -> None:
+        """
+        Extract LLO attributes and emit them as a single consolidated Gen AI Event.
+
+        This method:
+        1. Collects all LLO attributes using the pattern registry
+        2. Groups them into input and output messages
+        3. Emits one event per span containing all LLO content
+
+        The event body format:
+        {
+            "input": {
+                "messages": [
+                    {"role": "system", "content": "..."},
+                    {"role": "user", "content": "..."}
+                ]
+            },
+            "output": {
+                "messages": [
+                    {"role": "assistant", "content": "..."}
+                ]
+            }
+        }
+
+        Args:
+            span: The source ReadableSpan containing the attributes
+            attributes: Dictionary of attributes to process
+            event_timestamp: Optional timestamp to override span timestamps
+
+        Returns:
+            None: Event is emitted via the event logger
+        """
+        if attributes is None:
+            return
+        has_llo_attrs = any(self._is_llo_attribute(key) for key in attributes)
+        if not has_llo_attrs:
+            return
+
+        all_messages = self._collect_all_llo_messages(span, attributes)
+        if not all_messages:
+            return
+
+        # Group messages into input/output categories
+        grouped_messages = self._group_messages_by_type(all_messages)
+
+        # Build event body
+        event_body = {}
+        if grouped_messages["input"]:
+            event_body["input"] = {"messages": grouped_messages["input"]}
+        if grouped_messages["output"]:
+            event_body["output"] = {"messages": grouped_messages["output"]}
+
+        if not event_body:
+            return
+
+        timestamp = event_timestamp if event_timestamp is not None else span.end_time
+        event_logger = self._event_logger_provider.get_event_logger(span.instrumentation_scope.name)
+
+        event_attributes = {}
+        if span.attributes and "session.id" in span.attributes:
+            event_attributes["session.id"] = span.attributes["session.id"]
+
+        event = Event(
+            name=span.instrumentation_scope.name,
+            timestamp=timestamp,
+            body=event_body,
+            attributes=event_attributes if event_attributes else None,
+            trace_id=span.context.trace_id,
+            span_id=span.context.span_id,
+            trace_flags=span.context.trace_flags,
+        )
+
+        event_logger.emit(event)
+        _logger.debug("Emitted Gen AI Event with input/output message format")
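A tiny worked example of the grouping performed by _group_messages_by_type and the event body built in _emit_llo_attributes (the sample messages are invented for illustration):

    messages = [
        {"role": "system", "content": "You are a helpful assistant.", "source": "prompt"},
        {"role": "user", "content": "Say Hello, World!", "source": "input"},
        {"role": "assistant", "content": "Hello, World!", "source": "output"},
    ]
    # system/user roles route to "input"; assistant routes to "output"
    event_body = {
        "input": {"messages": [{"role": m["role"], "content": m["content"]}
                               for m in messages if m["role"] in ("system", "user")]},
        "output": {"messages": [{"role": m["role"], "content": m["content"]}
                                for m in messages if m["role"] == "assistant"]},
    }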
+
+    def _filter_attributes(self, attributes: types.Attributes) -> types.Attributes:
+        """
+        Create a new attributes dictionary with LLO attributes removed.
+
+        This method creates a new dictionary containing only non-LLO attributes,
+        preserving the original values while filtering out sensitive LLO content.
+        This helps maintain privacy and reduces the size of spans.
+
+        Args:
+            attributes: Original dictionary of span or event attributes
+
+        Returns:
+            types.Attributes: New dictionary with LLO attributes removed, or None if input is None
+        """
+        if attributes is None:
+            return None
+
+        has_llo_attrs = False
+        for key in attributes:
+            if self._is_llo_attribute(key):
+                has_llo_attrs = True
+                break
+
+        if not has_llo_attrs:
+            return attributes
+
+        filtered_attributes = {}
+        for key, value in attributes.items():
+            if not self._is_llo_attribute(key):
+                filtered_attributes[key] = value
+
+        return filtered_attributes
+
+    def _is_llo_attribute(self, key: str) -> bool:
+        """
+        Determine if an attribute key contains LLO content based on pattern matching.
+
+        Uses the pattern registry to check if a key matches any LLO pattern.
+
+        Args:
+            key: The attribute key to check
+
+        Returns:
+            bool: True if the key matches any LLO pattern, False otherwise
+        """
+        if key in self._exact_match_patterns:
+            return True
+
+        for regex, _, _ in self._regex_patterns:
+            if regex.match(key):
+                return True
+
+        return False
diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_bedrock_patches.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_bedrock_patches.py
index a25e55330..549154771 100644
--- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_bedrock_patches.py
+++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_bedrock_patches.py
@@ -2,13 +2,8 @@
 # SPDX-License-Identifier: Apache-2.0
 import abc
 import inspect
-import io
-import json
 import logging
-import math
-from typing import Any, Dict, Optional
-
-from botocore.response import StreamingBody
+from typing import Dict, Optional
 
 from amazon.opentelemetry.distro._aws_attribute_keys import (
     AWS_BEDROCK_AGENT_ID,
@@ -17,20 +12,11 @@
     AWS_BEDROCK_GUARDRAIL_ID,
     AWS_BEDROCK_KNOWLEDGE_BASE_ID,
 )
-from amazon.opentelemetry.distro._aws_span_processing_util import (
-    GEN_AI_REQUEST_MAX_TOKENS,
-    GEN_AI_REQUEST_MODEL,
-    GEN_AI_REQUEST_TEMPERATURE,
-    GEN_AI_REQUEST_TOP_P,
-    GEN_AI_RESPONSE_FINISH_REASONS,
-    GEN_AI_SYSTEM,
-    GEN_AI_USAGE_INPUT_TOKENS,
-    GEN_AI_USAGE_OUTPUT_TOKENS,
-)
 from opentelemetry.instrumentation.botocore.extensions.types import (
     _AttributeMapT,
     _AwsSdkCallContext,
     _AwsSdkExtension,
+    _BotocoreInstrumentorContext,
     _BotoResultT,
 )
 from opentelemetry.trace.span import Span
@@ -192,7 +178,7 @@ def extract_attributes(self, attributes: _AttributeMapT):
         if request_param_value:
             attributes[attribute_key] = request_param_value
 
-    def on_success(self, span: Span, result: _BotoResultT):
+    def on_success(self, span: Span, result: _BotoResultT, instrumentor_context: _BotocoreInstrumentorContext):
         if self._operation_class is None:
             return
 
@@ -229,7 +215,7 @@ class _BedrockExtension(_AwsSdkExtension):
     """
 
     # pylint: disable=no-self-use
-    def on_success(self, span: Span, result: _BotoResultT):
+    def on_success(self, span: Span, result: _BotoResultT, instrumentor_context: _BotocoreInstrumentorContext):
         # _GUARDRAIL_ID can only be retrieved from the response, not from the request
         guardrail_id = result.get(_GUARDRAIL_ID)
         if guardrail_id:
@@ -244,205 +230,3 @@ def on_success(self, span: Span, result: _BotoResultT):
             AWS_BEDROCK_GUARDRAIL_ARN,
             guardrail_arn,
         )
-
-
-class _BedrockRuntimeExtension(_AwsSdkExtension):
-    """
-    This class is an extension for
-    Amazon Bedrock Runtime.
- """ - - def extract_attributes(self, attributes: _AttributeMapT): - attributes[GEN_AI_SYSTEM] = _AWS_BEDROCK_SYSTEM - - model_id = self._call_context.params.get(_MODEL_ID) - if model_id: - attributes[GEN_AI_REQUEST_MODEL] = model_id - - # Get the request body if it exists - body = self._call_context.params.get("body") - if body: - try: - request_body = json.loads(body) - - if "amazon.titan" in model_id: - self._extract_titan_attributes(attributes, request_body) - if "amazon.nova" in model_id: - self._extract_nova_attributes(attributes, request_body) - elif "anthropic.claude" in model_id: - self._extract_claude_attributes(attributes, request_body) - elif "meta.llama" in model_id: - self._extract_llama_attributes(attributes, request_body) - elif "cohere.command" in model_id: - self._extract_cohere_attributes(attributes, request_body) - elif "ai21.jamba" in model_id: - self._extract_ai21_attributes(attributes, request_body) - elif "mistral" in model_id: - self._extract_mistral_attributes(attributes, request_body) - - except json.JSONDecodeError: - _logger.debug("Error: Unable to parse the body as JSON") - - def _extract_titan_attributes(self, attributes, request_body): - config = request_body.get("textGenerationConfig", {}) - self._set_if_not_none(attributes, GEN_AI_REQUEST_TEMPERATURE, config.get("temperature")) - self._set_if_not_none(attributes, GEN_AI_REQUEST_TOP_P, config.get("topP")) - self._set_if_not_none(attributes, GEN_AI_REQUEST_MAX_TOKENS, config.get("maxTokenCount")) - - def _extract_nova_attributes(self, attributes, request_body): - config = request_body.get("inferenceConfig", {}) - self._set_if_not_none(attributes, GEN_AI_REQUEST_TEMPERATURE, config.get("temperature")) - self._set_if_not_none(attributes, GEN_AI_REQUEST_TOP_P, config.get("top_p")) - self._set_if_not_none(attributes, GEN_AI_REQUEST_MAX_TOKENS, config.get("max_new_tokens")) - - def _extract_claude_attributes(self, attributes, request_body): - self._set_if_not_none(attributes, GEN_AI_REQUEST_MAX_TOKENS, request_body.get("max_tokens")) - self._set_if_not_none(attributes, GEN_AI_REQUEST_TEMPERATURE, request_body.get("temperature")) - self._set_if_not_none(attributes, GEN_AI_REQUEST_TOP_P, request_body.get("top_p")) - - def _extract_cohere_attributes(self, attributes, request_body): - prompt = request_body.get("message") - if prompt: - attributes[GEN_AI_USAGE_INPUT_TOKENS] = math.ceil(len(prompt) / 6) - self._set_if_not_none(attributes, GEN_AI_REQUEST_MAX_TOKENS, request_body.get("max_tokens")) - self._set_if_not_none(attributes, GEN_AI_REQUEST_TEMPERATURE, request_body.get("temperature")) - self._set_if_not_none(attributes, GEN_AI_REQUEST_TOP_P, request_body.get("p")) - - def _extract_ai21_attributes(self, attributes, request_body): - self._set_if_not_none(attributes, GEN_AI_REQUEST_MAX_TOKENS, request_body.get("max_tokens")) - self._set_if_not_none(attributes, GEN_AI_REQUEST_TEMPERATURE, request_body.get("temperature")) - self._set_if_not_none(attributes, GEN_AI_REQUEST_TOP_P, request_body.get("top_p")) - - def _extract_llama_attributes(self, attributes, request_body): - self._set_if_not_none(attributes, GEN_AI_REQUEST_MAX_TOKENS, request_body.get("max_gen_len")) - self._set_if_not_none(attributes, GEN_AI_REQUEST_TEMPERATURE, request_body.get("temperature")) - self._set_if_not_none(attributes, GEN_AI_REQUEST_TOP_P, request_body.get("top_p")) - - def _extract_mistral_attributes(self, attributes, request_body): - prompt = request_body.get("prompt") - if prompt: - attributes[GEN_AI_USAGE_INPUT_TOKENS] = 
math.ceil(len(prompt) / 6) - self._set_if_not_none(attributes, GEN_AI_REQUEST_MAX_TOKENS, request_body.get("max_tokens")) - self._set_if_not_none(attributes, GEN_AI_REQUEST_TEMPERATURE, request_body.get("temperature")) - self._set_if_not_none(attributes, GEN_AI_REQUEST_TOP_P, request_body.get("top_p")) - - @staticmethod - def _set_if_not_none(attributes, key, value): - if value is not None: - attributes[key] = value - - # pylint: disable=too-many-branches - def on_success(self, span: Span, result: Dict[str, Any]): - model_id = self._call_context.params.get(_MODEL_ID) - - if not model_id: - return - - if "body" in result and isinstance(result["body"], StreamingBody): - original_body = None - try: - original_body = result["body"] - body_content = original_body.read() - - # Use one stream for telemetry - stream = io.BytesIO(body_content) - telemetry_content = stream.read() - response_body = json.loads(telemetry_content.decode("utf-8")) - if "amazon.titan" in model_id: - self._handle_amazon_titan_response(span, response_body) - if "amazon.nova" in model_id: - self._handle_amazon_nova_response(span, response_body) - elif "anthropic.claude" in model_id: - self._handle_anthropic_claude_response(span, response_body) - elif "meta.llama" in model_id: - self._handle_meta_llama_response(span, response_body) - elif "cohere.command" in model_id: - self._handle_cohere_command_response(span, response_body) - elif "ai21.jamba" in model_id: - self._handle_ai21_jamba_response(span, response_body) - elif "mistral" in model_id: - self._handle_mistral_mistral_response(span, response_body) - # Replenish stream for downstream application use - new_stream = io.BytesIO(body_content) - result["body"] = StreamingBody(new_stream, len(body_content)) - - except json.JSONDecodeError: - _logger.debug("Error: Unable to parse the response body as JSON") - except Exception as e: # pylint: disable=broad-exception-caught, invalid-name - _logger.debug("Error processing response: %s", e) - finally: - if original_body is not None: - original_body.close() - - # pylint: disable=no-self-use - def _handle_amazon_titan_response(self, span: Span, response_body: Dict[str, Any]): - if "inputTextTokenCount" in response_body: - span.set_attribute(GEN_AI_USAGE_INPUT_TOKENS, response_body["inputTextTokenCount"]) - if "results" in response_body and response_body["results"]: - result = response_body["results"][0] - if "tokenCount" in result: - span.set_attribute(GEN_AI_USAGE_OUTPUT_TOKENS, result["tokenCount"]) - if "completionReason" in result: - span.set_attribute(GEN_AI_RESPONSE_FINISH_REASONS, [result["completionReason"]]) - - # pylint: disable=no-self-use - def _handle_amazon_nova_response(self, span: Span, response_body: Dict[str, Any]): - if "usage" in response_body: - usage = response_body["usage"] - if "inputTokens" in usage: - span.set_attribute(GEN_AI_USAGE_INPUT_TOKENS, usage["inputTokens"]) - if "outputTokens" in usage: - span.set_attribute(GEN_AI_USAGE_OUTPUT_TOKENS, usage["outputTokens"]) - if "stopReason" in response_body: - span.set_attribute(GEN_AI_RESPONSE_FINISH_REASONS, [response_body["stopReason"]]) - - # pylint: disable=no-self-use - def _handle_anthropic_claude_response(self, span: Span, response_body: Dict[str, Any]): - if "usage" in response_body: - usage = response_body["usage"] - if "input_tokens" in usage: - span.set_attribute(GEN_AI_USAGE_INPUT_TOKENS, usage["input_tokens"]) - if "output_tokens" in usage: - span.set_attribute(GEN_AI_USAGE_OUTPUT_TOKENS, usage["output_tokens"]) - if "stop_reason" in response_body: 
- span.set_attribute(GEN_AI_RESPONSE_FINISH_REASONS, [response_body["stop_reason"]]) - - # pylint: disable=no-self-use - def _handle_cohere_command_response(self, span: Span, response_body: Dict[str, Any]): - # Output tokens: Approximate from the response text - if "text" in response_body: - span.set_attribute(GEN_AI_USAGE_OUTPUT_TOKENS, math.ceil(len(response_body["text"]) / 6)) - if "finish_reason" in response_body: - span.set_attribute(GEN_AI_RESPONSE_FINISH_REASONS, [response_body["finish_reason"]]) - - # pylint: disable=no-self-use - def _handle_ai21_jamba_response(self, span: Span, response_body: Dict[str, Any]): - if "usage" in response_body: - usage = response_body["usage"] - if "prompt_tokens" in usage: - span.set_attribute(GEN_AI_USAGE_INPUT_TOKENS, usage["prompt_tokens"]) - if "completion_tokens" in usage: - span.set_attribute(GEN_AI_USAGE_OUTPUT_TOKENS, usage["completion_tokens"]) - if "choices" in response_body: - choices = response_body["choices"][0] - if "finish_reason" in choices: - span.set_attribute(GEN_AI_RESPONSE_FINISH_REASONS, [choices["finish_reason"]]) - - # pylint: disable=no-self-use - def _handle_meta_llama_response(self, span: Span, response_body: Dict[str, Any]): - if "prompt_token_count" in response_body: - span.set_attribute(GEN_AI_USAGE_INPUT_TOKENS, response_body["prompt_token_count"]) - if "generation_token_count" in response_body: - span.set_attribute(GEN_AI_USAGE_OUTPUT_TOKENS, response_body["generation_token_count"]) - if "stop_reason" in response_body: - span.set_attribute(GEN_AI_RESPONSE_FINISH_REASONS, [response_body["stop_reason"]]) - - # pylint: disable=no-self-use - def _handle_mistral_mistral_response(self, span: Span, response_body: Dict[str, Any]): - if "outputs" in response_body: - outputs = response_body["outputs"][0] - if "text" in outputs: - span.set_attribute(GEN_AI_USAGE_OUTPUT_TOKENS, math.ceil(len(outputs["text"]) / 6)) - if "stop_reason" in outputs: - span.set_attribute(GEN_AI_RESPONSE_FINISH_REASONS, [outputs["stop_reason"]]) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_botocore_patches.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_botocore_patches.py index 0f4a77d1e..ffc81b348 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_botocore_patches.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/patches/_botocore_patches.py @@ -19,13 +19,17 @@ _BedrockAgentExtension, _BedrockAgentRuntimeExtension, _BedrockExtension, - _BedrockRuntimeExtension, ) from opentelemetry.instrumentation.botocore.extensions import _KNOWN_EXTENSIONS from opentelemetry.instrumentation.botocore.extensions.lmbd import _LambdaExtension from opentelemetry.instrumentation.botocore.extensions.sns import _SnsExtension from opentelemetry.instrumentation.botocore.extensions.sqs import _SqsExtension -from opentelemetry.instrumentation.botocore.extensions.types import _AttributeMapT, _AwsSdkExtension, _BotoResultT +from opentelemetry.instrumentation.botocore.extensions.types import ( + _AttributeMapT, + _AwsSdkExtension, + _BotocoreInstrumentorContext, + _BotoResultT, +) from opentelemetry.semconv.trace import SpanAttributes from opentelemetry.trace.span import Span @@ -75,8 +79,8 @@ def patch_extract_attributes(self, attributes: _AttributeMapT): old_on_success = _LambdaExtension.on_success - def patch_on_success(self, span: Span, result: _BotoResultT): - old_on_success(self, span, result) + def patch_on_success(self, span: Span, result: _BotoResultT, instrumentor_context: 
_BotocoreInstrumentorContext): + old_on_success(self, span, result, instrumentor_context) lambda_configuration = result.get("Configuration", {}) function_arn = lambda_configuration.get("FunctionArn") if function_arn: @@ -180,8 +184,8 @@ def patch_extract_attributes(self, attributes: _AttributeMapT): old_on_success = _SqsExtension.on_success - def patch_on_success(self, span: Span, result: _BotoResultT): - old_on_success(self, span, result) + def patch_on_success(self, span: Span, result: _BotoResultT, instrumentor_context: _BotocoreInstrumentorContext): + old_on_success(self, span, result, instrumentor_context) queue_url = result.get("QueueUrl") if queue_url: span.set_attribute(AWS_SQS_QUEUE_URL, queue_url) @@ -191,17 +195,17 @@ def patch_on_success(self, span: Span, result: _BotoResultT): def _apply_botocore_bedrock_patch() -> None: - """Botocore instrumentation patch for Bedrock, Bedrock Agent, Bedrock Runtime and Bedrock Agent Runtime + """Botocore instrumentation patch for Bedrock, Bedrock Agent, and Bedrock Agent Runtime This patch adds an extension to the upstream's list of known extension for Bedrock. Extensions allow for custom logic for adding service-specific information to spans, such as attributes. - Specifically, we are adding logic to add the AWS_BEDROCK attributes referenced in _aws_attribute_keys, - GEN_AI_REQUEST_MODEL and GEN_AI_SYSTEM attributes referenced in _aws_span_processing_util. + Specifically, we are adding logic to add the AWS_BEDROCK attributes referenced in _aws_attribute_keys. + Note: Bedrock Runtime uses the upstream extension directly. """ _KNOWN_EXTENSIONS["bedrock"] = _lazy_load(".", "_BedrockExtension") _KNOWN_EXTENSIONS["bedrock-agent"] = _lazy_load(".", "_BedrockAgentExtension") _KNOWN_EXTENSIONS["bedrock-agent-runtime"] = _lazy_load(".", "_BedrockAgentRuntimeExtension") - _KNOWN_EXTENSIONS["bedrock-runtime"] = _lazy_load(".", "_BedrockRuntimeExtension") + # bedrock-runtime is handled by upstream # The OpenTelemetry Authors code @@ -243,7 +247,7 @@ def extract_attributes(self, attributes: _AttributeMapT): attributes[AWS_SECRETSMANAGER_SECRET_ARN] = secret_id # pylint: disable=no-self-use - def on_success(self, span: Span, result: _BotoResultT): + def on_success(self, span: Span, result: _BotoResultT, instrumentor_context: _BotocoreInstrumentorContext): secret_arn = result.get("ARN") if secret_arn: span.set_attribute(AWS_SECRETSMANAGER_SECRET_ARN, secret_arn) diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/aws/metrics/test_aws_cloudwatch_emf_exporter.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/aws/metrics/test_aws_cloudwatch_emf_exporter.py new file mode 100644 index 000000000..9a90c56a5 --- /dev/null +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/aws/metrics/test_aws_cloudwatch_emf_exporter.py @@ -0,0 +1,625 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import json +import time +import unittest +from unittest.mock import Mock, patch + +from amazon.opentelemetry.distro.exporter.aws.metrics.aws_cloudwatch_emf_exporter import AwsCloudWatchEmfExporter +from opentelemetry.sdk.metrics.export import Gauge, Histogram, MetricExportResult, Sum +from opentelemetry.sdk.resources import Resource + + +class MockDataPoint: + """Mock datapoint for testing.""" + + def __init__(self, value=10.0, attributes=None, time_unix_nano=None): + self.value = value + self.attributes = attributes or {} + self.time_unix_nano = time_unix_nano or int(time.time() * 1_000_000_000) + + +class MockMetric: + """Mock metric for testing.""" + + def __init__(self, name="test_metric", unit="1", description="Test metric"): + self.name = name + self.unit = unit + self.description = description + + +class MockHistogramDataPoint(MockDataPoint): + """Mock histogram datapoint for testing.""" + + def __init__(self, count=5, sum_val=25.0, min_val=1.0, max_val=10.0, **kwargs): + super().__init__(**kwargs) + self.count = count + self.sum = sum_val + self.min = min_val + self.max = max_val + + +class MockExpHistogramDataPoint(MockDataPoint): + """Mock exponential histogram datapoint for testing.""" + + def __init__(self, count=10, sum_val=50.0, min_val=1.0, max_val=20.0, scale=2, **kwargs): + super().__init__(**kwargs) + self.count = count + self.sum = sum_val + self.min = min_val + self.max = max_val + self.scale = scale + + # Mock positive buckets + self.positive = Mock() + self.positive.offset = 0 + self.positive.bucket_counts = [1, 2, 3, 4] + + # Mock negative buckets + self.negative = Mock() + self.negative.offset = 0 + self.negative.bucket_counts = [] + + # Mock zero count + self.zero_count = 0 + + +class MockGaugeData: + """Mock gauge data that passes isinstance checks.""" + + def __init__(self, data_points=None): + self.data_points = data_points or [] + + +class MockSumData: + """Mock sum data that passes isinstance checks.""" + + def __init__(self, data_points=None): + self.data_points = data_points or [] + + +class MockHistogramData: + """Mock histogram data that passes isinstance checks.""" + + def __init__(self, data_points=None): + self.data_points = data_points or [] + + +class MockExpHistogramData: + """Mock exponential histogram data that passes isinstance checks.""" + + def __init__(self, data_points=None): + self.data_points = data_points or [] + + +class MockMetricWithData: + """Mock metric with data attribute.""" + + def __init__(self, name="test_metric", unit="1", description="Test metric", data=None): + self.name = name + self.unit = unit + self.description = description + self.data = data or MockGaugeData() + + +class MockResourceMetrics: + """Mock resource metrics for testing.""" + + def __init__(self, resource=None, scope_metrics=None): + self.resource = resource or Resource.create({"service.name": "test-service"}) + self.scope_metrics = scope_metrics or [] + + +class MockScopeMetrics: + """Mock scope metrics for testing.""" + + def __init__(self, scope=None, metrics=None): + self.scope = scope or Mock() + self.metrics = metrics or [] + + +# pylint: disable=too-many-public-methods +class TestAwsCloudWatchEmfExporter(unittest.TestCase): + """Test AwsCloudWatchEmfExporter class.""" + + def setUp(self): + """Set up test fixtures.""" + # Mock the botocore session to avoid AWS calls + with patch("botocore.session.Session") as mock_session: + mock_client = Mock() + mock_session_instance = Mock() + mock_session.return_value = 
mock_session_instance + mock_session_instance.create_client.return_value = mock_client + + self.exporter = AwsCloudWatchEmfExporter( + session=mock_session, namespace="TestNamespace", log_group_name="test-log-group" + ) + + def test_initialization(self): + """Test exporter initialization.""" + self.assertEqual(self.exporter.namespace, "TestNamespace") + self.assertEqual(self.exporter.log_group_name, "test-log-group") + self.assertIsNotNone(self.exporter.log_client) + + @patch("botocore.session.Session") + def test_initialization_with_custom_params(self, mock_session): + """Test exporter initialization with custom parameters.""" + # Mock the botocore session to avoid AWS calls + mock_client = Mock() + mock_session_instance = Mock() + mock_session.return_value = mock_session_instance + mock_session_instance.create_client.return_value = mock_client + + exporter = AwsCloudWatchEmfExporter( + session=mock_session_instance, + namespace="CustomNamespace", + log_group_name="custom-log-group", + log_stream_name="custom-stream", + aws_region="us-west-2", + ) + self.assertEqual(exporter.namespace, "CustomNamespace") + self.assertEqual(exporter.log_group_name, "custom-log-group") + + def test_get_unit_mapping(self): + """Test unit mapping functionality.""" + # Test known units from UNIT_MAPPING + self.assertEqual( + self.exporter._get_unit(self.exporter._create_metric_record("test", "ms", "test")), "Milliseconds" + ) + self.assertEqual(self.exporter._get_unit(self.exporter._create_metric_record("test", "s", "test")), "Seconds") + self.assertEqual( + self.exporter._get_unit(self.exporter._create_metric_record("test", "us", "test")), "Microseconds" + ) + self.assertEqual(self.exporter._get_unit(self.exporter._create_metric_record("test", "By", "test")), "Bytes") + self.assertEqual(self.exporter._get_unit(self.exporter._create_metric_record("test", "bit", "test")), "Bits") + + # Test units that map to empty string (should return empty string from mapping) + self.assertEqual(self.exporter._get_unit(self.exporter._create_metric_record("test", "1", "test")), "") + self.assertEqual(self.exporter._get_unit(self.exporter._create_metric_record("test", "ns", "test")), "") + + # Test EMF supported units directly (should return as-is) + self.assertEqual(self.exporter._get_unit(self.exporter._create_metric_record("test", "Count", "test")), "Count") + self.assertEqual( + self.exporter._get_unit(self.exporter._create_metric_record("test", "Percent", "test")), "Percent" + ) + self.assertEqual( + self.exporter._get_unit(self.exporter._create_metric_record("test", "Kilobytes", "test")), "Kilobytes" + ) + + # Test unknown unit (not in mapping and not in supported units, returns None) + self.assertIsNone(self.exporter._get_unit(self.exporter._create_metric_record("test", "unknown", "test"))) + + # Test empty unit (should return None due to falsy check) + self.assertIsNone(self.exporter._get_unit(self.exporter._create_metric_record("test", "", "test"))) + + # Test None unit + self.assertIsNone(self.exporter._get_unit(self.exporter._create_metric_record("test", None, "test"))) + + def test_get_metric_name(self): + """Test metric name extraction.""" + # Test with record that has name attribute + record = Mock() + record.name = "test_metric" + + result = self.exporter._get_metric_name(record) + self.assertEqual(result, "test_metric") + + # Test with record that has empty name (should return None) + record_empty = Mock() + record_empty.name = "" + + result_empty = self.exporter._get_metric_name(record_empty) + 
self.assertIsNone(result_empty) + + def test_get_dimension_names(self): + """Test dimension names extraction.""" + attributes = {"service.name": "test-service", "env": "prod", "region": "us-east-1"} + + result = self.exporter._get_dimension_names(attributes) + + # Should return all attribute keys + self.assertEqual(set(result), {"service.name", "env", "region"}) + + def test_get_attributes_key(self): + """Test attributes key generation.""" + attributes = {"service": "test", "env": "prod"} + + result = self.exporter._get_attributes_key(attributes) + + # Should be a string representation of sorted attributes + self.assertIsInstance(result, str) + self.assertIn("service", result) + self.assertIn("test", result) + self.assertIn("env", result) + self.assertIn("prod", result) + + def test_get_attributes_key_consistent(self): + """Test that attributes key generation is consistent.""" + # Same attributes in different order should produce same key + attrs1 = {"b": "2", "a": "1"} + attrs2 = {"a": "1", "b": "2"} + + key1 = self.exporter._get_attributes_key(attrs1) + key2 = self.exporter._get_attributes_key(attrs2) + + self.assertEqual(key1, key2) + + def test_group_by_attributes_and_timestamp(self): + """Test grouping by attributes and timestamp.""" + record = Mock() + record.attributes = {"env": "test"} + record.timestamp = 1234567890 + + result = self.exporter._group_by_attributes_and_timestamp(record) + + # Should return a tuple with attributes key and timestamp + self.assertIsInstance(result, tuple) + self.assertEqual(len(result), 2) + self.assertEqual(result[1], 1234567890) + + def test_normalize_timestamp(self): + """Test timestamp normalization.""" + timestamp_ns = 1609459200000000000 # 2021-01-01 00:00:00 in nanoseconds + expected_ms = 1609459200000 # Same time in milliseconds + + result = self.exporter._normalize_timestamp(timestamp_ns) + self.assertEqual(result, expected_ms) + + def test_create_metric_record(self): + """Test metric record creation.""" + record = self.exporter._create_metric_record("test_metric", "Count", "Test description") + + self.assertIsNotNone(record) + self.assertEqual(record.name, "test_metric") + self.assertEqual(record.unit, "Count") + self.assertEqual(record.description, "Test description") + + def test_convert_gauge(self): + """Test gauge conversion.""" + metric = MockMetric("gauge_metric", "Count", "Gauge description") + dp = MockDataPoint(value=42.5, attributes={"key": "value"}) + + record = self.exporter._convert_gauge_and_sum(metric, dp) + + self.assertIsNotNone(record) + self.assertEqual(record.name, "gauge_metric") + self.assertEqual(record.value, 42.5) + self.assertEqual(record.attributes, {"key": "value"}) + self.assertIsInstance(record.timestamp, int) + + def test_convert_sum(self): + """Test sum conversion.""" + metric = MockMetric("sum_metric", "Count", "Sum description") + dp = MockDataPoint(value=100.0, attributes={"env": "test"}) + + record = self.exporter._convert_gauge_and_sum(metric, dp) + + self.assertIsNotNone(record) + self.assertEqual(record.name, "sum_metric") + self.assertEqual(record.value, 100.0) + self.assertEqual(record.attributes, {"env": "test"}) + self.assertIsInstance(record.timestamp, int) + + def test_convert_histogram(self): + """Test histogram conversion.""" + metric = MockMetric("histogram_metric", "ms", "Histogram description") + dp = MockHistogramDataPoint( + count=10, sum_val=150.0, min_val=5.0, max_val=25.0, attributes={"region": "us-east-1"} + ) + + record = self.exporter._convert_histogram(metric, dp) + + 
self.assertIsNotNone(record) + self.assertEqual(record.name, "histogram_metric") + self.assertTrue(hasattr(record, "histogram_data")) + + expected_value = {"Count": 10, "Sum": 150.0, "Min": 5.0, "Max": 25.0} + self.assertEqual(record.histogram_data, expected_value) + self.assertEqual(record.attributes, {"region": "us-east-1"}) + self.assertIsInstance(record.timestamp, int) + + def test_convert_exp_histogram(self): + """Test exponential histogram conversion.""" + metric = MockMetric("exp_histogram_metric", "s", "Exponential histogram description") + dp = MockExpHistogramDataPoint(count=8, sum_val=64.0, min_val=2.0, max_val=32.0, attributes={"service": "api"}) + + record = self.exporter._convert_exp_histogram(metric, dp) + + self.assertIsNotNone(record) + self.assertEqual(record.name, "exp_histogram_metric") + self.assertTrue(hasattr(record, "exp_histogram_data")) + + exp_data = record.exp_histogram_data + self.assertIn("Values", exp_data) + self.assertIn("Counts", exp_data) + self.assertEqual(exp_data["Count"], 8) + self.assertEqual(exp_data["Sum"], 64.0) + self.assertEqual(exp_data["Min"], 2.0) + self.assertEqual(exp_data["Max"], 32.0) + self.assertEqual(record.attributes, {"service": "api"}) + self.assertIsInstance(record.timestamp, int) + + def test_create_emf_log(self): + """Test EMF log creation.""" + # Create test records + gauge_record = self.exporter._create_metric_record("gauge_metric", "Count", "Gauge") + gauge_record.value = 50.0 + gauge_record.timestamp = int(time.time() * 1000) + gauge_record.attributes = {"env": "test"} + + records = [gauge_record] + resource = Resource.create({"service.name": "test-service"}) + + result = self.exporter._create_emf_log(records, resource) + + self.assertIsInstance(result, dict) + + # Check that the result is JSON serializable + json.dumps(result) # Should not raise exception + + def test_create_emf_log_with_resource(self): + """Test EMF log creation with resource attributes.""" + # Create test records + gauge_record = self.exporter._create_metric_record("gauge_metric", "Count", "Gauge") + gauge_record.value = 50.0 + gauge_record.timestamp = int(time.time() * 1000) + gauge_record.attributes = {"env": "test", "service": "api"} + + records = [gauge_record] + resource = Resource.create({"service.name": "test-service", "service.version": "1.0.0"}) + + result = self.exporter._create_emf_log(records, resource, 1234567890) + + # Verify EMF log structure + self.assertIn("_aws", result) + self.assertIn("CloudWatchMetrics", result["_aws"]) + self.assertEqual(result["_aws"]["Timestamp"], 1234567890) + self.assertEqual(result["Version"], "1") + + # Check resource attributes are prefixed + self.assertEqual(result["otel.resource.service.name"], "test-service") + self.assertEqual(result["otel.resource.service.version"], "1.0.0") + + # Check metric attributes + self.assertEqual(result["env"], "test") + self.assertEqual(result["service"], "api") + + # Check metric value + self.assertEqual(result["gauge_metric"], 50.0) + + # Check CloudWatch metrics structure + cw_metrics = result["_aws"]["CloudWatchMetrics"][0] + self.assertEqual(cw_metrics["Namespace"], "TestNamespace") + self.assertEqual(set(cw_metrics["Dimensions"][0]), {"env", "service"}) + self.assertEqual(cw_metrics["Metrics"][0]["Name"], "gauge_metric") + + def test_create_emf_log_skips_empty_metric_names(self): + """Test that EMF log creation skips records with empty metric names.""" + # Create a record with no metric name + record_without_name = Mock() + record_without_name.attributes = {"key": "value"} 
+ record_without_name.value = 10.0 + record_without_name.name = None # No valid name + + # Create a record with valid metric name + valid_record = self.exporter._create_metric_record("valid_metric", "Count", "Valid metric") + valid_record.value = 20.0 + valid_record.attributes = {"key": "value"} + + records = [record_without_name, valid_record] + resource = Resource.create({"service.name": "test-service"}) + + result = self.exporter._create_emf_log(records, resource, 1234567890) + + # Only the valid record should be processed + self.assertIn("valid_metric", result) + self.assertEqual(result["valid_metric"], 20.0) + + # Check that only the valid metric is in the definitions (empty names are skipped) + cw_metrics = result["_aws"]["CloudWatchMetrics"][0] + self.assertEqual(len(cw_metrics["Metrics"]), 1) + # Ensure our valid metric is present + metric_names = [m["Name"] for m in cw_metrics["Metrics"]] + self.assertIn("valid_metric", metric_names) + + @patch("botocore.session.Session") + def test_export_success(self, mock_session): + """Test successful export.""" + # Mock CloudWatch Logs client + mock_client = Mock() + mock_session_instance = Mock() + mock_session.return_value = mock_session_instance + mock_session_instance.create_client.return_value = mock_client + mock_client.put_log_events.return_value = {"nextSequenceToken": "12345"} + + # Create empty metrics data to test basic export flow + metrics_data = Mock() + metrics_data.resource_metrics = [] + + result = self.exporter.export(metrics_data) + + self.assertEqual(result, MetricExportResult.SUCCESS) + + def test_export_failure(self): + """Test export failure handling.""" + # Create metrics data that will cause an exception during iteration + metrics_data = Mock() + # Make resource_metrics raise an exception when iterated over + metrics_data.resource_metrics = Mock() + metrics_data.resource_metrics.__iter__ = Mock(side_effect=Exception("Test exception")) + + result = self.exporter.export(metrics_data) + + self.assertEqual(result, MetricExportResult.FAILURE) + + def test_export_with_gauge_metrics(self): + """Test exporting actual gauge metrics.""" + # Create mock metrics data + resource = Resource.create({"service.name": "test-service"}) + + # Create gauge data + gauge_data = Gauge(data_points=[MockDataPoint(value=42.0, attributes={"key": "value"})]) + + metric = MockMetricWithData(name="test_gauge", data=gauge_data) + + scope_metrics = MockScopeMetrics(metrics=[metric]) + resource_metrics = MockResourceMetrics(resource=resource, scope_metrics=[scope_metrics]) + + metrics_data = Mock() + metrics_data.resource_metrics = [resource_metrics] + + result = self.exporter.export(metrics_data) + + self.assertEqual(result, MetricExportResult.SUCCESS) + + def test_export_with_sum_metrics(self): + """Test export with Sum metrics.""" + # Create mock metrics data with Sum type + resource = Resource.create({"service.name": "test-service"}) + + sum_data = MockSumData([MockDataPoint(value=25.0, attributes={"env": "test"})]) + # Create a mock that will pass the type() check for Sum + sum_data.__class__ = Sum + metric = MockMetricWithData(name="test_sum", data=sum_data) + + scope_metrics = MockScopeMetrics(metrics=[metric]) + resource_metrics = MockResourceMetrics(resource=resource, scope_metrics=[scope_metrics]) + + metrics_data = Mock() + metrics_data.resource_metrics = [resource_metrics] + + result = self.exporter.export(metrics_data) + self.assertEqual(result, MetricExportResult.SUCCESS) + + def test_export_with_histogram_metrics(self): + """Test export 
with Histogram metrics.""" + # Create mock metrics data with Histogram type + resource = Resource.create({"service.name": "test-service"}) + + hist_dp = MockHistogramDataPoint(count=5, sum_val=25.0, min_val=1.0, max_val=10.0, attributes={"env": "test"}) + hist_data = MockHistogramData([hist_dp]) + # Create a mock that will pass the type() check for Histogram + hist_data.__class__ = Histogram + metric = MockMetricWithData(name="test_histogram", data=hist_data) + + scope_metrics = MockScopeMetrics(metrics=[metric]) + resource_metrics = MockResourceMetrics(resource=resource, scope_metrics=[scope_metrics]) + + metrics_data = Mock() + metrics_data.resource_metrics = [resource_metrics] + + result = self.exporter.export(metrics_data) + self.assertEqual(result, MetricExportResult.SUCCESS) + + def test_export_with_unsupported_metric_type(self): + """Test export with unsupported metric types.""" + # Create mock metrics data with unsupported metric type + resource = Resource.create({"service.name": "test-service"}) + + # Create non-gauge data + unsupported_data = Mock() + unsupported_data.data_points = [MockDataPoint(value=42.0)] + + metric = MockMetricWithData(name="test_counter", data=unsupported_data) + + scope_metrics = MockScopeMetrics(metrics=[metric]) + resource_metrics = MockResourceMetrics(resource=resource, scope_metrics=[scope_metrics]) + + metrics_data = Mock() + metrics_data.resource_metrics = [resource_metrics] + + # Should still return success even with unsupported metrics + result = self.exporter.export(metrics_data) + self.assertEqual(result, MetricExportResult.SUCCESS) + + def test_export_with_metric_without_data(self): + """Test export with metrics that don't have data attribute.""" + # Create mock metrics data + resource = Resource.create({"service.name": "test-service"}) + + # Create metric without data attribute + metric = Mock(spec=[]) + + scope_metrics = MockScopeMetrics(metrics=[metric]) + resource_metrics = MockResourceMetrics(resource=resource, scope_metrics=[scope_metrics]) + + metrics_data = Mock() + metrics_data.resource_metrics = [resource_metrics] + + # Should still return success + result = self.exporter.export(metrics_data) + self.assertEqual(result, MetricExportResult.SUCCESS) + + def test_get_metric_name_fallback(self): + """Test metric name extraction fallback.""" + # Test with record that has no instrument attribute + record = Mock(spec=[]) + + result = self.exporter._get_metric_name(record) + self.assertIsNone(result) + + def test_get_metric_name_empty_name(self): + """Test metric name extraction with empty name.""" + # Test with record that has empty name + record = Mock() + record.name = "" + + result = self.exporter._get_metric_name(record) + self.assertIsNone(result) + + @patch("os.environ.get") + @patch("botocore.session.Session") + def test_initialization_with_env_region(self, mock_session, mock_env_get): + """Test initialization with AWS region from environment.""" + # Mock environment variable + mock_env_get.side_effect = lambda key: "us-west-1" if key == "AWS_REGION" else None + + # Mock the botocore session to avoid AWS calls + mock_client = Mock() + mock_session_instance = Mock() + mock_session.return_value = mock_session_instance + mock_session_instance.create_client.return_value = mock_client + + exporter = AwsCloudWatchEmfExporter( + session=mock_session, namespace="TestNamespace", log_group_name="test-log-group" + ) + + # Just verify the exporter was created successfully with region handling + self.assertIsNotNone(exporter) + 
self.assertEqual(exporter.namespace, "TestNamespace") + + def test_force_flush_no_pending_events(self): + """Test force flush functionality with no pending events.""" + result = self.exporter.force_flush() + + self.assertTrue(result) + + @patch.object(AwsCloudWatchEmfExporter, "force_flush") + def test_shutdown(self, mock_force_flush): + """Test shutdown functionality.""" + mock_force_flush.return_value = True + + result = self.exporter.shutdown(timeout_millis=5000) + + self.assertTrue(result) + mock_force_flush.assert_called_once_with(5000) + + # pylint: disable=broad-exception-caught + def test_send_log_event_method_exists(self): + """Test that _send_log_event method exists and can be called.""" + # Just test that the method exists and doesn't crash with basic input + log_event = {"message": "test message", "timestamp": 1234567890} + + # Mock the log client to avoid actual AWS calls + with patch.object(self.exporter.log_client, "send_log_event") as mock_send: + # Should not raise an exception + try: + self.exporter._send_log_event(log_event) + mock_send.assert_called_once_with(log_event) + except Exception as error: + self.fail(f"_send_log_event raised an exception: {error}") + + +if __name__ == "__main__": + unittest.main() diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/aws/metrics/test_cloudwatch_log_client.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/aws/metrics/test_cloudwatch_log_client.py new file mode 100644 index 000000000..2793aeb34 --- /dev/null +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/aws/metrics/test_cloudwatch_log_client.py @@ -0,0 +1,584 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +# pylint: disable=too-many-public-methods + +import time +import unittest +from unittest.mock import Mock, patch + +from botocore.exceptions import ClientError + +from amazon.opentelemetry.distro.exporter.aws.metrics._cloudwatch_log_client import CloudWatchLogClient + + +class TestCloudWatchLogClient(unittest.TestCase): + """Test CloudWatchLogClient class.""" + + def setUp(self): + """Set up test fixtures.""" + # Mock the botocore session to avoid AWS calls + with patch("botocore.session.Session") as mock_session: + mock_client = Mock() + mock_session_instance = Mock() + mock_session.return_value = mock_session_instance + mock_session_instance.create_client.return_value = mock_client + + self.log_client = CloudWatchLogClient(session=mock_session, log_group_name="test-log-group") + + def test_initialization(self): + """Test log client initialization.""" + self.assertEqual(self.log_client.log_group_name, "test-log-group") + self.assertIsNotNone(self.log_client.log_stream_name) + self.assertTrue(self.log_client.log_stream_name.startswith("otel-python-")) + + @patch("botocore.session.Session") + def test_initialization_with_custom_params(self, mock_session): + """Test log client initialization with custom parameters.""" + # Mock the botocore session to avoid AWS calls + mock_client = Mock() + mock_session_instance = Mock() + mock_session.return_value = mock_session_instance + mock_session_instance.create_client.return_value = mock_client + + log_client = CloudWatchLogClient( + session=mock_session, + log_group_name="custom-log-group", + log_stream_name="custom-stream", + aws_region="us-west-2", + ) + self.assertEqual(log_client.log_group_name, "custom-log-group") + self.assertEqual(log_client.log_stream_name, "custom-stream") + + def 
test_generate_log_stream_name(self): + """Test log stream name generation.""" + name1 = self.log_client._generate_log_stream_name() + name2 = self.log_client._generate_log_stream_name() + + # Should generate unique names + self.assertNotEqual(name1, name2) + self.assertTrue(name1.startswith("otel-python-")) + self.assertTrue(name2.startswith("otel-python-")) + + def test_create_log_group_if_needed_success(self): + """Test log group creation when needed.""" + # This method should not raise an exception + self.log_client._create_log_group_if_needed() + + def test_create_log_group_if_needed_already_exists(self): + """Test log group creation when it already exists.""" + # Mock the create_log_group to raise ResourceAlreadyExistsException + self.log_client.logs_client.create_log_group.side_effect = ClientError( + {"Error": {"Code": "ResourceAlreadyExistsException"}}, "CreateLogGroup" + ) + + # This should not raise an exception + self.log_client._create_log_group_if_needed() + + def test_create_log_group_if_needed_failure(self): + """Test log group creation failure.""" + # Mock the create_log_group to raise AccessDenied error + self.log_client.logs_client.create_log_group.side_effect = ClientError( + {"Error": {"Code": "AccessDenied"}}, "CreateLogGroup" + ) + + with self.assertRaises(ClientError): + self.log_client._create_log_group_if_needed() + + def test_create_event_batch(self): + """Test event batch creation.""" + batch = self.log_client._create_event_batch() + + self.assertEqual(batch.log_events, []) + self.assertEqual(batch.byte_total, 0) + self.assertEqual(batch.min_timestamp_ms, 0) + self.assertEqual(batch.max_timestamp_ms, 0) + self.assertIsInstance(batch.created_timestamp_ms, int) + + def test_validate_log_event_valid(self): + """Test log event validation with valid event.""" + log_event = {"message": "test message", "timestamp": int(time.time() * 1000)} + + result = self.log_client._validate_log_event(log_event) + self.assertTrue(result) + + def test_validate_log_event_empty_message(self): + """Test log event validation with empty message.""" + log_event = {"message": "", "timestamp": int(time.time() * 1000)} + + result = self.log_client._validate_log_event(log_event) + self.assertFalse(result) + + # Test whitespace-only message + whitespace_event = {"message": " ", "timestamp": int(time.time() * 1000)} + result = self.log_client._validate_log_event(whitespace_event) + self.assertFalse(result) + + # Test missing message key + missing_message_event = {"timestamp": int(time.time() * 1000)} + result = self.log_client._validate_log_event(missing_message_event) + self.assertFalse(result) + + def test_validate_log_event_oversized_message(self): + """Test log event validation with oversized message.""" + # Create a message larger than the maximum allowed size + large_message = "x" * (self.log_client.CW_MAX_EVENT_PAYLOAD_BYTES + 100) + log_event = {"message": large_message, "timestamp": int(time.time() * 1000)} + + result = self.log_client._validate_log_event(log_event) + self.assertTrue(result) # Should still be valid after truncation + # Check that message was truncated + self.assertLess(len(log_event["message"]), len(large_message)) + self.assertTrue(log_event["message"].endswith(self.log_client.CW_TRUNCATED_SUFFIX)) + + def test_validate_log_event_old_timestamp(self): + """Test log event validation with very old timestamp.""" + # Timestamp from 15 days ago + old_timestamp = int(time.time() * 1000) - (15 * 24 * 60 * 60 * 1000) + log_event = {"message": "test message", "timestamp": 
old_timestamp} + + result = self.log_client._validate_log_event(log_event) + self.assertFalse(result) + + def test_validate_log_event_future_timestamp(self): + """Test log event validation with future timestamp.""" + # Timestamp 3 hours in the future + future_timestamp = int(time.time() * 1000) + (3 * 60 * 60 * 1000) + log_event = {"message": "test message", "timestamp": future_timestamp} + + result = self.log_client._validate_log_event(log_event) + self.assertFalse(result) + + def test_event_batch_exceeds_limit_by_count(self): + """Test batch limit checking by event count.""" + batch = self.log_client._create_event_batch() + # Simulate batch with maximum events + for _ in range(self.log_client.CW_MAX_REQUEST_EVENT_COUNT): + batch.add_event({"message": "test", "timestamp": int(time.time() * 1000)}, 10) + + result = self.log_client._event_batch_exceeds_limit(batch, 100) + self.assertTrue(result) + + def test_event_batch_exceeds_limit_by_size(self): + """Test batch limit checking by byte size.""" + batch = self.log_client._create_event_batch() + # Manually set byte_total to near limit + batch.byte_total = self.log_client.CW_MAX_REQUEST_PAYLOAD_BYTES - 50 + + result = self.log_client._event_batch_exceeds_limit(batch, 100) + self.assertTrue(result) + + def test_event_batch_within_limits(self): + """Test batch limit checking within limits.""" + batch = self.log_client._create_event_batch() + for _ in range(10): + batch.add_event({"message": "test", "timestamp": int(time.time() * 1000)}, 100) + + result = self.log_client._event_batch_exceeds_limit(batch, 100) + self.assertFalse(result) + + def test_is_batch_active_new_batch(self): + """Test batch activity check for new batch.""" + batch = self.log_client._create_event_batch() + current_time = int(time.time() * 1000) + + result = self.log_client._is_batch_active(batch, current_time) + self.assertTrue(result) + + def test_is_batch_active_24_hour_span(self): + """Test batch activity check for 24+ hour span.""" + batch = self.log_client._create_event_batch() + current_time = int(time.time() * 1000) + # Add an event to set the timestamps + batch.add_event({"message": "test", "timestamp": current_time}, 10) + + # Test with timestamp 25 hours in the future + future_timestamp = current_time + (25 * 60 * 60 * 1000) + + result = self.log_client._is_batch_active(batch, future_timestamp) + self.assertFalse(result) + + def test_log_event_batch_add_event(self): + """Test adding log event to batch.""" + batch = self.log_client._create_event_batch() + log_event = {"message": "test message", "timestamp": int(time.time() * 1000)} + event_size = 100 + + batch.add_event(log_event, event_size) + + self.assertEqual(batch.size(), 1) + self.assertEqual(batch.byte_total, event_size) + self.assertEqual(batch.min_timestamp_ms, log_event["timestamp"]) + self.assertEqual(batch.max_timestamp_ms, log_event["timestamp"]) + + def test_sort_log_events(self): + """Test sorting log events by timestamp.""" + batch = self.log_client._create_event_batch() + current_time = int(time.time() * 1000) + + # Add events with timestamps in reverse order + events = [ + {"message": "third", "timestamp": current_time + 2000}, + {"message": "first", "timestamp": current_time}, + {"message": "second", "timestamp": current_time + 1000}, + ] + + # Add events to batch in unsorted order + for event in events: + batch.add_event(event, 10) + + self.log_client._sort_log_events(batch) + + # Check that events are now sorted by timestamp + self.assertEqual(batch.log_events[0]["message"], "first") + 
self.assertEqual(batch.log_events[1]["message"], "second") + self.assertEqual(batch.log_events[2]["message"], "third") + + @patch.object(CloudWatchLogClient, "_send_log_batch") + def test_flush_pending_events_with_pending_events(self, mock_send_batch): + """Test flush pending events functionality with pending events.""" + # Create a batch with events + self.log_client._event_batch = self.log_client._create_event_batch() + self.log_client._event_batch.add_event({"message": "test", "timestamp": int(time.time() * 1000)}, 10) + + result = self.log_client.flush_pending_events() + + self.assertTrue(result) + mock_send_batch.assert_called_once() + + def test_flush_pending_events_no_pending_events(self): + """Test flush pending events functionality with no pending events.""" + # No batch exists + self.assertIsNone(self.log_client._event_batch) + + result = self.log_client.flush_pending_events() + + self.assertTrue(result) + + def test_send_log_event_method_exists(self): + """Test that send_log_event method exists and can be called.""" + # Just test that the method exists and doesn't crash with basic input + log_event = {"message": "test message", "timestamp": 1234567890} + + # Mock the AWS client methods to avoid actual AWS calls + with patch.object(self.log_client.logs_client, "put_log_events") as mock_put: + mock_put.return_value = {"nextSequenceToken": "12345"} + + # Should not raise an exception + try: + self.log_client.send_log_event(log_event) + # Method should complete without error + except ClientError as error: + self.fail(f"send_log_event raised an exception: {error}") + + def test_send_log_batch_with_resource_not_found(self): + """Test lazy creation when put_log_events fails with ResourceNotFoundException.""" + batch = self.log_client._create_event_batch() + batch.add_event({"message": "test message", "timestamp": int(time.time() * 1000)}, 10) + + # Mock put_log_events to fail first, then succeed + mock_put = self.log_client.logs_client.put_log_events + mock_put.side_effect = [ + ClientError({"Error": {"Code": "ResourceNotFoundException"}}, "PutLogEvents"), + {"nextSequenceToken": "12345"}, + ] + + # Mock the create methods + mock_create_group = Mock() + mock_create_stream = Mock() + self.log_client._create_log_group_if_needed = mock_create_group + self.log_client._create_log_stream_if_needed = mock_create_stream + + # Should not raise an exception and should create resources + self.log_client._send_log_batch(batch) + + # Verify that creation methods were called + mock_create_group.assert_called_once() + mock_create_stream.assert_called_once() + + # Verify put_log_events was called twice (initial attempt + retry) + self.assertEqual(mock_put.call_count, 2) + + def test_send_log_batch_with_other_error(self): + """Test that non-ResourceNotFoundException errors are re-raised.""" + batch = self.log_client._create_event_batch() + batch.add_event({"message": "test message", "timestamp": int(time.time() * 1000)}, 10) + + # Mock put_log_events to fail with different error + self.log_client.logs_client.put_log_events.side_effect = ClientError( + {"Error": {"Code": "AccessDenied"}}, "PutLogEvents" + ) + + # Should raise the original exception + with self.assertRaises(ClientError): + self.log_client._send_log_batch(batch) + + def test_create_log_stream_if_needed_success(self): + """Test log stream creation when needed.""" + # This method should not raise an exception + self.log_client._create_log_stream_if_needed() + + def test_create_log_stream_if_needed_already_exists(self): + """Test log stream 
creation when it already exists.""" + # Mock the create_log_stream to raise ResourceAlreadyExistsException + self.log_client.logs_client.create_log_stream.side_effect = ClientError( + {"Error": {"Code": "ResourceAlreadyExistsException"}}, "CreateLogStream" + ) + + # This should not raise an exception + self.log_client._create_log_stream_if_needed() + + def test_create_log_stream_if_needed_failure(self): + """Test log stream creation failure.""" + # Mock the create_log_stream to raise AccessDenied error + self.log_client.logs_client.create_log_stream.side_effect = ClientError( + {"Error": {"Code": "AccessDenied"}}, "CreateLogStream" + ) + + with self.assertRaises(ClientError): + self.log_client._create_log_stream_if_needed() + + def test_send_log_batch_success(self): + """Test successful log batch sending.""" + batch = self.log_client._create_event_batch() + batch.add_event({"message": "test message", "timestamp": int(time.time() * 1000)}, 10) + + # Mock successful put_log_events call + self.log_client.logs_client.put_log_events.return_value = {"nextSequenceToken": "12345"} + + # Should not raise an exception + result = self.log_client._send_log_batch(batch) + self.assertEqual(result["nextSequenceToken"], "12345") + + def test_send_log_batch_empty_batch(self): + """Test sending empty batch does nothing.""" + batch = self.log_client._create_event_batch() + # Empty batch should return early without calling AWS + + result = self.log_client._send_log_batch(batch) + self.assertIsNone(result) + + # Verify put_log_events was not called + self.log_client.logs_client.put_log_events.assert_not_called() + + def test_is_batch_active_flush_interval_reached(self): + """Test batch activity check when flush interval is reached.""" + batch = self.log_client._create_event_batch() + current_time = int(time.time() * 1000) + + # Set the batch creation time to more than flush interval ago + batch.created_timestamp_ms = current_time - (self.log_client.BATCH_FLUSH_INTERVAL + 1000) + # Add an event to set the timestamps + batch.add_event({"message": "test", "timestamp": current_time}, 10) + + result = self.log_client._is_batch_active(batch, current_time) + self.assertFalse(result) + + def test_send_log_event_with_invalid_event(self): + """Test send_log_event with an invalid event that fails validation.""" + # Create an event that will fail validation (empty message) + log_event = {"message": "", "timestamp": int(time.time() * 1000)} + + # Should not raise an exception, but should not call put_log_events + self.log_client.send_log_event(log_event) + + # Verify put_log_events was not called due to validation failure + self.log_client.logs_client.put_log_events.assert_not_called() + + def test_send_log_event_batching_logic(self): + """Test that send_log_event properly batches events.""" + log_event = {"message": "test message", "timestamp": int(time.time() * 1000)} + + # Mock put_log_events to not be called initially (batching) + self.log_client.logs_client.put_log_events.return_value = {"nextSequenceToken": "12345"} + + # Send one event (should be batched, not sent immediately) + self.log_client.send_log_event(log_event) + + # Verify event was added to batch + self.assertIsNotNone(self.log_client._event_batch) + self.assertEqual(self.log_client._event_batch.size(), 1) + + # put_log_events should not be called yet (event is batched) + self.log_client.logs_client.put_log_events.assert_not_called() + + def test_send_log_event_force_batch_send(self): + """Test that send_log_event sends batch when limits are exceeded.""" 
+        # Mock put_log_events
+        self.log_client.logs_client.put_log_events.return_value = {"nextSequenceToken": "12345"}
+
+        # Create events to reach the maximum event count limit
+        current_time = int(time.time() * 1000)
+
+        # Send events up to the limit (should all be batched)
+        for event_index in range(self.log_client.CW_MAX_REQUEST_EVENT_COUNT):
+            log_event = {"message": f"test message {event_index}", "timestamp": current_time}
+            self.log_client.send_log_event(log_event)
+
+        # At this point, no batch should have been sent yet
+        self.log_client.logs_client.put_log_events.assert_not_called()
+
+        # Send one more event (should trigger batch send due to count limit)
+        final_event = {"message": "final message", "timestamp": current_time}
+        self.log_client.send_log_event(final_event)
+
+        # put_log_events should have been called once
+        self.log_client.logs_client.put_log_events.assert_called_once()
+
+    def test_log_event_batch_clear(self):
+        """Test clearing a log event batch."""
+        batch = self.log_client._create_event_batch()
+        batch.add_event({"message": "test", "timestamp": int(time.time() * 1000)}, 100)
+
+        # Verify batch has content
+        self.assertFalse(batch.is_empty())
+        self.assertEqual(batch.size(), 1)
+
+        # Clear and verify
+        batch.clear()
+        self.assertTrue(batch.is_empty())
+        self.assertEqual(batch.size(), 0)
+        self.assertEqual(batch.byte_total, 0)
+
+    def test_log_event_batch_timestamp_tracking(self):
+        """Test timestamp tracking in LogEventBatch."""
+        batch = self.log_client._create_event_batch()
+        current_time = int(time.time() * 1000)
+
+        # Add first event
+        batch.add_event({"message": "first", "timestamp": current_time}, 10)
+        self.assertEqual(batch.min_timestamp_ms, current_time)
+        self.assertEqual(batch.max_timestamp_ms, current_time)
+
+        # Add earlier event
+        earlier_time = current_time - 1000
+        batch.add_event({"message": "earlier", "timestamp": earlier_time}, 10)
+        self.assertEqual(batch.min_timestamp_ms, earlier_time)
+        self.assertEqual(batch.max_timestamp_ms, current_time)
+
+        # Add later event
+        later_time = current_time + 1000
+        batch.add_event({"message": "later", "timestamp": later_time}, 10)
+        self.assertEqual(batch.min_timestamp_ms, earlier_time)
+        self.assertEqual(batch.max_timestamp_ms, later_time)
+
+    def test_generate_log_stream_name_format(self):
+        """Test log stream name generation format and uniqueness."""
+        name = self.log_client._generate_log_stream_name()
+        self.assertTrue(name.startswith("otel-python-"))
+        self.assertEqual(len(name), len("otel-python-") + 8)
+
+        # Generate another and ensure they're different
+        name2 = self.log_client._generate_log_stream_name()
+        self.assertNotEqual(name, name2)
+
+    @patch("botocore.session.Session")
+    def test_initialization_with_custom_log_stream_name(self, mock_session):
+        """Test initialization with custom log stream name."""
+        # Mock the session and client
+        mock_client = Mock()
+        mock_session.return_value.create_client.return_value = mock_client
+
+        custom_stream = "my-custom-stream"
+        client = CloudWatchLogClient(session=mock_session, log_group_name="test-group", log_stream_name=custom_stream)
+        self.assertEqual(client.log_stream_name, custom_stream)
+
+    def test_send_log_batch_empty_batch_no_aws_call(self):
+        """Test sending an empty batch returns None and doesn't call AWS."""
+        batch = self.log_client._create_event_batch()
+        result = self.log_client._send_log_batch(batch)
+        self.assertIsNone(result)
+
+        # Verify put_log_events is not called for empty batch
+        self.log_client.logs_client.put_log_events.assert_not_called()
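The batching tests above pin down the client-side contract: events accumulate in _event_batch until adding one more would cross a CloudWatch PutLogEvents limit (event count or request payload) or the batch goes stale, at which point the pending batch is flushed and a fresh one started. A minimal sketch of that flow, assuming the helper names exercised in this test file; the actual CloudWatchLogClient in this change may differ in detail:

    # Sketch only; helper names (_validate_log_event, _event_batch_exceeds_limit,
    # _is_batch_active, _send_log_batch) are assumed from the tests above.
    def send_log_event(self, log_event):
        if not self._validate_log_event(log_event):
            return  # dropped: empty message or timestamp outside the accepted window

        # CloudWatch Logs counts the UTF-8 message size plus a fixed 26-byte
        # per-event overhead toward the request payload limit.
        event_size = len(log_event["message"].encode("utf-8")) + 26

        if self._event_batch is None:
            self._event_batch = self._create_event_batch()

        batch = self._event_batch
        # Flush first when this event would cross the event-count or payload
        # limit, or when the batch is stale or would span more than 24 hours.
        if self._event_batch_exceeds_limit(batch, event_size) or not self._is_batch_active(
            batch, log_event["timestamp"]
        ):
            self._send_log_batch(batch)
            batch = self._event_batch = self._create_event_batch()

        batch.add_event(log_event, event_size)

This matches the observed behavior: the first CW_MAX_REQUEST_EVENT_COUNT events are only queued, and the flush happens when the next event arrives.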
+ + def test_validate_log_event_missing_timestamp(self): + """Test validation of log event with missing timestamp.""" + log_event = {"message": "test message"} # No timestamp + result = self.log_client._validate_log_event(log_event) + + # Should be invalid - timestamp defaults to 0 which is too old + self.assertFalse(result) + + def test_validate_log_event_invalid_timestamp_past(self): + """Test validation of log event with timestamp too far in the past.""" + # Create timestamp older than 14 days + old_time = int(time.time() * 1000) - (15 * 24 * 60 * 60 * 1000) + log_event = {"message": "test message", "timestamp": old_time} + + result = self.log_client._validate_log_event(log_event) + self.assertFalse(result) + + def test_validate_log_event_invalid_timestamp_future(self): + """Test validation of log event with timestamp too far in the future.""" + # Create timestamp more than 2 hours in the future + future_time = int(time.time() * 1000) + (3 * 60 * 60 * 1000) + log_event = {"message": "test message", "timestamp": future_time} + + result = self.log_client._validate_log_event(log_event) + self.assertFalse(result) + + def test_send_log_event_validation_failure(self): + """Test send_log_event when validation fails.""" + # Create invalid event (empty message) + invalid_event = {"message": "", "timestamp": int(time.time() * 1000)} + + # Mock put_log_events to track calls + self.log_client.logs_client.put_log_events.return_value = {"nextSequenceToken": "12345"} + + # Send invalid event + self.log_client.send_log_event(invalid_event) + + # Should not call put_log_events or create batch + self.log_client.logs_client.put_log_events.assert_not_called() + self.assertIsNone(self.log_client._event_batch) + + def test_send_log_event_exception_handling(self): + """Test exception handling in send_log_event.""" + # Mock _validate_log_event to raise an exception + with patch.object(self.log_client, "_validate_log_event", side_effect=Exception("Test error")): + log_event = {"message": "test", "timestamp": int(time.time() * 1000)} + + with self.assertRaises(Exception) as context: + self.log_client.send_log_event(log_event) + + self.assertEqual(str(context.exception), "Test error") + + def test_flush_pending_events_no_batch(self): + """Test flush pending events when no batch exists.""" + # Ensure no batch exists + self.log_client._event_batch = None + + result = self.log_client.flush_pending_events() + self.assertTrue(result) + + # Should not call send_log_batch + self.log_client.logs_client.put_log_events.assert_not_called() + + def test_is_batch_active_edge_cases(self): + """Test edge cases for batch activity checking.""" + batch = self.log_client._create_event_batch() + current_time = int(time.time() * 1000) + + # Test exactly at 24 hour boundary (should still be active) + batch.add_event({"message": "test", "timestamp": current_time}, 10) + exactly_24h_future = current_time + (24 * 60 * 60 * 1000) + result = self.log_client._is_batch_active(batch, exactly_24h_future) + self.assertTrue(result) + + # Test just over 24 hour boundary (should be inactive) + over_24h_future = current_time + (24 * 60 * 60 * 1000 + 1) + result = self.log_client._is_batch_active(batch, over_24h_future) + self.assertFalse(result) + + # Test exactly at flush interval boundary + # Create a new batch for this test + batch2 = self.log_client._create_event_batch() + batch2.add_event({"message": "test", "timestamp": current_time}, 10) + batch2.created_timestamp_ms = current_time - self.log_client.BATCH_FLUSH_INTERVAL + result = 
self.log_client._is_batch_active(batch2, current_time)
+        self.assertFalse(result)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/common/test_aws_auth_session.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/common/test_aws_auth_session.py
index e0c62b89d..11babbb7b 100644
--- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/common/test_aws_auth_session.py
+++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/common/test_aws_auth_session.py
@@ -6,6 +6,7 @@
 import requests
 from botocore.credentials import Credentials
 
+from amazon.opentelemetry.distro._utils import get_aws_session
 from amazon.opentelemetry.distro.exporter.otlp.aws.common.aws_auth_session import AwsAuthSession
 
 AWS_OTLP_TRACES_ENDPOINT = "https://xray.us-east-1.amazonaws.com/v1/traces"
@@ -19,27 +20,12 @@
 
 
 class TestAwsAuthSession(TestCase):
-    @patch("pkg_resources.get_distribution", side_effect=ImportError("test error"))
-    @patch.dict("sys.modules", {"botocore": None}, clear=False)
-    @patch("requests.Session.request", return_value=requests.Response())
-    def test_aws_auth_session_no_botocore(self, _, __):
-        """Tests that aws_auth_session will not inject SigV4 Headers if botocore is not installed."""
-
-        session = AwsAuthSession("us-east-1", "xray")
-        actual_headers = {"test": "test"}
-
-        session.request("POST", AWS_OTLP_TRACES_ENDPOINT, data="", headers=actual_headers)
-
-        self.assertNotIn(AUTHORIZATION_HEADER, actual_headers)
-        self.assertNotIn(X_AMZ_DATE_HEADER, actual_headers)
-        self.assertNotIn(X_AMZ_SECURITY_TOKEN_HEADER, actual_headers)
-
     @patch("requests.Session.request", return_value=requests.Response())
     @patch("botocore.session.Session.get_credentials", return_value=None)
     def test_aws_auth_session_no_credentials(self, _, __):
         """Tests that aws_auth_session will not inject SigV4 Headers if retrieving credentials returns None."""
 
-        session = AwsAuthSession("us-east-1", "xray")
+        session = AwsAuthSession("us-east-1", "xray", get_aws_session())
         actual_headers = {"test": "test"}
 
         session.request("POST", AWS_OTLP_TRACES_ENDPOINT, data="", headers=actual_headers)
@@ -53,7 +39,7 @@ def test_aws_auth_session_no_credentials(self, _, __):
     def test_aws_auth_session(self, _, __):
         """Tests that aws_auth_session will inject SigV4 Headers if botocore is installed."""
 
-        session = AwsAuthSession("us-east-1", "xray")
+        session = AwsAuthSession("us-east-1", "xray", get_aws_session())
         actual_headers = {"test": "test"}
 
         session.request("POST", AWS_OTLP_TRACES_ENDPOINT, data="", headers=actual_headers)
diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/logs/aws_batch_log_record_processor_test.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/logs/aws_batch_log_record_processor_test.py
deleted file mode 100644
index 1abf680f1..000000000
--- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/logs/aws_batch_log_record_processor_test.py
+++ /dev/null
@@ -1,236 +0,0 @@
-import time
-import unittest
-from typing import List
-from unittest.mock import MagicMock, patch
-
-from amazon.opentelemetry.distro.exporter.otlp.aws.logs.aws_batch_log_record_processor import (
-    AwsBatchLogRecordProcessor,
-    BatchLogExportStrategy,
-)
-from opentelemetry._logs.severity import SeverityNumber
-from opentelemetry.sdk._logs import LogData, LogRecord
-from opentelemetry.sdk._logs.export 
import LogExportResult -from opentelemetry.sdk.util.instrumentation import InstrumentationScope -from opentelemetry.trace import TraceFlags -from opentelemetry.util.types import AnyValue - - -class TestAwsBatchLogRecordProcessor(unittest.TestCase): - - def setUp(self): - self.mock_exporter = MagicMock() - self.mock_exporter.export.return_value = LogExportResult.SUCCESS - - self.processor = AwsBatchLogRecordProcessor(exporter=self.mock_exporter) - - def test_process_log_data_nested_structure(self): - """Tests that the processor correctly handles nested structures (dict/list)""" - message_size = 400 - depth = 2 - - nested_dict_log_body = self.generate_nested_log_body( - depth=depth, expected_body="X" * message_size, create_map=True - ) - nested_array_log_body = self.generate_nested_log_body( - depth=depth, expected_body="X" * message_size, create_map=False - ) - - dict_size = self.processor._get_any_value_size(val=nested_dict_log_body, depth=depth) - array_size = self.processor._get_any_value_size(val=nested_array_log_body, depth=depth) - - # Asserting almost equal to account for key lengths in the Log object body - self.assertAlmostEqual(dict_size, message_size, delta=20) - self.assertAlmostEqual(array_size, message_size, delta=20) - - def test_process_log_data_nested_structure_exceeds_depth(self): - """Tests that the processor returns 0 for nested structure that exceeds the depth limit""" - message_size = 400 - log_body = "X" * message_size - - nested_dict_log_body = self.generate_nested_log_body(depth=4, expected_body=log_body, create_map=True) - nested_array_log_body = self.generate_nested_log_body(depth=4, expected_body=log_body, create_map=False) - - dict_size = self.processor._get_any_value_size(val=nested_dict_log_body, depth=3) - array_size = self.processor._get_any_value_size(val=nested_array_log_body, depth=3) - - self.assertEqual(dict_size, 0) - self.assertEqual(array_size, 0) - - def test_process_log_data_nested_structure_size_exceeds_max_log_size(self): - """Tests that the processor returns prematurely if the size already exceeds _MAX_LOG_REQUEST_BYTE_SIZE""" - log_body = { - "smallKey": "X" * (self.processor._MAX_LOG_REQUEST_BYTE_SIZE // 2), - "bigKey": "X" * (self.processor._MAX_LOG_REQUEST_BYTE_SIZE + 1), - } - - nested_dict_log_body = self.generate_nested_log_body(depth=0, expected_body=log_body, create_map=True) - nested_array_log_body = self.generate_nested_log_body(depth=0, expected_body=log_body, create_map=False) - - dict_size = self.processor._get_any_value_size(val=nested_dict_log_body) - array_size = self.processor._get_any_value_size(val=nested_array_log_body) - - self.assertAlmostEqual(dict_size, self.processor._MAX_LOG_REQUEST_BYTE_SIZE, delta=20) - self.assertAlmostEqual(array_size, self.processor._MAX_LOG_REQUEST_BYTE_SIZE, delta=20) - - def test_process_log_data_primitive(self): - - primitives: List[AnyValue] = ["test", b"test", 1, 1.2, True, False, None] - expected_sizes = [4, 4, 1, 3, 4, 5, 0] - - for i in range(len(primitives)): - body = primitives[i] - expected_size = expected_sizes[i] - - actual_size = self.processor._get_any_value_size(body) - self.assertEqual(actual_size, expected_size) - - @patch( - "amazon.opentelemetry.distro.exporter.otlp.aws.logs.aws_batch_log_record_processor.attach", - return_value=MagicMock(), - ) - @patch("amazon.opentelemetry.distro.exporter.otlp.aws.logs.aws_batch_log_record_processor.detach") - @patch("amazon.opentelemetry.distro.exporter.otlp.aws.logs.aws_batch_log_record_processor.set_value") - def 
test_export_single_batch_under_size_limit(self, _, __, ___): - """Tests that export is only called once if a single batch is under the size limit""" - log_count = 10 - log_body = "test" - test_logs = self.generate_test_log_data(count=log_count, log_body=log_body) - total_data_size = 0 - - for log in test_logs: - size = self.processor._get_any_value_size(log.log_record.body) - total_data_size += size - self.processor._queue.appendleft(log) - - self.processor._export(batch_strategy=BatchLogExportStrategy.EXPORT_ALL) - args, _ = self.mock_exporter.export.call_args - actual_batch = args[0] - - self.assertLess(total_data_size, self.processor._MAX_LOG_REQUEST_BYTE_SIZE) - self.assertEqual(len(self.processor._queue), 0) - self.assertEqual(len(actual_batch), log_count) - self.mock_exporter.export.assert_called_once() - self.mock_exporter.set_gen_ai_log_flag.assert_not_called() - - @patch( - "amazon.opentelemetry.distro.exporter.otlp.aws.logs.aws_batch_log_record_processor.attach", - return_value=MagicMock(), - ) - @patch("amazon.opentelemetry.distro.exporter.otlp.aws.logs.aws_batch_log_record_processor.detach") - @patch("amazon.opentelemetry.distro.exporter.otlp.aws.logs.aws_batch_log_record_processor.set_value") - def test_export_single_batch_all_logs_over_size_limit(self, _, __, ___): - """Should make multiple export calls of batch size 1 to export logs of size > 1 MB. - But should only call set_gen_ai_log_flag if it's a Gen AI log event.""" - - large_log_body = "X" * (self.processor._MAX_LOG_REQUEST_BYTE_SIZE + 1) - non_gen_ai_test_logs = self.generate_test_log_data(count=3, log_body=large_log_body) - gen_ai_test_logs = [] - - gen_ai_scopes = [ - "openinference.instrumentation.langchain", - "openinference.instrumentation.crewai", - "opentelemetry.instrumentation.langchain", - "crewai.telemetry", - "openlit.otel.tracing", - ] - - for gen_ai_scope in gen_ai_scopes: - gen_ai_test_logs.extend( - self.generate_test_log_data( - count=1, log_body=large_log_body, instrumentation_scope=InstrumentationScope(gen_ai_scope, "1.0.0") - ) - ) - - test_logs = gen_ai_test_logs + non_gen_ai_test_logs - - for log in test_logs: - self.processor._queue.appendleft(log) - - self.processor._export(batch_strategy=BatchLogExportStrategy.EXPORT_ALL) - - self.assertEqual(len(self.processor._queue), 0) - self.assertEqual(self.mock_exporter.export.call_count, 3 + len(gen_ai_test_logs)) - self.assertEqual(self.mock_exporter.set_gen_ai_log_flag.call_count, len(gen_ai_test_logs)) - - batches = self.mock_exporter.export.call_args_list - - for batch in batches: - self.assertEqual(len(batch[0]), 1) - - @patch( - "amazon.opentelemetry.distro.exporter.otlp.aws.logs.aws_batch_log_record_processor.attach", - return_value=MagicMock(), - ) - @patch("amazon.opentelemetry.distro.exporter.otlp.aws.logs.aws_batch_log_record_processor.detach") - @patch("amazon.opentelemetry.distro.exporter.otlp.aws.logs.aws_batch_log_record_processor.set_value") - def test_export_single_batch_some_logs_over_size_limit(self, _, __, ___): - """Should make calls to export smaller sub-batch logs""" - large_log_body = "X" * (self.processor._MAX_LOG_REQUEST_BYTE_SIZE + 1) - gen_ai_scope = InstrumentationScope("openinference.instrumentation.langchain", "1.0.0") - small_log_body = "X" * ( - int(self.processor._MAX_LOG_REQUEST_BYTE_SIZE / 10) - self.processor._BASE_LOG_BUFFER_BYTE_SIZE - ) - test_logs = self.generate_test_log_data(count=3, log_body=large_log_body, instrumentation_scope=gen_ai_scope) - # 1st, 2nd, 3rd batch = size 1 - # 4th batch = size 10 - # 5th 
batch = size 2 - small_logs = self.generate_test_log_data(count=12, log_body=small_log_body, instrumentation_scope=gen_ai_scope) - - test_logs.extend(small_logs) - - for log in test_logs: - self.processor._queue.appendleft(log) - - self.processor._export(batch_strategy=BatchLogExportStrategy.EXPORT_ALL) - - self.assertEqual(len(self.processor._queue), 0) - self.assertEqual(self.mock_exporter.export.call_count, 5) - self.assertEqual(self.mock_exporter.set_gen_ai_log_flag.call_count, 3) - - batches = self.mock_exporter.export.call_args_list - - expected_sizes = { - 0: 1, # 1st batch (index 1) should have 1 log - 1: 1, # 2nd batch (index 1) should have 1 log - 2: 1, # 3rd batch (index 2) should have 1 log - 3: 10, # 4th batch (index 3) should have 10 logs - 4: 2, # 5th batch (index 4) should have 2 logs - } - - for i, call in enumerate(batches): - batch = call[0][0] - expected_size = expected_sizes[i] - self.assertEqual(len(batch), expected_size) - - def generate_test_log_data( - self, log_body: AnyValue, count=5, instrumentation_scope=InstrumentationScope("test-scope", "1.0.0") - ) -> List[LogData]: - logs = [] - for i in range(count): - record = LogRecord( - timestamp=int(time.time_ns()), - trace_id=int(f"0x{i + 1:032x}", 16), - span_id=int(f"0x{i + 1:016x}", 16), - trace_flags=TraceFlags(1), - severity_text="INFO", - severity_number=SeverityNumber.INFO, - body=log_body, - attributes={"test.attribute": f"value-{i + 1}"}, - ) - - log_data = LogData(log_record=record, instrumentation_scope=instrumentation_scope) - logs.append(log_data) - - return logs - - @staticmethod - def generate_nested_log_body(depth=0, expected_body: AnyValue = "test", create_map=True): - if depth < 0: - return expected_body - - if create_map: - return { - "key": TestAwsBatchLogRecordProcessor.generate_nested_log_body(depth - 1, expected_body, create_map) - } - - return [TestAwsBatchLogRecordProcessor.generate_nested_log_body(depth - 1, expected_body, create_map)] diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/logs/otlp_aws_logs_exporter_test.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/logs/otlp_aws_logs_exporter_test.py deleted file mode 100644 index 9f6d84b32..000000000 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/logs/otlp_aws_logs_exporter_test.py +++ /dev/null @@ -1,180 +0,0 @@ -import time -from unittest import TestCase -from unittest.mock import patch - -import requests -from requests.structures import CaseInsensitiveDict - -from amazon.opentelemetry.distro.exporter.otlp.aws.logs.otlp_aws_logs_exporter import OTLPAwsLogExporter -from opentelemetry._logs.severity import SeverityNumber -from opentelemetry.sdk._logs import LogData, LogRecord -from opentelemetry.sdk._logs.export import ( - LogExportResult, -) -from opentelemetry.sdk.util.instrumentation import InstrumentationScope -from opentelemetry.trace import TraceFlags - - -class TestOTLPAwsLogsExporter(TestCase): - _ENDPOINT = "https://logs.us-west-2.amazonaws.com/v1/logs" - good_response = requests.Response() - good_response.status_code = 200 - - non_retryable_response = requests.Response() - non_retryable_response.status_code = 404 - - retryable_response_no_header = requests.Response() - retryable_response_no_header.status_code = 429 - - retryable_response_header = requests.Response() - retryable_response_header.headers = CaseInsensitiveDict({"Retry-After": "10"}) - retryable_response_header.status_code = 503 - - retryable_response_bad_header 
= requests.Response() - retryable_response_bad_header.headers = CaseInsensitiveDict({"Retry-After": "-12"}) - retryable_response_bad_header.status_code = 503 - - def setUp(self): - self.logs = self.generate_test_log_data() - self.exporter = OTLPAwsLogExporter(endpoint=self._ENDPOINT) - - @patch("requests.Session.request", return_value=good_response) - def test_export_success(self, mock_request): - """Tests that the exporter always compresses the serialized logs with gzip before exporting.""" - result = self.exporter.export(self.logs) - - mock_request.assert_called_once() - - _, kwargs = mock_request.call_args - data = kwargs.get("data", None) - - self.assertEqual(result, LogExportResult.SUCCESS) - - # Gzip first 10 bytes are reserved for metadata headers: - # https://www.loc.gov/preservation/digital/formats/fdd/fdd000599.shtml?loclr=blogsig - self.assertIsNotNone(data) - self.assertTrue(len(data) >= 10) - self.assertEqual(data[0:2], b"\x1f\x8b") - - @patch("requests.Session.request", return_value=good_response) - def test_export_gen_ai_logs(self, mock_request): - """Tests that when set_gen_ai_log_flag is set, the exporter includes the LLO header in the request.""" - - self.exporter.set_gen_ai_log_flag() - - result = self.exporter.export(self.logs) - - mock_request.assert_called_once() - - _, kwargs = mock_request.call_args - headers = kwargs.get("headers", None) - - self.assertEqual(result, LogExportResult.SUCCESS) - self.assertIsNotNone(headers) - self.assertIn(self.exporter._LARGE_LOG_HEADER, headers) - self.assertEqual(headers[self.exporter._LARGE_LOG_HEADER], self.exporter._LARGE_GEN_AI_LOG_PATH_HEADER) - - @patch("requests.Session.request", return_value=good_response) - def test_should_not_export_if_shutdown(self, mock_request): - """Tests that no export request is made if the exporter is shutdown.""" - self.exporter.shutdown() - result = self.exporter.export(self.logs) - - mock_request.assert_not_called() - self.assertEqual(result, LogExportResult.FAILURE) - - @patch("requests.Session.request", return_value=non_retryable_response) - def test_should_not_export_again_if_not_retryable(self, mock_request): - """Tests that only one export request is made if the response status code is non-retryable.""" - result = self.exporter.export(self.logs) - mock_request.assert_called_once() - - self.assertEqual(result, LogExportResult.FAILURE) - - @patch( - "amazon.opentelemetry.distro.exporter.otlp.aws.logs.otlp_aws_logs_exporter.sleep", side_effect=lambda x: None - ) - @patch("requests.Session.request", return_value=retryable_response_no_header) - def test_should_export_again_with_backoff_if_retryable_and_no_retry_after_header(self, mock_request, mock_sleep): - """Tests that multiple export requests are made with exponential delay if the response status code is retryable. 
- But there is no Retry-After header.""" - result = self.exporter.export(self.logs) - - # 1, 2, 4, 8, 16, 32 delays - self.assertEqual(mock_sleep.call_count, 6) - - delays = mock_sleep.call_args_list - - for i in range(len(delays)): - self.assertEqual(delays[i][0][0], 2**i) - - # Number of calls: 1 + len(1, 2, 4, 8, 16, 32 delays) - self.assertEqual(mock_request.call_count, 7) - self.assertEqual(result, LogExportResult.FAILURE) - - @patch( - "amazon.opentelemetry.distro.exporter.otlp.aws.logs.otlp_aws_logs_exporter.sleep", side_effect=lambda x: None - ) - @patch( - "requests.Session.request", - side_effect=[retryable_response_header, retryable_response_header, retryable_response_header, good_response], - ) - def test_should_export_again_with_server_delay_if_retryable_and_retry_after_header(self, mock_request, mock_sleep): - """Tests that multiple export requests are made with the server's suggested - delay if the response status code is retryable and there is a Retry-After header.""" - result = self.exporter.export(self.logs) - delays = mock_sleep.call_args_list - - for i in range(len(delays)): - self.assertEqual(delays[i][0][0], 10) - - self.assertEqual(mock_sleep.call_count, 3) - self.assertEqual(mock_request.call_count, 4) - self.assertEqual(result, LogExportResult.SUCCESS) - - @patch( - "amazon.opentelemetry.distro.exporter.otlp.aws.logs.otlp_aws_logs_exporter.sleep", side_effect=lambda x: None - ) - @patch( - "requests.Session.request", - side_effect=[ - retryable_response_bad_header, - retryable_response_bad_header, - retryable_response_bad_header, - good_response, - ], - ) - def test_should_export_again_with_backoff_delay_if_retryable_and_bad_retry_after_header( - self, mock_request, mock_sleep - ): - """Tests that multiple export requests are made with exponential delay if the response status code is retryable. - but the Retry-After header ins invalid or malformed.""" - result = self.exporter.export(self.logs) - delays = mock_sleep.call_args_list - - for i in range(len(delays)): - self.assertEqual(delays[i][0][0], 2**i) - - self.assertEqual(mock_sleep.call_count, 3) - self.assertEqual(mock_request.call_count, 4) - self.assertEqual(result, LogExportResult.SUCCESS) - - def generate_test_log_data(self, count=5): - logs = [] - for i in range(count): - record = LogRecord( - timestamp=int(time.time_ns()), - trace_id=int(f"0x{i + 1:032x}", 16), - span_id=int(f"0x{i + 1:016x}", 16), - trace_flags=TraceFlags(1), - severity_text="INFO", - severity_number=SeverityNumber.INFO, - body=f"Test log {i + 1}", - attributes={"test.attribute": f"value-{i + 1}"}, - ) - - log_data = LogData(log_record=record, instrumentation_scope=InstrumentationScope("test-scope", "1.0.0")) - - logs.append(log_data) - - return logs diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/logs/test_aws_cw_otlp_batch_log_record_processor.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/logs/test_aws_cw_otlp_batch_log_record_processor.py new file mode 100644 index 000000000..156f177cb --- /dev/null +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/logs/test_aws_cw_otlp_batch_log_record_processor.py @@ -0,0 +1,310 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
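+#
+# The tests below exercise the size-based batching in
+# AwsCloudWatchOtlpBatchLogRecordProcessor._export. A rough sketch of that
+# behavior (illustrative pseudocode only, not the actual implementation in
+# _aws_cw_otlp_batch_log_record_processor.py):
+#
+#     batch, batch_size = [], 0
+#     for log in queue:
+#         log_size = estimate_log_size(log)
+#         if batch and batch_size + log_size > _MAX_LOG_REQUEST_BYTE_SIZE:
+#             exporter.export(batch)  # flush before exceeding the 1 MB limit
+#             batch, batch_size = [], 0
+#         batch.append(log)
+#         batch_size += log_size
+#     if batch:
+#         exporter.export(batch)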
+# SPDX-License-Identifier: Apache-2.0
+import time
+import unittest
+from typing import List
+from unittest.mock import MagicMock, patch
+
+from amazon.opentelemetry.distro.exporter.otlp.aws.logs._aws_cw_otlp_batch_log_record_processor import (
+    AwsCloudWatchOtlpBatchLogRecordProcessor,
+    BatchLogExportStrategy,
+)
+from opentelemetry._logs.severity import SeverityNumber
+from opentelemetry.sdk._logs import LogData, LogRecord
+from opentelemetry.sdk._logs.export import LogExportResult
+from opentelemetry.sdk.util.instrumentation import InstrumentationScope
+from opentelemetry.trace import TraceFlags
+from opentelemetry.util.types import AnyValue
+
+
+class TestAwsBatchLogRecordProcessor(unittest.TestCase):
+
+    def setUp(self):
+        self.mock_exporter = MagicMock()
+        self.mock_exporter.export.return_value = LogExportResult.SUCCESS
+
+        self.processor = AwsCloudWatchOtlpBatchLogRecordProcessor(exporter=self.mock_exporter)
+
+    def test_process_log_data_nested_structure(self):
+        """Tests that the processor correctly handles nested structures (dict/list)"""
+        log_body = "X" * 400
+        log_key = "test"
+        log_depth = 2
+
+        nested_dict_log = self.generate_test_log_data(
+            log_body=log_body, log_key=log_key, log_body_depth=log_depth, count=1, create_map=True
+        )
+        nested_array_log = self.generate_test_log_data(
+            log_body=log_body, log_key=log_key, log_body_depth=log_depth, count=1, create_map=False
+        )
+
+        expected_dict_size = len(log_key) * log_depth + len(log_body)
+        expected_array_size = len(log_body)
+
+        dict_size = self.processor._estimate_log_size(log=nested_dict_log[0], depth=log_depth)
+        array_size = self.processor._estimate_log_size(log=nested_array_log[0], depth=log_depth)
+
+        self.assertEqual(dict_size - self.processor._BASE_LOG_BUFFER_BYTE_SIZE, expected_dict_size)
+        self.assertEqual(array_size - self.processor._BASE_LOG_BUFFER_BYTE_SIZE, expected_array_size)
+
+    def test_process_log_data_with_attributes(self):
+        """Tests that the processor correctly handles both body and attributes"""
+        log_body = "test_body"
+        attr_key = "attr_key"
+        attr_value = "attr_value"
+
+        record = LogRecord(
+            timestamp=int(time.time_ns()),
+            trace_id=0x123456789ABCDEF0123456789ABCDEF0,
+            span_id=0x123456789ABCDEF0,
+            trace_flags=TraceFlags(1),
+            severity_text="INFO",
+            severity_number=SeverityNumber.INFO,
+            body=log_body,
+            attributes={attr_key: attr_value},
+        )
+        log_data = LogData(log_record=record, instrumentation_scope=InstrumentationScope("test-scope", "1.0.0"))
+
+        expected_size = len(log_body) + len(attr_key) + len(attr_value)
+        actual_size = self.processor._estimate_log_size(log_data)
+
+        self.assertEqual(actual_size - self.processor._BASE_LOG_BUFFER_BYTE_SIZE, expected_size)
+
+    def test_process_log_data_nested_structure_exceeds_depth(self):
+        """Tests that the processor cuts off the calculation for a nested structure that exceeds the depth limit"""
+        max_depth = 0
+        calculated_body = "X" * 400
+        log_body = {
+            "calculated": calculated_body,
+            "restOfThisLogWillBeTruncated": {"truncated": {"test": "X" * self.processor._MAX_LOG_REQUEST_BYTE_SIZE}},
+        }
+
+        expected_size = self.processor._BASE_LOG_BUFFER_BYTE_SIZE + (
+            len("calculated") + len(calculated_body) + len("restOfThisLogWillBeTruncated")
+        )
+
+        test_logs = self.generate_test_log_data(log_body=log_body, count=1)
+
+        # Only calculates the log size up to a depth of 0
+        dict_size = self.processor._estimate_log_size(log=test_logs[0], depth=max_depth)
+
+        self.assertEqual(dict_size, expected_size)
+
+    def
test_process_log_data_nested_structure_size_exceeds_max_log_size(self): + """Tests that the processor returns prematurely if the size already exceeds _MAX_LOG_REQUEST_BYTE_SIZE""" + # Should stop calculation at bigKey + biggerKey and not calculate the content of biggerKey + log_body = { + "bigKey": "X" * (self.processor._MAX_LOG_REQUEST_BYTE_SIZE), + "biggerKey": "X" * (self.processor._MAX_LOG_REQUEST_BYTE_SIZE * 100), + } + + expected_size = ( + self.processor._BASE_LOG_BUFFER_BYTE_SIZE + + self.processor._MAX_LOG_REQUEST_BYTE_SIZE + + len("bigKey") + + len("biggerKey") + ) + + nest_dict_log = self.generate_test_log_data(log_body=log_body, count=1, create_map=True) + nest_array_log = self.generate_test_log_data(log_body=log_body, count=1, create_map=False) + + dict_size = self.processor._estimate_log_size(log=nest_dict_log[0]) + array_size = self.processor._estimate_log_size(log=nest_array_log[0]) + + self.assertEqual(dict_size, expected_size) + self.assertEqual(array_size, expected_size) + + def test_process_log_data_primitive(self): + + primitives: List[AnyValue] = ["test", b"test", 1, 1.2, True, False, None, "深入 Python", "calfé"] + expected_sizes = [4, 4, 1, 3, 4, 5, 0, 2 * 4 + len(" Python"), 1 * 4 + len("calf")] + + for index, primitive in enumerate(primitives): + log = self.generate_test_log_data(log_body=primitive, count=1) + expected_size = self.processor._BASE_LOG_BUFFER_BYTE_SIZE + expected_sizes[index] + actual_size = self.processor._estimate_log_size(log[0]) + self.assertEqual(actual_size, expected_size) + + def test_process_log_data_with_cycle(self): + """Test that processor handles processing logs with circular references only once""" + cyclic_dict: dict = {"data": "test"} + cyclic_dict["self_ref"] = cyclic_dict + + log = self.generate_test_log_data(log_body=cyclic_dict, count=1) + expected_size = self.processor._BASE_LOG_BUFFER_BYTE_SIZE + len("data") + len("self_ref") + len("test") + actual_size = self.processor._estimate_log_size(log[0]) + self.assertEqual(actual_size, expected_size) + + @patch( + "amazon.opentelemetry.distro.exporter.otlp.aws.logs._aws_cw_otlp_batch_log_record_processor.attach", + return_value=MagicMock(), + ) + @patch("amazon.opentelemetry.distro.exporter.otlp.aws.logs._aws_cw_otlp_batch_log_record_processor.detach") + @patch("amazon.opentelemetry.distro.exporter.otlp.aws.logs._aws_cw_otlp_batch_log_record_processor.set_value") + def test_export_single_batch_under_size_limit(self, _, __, ___): + """Tests that export is only called once if a single batch is under the size limit""" + log_count = 10 + log_body = "test" + test_logs = self.generate_test_log_data(log_body=log_body, count=log_count) + total_data_size = 0 + + for log in test_logs: + size = self.processor._estimate_log_size(log) + total_data_size += size + self.processor._queue.appendleft(log) + + self.processor._export(batch_strategy=BatchLogExportStrategy.EXPORT_ALL) + args, _ = self.mock_exporter.export.call_args + actual_batch = args[0] + + self.assertLess(total_data_size, self.processor._MAX_LOG_REQUEST_BYTE_SIZE) + self.assertEqual(len(self.processor._queue), 0) + self.assertEqual(len(actual_batch), log_count) + self.mock_exporter.export.assert_called_once() + + @patch( + "amazon.opentelemetry.distro.exporter.otlp.aws.logs._aws_cw_otlp_batch_log_record_processor.attach", + return_value=MagicMock(), + ) + @patch("amazon.opentelemetry.distro.exporter.otlp.aws.logs._aws_cw_otlp_batch_log_record_processor.detach") + 
@patch("amazon.opentelemetry.distro.exporter.otlp.aws.logs._aws_cw_otlp_batch_log_record_processor.set_value")
+    def test_export_single_batch_all_logs_over_size_limit(self, _, __, ___):
+        """Should make multiple export calls of batch size 1 to export logs of size > 1 MB."""
+
+        large_log_body = "X" * (self.processor._MAX_LOG_REQUEST_BYTE_SIZE + 1)
+        test_logs = self.generate_test_log_data(log_body=large_log_body, count=15)
+
+        for log in test_logs:
+            self.processor._queue.appendleft(log)
+
+        self.processor._export(batch_strategy=BatchLogExportStrategy.EXPORT_ALL)
+
+        self.assertEqual(len(self.processor._queue), 0)
+        self.assertEqual(self.mock_exporter.export.call_count, len(test_logs))
+
+        batches = self.mock_exporter.export.call_args_list
+
+        for batch in batches:
+            self.assertEqual(len(batch[0]), 1)
+
+    @patch(
+        "amazon.opentelemetry.distro.exporter.otlp.aws.logs._aws_cw_otlp_batch_log_record_processor.attach",
+        return_value=MagicMock(),
+    )
+    @patch("amazon.opentelemetry.distro.exporter.otlp.aws.logs._aws_cw_otlp_batch_log_record_processor.detach")
+    @patch("amazon.opentelemetry.distro.exporter.otlp.aws.logs._aws_cw_otlp_batch_log_record_processor.set_value")
+    def test_export_single_batch_some_logs_over_size_limit(self, _, __, ___):
+        """Should make calls to export smaller sub-batches of logs"""
+        large_log_body = "X" * (self.processor._MAX_LOG_REQUEST_BYTE_SIZE + 1)
+        small_log_body = "X" * (
+            self.processor._MAX_LOG_REQUEST_BYTE_SIZE // 10 - self.processor._BASE_LOG_BUFFER_BYTE_SIZE
+        )
+
+        large_logs = self.generate_test_log_data(log_body=large_log_body, count=3)
+        small_logs = self.generate_test_log_data(log_body=small_log_body, count=12)
+
+        # 1st, 2nd, 3rd batch = size 1
+        # 4th batch = size 10
+        # 5th batch = size 2
+        test_logs = large_logs + small_logs
+
+        for log in test_logs:
+            self.processor._queue.appendleft(log)
+
+        self.processor._export(batch_strategy=BatchLogExportStrategy.EXPORT_ALL)
+
+        self.assertEqual(len(self.processor._queue), 0)
+        self.assertEqual(self.mock_exporter.export.call_count, 5)
+
+        batches = self.mock_exporter.export.call_args_list
+
+        expected_sizes = {
+            0: 1,  # 1st batch (index 0) should have 1 log
+            1: 1,  # 2nd batch (index 1) should have 1 log
+            2: 1,  # 3rd batch (index 2) should have 1 log
+            3: 10,  # 4th batch (index 3) should have 10 logs
+            4: 2,  # 5th batch (index 4) should have 2 logs
+        }
+
+        for index, call in enumerate(batches):
+            batch = call[0][0]
+            expected_size = expected_sizes[index]
+            self.assertEqual(len(batch), expected_size)
+
+    def test_force_flush_returns_false_when_shutdown(self):
+        """Tests that force_flush returns False when the processor is shut down"""
+        self.processor.shutdown()
+        result = self.processor.force_flush()
+
+        # Verify force_flush returns False and no export is called
+        self.assertFalse(result)
+        self.mock_exporter.export.assert_not_called()
+
+    @patch(
+        "amazon.opentelemetry.distro.exporter.otlp.aws.logs._aws_cw_otlp_batch_log_record_processor.attach",
+        return_value=MagicMock(),
+    )
+    @patch("amazon.opentelemetry.distro.exporter.otlp.aws.logs._aws_cw_otlp_batch_log_record_processor.detach")
+    @patch("amazon.opentelemetry.distro.exporter.otlp.aws.logs._aws_cw_otlp_batch_log_record_processor.set_value")
+    def test_force_flush_exports_only_one_batch(self, _, __, ___):
+        """Tests that force_flush should try to export at least one batch of logs.
The rest of the logs will be dropped."""
+        # Set max_export_batch_size to 5 to limit batch size
+        self.processor._max_export_batch_size = 5
+        self.processor._shutdown = False
+
+        # Add 6 logs to the queue; after the export there should be 1 log remaining
+        log_count = 6
+        test_logs = self.generate_test_log_data(log_body="test message", count=log_count)
+
+        for log in test_logs:
+            self.processor._queue.appendleft(log)
+
+        self.assertEqual(len(self.processor._queue), log_count)
+
+        result = self.processor.force_flush()
+
+        self.assertTrue(result)
+        # 1 log should remain
+        self.assertEqual(len(self.processor._queue), 1)
+        self.mock_exporter.export.assert_called_once()
+
+        # Verify only one batch of 5 logs was exported
+        args, _ = self.mock_exporter.export.call_args
+        exported_batch = args[0]
+        self.assertEqual(len(exported_batch), 5)
+
+    @staticmethod
+    def generate_test_log_data(
+        log_body,
+        log_key="key",
+        log_body_depth=0,
+        count=5,
+        create_map=True,
+    ) -> List[LogData]:
+
+        def generate_nested_value(depth, value, create_map=True) -> AnyValue:
+            if depth <= 0:
+                return value
+
+            if create_map:
+                return {log_key: generate_nested_value(depth - 1, value, True)}
+
+            return [generate_nested_value(depth - 1, value, False)]
+
+        logs = []
+
+        for _ in range(count):
+            record = LogRecord(
+                timestamp=int(time.time_ns()),
+                trace_id=0x123456789ABCDEF0123456789ABCDEF0,
+                span_id=0x123456789ABCDEF0,
+                trace_flags=TraceFlags(1),
+                severity_text="INFO",
+                severity_number=SeverityNumber.INFO,
+                body=generate_nested_value(log_body_depth, log_body, create_map),
+            )
+
+            log_data = LogData(log_record=record, instrumentation_scope=InstrumentationScope("test-scope", "1.0.0"))
+            logs.append(log_data)
+
+        return logs
diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/logs/test_otlp_aws_logs_exporter.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/logs/test_otlp_aws_logs_exporter.py
new file mode 100644
index 000000000..5c75f63de
--- /dev/null
+++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/logs/test_otlp_aws_logs_exporter.py
@@ -0,0 +1,250 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
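+#
+# The retry behavior these tests assume, as a sketch (parse_retry_after,
+# uniform, and shutdown_event are illustrative names; _MAX_RETRYS and
+# _timeout come from the exporter under test):
+#
+#     deadline = time() + _timeout
+#     for attempt in range(_MAX_RETRYS):
+#         response = session.post(endpoint, data=gzip_compressed_payload)
+#         if response.ok or response.status_code not in RETRYABLE_CODES:  # e.g. 429, 503 below
+#             break
+#         if attempt == _MAX_RETRYS - 1:
+#             break  # retries exhausted; no wait after the final attempt
+#         # A valid Retry-After header wins; otherwise exponential backoff
+#         # with +/-20% jitter, capped by the remaining deadline.
+#         delay = parse_retry_after(response) or (2 ** attempt) * uniform(0.8, 1.2)
+#         if time() + delay >= deadline or shutdown_event.wait(delay):
+#             break  # deadline exceeded, or shutdown interrupted the wait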
+# SPDX-License-Identifier: Apache-2.0 +import time +from unittest import TestCase +from unittest.mock import patch + +import requests +from requests.structures import CaseInsensitiveDict + +from amazon.opentelemetry.distro._utils import get_aws_session +from amazon.opentelemetry.distro.exporter.otlp.aws.logs.otlp_aws_logs_exporter import _MAX_RETRYS, OTLPAwsLogExporter +from opentelemetry._logs.severity import SeverityNumber +from opentelemetry.sdk._logs import LogData, LogRecord +from opentelemetry.sdk._logs.export import LogExportResult +from opentelemetry.sdk.util.instrumentation import InstrumentationScope +from opentelemetry.trace import TraceFlags + + +class TestOTLPAwsLogsExporter(TestCase): + _ENDPOINT = "https://logs.us-west-2.amazonaws.com/v1/logs" + good_response = requests.Response() + good_response.status_code = 200 + + non_retryable_response = requests.Response() + non_retryable_response.status_code = 404 + + retryable_response_no_header = requests.Response() + retryable_response_no_header.status_code = 429 + + retryable_response_header = requests.Response() + retryable_response_header.headers = CaseInsensitiveDict({"Retry-After": "10"}) + retryable_response_header.status_code = 503 + + retryable_response_bad_header = requests.Response() + retryable_response_bad_header.headers = CaseInsensitiveDict({"Retry-After": "-12"}) + retryable_response_bad_header.status_code = 503 + + def setUp(self): + self.logs = self.generate_test_log_data() + self.exporter = OTLPAwsLogExporter(session=get_aws_session(), aws_region="us-east-1", endpoint=self._ENDPOINT) + + @patch("requests.Session.post", return_value=good_response) + def test_export_success(self, mock_request): + """Tests that the exporter always compresses the serialized logs with gzip before exporting.""" + result = self.exporter.export(self.logs) + + mock_request.assert_called_once() + + _, kwargs = mock_request.call_args + data = kwargs.get("data", None) + + self.assertEqual(result, LogExportResult.SUCCESS) + + # Gzip first 10 bytes are reserved for metadata headers: + # https://www.loc.gov/preservation/digital/formats/fdd/fdd000599.shtml?loclr=blogsig + self.assertIsNotNone(data) + self.assertTrue(len(data) >= 10) + self.assertEqual(data[0:2], b"\x1f\x8b") + + @patch("requests.Session.post", return_value=good_response) + def test_should_not_export_if_shutdown(self, mock_request): + """Tests that no export request is made if the exporter is shutdown.""" + self.exporter.shutdown() + result = self.exporter.export(self.logs) + + mock_request.assert_not_called() + self.assertEqual(result, LogExportResult.FAILURE) + + @patch("requests.Session.post", return_value=non_retryable_response) + def test_should_not_export_again_if_not_retryable(self, mock_request): + """Tests that only one export request is made if the response status code is non-retryable.""" + result = self.exporter.export(self.logs) + mock_request.assert_called_once() + + self.assertEqual(result, LogExportResult.FAILURE) + + @patch( + "amazon.opentelemetry.distro.exporter.otlp.aws.logs.otlp_aws_logs_exporter.Event.wait", + side_effect=lambda x: False, + ) + @patch("requests.Session.post", return_value=retryable_response_no_header) + def test_should_export_again_with_backoff_if_retryable_and_no_retry_after_header(self, mock_request, mock_wait): + """Tests that multiple export requests are made with exponential delay if the response status code is retryable. 
+ But there is no Retry-After header.""" + self.exporter._timeout = 10000 # Large timeout to avoid early exit + result = self.exporter.export(self.logs) + + self.assertEqual(mock_wait.call_count, _MAX_RETRYS - 1) + + delays = mock_wait.call_args_list + + for index, delay in enumerate(delays): + expected_base = 2**index + actual_delay = delay[0][0] + # Assert delay is within jitter range: base * [0.8, 1.2] + self.assertGreaterEqual(actual_delay, expected_base * 0.8) + self.assertLessEqual(actual_delay, expected_base * 1.2) + + self.assertEqual(mock_request.call_count, _MAX_RETRYS) + self.assertEqual(result, LogExportResult.FAILURE) + + @patch( + "amazon.opentelemetry.distro.exporter.otlp.aws.logs.otlp_aws_logs_exporter.Event.wait", + side_effect=lambda x: False, + ) + @patch( + "requests.Session.post", + side_effect=[retryable_response_header, retryable_response_header, retryable_response_header, good_response], + ) + def test_should_export_again_with_server_delay_if_retryable_and_retry_after_header(self, mock_request, mock_wait): + """Tests that multiple export requests are made with the server's suggested + delay if the response status code is retryable and there is a Retry-After header.""" + self.exporter._timeout = 10000 # Large timeout to avoid early exit + result = self.exporter.export(self.logs) + + delays = mock_wait.call_args_list + + for delay in delays: + self.assertEqual(delay[0][0], 10) + + self.assertEqual(mock_wait.call_count, 3) + self.assertEqual(mock_request.call_count, 4) + self.assertEqual(result, LogExportResult.SUCCESS) + + @patch( + "amazon.opentelemetry.distro.exporter.otlp.aws.logs.otlp_aws_logs_exporter.Event.wait", + side_effect=lambda x: False, + ) + @patch( + "requests.Session.post", + side_effect=[ + retryable_response_bad_header, + retryable_response_bad_header, + retryable_response_bad_header, + good_response, + ], + ) + def test_should_export_again_with_backoff_delay_if_retryable_and_bad_retry_after_header( + self, mock_request, mock_wait + ): + """Tests that multiple export requests are made with exponential delay if the response status code is retryable. 
+        But the Retry-After header is invalid or malformed."""
+        self.exporter._timeout = 10000  # Large timeout to avoid early exit
+        result = self.exporter.export(self.logs)
+
+        delays = mock_wait.call_args_list
+
+        for index, delay in enumerate(delays):
+            expected_base = 2**index
+            actual_delay = delay[0][0]
+            # Assert delay is within jitter range: base * [0.8, 1.2]
+            self.assertGreaterEqual(actual_delay, expected_base * 0.8)
+            self.assertLessEqual(actual_delay, expected_base * 1.2)
+
+        self.assertEqual(mock_wait.call_count, 3)
+        self.assertEqual(mock_request.call_count, 4)
+        self.assertEqual(result, LogExportResult.SUCCESS)
+
+    @patch("requests.Session.post", side_effect=[requests.exceptions.ConnectionError(), good_response])
+    def test_export_connection_error_retry(self, mock_request):
+        """Tests that the exporter retries on ConnectionError."""
+        result = self.exporter.export(self.logs)
+
+        self.assertEqual(mock_request.call_count, 2)
+        self.assertEqual(result, LogExportResult.SUCCESS)
+
+    @patch(
+        "amazon.opentelemetry.distro.exporter.otlp.aws.logs.otlp_aws_logs_exporter.Event.wait",
+        side_effect=lambda x: False,
+    )
+    @patch("requests.Session.post", return_value=retryable_response_no_header)
+    def test_should_stop_retrying_when_deadline_exceeded(self, mock_request, mock_wait):
+        """Tests that the exporter stops retrying when the deadline is exceeded."""
+        self.exporter._timeout = 5  # Short timeout to trigger deadline check
+
+        with patch("amazon.opentelemetry.distro.exporter.otlp.aws.logs.otlp_aws_logs_exporter.time") as mock_time:
+            # First call returns start time, subsequent calls simulate time passing
+            mock_time.side_effect = [0, 0, 1, 2, 4, 8]  # Exponential backoff would be 1, 2, 4 seconds
+
+            result = self.exporter.export(self.logs)
+
+            # Should stop before max retries due to deadline
+            self.assertLess(mock_wait.call_count, _MAX_RETRYS)
+            self.assertLess(mock_request.call_count, _MAX_RETRYS + 1)
+            self.assertEqual(result, LogExportResult.FAILURE)
+
+            # Verify total time passed is at the timeout limit
+            self.assertGreaterEqual(5, self.exporter._timeout)
+
+    @patch(
+        "amazon.opentelemetry.distro.exporter.otlp.aws.logs.otlp_aws_logs_exporter.Event.wait",
+        side_effect=lambda x: True,
+    )
+    @patch("requests.Session.post", return_value=retryable_response_no_header)
+    def test_export_interrupted_by_shutdown(self, mock_request, mock_wait):
+        """Tests that export can be interrupted by shutdown during retry wait."""
+        self.exporter._timeout = 10000
+
+        result = self.exporter.export(self.logs)
+
+        # Should make one request, then get interrupted during retry wait
+        self.assertEqual(mock_request.call_count, 1)
+        self.assertEqual(result, LogExportResult.FAILURE)
+
+    @patch("requests.Session.post", return_value=good_response)
+    def test_export_with_log_group_and_stream_headers(self, mock_request):
+        """Tests that log_group and log_stream are properly set as headers when provided."""
+        log_group = "test-log-group"
+        log_stream = "test-log-stream"
+
+        exporter = OTLPAwsLogExporter(
+            session=get_aws_session(),
+            aws_region="us-east-1",
+            endpoint=self._ENDPOINT,
+            log_group=log_group,
+            log_stream=log_stream,
+        )
+
+        result = exporter.export(self.logs)
+
+        mock_request.assert_called_once()
+        self.assertEqual(result, LogExportResult.SUCCESS)
+
+        # Verify headers contain log group and stream
+        session_headers = exporter._session.headers
+        self.assertIn("x-aws-log-group", session_headers)
+        self.assertIn("x-aws-log-stream", session_headers)
+        self.assertEqual(session_headers["x-aws-log-group"],
log_group) + self.assertEqual(session_headers["x-aws-log-stream"], log_stream) + + @staticmethod + def generate_test_log_data(count=5): + logs = [] + for index in range(count): + record = LogRecord( + timestamp=int(time.time_ns()), + trace_id=int(f"0x{index + 1:032x}", 16), + span_id=int(f"0x{index + 1:016x}", 16), + trace_flags=TraceFlags(1), + severity_text="INFO", + severity_number=SeverityNumber.INFO, + body=f"Test log {index + 1}", + attributes={"test.attribute": f"value-{index + 1}"}, + ) + + log_data = LogData(log_record=record, instrumentation_scope=InstrumentationScope("test-scope", "1.0.0")) + + logs.append(log_data) + + return logs diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/traces/test_otlp_aws_span_exporter.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/traces/test_otlp_aws_span_exporter.py new file mode 100644 index 000000000..1553dd8e2 --- /dev/null +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/exporter/otlp/aws/traces/test_otlp_aws_span_exporter.py @@ -0,0 +1,196 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +from unittest import TestCase +from unittest.mock import MagicMock, patch + +from amazon.opentelemetry.distro._utils import get_aws_session +from amazon.opentelemetry.distro.exporter.otlp.aws.traces.otlp_aws_span_exporter import OTLPAwsSpanExporter +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +from opentelemetry.sdk._logs import LoggerProvider +from opentelemetry.sdk.trace import ReadableSpan +from opentelemetry.sdk.trace.export import SpanExportResult + + +class TestOTLPAwsSpanExporter(TestCase): + def test_init_with_logger_provider(self): + # Test initialization with logger_provider + mock_logger_provider = MagicMock(spec=LoggerProvider) + endpoint = "https://xray.us-east-1.amazonaws.com/v1/traces" + + exporter = OTLPAwsSpanExporter( + session=get_aws_session(), aws_region="us-east-1", endpoint=endpoint, logger_provider=mock_logger_provider + ) + + self.assertEqual(exporter._logger_provider, mock_logger_provider) + self.assertEqual(exporter._aws_region, "us-east-1") + + def test_init_without_logger_provider(self): + # Test initialization without logger_provider (default behavior) + endpoint = "https://xray.us-west-2.amazonaws.com/v1/traces" + + exporter = OTLPAwsSpanExporter(session=get_aws_session(), aws_region="us-west-2", endpoint=endpoint) + + self.assertIsNone(exporter._logger_provider) + self.assertEqual(exporter._aws_region, "us-west-2") + self.assertIsNone(exporter._llo_handler) + + @patch("amazon.opentelemetry.distro.exporter.otlp.aws.traces.otlp_aws_span_exporter.is_agent_observability_enabled") + def test_ensure_llo_handler_when_disabled(self, mock_is_enabled): + # Test _ensure_llo_handler when agent observability is disabled + mock_is_enabled.return_value = False + endpoint = "https://xray.us-east-1.amazonaws.com/v1/traces" + + exporter = OTLPAwsSpanExporter(session=get_aws_session(), aws_region="us-east-1", endpoint=endpoint) + result = exporter._ensure_llo_handler() + + self.assertFalse(result) + self.assertIsNone(exporter._llo_handler) + mock_is_enabled.assert_called_once() + + @patch("amazon.opentelemetry.distro.exporter.otlp.aws.traces.otlp_aws_span_exporter.get_logger_provider") + @patch("amazon.opentelemetry.distro.exporter.otlp.aws.traces.otlp_aws_span_exporter.is_agent_observability_enabled") + 
@patch("amazon.opentelemetry.distro.exporter.otlp.aws.traces.otlp_aws_span_exporter.LLOHandler") + def test_ensure_llo_handler_lazy_initialization( + self, mock_llo_handler_class, mock_is_enabled, mock_get_logger_provider + ): + # Test lazy initialization of LLO handler when enabled + mock_is_enabled.return_value = True + mock_logger_provider = MagicMock(spec=LoggerProvider) + mock_get_logger_provider.return_value = mock_logger_provider + mock_llo_handler = MagicMock() + mock_llo_handler_class.return_value = mock_llo_handler + + endpoint = "https://xray.us-east-1.amazonaws.com/v1/traces" + exporter = OTLPAwsSpanExporter(session=get_aws_session(), aws_region="us-east-1", endpoint=endpoint) + + # First call should initialize + result = exporter._ensure_llo_handler() + + self.assertTrue(result) + self.assertEqual(exporter._llo_handler, mock_llo_handler) + mock_llo_handler_class.assert_called_once_with(mock_logger_provider) + mock_get_logger_provider.assert_called_once() + + # Second call should not re-initialize + mock_llo_handler_class.reset_mock() + mock_get_logger_provider.reset_mock() + + result = exporter._ensure_llo_handler() + + self.assertTrue(result) + mock_llo_handler_class.assert_not_called() + mock_get_logger_provider.assert_not_called() + + @patch("amazon.opentelemetry.distro.exporter.otlp.aws.traces.otlp_aws_span_exporter.get_logger_provider") + @patch("amazon.opentelemetry.distro.exporter.otlp.aws.traces.otlp_aws_span_exporter.is_agent_observability_enabled") + def test_ensure_llo_handler_with_existing_logger_provider(self, mock_is_enabled, mock_get_logger_provider): + # Test when logger_provider is already provided + mock_is_enabled.return_value = True + mock_logger_provider = MagicMock(spec=LoggerProvider) + + endpoint = "https://xray.us-east-1.amazonaws.com/v1/traces" + exporter = OTLPAwsSpanExporter( + session=get_aws_session(), aws_region="us-east-1", endpoint=endpoint, logger_provider=mock_logger_provider + ) + + with patch( + "amazon.opentelemetry.distro.exporter.otlp.aws.traces.otlp_aws_span_exporter.LLOHandler" + ) as mock_llo_handler_class: + mock_llo_handler = MagicMock() + mock_llo_handler_class.return_value = mock_llo_handler + + result = exporter._ensure_llo_handler() + + self.assertTrue(result) + self.assertEqual(exporter._llo_handler, mock_llo_handler) + mock_llo_handler_class.assert_called_once_with(mock_logger_provider) + mock_get_logger_provider.assert_not_called() + + @patch("amazon.opentelemetry.distro.exporter.otlp.aws.traces.otlp_aws_span_exporter.get_logger_provider") + @patch("amazon.opentelemetry.distro.exporter.otlp.aws.traces.otlp_aws_span_exporter.is_agent_observability_enabled") + def test_ensure_llo_handler_get_logger_provider_fails(self, mock_is_enabled, mock_get_logger_provider): + # Test when get_logger_provider raises exception + mock_is_enabled.return_value = True + mock_get_logger_provider.side_effect = Exception("Failed to get logger provider") + + endpoint = "https://xray.us-east-1.amazonaws.com/v1/traces" + exporter = OTLPAwsSpanExporter(session=get_aws_session(), aws_region="us-east-1", endpoint=endpoint) + + result = exporter._ensure_llo_handler() + + self.assertFalse(result) + self.assertIsNone(exporter._llo_handler) + + @patch("amazon.opentelemetry.distro.exporter.otlp.aws.traces.otlp_aws_span_exporter.is_agent_observability_enabled") + def test_export_with_llo_disabled(self, mock_is_enabled): + # Test export when LLO is disabled + mock_is_enabled.return_value = False + endpoint = "https://xray.us-east-1.amazonaws.com/v1/traces" + + 
exporter = OTLPAwsSpanExporter(session=get_aws_session(), aws_region="us-east-1", endpoint=endpoint) + + # Mock the parent class export method + with patch.object(OTLPSpanExporter, "export") as mock_parent_export: + mock_parent_export.return_value = SpanExportResult.SUCCESS + + spans = [MagicMock(spec=ReadableSpan), MagicMock(spec=ReadableSpan)] + result = exporter.export(spans) + + self.assertEqual(result, SpanExportResult.SUCCESS) + mock_parent_export.assert_called_once_with(spans) + self.assertIsNone(exporter._llo_handler) + + @patch("amazon.opentelemetry.distro.exporter.otlp.aws.traces.otlp_aws_span_exporter.is_agent_observability_enabled") + @patch("amazon.opentelemetry.distro.exporter.otlp.aws.traces.otlp_aws_span_exporter.get_logger_provider") + @patch("amazon.opentelemetry.distro.exporter.otlp.aws.traces.otlp_aws_span_exporter.LLOHandler") + def test_export_with_llo_enabled(self, mock_llo_handler_class, mock_get_logger_provider, mock_is_enabled): + # Test export when LLO is enabled and successfully processes spans + mock_is_enabled.return_value = True + mock_logger_provider = MagicMock(spec=LoggerProvider) + mock_get_logger_provider.return_value = mock_logger_provider + + mock_llo_handler = MagicMock() + mock_llo_handler_class.return_value = mock_llo_handler + + endpoint = "https://xray.us-east-1.amazonaws.com/v1/traces" + exporter = OTLPAwsSpanExporter(session=get_aws_session(), aws_region="us-east-1", endpoint=endpoint) + + # Mock spans and processed spans + original_spans = [MagicMock(spec=ReadableSpan), MagicMock(spec=ReadableSpan)] + processed_spans = [MagicMock(spec=ReadableSpan), MagicMock(spec=ReadableSpan)] + mock_llo_handler.process_spans.return_value = processed_spans + + # Mock the parent class export method + with patch.object(OTLPSpanExporter, "export") as mock_parent_export: + mock_parent_export.return_value = SpanExportResult.SUCCESS + + result = exporter.export(original_spans) + + self.assertEqual(result, SpanExportResult.SUCCESS) + mock_llo_handler.process_spans.assert_called_once_with(original_spans) + mock_parent_export.assert_called_once_with(processed_spans) + + @patch("amazon.opentelemetry.distro.exporter.otlp.aws.traces.otlp_aws_span_exporter.is_agent_observability_enabled") + @patch("amazon.opentelemetry.distro.exporter.otlp.aws.traces.otlp_aws_span_exporter.get_logger_provider") + @patch("amazon.opentelemetry.distro.exporter.otlp.aws.traces.otlp_aws_span_exporter.LLOHandler") + def test_export_with_llo_processing_failure( + self, mock_llo_handler_class, mock_get_logger_provider, mock_is_enabled + ): + # Test export when LLO processing fails + mock_is_enabled.return_value = True + mock_logger_provider = MagicMock(spec=LoggerProvider) + mock_get_logger_provider.return_value = mock_logger_provider + + mock_llo_handler = MagicMock() + mock_llo_handler_class.return_value = mock_llo_handler + mock_llo_handler.process_spans.side_effect = Exception("LLO processing failed") + + endpoint = "https://xray.us-east-1.amazonaws.com/v1/traces" + exporter = OTLPAwsSpanExporter(session=get_aws_session(), aws_region="us-east-1", endpoint=endpoint) + + spans = [MagicMock(spec=ReadableSpan), MagicMock(spec=ReadableSpan)] + + result = exporter.export(spans) + + self.assertEqual(result, SpanExportResult.FAILURE) diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/llo_handler/test_llo_handler_base.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/llo_handler/test_llo_handler_base.py new file mode 100644 index 000000000..9f45da93d --- /dev/null 
+++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/llo_handler/test_llo_handler_base.py @@ -0,0 +1,57 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Base test utilities for LLO Handler tests.""" +from unittest import TestCase +from unittest.mock import MagicMock, patch + +from amazon.opentelemetry.distro.llo_handler import LLOHandler +from opentelemetry.sdk._logs import LoggerProvider +from opentelemetry.sdk.trace import ReadableSpan, SpanContext +from opentelemetry.trace import SpanKind, TraceFlags, TraceState + + +class LLOHandlerTestBase(TestCase): + """Base class with common setup and utilities for LLO Handler tests.""" + + def setUp(self): + self.logger_provider_mock = MagicMock(spec=LoggerProvider) + self.event_logger_mock = MagicMock() + self.event_logger_provider_mock = MagicMock() + self.event_logger_provider_mock.get_event_logger.return_value = self.event_logger_mock + + with patch( + "amazon.opentelemetry.distro.llo_handler.EventLoggerProvider", return_value=self.event_logger_provider_mock + ): + self.llo_handler = LLOHandler(self.logger_provider_mock) + + @staticmethod + def _create_mock_span(attributes=None, kind=SpanKind.INTERNAL, preserve_none=False): + """ + Create a mock ReadableSpan for testing. + + Args: + attributes: Span attributes dictionary. Defaults to empty dict unless preserve_none=True + kind: The span kind (default: INTERNAL) + preserve_none: If True, keeps None attributes instead of converting to empty dict + + Returns: + MagicMock: A mock span with context, attributes, and basic properties set + """ + if attributes is None and not preserve_none: + attributes = {} + + span_context = SpanContext( + trace_id=0x123456789ABCDEF0123456789ABCDEF0, + span_id=0x123456789ABCDEF0, + is_remote=False, + trace_flags=TraceFlags.SAMPLED, + trace_state=TraceState.get_default(), + ) + + mock_span = MagicMock(spec=ReadableSpan) + mock_span.context = span_context + mock_span.attributes = attributes + mock_span.kind = kind + mock_span.start_time = 1234567890 + + return mock_span diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/llo_handler/test_llo_handler_collection.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/llo_handler/test_llo_handler_collection.py new file mode 100644 index 000000000..a86cebb20 --- /dev/null +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/llo_handler/test_llo_handler_collection.py @@ -0,0 +1,269 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for LLO Handler message collection functionality.""" + +from test_llo_handler_base import LLOHandlerTestBase + + +class TestLLOHandlerCollection(LLOHandlerTestBase): + """Test message collection from various frameworks.""" + + def test_collect_gen_ai_prompt_messages_system_role(self): + """ + Verify indexed prompt messages with system role are collected with correct content, role, and source. 
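+
+        A sketch of the mapping this test asserts (values abridged):
+
+            {"gen_ai.prompt.0.content": "system instruction",
+             "gen_ai.prompt.0.role": "system"}
+            -> {"content": "system instruction", "role": "system", "source": "prompt"}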
+ """ + attributes = { + "gen_ai.prompt.0.content": "system instruction", + "gen_ai.prompt.0.role": "system", + } + + span = self._create_mock_span(attributes) + + messages = self.llo_handler._collect_all_llo_messages(span, attributes) + + self.assertEqual(len(messages), 1) + message = messages[0] + self.assertEqual(message["content"], "system instruction") + self.assertEqual(message["role"], "system") + self.assertEqual(message["source"], "prompt") + + def test_collect_gen_ai_prompt_messages_user_role(self): + """ + Verify indexed prompt messages with user role are collected with correct content, role, and source. + """ + attributes = { + "gen_ai.prompt.0.content": "user question", + "gen_ai.prompt.0.role": "user", + } + + span = self._create_mock_span(attributes) + + messages = self.llo_handler._collect_all_llo_messages(span, attributes) + + self.assertEqual(len(messages), 1) + message = messages[0] + self.assertEqual(message["content"], "user question") + self.assertEqual(message["role"], "user") + self.assertEqual(message["source"], "prompt") + + def test_collect_gen_ai_prompt_messages_assistant_role(self): + """ + Verify indexed prompt messages with assistant role are collected with correct content, role, and source. + """ + attributes = { + "gen_ai.prompt.1.content": "assistant response", + "gen_ai.prompt.1.role": "assistant", + } + + span = self._create_mock_span(attributes) + + messages = self.llo_handler._collect_all_llo_messages(span, attributes) + + self.assertEqual(len(messages), 1) + message = messages[0] + self.assertEqual(message["content"], "assistant response") + self.assertEqual(message["role"], "assistant") + self.assertEqual(message["source"], "prompt") + + def test_collect_gen_ai_prompt_messages_function_role(self): + """ + Verify indexed prompt messages with non-standard 'function' role are collected correctly. + """ + attributes = { + "gen_ai.prompt.2.content": "function data", + "gen_ai.prompt.2.role": "function", + } + + span = self._create_mock_span(attributes) + messages = self.llo_handler._collect_all_llo_messages(span, attributes) + + self.assertEqual(len(messages), 1) + message = messages[0] + self.assertEqual(message["content"], "function data") + self.assertEqual(message["role"], "function") + self.assertEqual(message["source"], "prompt") + + def test_collect_gen_ai_prompt_messages_unknown_role(self): + """ + Verify indexed prompt messages with unknown role are collected with the role preserved. + """ + attributes = { + "gen_ai.prompt.3.content": "unknown type content", + "gen_ai.prompt.3.role": "unknown", + } + + span = self._create_mock_span(attributes) + messages = self.llo_handler._collect_all_llo_messages(span, attributes) + + self.assertEqual(len(messages), 1) + message = messages[0] + self.assertEqual(message["content"], "unknown type content") + self.assertEqual(message["role"], "unknown") + self.assertEqual(message["source"], "prompt") + + def test_collect_gen_ai_completion_messages_assistant_role(self): + """ + Verify indexed completion messages with assistant role are collected with source='completion'. 
+ """ + attributes = { + "gen_ai.completion.0.content": "assistant completion", + "gen_ai.completion.0.role": "assistant", + } + + span = self._create_mock_span(attributes) + span.end_time = 1234567899 + + messages = self.llo_handler._collect_all_llo_messages(span, attributes) + + self.assertEqual(len(messages), 1) + message = messages[0] + self.assertEqual(message["content"], "assistant completion") + self.assertEqual(message["role"], "assistant") + self.assertEqual(message["source"], "completion") + + def test_collect_gen_ai_completion_messages_other_role(self): + """ + Verify indexed completion messages with custom roles are collected with source='completion'. + """ + attributes = { + "gen_ai.completion.1.content": "other completion", + "gen_ai.completion.1.role": "other", + } + + span = self._create_mock_span(attributes) + span.end_time = 1234567899 + + messages = self.llo_handler._collect_all_llo_messages(span, attributes) + + self.assertEqual(len(messages), 1) + message = messages[0] + self.assertEqual(message["content"], "other completion") + self.assertEqual(message["role"], "other") + self.assertEqual(message["source"], "completion") + + def test_collect_all_llo_messages_none_attributes(self): + """ + Verify _collect_all_llo_messages returns empty list when attributes are None. + """ + span = self._create_mock_span(None, preserve_none=True) + + messages = self.llo_handler._collect_all_llo_messages(span, None) + + self.assertEqual(messages, []) + self.assertEqual(len(messages), 0) + + def test_collect_indexed_messages_none_attributes(self): + """ + Verify _collect_indexed_messages returns empty list when attributes are None. + """ + messages = self.llo_handler._collect_indexed_messages(None) + + self.assertEqual(messages, []) + self.assertEqual(len(messages), 0) + + def test_collect_indexed_messages_missing_role(self): + """ + Verify indexed messages use default roles when role attributes are missing. 
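+
+        For example, a content attribute with no matching role attribute:
+
+            {"gen_ai.prompt.0.content": "prompt without role"}
+            -> {"content": "prompt without role", "role": "unknown", "source": "prompt"}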
+ """ + attributes = { + "gen_ai.prompt.0.content": "prompt without role", + "gen_ai.completion.0.content": "completion without role", + } + + span = self._create_mock_span(attributes) + + messages = self.llo_handler._collect_all_llo_messages(span, attributes) + + self.assertEqual(len(messages), 2) + + prompt_msg = next((m for m in messages if m["content"] == "prompt without role"), None) + self.assertIsNotNone(prompt_msg) + self.assertEqual(prompt_msg["role"], "unknown") + self.assertEqual(prompt_msg["source"], "prompt") + + completion_msg = next((m for m in messages if m["content"] == "completion without role"), None) + self.assertIsNotNone(completion_msg) + self.assertEqual(completion_msg["role"], "unknown") + self.assertEqual(completion_msg["source"], "completion") + + def test_indexed_messages_with_out_of_order_indices(self): + """ + Test that indexed messages are sorted correctly even with out-of-order indices + """ + attributes = { + "gen_ai.prompt.5.content": "fifth prompt", + "gen_ai.prompt.5.role": "user", + "gen_ai.prompt.1.content": "first prompt", + "gen_ai.prompt.1.role": "system", + "gen_ai.prompt.3.content": "third prompt", + "gen_ai.prompt.3.role": "user", + "llm.input_messages.10.message.content": "tenth message", + "llm.input_messages.10.message.role": "assistant", + "llm.input_messages.2.message.content": "second message", + "llm.input_messages.2.message.role": "user", + } + + messages = self.llo_handler._collect_indexed_messages(attributes) + + # Messages should be sorted by pattern key first, then by index + self.assertEqual(len(messages), 5) + + # Check gen_ai.prompt messages are in order + gen_ai_messages = [m for m in messages if "prompt" in m["source"]] + self.assertEqual(gen_ai_messages[0]["content"], "first prompt") + self.assertEqual(gen_ai_messages[1]["content"], "third prompt") + self.assertEqual(gen_ai_messages[2]["content"], "fifth prompt") + + # Check llm.input_messages are in order + llm_messages = [m for m in messages if m["content"] in ["second message", "tenth message"]] + self.assertEqual(llm_messages[0]["content"], "second message") + self.assertEqual(llm_messages[1]["content"], "tenth message") + + def test_collect_methods_message_format(self): + """ + Verify all message collection methods return consistent message format with content, role, and source fields. 
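+
+        Every collected message is expected to have the shape:
+
+            {"content": str, "role": str, "source": str}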
+ """ + attributes = { + "gen_ai.prompt.0.content": "prompt", + "gen_ai.prompt.0.role": "user", + "gen_ai.completion.0.content": "response", + "gen_ai.completion.0.role": "assistant", + "traceloop.entity.input": "input", + "gen_ai.prompt": "direct prompt", + "input.value": "inference input", + } + + span = self._create_mock_span(attributes) + + prompt_messages = self.llo_handler._collect_all_llo_messages(span, attributes) + for msg in prompt_messages: + self.assertIn("content", msg) + self.assertIn("role", msg) + self.assertIn("source", msg) + self.assertIsInstance(msg["content"], str) + self.assertIsInstance(msg["role"], str) + self.assertIsInstance(msg["source"], str) + + completion_messages = self.llo_handler._collect_all_llo_messages(span, attributes) + for msg in completion_messages: + self.assertIn("content", msg) + self.assertIn("role", msg) + self.assertIn("source", msg) + + traceloop_messages = self.llo_handler._collect_all_llo_messages(span, attributes) + for msg in traceloop_messages: + self.assertIn("content", msg) + self.assertIn("role", msg) + self.assertIn("source", msg) + + openlit_messages = self.llo_handler._collect_all_llo_messages(span, attributes) + for msg in openlit_messages: + self.assertIn("content", msg) + self.assertIn("role", msg) + self.assertIn("source", msg) + + openinference_messages = self.llo_handler._collect_all_llo_messages(span, attributes) + for msg in openinference_messages: + self.assertIn("content", msg) + self.assertIn("role", msg) + self.assertIn("source", msg) diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/llo_handler/test_llo_handler_events.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/llo_handler/test_llo_handler_events.py new file mode 100644 index 000000000..5d90ebc77 --- /dev/null +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/llo_handler/test_llo_handler_events.py @@ -0,0 +1,651 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for LLO Handler event emission functionality.""" + +from unittest.mock import MagicMock, patch + +from test_llo_handler_base import LLOHandlerTestBase + + +class TestLLOHandlerEvents(LLOHandlerTestBase): + """Test event emission and formatting functionality.""" + + def test_emit_llo_attributes(self): + """ + Verify _emit_llo_attributes creates a single consolidated event with input/output message groups + containing all LLO content from various frameworks. 
+ """ + attributes = { + "gen_ai.prompt.0.content": "prompt content", + "gen_ai.prompt.0.role": "user", + "gen_ai.completion.0.content": "completion content", + "gen_ai.completion.0.role": "assistant", + "traceloop.entity.input": "traceloop input", + "traceloop.entity.name": "entity_name", + "gen_ai.agent.actual_output": "agent output", + "crewai.crew.tasks_output": "tasks output", + "crewai.crew.result": "crew result", + } + + span = self._create_mock_span(attributes) + span.end_time = 1234567899 + span.instrumentation_scope = MagicMock() + span.instrumentation_scope.name = "test.scope" + + self.llo_handler._emit_llo_attributes(span, attributes) + + self.event_logger_mock.emit.assert_called_once() + emitted_event = self.event_logger_mock.emit.call_args[0][0] + + self.assertEqual(emitted_event.name, "test.scope") + self.assertEqual(emitted_event.timestamp, span.end_time) + self.assertEqual(emitted_event.trace_id, span.context.trace_id) + self.assertEqual(emitted_event.span_id, span.context.span_id) + self.assertEqual(emitted_event.trace_flags, span.context.trace_flags) + + event_body = emitted_event.body + self.assertIn("input", event_body) + self.assertIn("output", event_body) + self.assertIn("messages", event_body["input"]) + self.assertIn("messages", event_body["output"]) + + input_messages = event_body["input"]["messages"] + self.assertEqual(len(input_messages), 2) + + user_prompt = next((msg for msg in input_messages if msg["content"] == "prompt content"), None) + self.assertIsNotNone(user_prompt) + self.assertEqual(user_prompt["role"], "user") + + traceloop_input = next((msg for msg in input_messages if msg["content"] == "traceloop input"), None) + self.assertIsNotNone(traceloop_input) + self.assertEqual(traceloop_input["role"], "user") + + output_messages = event_body["output"]["messages"] + self.assertTrue(len(output_messages) >= 3) + + completion = next((msg for msg in output_messages if msg["content"] == "completion content"), None) + self.assertIsNotNone(completion) + self.assertEqual(completion["role"], "assistant") + + agent_output = next((msg for msg in output_messages if msg["content"] == "agent output"), None) + self.assertIsNotNone(agent_output) + self.assertEqual(agent_output["role"], "assistant") + + def test_emit_llo_attributes_multiple_frameworks(self): + """ + Verify a single span containing LLO attributes from multiple frameworks + (Traceloop, OpenLit, OpenInference, CrewAI) generates one consolidated event. 
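+
+        The consolidated event body asserted below has the rough shape:
+
+            {"input": {"messages": [{"content": ..., "role": "user"}, ...]},
+             "output": {"messages": [{"content": ..., "role": "assistant"}, ...]}}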
+ """ + attributes = { + "gen_ai.prompt.0.content": "Tell me about AI", + "gen_ai.prompt.0.role": "user", + "gen_ai.completion.0.content": "AI is a field of computer science...", + "gen_ai.completion.0.role": "assistant", + "traceloop.entity.input": "What is machine learning?", + "traceloop.entity.output": "Machine learning is a subset of AI...", + "gen_ai.prompt": "Explain neural networks", + "gen_ai.completion": "Neural networks are computing systems...", + "input.value": "How do transformers work?", + "output.value": "Transformers are a type of neural network architecture...", + "crewai.crew.result": "Task completed successfully", + } + + span = self._create_mock_span(attributes) + span.end_time = 1234567899 + span.instrumentation_scope = MagicMock() + span.instrumentation_scope.name = "test.multi.framework" + + self.llo_handler._emit_llo_attributes(span, attributes) + + self.event_logger_mock.emit.assert_called_once() + emitted_event = self.event_logger_mock.emit.call_args[0][0] + + self.assertEqual(emitted_event.name, "test.multi.framework") + self.assertEqual(emitted_event.timestamp, span.end_time) + + event_body = emitted_event.body + self.assertIn("input", event_body) + self.assertIn("output", event_body) + + input_messages = event_body["input"]["messages"] + input_contents = [msg["content"] for msg in input_messages] + self.assertIn("Tell me about AI", input_contents) + self.assertIn("What is machine learning?", input_contents) + self.assertIn("Explain neural networks", input_contents) + self.assertIn("How do transformers work?", input_contents) + + output_messages = event_body["output"]["messages"] + output_contents = [msg["content"] for msg in output_messages] + self.assertIn("AI is a field of computer science...", output_contents) + self.assertIn("Machine learning is a subset of AI...", output_contents) + self.assertIn("Neural networks are computing systems...", output_contents) + self.assertIn("Transformers are a type of neural network architecture...", output_contents) + self.assertIn("Task completed successfully", output_contents) + + for msg in input_messages: + self.assertIn(msg["role"], ["user", "system"]) + for msg in output_messages: + self.assertEqual(msg["role"], "assistant") + + def test_emit_llo_attributes_no_llo_attributes(self): + """ + Verify _emit_llo_attributes does not emit events when span contains only non-LLO attributes. + """ + attributes = { + "normal.attribute": "value", + "another.attribute": 123, + } + + span = self._create_mock_span(attributes) + span.instrumentation_scope = MagicMock() + span.instrumentation_scope.name = "test.scope" + + self.llo_handler._emit_llo_attributes(span, attributes) + + self.event_logger_mock.emit.assert_not_called() + + def test_emit_llo_attributes_mixed_input_output(self): + """ + Verify event generation correctly separates mixed input (system/user) and output (assistant) messages. 
+ """ + attributes = { + "gen_ai.prompt.0.content": "system message", + "gen_ai.prompt.0.role": "system", + "gen_ai.prompt.1.content": "user message", + "gen_ai.prompt.1.role": "user", + "gen_ai.completion.0.content": "assistant response", + "gen_ai.completion.0.role": "assistant", + "input.value": "direct input", + "output.value": "direct output", + } + + span = self._create_mock_span(attributes) + span.end_time = 1234567899 + span.instrumentation_scope = MagicMock() + span.instrumentation_scope.name = "test.scope" + + self.llo_handler._emit_llo_attributes(span, attributes) + + self.event_logger_mock.emit.assert_called_once() + emitted_event = self.event_logger_mock.emit.call_args[0][0] + + event_body = emitted_event.body + self.assertIn("input", event_body) + self.assertIn("output", event_body) + + input_messages = event_body["input"]["messages"] + self.assertEqual(len(input_messages), 3) + + input_roles = [msg["role"] for msg in input_messages] + self.assertIn("system", input_roles) + self.assertIn("user", input_roles) + + output_messages = event_body["output"]["messages"] + self.assertEqual(len(output_messages), 2) + + for msg in output_messages: + self.assertEqual(msg["role"], "assistant") + + def test_emit_llo_attributes_with_event_timestamp(self): + """ + Verify _emit_llo_attributes uses provided event timestamp instead of span end time. + """ + attributes = { + "gen_ai.prompt": "test prompt", + } + + span = self._create_mock_span(attributes) + span.end_time = 1234567899 + span.instrumentation_scope = MagicMock() + span.instrumentation_scope.name = "test.scope" + + event_timestamp = 9999999999 + + self.llo_handler._emit_llo_attributes(span, attributes, event_timestamp=event_timestamp) + + self.event_logger_mock.emit.assert_called_once() + emitted_event = self.event_logger_mock.emit.call_args[0][0] + self.assertEqual(emitted_event.timestamp, event_timestamp) + + def test_emit_llo_attributes_none_attributes(self): + """ + Test _emit_llo_attributes with None attributes - should return early + """ + span = self._create_mock_span({}) + span.instrumentation_scope = MagicMock() + span.instrumentation_scope.name = "test.scope" + + self.llo_handler._emit_llo_attributes(span, None) + + self.event_logger_mock.emit.assert_not_called() + + def test_emit_llo_attributes_role_based_routing(self): + """ + Test role-based routing for non-standard roles + """ + attributes = { + # Standard roles - should go to their expected places + "gen_ai.prompt.0.content": "system prompt", + "gen_ai.prompt.0.role": "system", + "gen_ai.prompt.1.content": "user prompt", + "gen_ai.prompt.1.role": "user", + "gen_ai.completion.0.content": "assistant response", + "gen_ai.completion.0.role": "assistant", + # Non-standard roles - should be routed based on source + "gen_ai.prompt.2.content": "function prompt", + "gen_ai.prompt.2.role": "function", + "gen_ai.completion.1.content": "tool completion", + "gen_ai.completion.1.role": "tool", + "gen_ai.prompt.3.content": "unknown prompt", + "gen_ai.prompt.3.role": "custom_role", + "gen_ai.completion.2.content": "unknown completion", + "gen_ai.completion.2.role": "another_custom", + } + + span = self._create_mock_span(attributes) + span.end_time = 1234567899 + span.instrumentation_scope = MagicMock() + span.instrumentation_scope.name = "test.scope" + + self.llo_handler._emit_llo_attributes(span, attributes) + + # Verify event was emitted + self.event_logger_mock.emit.assert_called_once() + emitted_event = self.event_logger_mock.emit.call_args[0][0] + + event_body = 
emitted_event.body + + # Check input messages + input_messages = event_body["input"]["messages"] + input_contents = [msg["content"] for msg in input_messages] + + # Standard roles (system, user) should be in input + self.assertIn("system prompt", input_contents) + self.assertIn("user prompt", input_contents) + + # Non-standard roles from prompt source should be in input + self.assertIn("function prompt", input_contents) + self.assertIn("unknown prompt", input_contents) + + # Check output messages + output_messages = event_body["output"]["messages"] + output_contents = [msg["content"] for msg in output_messages] + + # Standard role (assistant) should be in output + self.assertIn("assistant response", output_contents) + + # Non-standard roles from completion source should be in output + self.assertIn("tool completion", output_contents) + self.assertIn("unknown completion", output_contents) + + def test_emit_llo_attributes_empty_messages(self): + """ + Test _emit_llo_attributes when messages list is empty after collection + """ + # Create a span with attributes that would normally match patterns but with empty content + attributes = { + "gen_ai.prompt.0.content": "", + "gen_ai.prompt.0.role": "user", + } + + span = self._create_mock_span(attributes) + span.instrumentation_scope = MagicMock() + span.instrumentation_scope.name = "test.scope" + + # Mock _collect_all_llo_messages to return empty list + with patch.object(self.llo_handler, "_collect_all_llo_messages", return_value=[]): + self.llo_handler._emit_llo_attributes(span, attributes) + + # Should not emit event when no messages collected + self.event_logger_mock.emit.assert_not_called() + + def test_emit_llo_attributes_only_input_messages(self): + """ + Test event generation when only input messages are present + """ + attributes = { + "gen_ai.prompt.0.content": "system instruction", + "gen_ai.prompt.0.role": "system", + "gen_ai.prompt.1.content": "user question", + "gen_ai.prompt.1.role": "user", + } + + span = self._create_mock_span(attributes) + span.end_time = 1234567899 + span.instrumentation_scope = MagicMock() + span.instrumentation_scope.name = "test.scope" + + self.llo_handler._emit_llo_attributes(span, attributes) + + self.event_logger_mock.emit.assert_called_once() + emitted_event = self.event_logger_mock.emit.call_args[0][0] + + event_body = emitted_event.body + + self.assertIn("input", event_body) + self.assertNotIn("output", event_body) + + input_messages = event_body["input"]["messages"] + self.assertEqual(len(input_messages), 2) + + def test_emit_llo_attributes_only_output_messages(self): + """ + Test event generation when only output messages are present + """ + attributes = { + "gen_ai.completion.0.content": "assistant response", + "gen_ai.completion.0.role": "assistant", + "output.value": "another output", + } + + span = self._create_mock_span(attributes) + span.end_time = 1234567899 + span.instrumentation_scope = MagicMock() + span.instrumentation_scope.name = "test.scope" + + self.llo_handler._emit_llo_attributes(span, attributes) + + self.event_logger_mock.emit.assert_called_once() + emitted_event = self.event_logger_mock.emit.call_args[0][0] + + event_body = emitted_event.body + + self.assertNotIn("input", event_body) + self.assertIn("output", event_body) + + output_messages = event_body["output"]["messages"] + self.assertEqual(len(output_messages), 2) + + def test_emit_llo_attributes_empty_event_body(self): + """ + Test that an event is still emitted when the collected messages exist but have empty content (the event body is not empty) + """ + # Create attributes that would 
result in messages with empty content + attributes = { + "gen_ai.prompt.0.content": "", + "gen_ai.prompt.0.role": "user", + } + + span = self._create_mock_span(attributes) + span.end_time = 1234567899 + span.instrumentation_scope = MagicMock() + span.instrumentation_scope.name = "test.scope" + + # Mock _collect_all_llo_messages to return messages with empty content + with patch.object( + self.llo_handler, + "_collect_all_llo_messages", + return_value=[{"content": "", "role": "user", "source": "prompt"}], + ): + self.llo_handler._emit_llo_attributes(span, attributes) + + # Event should still be emitted as we have a message (even with empty content) + self.event_logger_mock.emit.assert_called_once() + + def test_group_messages_by_type_standard_roles(self): + """ + Test _group_messages_by_type correctly groups messages with standard roles. + """ + messages = [ + {"role": "system", "content": "System message", "source": "prompt"}, + {"role": "user", "content": "User message", "source": "prompt"}, + {"role": "assistant", "content": "Assistant message", "source": "completion"}, + ] + + result = self.llo_handler._group_messages_by_type(messages) + + self.assertIn("input", result) + self.assertIn("output", result) + + # Check input messages + self.assertEqual(len(result["input"]), 2) + self.assertEqual(result["input"][0], {"role": "system", "content": "System message"}) + self.assertEqual(result["input"][1], {"role": "user", "content": "User message"}) + + # Check output messages + self.assertEqual(len(result["output"]), 1) + self.assertEqual(result["output"][0], {"role": "assistant", "content": "Assistant message"}) + + def test_group_messages_by_type_non_standard_roles(self): + """ + Test _group_messages_by_type correctly routes non-standard roles based on source. + """ + messages = [ + {"role": "function", "content": "Function call", "source": "prompt"}, + {"role": "tool", "content": "Tool result", "source": "completion"}, + {"role": "custom", "content": "Custom output", "source": "output"}, + {"role": "other", "content": "Other result", "source": "result"}, + ] + + result = self.llo_handler._group_messages_by_type(messages) + + # Non-standard roles from prompt source go to input + self.assertEqual(len(result["input"]), 1) + self.assertEqual(result["input"][0], {"role": "function", "content": "Function call"}) + + # Non-standard roles from completion/output/result sources go to output + self.assertEqual(len(result["output"]), 3) + output_contents = [msg["content"] for msg in result["output"]] + self.assertIn("Tool result", output_contents) + self.assertIn("Custom output", output_contents) + self.assertIn("Other result", output_contents) + + def test_group_messages_by_type_empty_list(self): + """ + Test _group_messages_by_type handles empty message list. + """ + result = self.llo_handler._group_messages_by_type([]) + + self.assertEqual(result, {"input": [], "output": []}) + self.assertEqual(len(result["input"]), 0) + self.assertEqual(len(result["output"]), 0) + + def test_group_messages_by_type_missing_fields(self): + """ + Test _group_messages_by_type handles messages with missing role or content. 
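+
+ Sketch of the defaults asserted below:
+
+ {"content": "No role", "source": "prompt"}  -> {"role": "unknown", "content": "No role"} (input)
+ {"role": "user", "source": "prompt"}        -> {"role": "user", "content": ""} (input)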
+ """ + messages = [ + {"content": "No role", "source": "prompt"}, # Missing role + {"role": "user", "source": "prompt"}, # Missing content + {"role": "assistant", "content": "Complete message", "source": "completion"}, + ] + + result = self.llo_handler._group_messages_by_type(messages) + + # Message without role gets "unknown" role and goes to input (no completion/output/result in source) + self.assertEqual(len(result["input"]), 2) + self.assertEqual(result["input"][0], {"role": "unknown", "content": "No role"}) + self.assertEqual(result["input"][1], {"role": "user", "content": ""}) + + # Complete message goes to output + self.assertEqual(len(result["output"]), 1) + self.assertEqual(result["output"][0], {"role": "assistant", "content": "Complete message"}) + + def test_emit_llo_attributes_with_llm_prompts(self): + """ + Test that llm.prompts attribute is properly emitted in the input section. + """ + llm_prompts_content = "[{'role': 'system', 'content': [{'text': 'You are helpful.', 'type': 'text'}]}]" + attributes = { + "llm.prompts": llm_prompts_content, + "gen_ai.completion.0.content": "I understand.", + "gen_ai.completion.0.role": "assistant", + } + + span = self._create_mock_span(attributes) + span.end_time = 1234567899 + span.instrumentation_scope = MagicMock() + span.instrumentation_scope.name = "test.scope" + + self.llo_handler._emit_llo_attributes(span, attributes) + + self.event_logger_mock.emit.assert_called_once() + emitted_event = self.event_logger_mock.emit.call_args[0][0] + + event_body = emitted_event.body + + # Check that llm.prompts is in input section + self.assertIn("input", event_body) + self.assertIn("output", event_body) + + input_messages = event_body["input"]["messages"] + self.assertEqual(len(input_messages), 1) + self.assertEqual(input_messages[0]["content"], llm_prompts_content) + self.assertEqual(input_messages[0]["role"], "user") + + # Check output section has the completion + output_messages = event_body["output"]["messages"] + self.assertEqual(len(output_messages), 1) + self.assertEqual(output_messages[0]["content"], "I understand.") + self.assertEqual(output_messages[0]["role"], "assistant") + + def test_emit_llo_attributes_openlit_style_events(self): + """ + Test that LLO attributes from OpenLit-style span events are collected and emitted + in a single consolidated event, not as separate events. 
+ """ + # This test simulates the OpenLit pattern where prompt and completion are in span events + # The span processor should collect from both and emit a single event + + span_attributes = {"normal.attribute": "value"} + + # Create events like OpenLit does + prompt_event_attrs = {"gen_ai.prompt": "Explain quantum computing"} + prompt_event = MagicMock(attributes=prompt_event_attrs, timestamp=1234567890) + + completion_event_attrs = {"gen_ai.completion": "Quantum computing is..."} + completion_event = MagicMock(attributes=completion_event_attrs, timestamp=1234567891) + + span = self._create_mock_span(span_attributes) + span.events = [prompt_event, completion_event] + span.end_time = 1234567899 + span.instrumentation_scope = MagicMock() + span.instrumentation_scope.name = "openlit.otel.tracing" + + # Process the span (this would normally be called by process_spans) + all_llo_attrs = {} + + # Collect from span attributes + for key, value in span_attributes.items(): + if self.llo_handler._is_llo_attribute(key): + all_llo_attrs[key] = value + + # Collect from events + for event in span.events: + if event.attributes: + for key, value in event.attributes.items(): + if self.llo_handler._is_llo_attribute(key): + all_llo_attrs[key] = value + + # Emit consolidated event + self.llo_handler._emit_llo_attributes(span, all_llo_attrs) + + # Verify single event was emitted with both input and output + self.event_logger_mock.emit.assert_called_once() + emitted_event = self.event_logger_mock.emit.call_args[0][0] + + event_body = emitted_event.body + + # Both input and output should be in the same event + self.assertIn("input", event_body) + self.assertIn("output", event_body) + + # Check input section + input_messages = event_body["input"]["messages"] + self.assertEqual(len(input_messages), 1) + self.assertEqual(input_messages[0]["content"], "Explain quantum computing") + self.assertEqual(input_messages[0]["role"], "user") + + # Check output section + output_messages = event_body["output"]["messages"] + self.assertEqual(len(output_messages), 1) + self.assertEqual(output_messages[0]["content"], "Quantum computing is...") + self.assertEqual(output_messages[0]["role"], "assistant") + + def test_emit_llo_attributes_with_session_id(self): + """ + Verify session.id attribute from span is copied to event attributes when present. + """ + attributes = { + "session.id": "test-session-123", + "gen_ai.prompt": "Hello, AI", + "gen_ai.completion": "Hello! How can I help you?", + } + + span = self._create_mock_span(attributes) + span.end_time = 1234567899 + span.instrumentation_scope = MagicMock() + span.instrumentation_scope.name = "test.scope" + + self.llo_handler._emit_llo_attributes(span, attributes) + + self.event_logger_mock.emit.assert_called_once() + emitted_event = self.event_logger_mock.emit.call_args[0][0] + + # Verify session.id was copied to event attributes + self.assertIsNotNone(emitted_event.attributes) + self.assertEqual(emitted_event.attributes.get("session.id"), "test-session-123") + # Event class always adds event.name + self.assertIn("event.name", emitted_event.attributes) + + # Verify event body still contains LLO data + event_body = emitted_event.body + self.assertIn("input", event_body) + self.assertIn("output", event_body) + + def test_emit_llo_attributes_without_session_id(self): + """ + Verify event attributes do not contain session.id when not present in span attributes. + """ + attributes = { + "gen_ai.prompt": "Hello, AI", + "gen_ai.completion": "Hello! 
How can I help you?", + } + + span = self._create_mock_span(attributes) + span.end_time = 1234567899 + span.instrumentation_scope = MagicMock() + span.instrumentation_scope.name = "test.scope" + + self.llo_handler._emit_llo_attributes(span, attributes) + + self.event_logger_mock.emit.assert_called_once() + emitted_event = self.event_logger_mock.emit.call_args[0][0] + + # Verify session.id is not in event attributes + self.assertIsNotNone(emitted_event.attributes) + self.assertNotIn("session.id", emitted_event.attributes) + # Event class always adds event.name + self.assertIn("event.name", emitted_event.attributes) + + def test_emit_llo_attributes_with_session_id_and_other_attributes(self): + """ + Verify only session.id is copied from span attributes when mixed with other attributes. + """ + attributes = { + "session.id": "session-456", + "user.id": "user-789", + "gen_ai.prompt": "What's the weather?", + "gen_ai.completion": "I can't check the weather.", + "other.attribute": "some-value", + } + + span = self._create_mock_span(attributes) + span.end_time = 1234567899 + span.instrumentation_scope = MagicMock() + span.instrumentation_scope.name = "test.scope" + + self.llo_handler._emit_llo_attributes(span, attributes) + + self.event_logger_mock.emit.assert_called_once() + emitted_event = self.event_logger_mock.emit.call_args[0][0] + + # Verify only session.id was copied to event attributes (plus event.name from Event class) + self.assertIsNotNone(emitted_event.attributes) + self.assertEqual(emitted_event.attributes.get("session.id"), "session-456") + self.assertIn("event.name", emitted_event.attributes) + # Verify other span attributes were not copied + self.assertNotIn("user.id", emitted_event.attributes) + self.assertNotIn("other.attribute", emitted_event.attributes) + self.assertNotIn("gen_ai.prompt", emitted_event.attributes) + self.assertNotIn("gen_ai.completion", emitted_event.attributes) diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/llo_handler/test_llo_handler_frameworks.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/llo_handler/test_llo_handler_frameworks.py new file mode 100644 index 000000000..5dfc069b9 --- /dev/null +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/llo_handler/test_llo_handler_frameworks.py @@ -0,0 +1,444 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for LLO Handler framework-specific functionality.""" + +from unittest.mock import MagicMock + +from test_llo_handler_base import LLOHandlerTestBase + + +class TestLLOHandlerFrameworks(LLOHandlerTestBase): + """Test framework-specific LLO attribute handling.""" + + def test_collect_traceloop_messages(self): + """ + Verify Traceloop entity input/output attributes are collected with correct roles + (input->user, output->assistant). 
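+
+ Expected mapping (illustrative):
+
+ "traceloop.entity.input"  -> {"role": "user", "source": "input"}
+ "traceloop.entity.output" -> {"role": "assistant", "source": "output"}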
+ """ + attributes = { + "traceloop.entity.input": "input data", + "traceloop.entity.output": "output data", + "traceloop.entity.name": "my_entity", + } + + span = self._create_mock_span(attributes) + span.end_time = 1234567899 + + messages = self.llo_handler._collect_all_llo_messages(span, attributes) + + traceloop_messages = [m for m in messages if m["source"] in ["input", "output"]] + + self.assertEqual(len(traceloop_messages), 2) + + input_message = traceloop_messages[0] + self.assertEqual(input_message["content"], "input data") + self.assertEqual(input_message["role"], "user") + self.assertEqual(input_message["source"], "input") + + output_message = traceloop_messages[1] + self.assertEqual(output_message["content"], "output data") + self.assertEqual(output_message["role"], "assistant") + self.assertEqual(output_message["source"], "output") + + def test_collect_traceloop_messages_all_attributes(self): + """ + Verify collection of mixed Traceloop and CrewAI attributes, ensuring all are collected + with appropriate roles and sources. + """ + attributes = { + "traceloop.entity.input": "input data", + "traceloop.entity.output": "output data", + "crewai.crew.tasks_output": "[TaskOutput(description='Task 1', output='Result 1')]", + "crewai.crew.result": "Final crew result", + "traceloop.entity.name": "crewai_agent", + } + + span = self._create_mock_span(attributes) + span.end_time = 1234567899 + + messages = self.llo_handler._collect_all_llo_messages(span, attributes) + + self.assertEqual(len(messages), 4) + + self.assertEqual(messages[0]["content"], "input data") + self.assertEqual(messages[0]["role"], "user") + self.assertEqual(messages[0]["source"], "input") + + self.assertEqual(messages[1]["content"], "output data") + self.assertEqual(messages[1]["role"], "assistant") + self.assertEqual(messages[1]["source"], "output") + + self.assertEqual(messages[2]["content"], "[TaskOutput(description='Task 1', output='Result 1')]") + self.assertEqual(messages[2]["role"], "assistant") + self.assertEqual(messages[2]["source"], "output") + + self.assertEqual(messages[3]["content"], "Final crew result") + self.assertEqual(messages[3]["role"], "assistant") + self.assertEqual(messages[3]["source"], "result") + + def test_collect_openlit_messages_direct_prompt(self): + """ + Verify OpenLit's direct gen_ai.prompt attribute is collected with user role and prompt source. + """ + attributes = {"gen_ai.prompt": "user direct prompt"} + + span = self._create_mock_span(attributes) + + messages = self.llo_handler._collect_all_llo_messages(span, attributes) + + self.assertEqual(len(messages), 1) + message = messages[0] + self.assertEqual(message["content"], "user direct prompt") + self.assertEqual(message["role"], "user") + self.assertEqual(message["source"], "prompt") + + def test_collect_openlit_messages_direct_completion(self): + """ + Verify OpenLit's direct gen_ai.completion attribute is collected with assistant role and completion source. 
+ """ + attributes = {"gen_ai.completion": "assistant direct completion"} + + span = self._create_mock_span(attributes) + span.end_time = 1234567899 + + messages = self.llo_handler._collect_all_llo_messages(span, attributes) + + self.assertEqual(len(messages), 1) + message = messages[0] + self.assertEqual(message["content"], "assistant direct completion") + self.assertEqual(message["role"], "assistant") + self.assertEqual(message["source"], "completion") + + def test_collect_openlit_messages_all_attributes(self): + """ + Verify all OpenLit framework attributes (prompt, completion, revised_prompt, agent.*) + are collected with correct roles and sources. + """ + attributes = { + "gen_ai.prompt": "user prompt", + "gen_ai.completion": "assistant response", + "gen_ai.content.revised_prompt": "revised prompt", + "gen_ai.agent.actual_output": "agent output", + "gen_ai.agent.human_input": "human input to agent", + } + + span = self._create_mock_span(attributes) + span.end_time = 1234567899 + + messages = self.llo_handler._collect_all_llo_messages(span, attributes) + + self.assertEqual(len(messages), 5) + + self.assertEqual(messages[0]["content"], "user prompt") + self.assertEqual(messages[0]["role"], "user") + self.assertEqual(messages[0]["source"], "prompt") + + self.assertEqual(messages[1]["content"], "assistant response") + self.assertEqual(messages[1]["role"], "assistant") + self.assertEqual(messages[1]["source"], "completion") + + self.assertEqual(messages[2]["content"], "revised prompt") + self.assertEqual(messages[2]["role"], "system") + self.assertEqual(messages[2]["source"], "prompt") + + self.assertEqual(messages[3]["content"], "agent output") + self.assertEqual(messages[3]["role"], "assistant") + self.assertEqual(messages[3]["source"], "output") + + self.assertEqual(messages[4]["content"], "human input to agent") + self.assertEqual(messages[4]["role"], "user") + self.assertEqual(messages[4]["source"], "input") + + def test_collect_openlit_messages_revised_prompt(self): + """ + Verify OpenLit's gen_ai.content.revised_prompt is collected with system role and prompt source. + """ + attributes = {"gen_ai.content.revised_prompt": "revised system prompt"} + + span = self._create_mock_span(attributes) + + messages = self.llo_handler._collect_all_llo_messages(span, attributes) + + self.assertEqual(len(messages), 1) + message = messages[0] + self.assertEqual(message["content"], "revised system prompt") + self.assertEqual(message["role"], "system") + self.assertEqual(message["source"], "prompt") + + def test_collect_openinference_messages_direct_attributes(self): + """ + Verify OpenInference's direct input.value and output.value attributes are collected + with appropriate roles (user/assistant) and sources. 
+ """ + attributes = { + "input.value": "user prompt", + "output.value": "assistant response", + "llm.model_name": "gpt-4", + } + + span = self._create_mock_span(attributes) + span.end_time = 1234567899 + + messages = self.llo_handler._collect_all_llo_messages(span, attributes) + + self.assertEqual(len(messages), 2) + + input_message = messages[0] + self.assertEqual(input_message["content"], "user prompt") + self.assertEqual(input_message["role"], "user") + self.assertEqual(input_message["source"], "input") + + output_message = messages[1] + self.assertEqual(output_message["content"], "assistant response") + self.assertEqual(output_message["role"], "assistant") + self.assertEqual(output_message["source"], "output") + + def test_collect_openinference_messages_structured_input(self): + """ + Verify OpenInference's indexed llm.input_messages.{n}.message.content attributes + are collected with roles from corresponding role attributes. + """ + attributes = { + "llm.input_messages.0.message.content": "system prompt", + "llm.input_messages.0.message.role": "system", + "llm.input_messages.1.message.content": "user message", + "llm.input_messages.1.message.role": "user", + "llm.model_name": "claude-3", + } + + span = self._create_mock_span(attributes) + + messages = self.llo_handler._collect_all_llo_messages(span, attributes) + + self.assertEqual(len(messages), 2) + + system_message = messages[0] + self.assertEqual(system_message["content"], "system prompt") + self.assertEqual(system_message["role"], "system") + self.assertEqual(system_message["source"], "input") + + user_message = messages[1] + self.assertEqual(user_message["content"], "user message") + self.assertEqual(user_message["role"], "user") + self.assertEqual(user_message["source"], "input") + + def test_collect_openinference_messages_structured_output(self): + """ + Verify OpenInference's indexed llm.output_messages.{n}.message.content attributes + are collected with source='output' and roles from corresponding attributes. + """ + attributes = { + "llm.output_messages.0.message.content": "assistant response", + "llm.output_messages.0.message.role": "assistant", + "llm.model_name": "llama-3", + } + + span = self._create_mock_span(attributes) + span.end_time = 1234567899 + + messages = self.llo_handler._collect_all_llo_messages(span, attributes) + + self.assertEqual(len(messages), 1) + + output_message = messages[0] + self.assertEqual(output_message["content"], "assistant response") + self.assertEqual(output_message["role"], "assistant") + self.assertEqual(output_message["source"], "output") + + def test_collect_openinference_messages_mixed_attributes(self): + """ + Verify mixed OpenInference attributes (direct and indexed) are all collected + and maintain correct roles and counts. 
+ """ + attributes = { + "input.value": "direct input", + "output.value": "direct output", + "llm.input_messages.0.message.content": "message input", + "llm.input_messages.0.message.role": "user", + "llm.output_messages.0.message.content": "message output", + "llm.output_messages.0.message.role": "assistant", + "llm.model_name": "bedrock.claude-3", + } + + span = self._create_mock_span(attributes) + span.end_time = 1234567899 + + messages = self.llo_handler._collect_all_llo_messages(span, attributes) + + self.assertEqual(len(messages), 4) + + contents = [msg["content"] for msg in messages] + self.assertIn("direct input", contents) + self.assertIn("direct output", contents) + self.assertIn("message input", contents) + self.assertIn("message output", contents) + + roles = [msg["role"] for msg in messages] + self.assertEqual(roles.count("user"), 2) + self.assertEqual(roles.count("assistant"), 2) + + def test_collect_openlit_messages_agent_actual_output(self): + """ + Verify OpenLit's gen_ai.agent.actual_output is collected with assistant role and output source. + """ + attributes = {"gen_ai.agent.actual_output": "Agent task output result"} + + span = self._create_mock_span(attributes) + span.end_time = 1234567899 + + messages = self.llo_handler._collect_all_llo_messages(span, attributes) + + self.assertEqual(len(messages), 1) + + message = messages[0] + self.assertEqual(message["content"], "Agent task output result") + self.assertEqual(message["role"], "assistant") + self.assertEqual(message["source"], "output") + + def test_collect_openlit_messages_agent_human_input(self): + """ + Verify OpenLit's gen_ai.agent.human_input is collected with user role and input source. + """ + attributes = {"gen_ai.agent.human_input": "Human input to the agent"} + + span = self._create_mock_span(attributes) + + messages = self.llo_handler._collect_all_llo_messages(span, attributes) + + self.assertEqual(len(messages), 1) + message = messages[0] + self.assertEqual(message["content"], "Human input to the agent") + self.assertEqual(message["role"], "user") + self.assertEqual(message["source"], "input") + + def test_collect_traceloop_messages_crew_outputs(self): + """ + Verify CrewAI-specific attributes (tasks_output, result) are collected with assistant role + and appropriate sources. + """ + attributes = { + "crewai.crew.tasks_output": "[TaskOutput(description='Task description', output='Task result')]", + "crewai.crew.result": "Final crew execution result", + "traceloop.entity.name": "crewai", + } + + span = self._create_mock_span(attributes) + span.end_time = 1234567899 + + messages = self.llo_handler._collect_all_llo_messages(span, attributes) + + self.assertEqual(len(messages), 2) + + tasks_message = messages[0] + self.assertEqual(tasks_message["content"], "[TaskOutput(description='Task description', output='Task result')]") + self.assertEqual(tasks_message["role"], "assistant") + self.assertEqual(tasks_message["source"], "output") + + result_message = messages[1] + self.assertEqual(result_message["content"], "Final crew execution result") + self.assertEqual(result_message["role"], "assistant") + self.assertEqual(result_message["source"], "result") + + def test_openinference_messages_with_default_roles(self): + """ + Verify OpenInference indexed messages use default roles (user for input, assistant for output) + when role attributes are missing. 
+ """ + attributes = { + "llm.input_messages.0.message.content": "input without role", + "llm.output_messages.0.message.content": "output without role", + } + + span = self._create_mock_span(attributes) + + messages = self.llo_handler._collect_all_llo_messages(span, attributes) + + self.assertEqual(len(messages), 2) + + input_msg = next((m for m in messages if m["content"] == "input without role"), None) + self.assertIsNotNone(input_msg) + self.assertEqual(input_msg["role"], "user") + self.assertEqual(input_msg["source"], "input") + + output_msg = next((m for m in messages if m["content"] == "output without role"), None) + self.assertIsNotNone(output_msg) + self.assertEqual(output_msg["role"], "assistant") + self.assertEqual(output_msg["source"], "output") + + def test_collect_strands_sdk_messages(self): + """ + Verify Strands SDK patterns (system_prompt, tool.result) are collected + with correct roles and sources. + """ + attributes = { + "system_prompt": "You are a helpful assistant", + "tool.result": "Tool execution completed successfully", + } + + span = self._create_mock_span(attributes) + span.end_time = 1234567899 + span.instrumentation_scope = MagicMock() + span.instrumentation_scope.name = "strands.sdk" + + messages = self.llo_handler._collect_all_llo_messages(span, attributes) + + self.assertEqual(len(messages), 2) + + system_msg = next((m for m in messages if m["content"] == "You are a helpful assistant"), None) + self.assertIsNotNone(system_msg) + self.assertEqual(system_msg["role"], "system") + self.assertEqual(system_msg["source"], "prompt") + + tool_msg = next((m for m in messages if m["content"] == "Tool execution completed successfully"), None) + self.assertIsNotNone(tool_msg) + self.assertEqual(tool_msg["role"], "assistant") + self.assertEqual(tool_msg["source"], "output") + + def test_collect_llm_prompts_messages(self): + """ + Verify llm.prompts attribute is collected as a user message with prompt source. + """ + attributes = { + "llm.prompts": ( + "[{'role': 'system', 'content': [{'text': 'You are a helpful AI assistant.', 'type': 'text'}]}, " + "{'role': 'user', 'content': [{'text': 'What are the benefits of using FastAPI?', 'type': 'text'}]}]" + ), + "other.attribute": "not collected", + } + + span = self._create_mock_span(attributes) + messages = self.llo_handler._collect_all_llo_messages(span, attributes) + + self.assertEqual(len(messages), 1) + message = messages[0] + self.assertEqual(message["content"], attributes["llm.prompts"]) + self.assertEqual(message["role"], "user") + self.assertEqual(message["source"], "prompt") + + def test_collect_llm_prompts_with_other_messages(self): + """ + Verify llm.prompts works correctly alongside other LLO attributes. 
+ """ + attributes = { + "llm.prompts": "[{'role': 'system', 'content': 'System prompt'}]", + "gen_ai.prompt": "Direct prompt", + "gen_ai.completion": "Assistant response", + } + + span = self._create_mock_span(attributes) + messages = self.llo_handler._collect_all_llo_messages(span, attributes) + + self.assertEqual(len(messages), 3) + + # Check llm.prompts message + llm_prompts_msg = next((m for m in messages if m["content"] == attributes["llm.prompts"]), None) + self.assertIsNotNone(llm_prompts_msg) + self.assertEqual(llm_prompts_msg["role"], "user") + self.assertEqual(llm_prompts_msg["source"], "prompt") + + # Check other messages are still collected + direct_prompt_msg = next((m for m in messages if m["content"] == "Direct prompt"), None) + self.assertIsNotNone(direct_prompt_msg) + + completion_msg = next((m for m in messages if m["content"] == "Assistant response"), None) + self.assertIsNotNone(completion_msg) diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/llo_handler/test_llo_handler_patterns.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/llo_handler/test_llo_handler_patterns.py new file mode 100644 index 000000000..25abfcca6 --- /dev/null +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/llo_handler/test_llo_handler_patterns.py @@ -0,0 +1,112 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for LLO Handler pattern matching functionality.""" + +from test_llo_handler_base import LLOHandlerTestBase + +from amazon.opentelemetry.distro.llo_handler import LLO_PATTERNS, LLOHandler, PatternType + + +class TestLLOHandlerPatterns(LLOHandlerTestBase): + """Test pattern matching and recognition functionality.""" + + def test_init(self): + """ + Verify LLOHandler initializes correctly with logger provider and creates event logger provider. + """ + self.assertEqual(self.llo_handler._logger_provider, self.logger_provider_mock) + self.assertEqual(self.llo_handler._event_logger_provider, self.event_logger_provider_mock) + + def test_is_llo_attribute_match(self): + """ + Verify _is_llo_attribute correctly identifies indexed Gen AI prompt patterns (gen_ai.prompt.{n}.content). + """ + self.assertTrue(self.llo_handler._is_llo_attribute("gen_ai.prompt.0.content")) + self.assertTrue(self.llo_handler._is_llo_attribute("gen_ai.prompt.123.content")) + + def test_is_llo_attribute_no_match(self): + """ + Verify _is_llo_attribute correctly rejects malformed patterns and non-LLO attributes. + """ + self.assertFalse(self.llo_handler._is_llo_attribute("gen_ai.prompt.content")) + self.assertFalse(self.llo_handler._is_llo_attribute("gen_ai.prompt.abc.content")) + self.assertFalse(self.llo_handler._is_llo_attribute("some.other.attribute")) + + def test_is_llo_attribute_traceloop_match(self): + """ + Verify _is_llo_attribute recognizes Traceloop framework patterns (traceloop.entity.input/output). + """ + self.assertTrue(self.llo_handler._is_llo_attribute("traceloop.entity.input")) + self.assertTrue(self.llo_handler._is_llo_attribute("traceloop.entity.output")) + + def test_is_llo_attribute_openlit_match(self): + """ + Verify _is_llo_attribute recognizes OpenLit framework patterns (gen_ai.prompt, gen_ai.completion, etc.). 
+ """ + self.assertTrue(self.llo_handler._is_llo_attribute("gen_ai.prompt")) + self.assertTrue(self.llo_handler._is_llo_attribute("gen_ai.completion")) + self.assertTrue(self.llo_handler._is_llo_attribute("gen_ai.content.revised_prompt")) + + def test_is_llo_attribute_openinference_match(self): + """ + Verify _is_llo_attribute recognizes OpenInference patterns including both direct (input/output.value) + and indexed (llm.input_messages.{n}.message.content) patterns. + """ + self.assertTrue(self.llo_handler._is_llo_attribute("input.value")) + self.assertTrue(self.llo_handler._is_llo_attribute("output.value")) + self.assertTrue(self.llo_handler._is_llo_attribute("llm.input_messages.0.message.content")) + self.assertTrue(self.llo_handler._is_llo_attribute("llm.output_messages.123.message.content")) + + def test_is_llo_attribute_crewai_match(self): + """ + Verify _is_llo_attribute recognizes CrewAI framework patterns (gen_ai.agent.*, crewai.crew.*). + """ + self.assertTrue(self.llo_handler._is_llo_attribute("gen_ai.agent.actual_output")) + self.assertTrue(self.llo_handler._is_llo_attribute("gen_ai.agent.human_input")) + self.assertTrue(self.llo_handler._is_llo_attribute("crewai.crew.tasks_output")) + self.assertTrue(self.llo_handler._is_llo_attribute("crewai.crew.result")) + + def test_is_llo_attribute_strands_sdk_match(self): + """ + Verify _is_llo_attribute recognizes Strands SDK patterns (system_prompt, tool.result). + """ + self.assertTrue(self.llo_handler._is_llo_attribute("system_prompt")) + self.assertTrue(self.llo_handler._is_llo_attribute("tool.result")) + + def test_is_llo_attribute_llm_prompts_match(self): + """ + Verify _is_llo_attribute recognizes llm.prompts pattern. + """ + self.assertTrue(self.llo_handler._is_llo_attribute("llm.prompts")) + + def test_build_pattern_matchers_with_missing_regex(self): + """ + Test _build_pattern_matchers handles patterns with missing regex gracefully + """ + # Temporarily modify LLO_PATTERNS to have a pattern without regex + original_patterns = LLO_PATTERNS.copy() + + # Add a malformed indexed pattern without regex + LLO_PATTERNS["test.bad.pattern"] = { + "type": PatternType.INDEXED, + # Missing "regex" key + "role_key": "test.bad.pattern.role", + "default_role": "unknown", + "source": "test", + } + + try: + # Create a new handler to trigger pattern building + handler = LLOHandler(self.logger_provider_mock) + + # Should handle gracefully - the bad pattern should be skipped + self.assertNotIn("test.bad.pattern", handler._pattern_configs) + + # Other patterns should still work + self.assertTrue(handler._is_llo_attribute("gen_ai.prompt")) + self.assertFalse(handler._is_llo_attribute("test.bad.pattern")) + + finally: + # Restore original patterns + LLO_PATTERNS.clear() + LLO_PATTERNS.update(original_patterns) diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/llo_handler/test_llo_handler_processing.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/llo_handler/test_llo_handler_processing.py new file mode 100644 index 000000000..d76699849 --- /dev/null +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/llo_handler/test_llo_handler_processing.py @@ -0,0 +1,328 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 +"""Tests for LLO Handler span and attribute processing functionality.""" + +from unittest.mock import MagicMock, patch + +from test_llo_handler_base import LLOHandlerTestBase + +from opentelemetry.attributes import BoundedAttributes +from opentelemetry.sdk.trace import Event as SpanEvent + + +class TestLLOHandlerProcessing(LLOHandlerTestBase): + """Test span processing and attribute filtering functionality.""" + + def test_filter_attributes(self): + """ + Verify _filter_attributes removes LLO content attributes while preserving role attributes + and other non-LLO attributes. + """ + attributes = { + "gen_ai.prompt.0.content": "test content", + "gen_ai.prompt.0.role": "user", + "normal.attribute": "value", + "another.normal.attribute": 123, + } + + filtered = self.llo_handler._filter_attributes(attributes) + + self.assertNotIn("gen_ai.prompt.0.content", filtered) + self.assertIn("gen_ai.prompt.0.role", filtered) + self.assertIn("normal.attribute", filtered) + self.assertIn("another.normal.attribute", filtered) + + def test_filter_attributes_empty_dict(self): + """ + Verify _filter_attributes returns empty dict when given empty dict. + """ + result = self.llo_handler._filter_attributes({}) + + self.assertEqual(result, {}) + + def test_filter_attributes_none_handling(self): + """ + Verify _filter_attributes returns original attributes when no LLO attributes are present. + """ + attributes = {"normal.attr": "value"} + result = self.llo_handler._filter_attributes(attributes) + + self.assertEqual(result, attributes) + + def test_filter_attributes_no_llo_attrs(self): + """ + Test _filter_attributes when there are no LLO attributes - should return original + """ + attributes = { + "normal.attr1": "value1", + "normal.attr2": "value2", + "other.attribute": "value", # This is not an LLO attribute + } + + result = self.llo_handler._filter_attributes(attributes) + + # Should return the same attributes object when no LLO attrs present + self.assertIs(result, attributes) + self.assertEqual(result, attributes) + + def test_process_spans(self): + """ + Verify process_spans extracts LLO attributes, emits events, filters attributes, + and processes span events correctly. + """ + attributes = {"gen_ai.prompt.0.content": "prompt content", "normal.attribute": "normal value"} + + span = self._create_mock_span(attributes) + span.events = [] + + with patch.object(self.llo_handler, "_emit_llo_attributes") as mock_emit, patch.object( + self.llo_handler, "_filter_attributes" + ) as mock_filter: + + filtered_attributes = {"normal.attribute": "normal value"} + mock_filter.return_value = filtered_attributes + + result = self.llo_handler.process_spans([span]) + + # Now it's called with only the LLO attributes + expected_llo_attrs = {"gen_ai.prompt.0.content": "prompt content"} + mock_emit.assert_called_once_with(span, expected_llo_attrs) + mock_filter.assert_called_once_with(attributes) + + self.assertEqual(len(result), 1) + self.assertEqual(result[0], span) + self.assertEqual(result[0]._attributes, filtered_attributes) + + def test_process_spans_with_bounded_attributes(self): + """ + Verify process_spans correctly handles spans with BoundedAttributes, + preserving attribute limits and settings. 
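+
+ BoundedAttributes is the SDK's size-limited attribute mapping; filtering should
+ preserve the wrapper type, e.g. (illustrative, mirroring the setup below):
+
+ BoundedAttributes(maxlen=10, attributes={...}, immutable=False, max_value_len=1000)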
+ """ + bounded_attrs = BoundedAttributes( + maxlen=10, + attributes={"gen_ai.prompt.0.content": "prompt content", "normal.attribute": "normal value"}, + immutable=False, + max_value_len=1000, + ) + + span = self._create_mock_span(bounded_attrs) + span.events = [] # Add empty events list + + with patch.object(self.llo_handler, "_emit_llo_attributes") as mock_emit, patch.object( + self.llo_handler, "_filter_attributes" + ) as mock_filter: + + filtered_attributes = {"normal.attribute": "normal value"} + mock_filter.return_value = filtered_attributes + + result = self.llo_handler.process_spans([span]) + + # Now it's called with only the LLO attributes + expected_llo_attrs = {"gen_ai.prompt.0.content": "prompt content"} + mock_emit.assert_called_once_with(span, expected_llo_attrs) + mock_filter.assert_called_once_with(bounded_attrs) + + self.assertEqual(len(result), 1) + self.assertEqual(result[0], span) + self.assertIsInstance(result[0]._attributes, BoundedAttributes) + self.assertEqual(dict(result[0]._attributes), filtered_attributes) + + def test_process_spans_none_attributes(self): + """ + Verify process_spans correctly handles spans with None attributes. + """ + span = self._create_mock_span(None, preserve_none=True) + span.events = [] + + result = self.llo_handler.process_spans([span]) + + self.assertEqual(len(result), 1) + self.assertIsNone(result[0]._attributes) + + def test_filter_span_events(self): + """ + Verify _filter_span_events filters LLO attributes from span events correctly. + """ + event_attributes = { + "gen_ai.prompt": "event prompt", + "normal.attribute": "keep this", + } + + event = SpanEvent( + name="test_event", + attributes=event_attributes, + timestamp=1234567890, + ) + + span = self._create_mock_span({}) + span.events = [event] + span.instrumentation_scope = MagicMock() + span.instrumentation_scope.name = "test.scope" + + self.llo_handler._filter_span_events(span) + + span_events = getattr(span, "_events", []) + updated_event = span_events[0] + self.assertIn("normal.attribute", updated_event.attributes) + self.assertNotIn("gen_ai.prompt", updated_event.attributes) + + def test_filter_span_events_no_events(self): + """ + Verify _filter_span_events handles spans with no events gracefully. 
+ """ + span = self._create_mock_span({}) + span.events = None + span._events = None + + self.llo_handler._filter_span_events(span) + + self.assertIsNone(span._events) + + def test_filter_span_events_no_attributes(self): + """ + Test _filter_span_events when event has no attributes + """ + event = SpanEvent( + name="test_event", + attributes=None, + timestamp=1234567890, + ) + + span = self._create_mock_span({}) + span.events = [event] + + self.llo_handler._filter_span_events(span) + + # Should handle gracefully and keep the original event + span_events = getattr(span, "_events", []) + self.assertEqual(len(span_events), 1) + self.assertEqual(span_events[0], event) + + def test_filter_span_events_bounded_attributes(self): + """ + Test _filter_span_events with BoundedAttributes in events + """ + bounded_event_attrs = BoundedAttributes( + maxlen=5, + attributes={ + "gen_ai.prompt": "event prompt", + "normal.attribute": "keep this", + }, + immutable=False, + max_value_len=100, + ) + + event = SpanEvent( + name="test_event", + attributes=bounded_event_attrs, + timestamp=1234567890, + limit=5, + ) + + span = self._create_mock_span({}) + span.events = [event] + span.instrumentation_scope = MagicMock() + span.instrumentation_scope.name = "test.scope" + + self.llo_handler._filter_span_events(span) + + # Verify event was updated with filtered attributes + span_events = getattr(span, "_events", []) + updated_event = span_events[0] + self.assertIsInstance(updated_event, SpanEvent) + self.assertEqual(updated_event.name, "test_event") + self.assertIn("normal.attribute", updated_event.attributes) + self.assertNotIn("gen_ai.prompt", updated_event.attributes) + + def test_process_spans_consolidated_event_emission(self): + """ + Verify process_spans collects LLO attributes from both span attributes and events, + then emits a single consolidated event. 
+ """ + # Span attributes with prompt + span_attributes = { + "gen_ai.prompt": "What is quantum computing?", + "normal.attribute": "keep this", + } + + # Event attributes with completion + event_attributes = { + "gen_ai.completion": "Quantum computing is...", + "other.attribute": "also keep this", + } + + event = SpanEvent( + name="gen_ai.content.completion", + attributes=event_attributes, + timestamp=1234567890, + ) + + span = self._create_mock_span(span_attributes) + span.events = [event] + span.instrumentation_scope = MagicMock() + span.instrumentation_scope.name = "openlit.otel.tracing" + + with patch.object(self.llo_handler, "_emit_llo_attributes") as mock_emit: + result = self.llo_handler.process_spans([span]) + + # Should emit once with combined attributes + mock_emit.assert_called_once() + call_args = mock_emit.call_args[0] + emitted_span = call_args[0] + emitted_attributes = call_args[1] + + # Verify the emitted attributes contain both prompt and completion + self.assertEqual(emitted_span, span) + self.assertIn("gen_ai.prompt", emitted_attributes) + self.assertIn("gen_ai.completion", emitted_attributes) + self.assertEqual(emitted_attributes["gen_ai.prompt"], "What is quantum computing?") + self.assertEqual(emitted_attributes["gen_ai.completion"], "Quantum computing is...") + + # Verify span attributes are filtered + self.assertNotIn("gen_ai.prompt", result[0]._attributes) + self.assertIn("normal.attribute", result[0]._attributes) + + # Verify event attributes are filtered + updated_event = result[0]._events[0] + self.assertNotIn("gen_ai.completion", updated_event.attributes) + self.assertIn("other.attribute", updated_event.attributes) + + def test_process_spans_multiple_events_consolidated(self): + """ + Verify process_spans handles multiple events correctly, collecting all LLO attributes + into a single consolidated event. + """ + span_attributes = {"normal.attribute": "keep this"} + + # First event with prompt + event1_attrs = {"gen_ai.prompt": "First question"} + event1 = SpanEvent( + name="gen_ai.content.prompt", + attributes=event1_attrs, + timestamp=1234567890, + ) + + # Second event with completion + event2_attrs = {"gen_ai.completion": "First answer"} + event2 = SpanEvent( + name="gen_ai.content.completion", + attributes=event2_attrs, + timestamp=1234567891, + ) + + span = self._create_mock_span(span_attributes) + span.events = [event1, event2] + span.instrumentation_scope = MagicMock() + span.instrumentation_scope.name = "openlit.otel.tracing" + + with patch.object(self.llo_handler, "_emit_llo_attributes") as mock_emit: + self.llo_handler.process_spans([span]) + + # Should emit once with attributes from both events + mock_emit.assert_called_once() + emitted_attributes = mock_emit.call_args[0][1] + + self.assertIn("gen_ai.prompt", emitted_attributes) + self.assertIn("gen_ai.completion", emitted_attributes) + self.assertEqual(emitted_attributes["gen_ai.prompt"], "First question") + self.assertEqual(emitted_attributes["gen_ai.completion"], "First answer") diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_aws_auth_session.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_aws_auth_session.py deleted file mode 100644 index e0c62b89d..000000000 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_aws_auth_session.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
-# SPDX-License-Identifier: Apache-2.0 -from unittest import TestCase -from unittest.mock import patch - -import requests -from botocore.credentials import Credentials - -from amazon.opentelemetry.distro.exporter.otlp.aws.common.aws_auth_session import AwsAuthSession - -AWS_OTLP_TRACES_ENDPOINT = "https://xray.us-east-1.amazonaws.com/v1/traces" -AWS_OTLP_LOGS_ENDPOINT = "https://logs.us-east-1.amazonaws.com/v1/logs" - -AUTHORIZATION_HEADER = "Authorization" -X_AMZ_DATE_HEADER = "X-Amz-Date" -X_AMZ_SECURITY_TOKEN_HEADER = "X-Amz-Security-Token" - -mock_credentials = Credentials(access_key="test_access_key", secret_key="test_secret_key", token="test_session_token") - - -class TestAwsAuthSession(TestCase): - @patch("pkg_resources.get_distribution", side_effect=ImportError("test error")) - @patch.dict("sys.modules", {"botocore": None}, clear=False) - @patch("requests.Session.request", return_value=requests.Response()) - def test_aws_auth_session_no_botocore(self, _, __): - """Tests that aws_auth_session will not inject SigV4 Headers if botocore is not installed.""" - - session = AwsAuthSession("us-east-1", "xray") - actual_headers = {"test": "test"} - - session.request("POST", AWS_OTLP_TRACES_ENDPOINT, data="", headers=actual_headers) - - self.assertNotIn(AUTHORIZATION_HEADER, actual_headers) - self.assertNotIn(X_AMZ_DATE_HEADER, actual_headers) - self.assertNotIn(X_AMZ_SECURITY_TOKEN_HEADER, actual_headers) - - @patch("requests.Session.request", return_value=requests.Response()) - @patch("botocore.session.Session.get_credentials", return_value=None) - def test_aws_auth_session_no_credentials(self, _, __): - """Tests that aws_auth_session will not inject SigV4 Headers if retrieving credentials returns None.""" - - session = AwsAuthSession("us-east-1", "xray") - actual_headers = {"test": "test"} - - session.request("POST", AWS_OTLP_TRACES_ENDPOINT, data="", headers=actual_headers) - - self.assertNotIn(AUTHORIZATION_HEADER, actual_headers) - self.assertNotIn(X_AMZ_DATE_HEADER, actual_headers) - self.assertNotIn(X_AMZ_SECURITY_TOKEN_HEADER, actual_headers) - - @patch("requests.Session.request", return_value=requests.Response()) - @patch("botocore.session.Session.get_credentials", return_value=mock_credentials) - def test_aws_auth_session(self, _, __): - """Tests that aws_auth_session will inject SigV4 Headers if botocore is installed.""" - - session = AwsAuthSession("us-east-1", "xray") - actual_headers = {"test": "test"} - - session.request("POST", AWS_OTLP_TRACES_ENDPOINT, data="", headers=actual_headers) - - self.assertIn(AUTHORIZATION_HEADER, actual_headers) - self.assertIn(X_AMZ_DATE_HEADER, actual_headers) - self.assertIn(X_AMZ_SECURITY_TOKEN_HEADER, actual_headers) diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_aws_metric_attribute_generator.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_aws_metric_attribute_generator.py index d122519cf..f99b0d154 100644 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_aws_metric_attribute_generator.py +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_aws_metric_attribute_generator.py @@ -37,12 +37,12 @@ AWS_STEPFUNCTIONS_STATEMACHINE_ARN, ) from amazon.opentelemetry.distro._aws_metric_attribute_generator import _AwsMetricAttributeGenerator -from amazon.opentelemetry.distro._aws_span_processing_util import GEN_AI_REQUEST_MODEL from amazon.opentelemetry.distro.metric_attribute_generator import DEPENDENCY_METRIC, SERVICE_METRIC from opentelemetry.attributes import 
BoundedAttributes from opentelemetry.sdk.resources import _DEFAULT_RESOURCE, SERVICE_NAME from opentelemetry.sdk.trace import ReadableSpan, Resource from opentelemetry.sdk.util.instrumentation import InstrumentationScope +from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_REQUEST_MODEL from opentelemetry.semconv.trace import MessagingOperationValues, SpanAttributes from opentelemetry.trace import SpanContext, SpanKind from opentelemetry.util.types import Attributes diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_aws_opentelementry_configurator.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_aws_opentelementry_configurator.py index 13397a0d5..14cb9f824 100644 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_aws_opentelementry_configurator.py +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_aws_opentelementry_configurator.py @@ -1,5 +1,8 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 + +# pylint: disable=too-many-lines + import os import time from unittest import TestCase @@ -7,6 +10,7 @@ from requests import Session +from amazon.opentelemetry.distro._aws_attribute_keys import AWS_LOCAL_SERVICE, AWS_SERVICE_TYPE from amazon.opentelemetry.distro.always_record_sampler import AlwaysRecordSampler from amazon.opentelemetry.distro.attribute_propagating_span_processor import AttributePropagatingSpanProcessor from amazon.opentelemetry.distro.aws_batch_unsampled_span_processor import BatchUnsampledSpanProcessor @@ -19,22 +23,36 @@ OTEL_EXPORTER_OTLP_TRACES_ENDPOINT, ApplicationSignalsExporterProvider, AwsOpenTelemetryConfigurator, + OtlpLogHeaderSetting, + _check_emf_exporter_enabled, + _create_aws_otlp_exporter, _custom_import_sampler, + _customize_log_record_processor, _customize_logs_exporter, _customize_metric_exporters, + _customize_resource, _customize_sampler, _customize_span_exporter, _customize_span_processors, + _export_unsampled_span_for_agent_observability, _export_unsampled_span_for_lambda, _init_logging, _is_application_signals_enabled, _is_application_signals_runtime_enabled, _is_defer_to_workers_enabled, _is_wsgi_master_process, + _validate_and_fetch_logs_header, + create_emf_exporter, ) from amazon.opentelemetry.distro.aws_opentelemetry_distro import AwsOpenTelemetryDistro from amazon.opentelemetry.distro.aws_span_metrics_processor import AwsSpanMetricsProcessor +from amazon.opentelemetry.distro.exporter.aws.metrics.aws_cloudwatch_emf_exporter import AwsCloudWatchEmfExporter from amazon.opentelemetry.distro.exporter.otlp.aws.common.aws_auth_session import AwsAuthSession + +# pylint: disable=line-too-long +from amazon.opentelemetry.distro.exporter.otlp.aws.logs._aws_cw_otlp_batch_log_record_processor import ( + AwsCloudWatchOtlpBatchLogRecordProcessor, +) from amazon.opentelemetry.distro.exporter.otlp.aws.logs.otlp_aws_logs_exporter import OTLPAwsLogExporter from amazon.opentelemetry.distro.exporter.otlp.aws.traces.otlp_aws_span_exporter import OTLPAwsSpanExporter from amazon.opentelemetry.distro.otlp_udp_exporter import OTLPUdpSpanExporter @@ -50,13 +68,16 @@ from opentelemetry.exporter.otlp.proto.http._log_exporter import OTLPLogExporter from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as OTLPHttpOTLPMetricExporter from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +from opentelemetry.metrics import get_meter_provider from 
opentelemetry.processor.baggage import BaggageSpanProcessor
+from opentelemetry.sdk._logs.export import BatchLogRecordProcessor
 from opentelemetry.sdk.environment_variables import OTEL_TRACES_SAMPLER, OTEL_TRACES_SAMPLER_ARG
 from opentelemetry.sdk.metrics._internal.export import PeriodicExportingMetricReader
 from opentelemetry.sdk.resources import Resource
 from opentelemetry.sdk.trace import Span, SpanProcessor, Tracer, TracerProvider
 from opentelemetry.sdk.trace.export import SpanExporter
 from opentelemetry.sdk.trace.sampling import DEFAULT_ON, Sampler
+from opentelemetry.semconv.resource import ResourceAttributes
 from opentelemetry.trace import get_tracer_provider
@@ -87,6 +108,22 @@ def setUpClass(cls):
         aws_otel_configurator.configure()
         cls.tracer_provider: TracerProvider = get_tracer_provider()
 
+    @classmethod
+    def tearDownClass(cls):
+        # Explicitly shut down the meter provider to avoid I/O errors on Python 3.9 with gevent.
+        # This ensures ConsoleMetricExporter is properly closed before Python cleanup.
+        try:
+            meter_provider = get_meter_provider()
+            if hasattr(meter_provider, "force_flush"):
+                meter_provider.force_flush()
+            if hasattr(meter_provider, "shutdown"):
+                meter_provider.shutdown()
+        except (ValueError, RuntimeError):
+            # Ignore errors during cleanup:
+            # - ValueError: I/O operation on closed file (the exact error we're trying to prevent)
+            # - RuntimeError: provider already shut down or threading issues
+            pass
+
     def tearDown(self):
         os.environ.pop("OTEL_AWS_APPLICATION_SIGNALS_ENABLED", None)
         os.environ.pop("OTEL_AWS_APPLICATION_SIGNALS_RUNTIME_ENABLED", None)
@@ -344,25 +381,24 @@ def test_customize_span_exporter_with_agent_observability(self):
     def test_customize_span_processors_with_agent_observability(self):
         mock_tracer_provider: TracerProvider = MagicMock()
 
-        # Test that BaggageSpanProcessor is not added when agent observability is disabled
         os.environ.pop("AGENT_OBSERVABILITY_ENABLED", None)
         _customize_span_processors(mock_tracer_provider, Resource.get_empty())
         self.assertEqual(mock_tracer_provider.add_span_processor.call_count, 0)
 
-        # Reset mock for next test
         mock_tracer_provider.reset_mock()
 
-        # Test that BaggageSpanProcessor is added when agent observability is enabled
         os.environ["AGENT_OBSERVABILITY_ENABLED"] = "true"
+        os.environ["OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"] = "https://xray.us-east-1.amazonaws.com/v1/traces"
         _customize_span_processors(mock_tracer_provider, Resource.get_empty())
-        self.assertEqual(mock_tracer_provider.add_span_processor.call_count, 1)
+        self.assertEqual(mock_tracer_provider.add_span_processor.call_count, 2)
 
-        # Verify the added processor is BaggageSpanProcessor
-        added_processor = mock_tracer_provider.add_span_processor.call_args_list[0].args[0]
-        self.assertIsInstance(added_processor, BaggageSpanProcessor)
+        first_processor = mock_tracer_provider.add_span_processor.call_args_list[0].args[0]
+        self.assertIsInstance(first_processor, BatchUnsampledSpanProcessor)
+        second_processor = mock_tracer_provider.add_span_processor.call_args_list[1].args[0]
+        self.assertIsInstance(second_processor, BaggageSpanProcessor)
 
-        # Clean up
         os.environ.pop("AGENT_OBSERVABILITY_ENABLED", None)
+        os.environ.pop("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", None)
 
     def test_baggage_span_processor_session_id_filtering(self):
         """Test that BaggageSpanProcessor only set session.id filter by default"""
@@ -479,6 +515,7 @@ def test_customize_span_exporter_sigv4(self):
             OTLPAwsSpanExporter,
             AwsAuthSession,
             Compression.NoCompression,
+            Resource.get_empty(),
        )
 
        for config in bad_configs:
@@ -489,6 +526,7 @@ def test_customize_span_exporter_sigv4(self):
                OTLPSpanExporter,
                Session,
                Compression.NoCompression,
+                Resource.get_empty(),
            )
 
        self.assertIsInstance(
@@ -592,13 +630,11 @@ def test_customize_logs_exporter_sigv4(self):
                config, _customize_logs_exporter, OTLPLogExporter(), OTLPLogExporter, Session, Compression.NoCompression
            )
 
-        self.assertIsInstance(
-            _customize_logs_exporter(OTLPGrpcLogExporter(), Resource.get_empty()), OTLPGrpcLogExporter
-        )
+        self.assertIsInstance(_customize_logs_exporter(OTLPGrpcLogExporter()), OTLPGrpcLogExporter)
 
     # Need to patch all of these to prevent some weird multi-threading error with the LogProvider
     @patch("amazon.opentelemetry.distro.aws_opentelemetry_configurator.LoggingHandler", return_value=MagicMock())
-    @patch("amazon.opentelemetry.distro.aws_opentelemetry_configurator.getLogger", return_value=MagicMock())
+    @patch("logging.getLogger", return_value=MagicMock())
     @patch("amazon.opentelemetry.distro.aws_opentelemetry_configurator._customize_logs_exporter")
     @patch("amazon.opentelemetry.distro.aws_opentelemetry_configurator.LoggerProvider", return_value=MagicMock())
     @patch(
@@ -642,7 +678,6 @@ def capture_exporter(*args, **kwargs):
 
     def test_customize_span_processors(self):
         mock_tracer_provider: TracerProvider = MagicMock()
-        # Clean up environment to ensure consistent test state
         os.environ.pop("AGENT_OBSERVABILITY_ENABLED", None)
         os.environ.pop("OTEL_AWS_APPLICATION_SIGNALS_ENABLED", None)
         os.environ.pop("OTEL_AWS_APPLICATION_SIGNALS_RUNTIME_ENABLED", None)
@@ -650,10 +685,8 @@ def test_customize_span_processors(self):
         _customize_span_processors(mock_tracer_provider, Resource.get_empty())
         self.assertEqual(mock_tracer_provider.add_span_processor.call_count, 0)
 
-        # Reset mock for next test
         mock_tracer_provider.reset_mock()
 
-        # Test application signals only
         os.environ.setdefault("OTEL_AWS_APPLICATION_SIGNALS_ENABLED", "True")
         os.environ.setdefault("OTEL_AWS_APPLICATION_SIGNALS_RUNTIME_ENABLED", "False")
         _customize_span_processors(mock_tracer_provider, Resource.get_empty())
@@ -663,19 +696,20 @@ def test_customize_span_processors(self):
         second_processor: SpanProcessor = mock_tracer_provider.add_span_processor.call_args_list[1].args[0]
         self.assertIsInstance(second_processor, AwsSpanMetricsProcessor)
 
-        # Reset mock for next test
         mock_tracer_provider.reset_mock()
 
-        # Test both agent observability and application signals enabled
         os.environ.setdefault("AGENT_OBSERVABILITY_ENABLED", "true")
+        os.environ.setdefault("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", "https://xray.us-east-1.amazonaws.com/v1/traces")
         _customize_span_processors(mock_tracer_provider, Resource.get_empty())
-        self.assertEqual(mock_tracer_provider.add_span_processor.call_count, 3)
+        self.assertEqual(mock_tracer_provider.add_span_processor.call_count, 4)
 
-        # Verify processors are added in the expected order
         processors = [call.args[0] for call in mock_tracer_provider.add_span_processor.call_args_list]
-        self.assertIsInstance(processors[0], BaggageSpanProcessor)  # Agent observability processor added first
-        self.assertIsInstance(processors[1], AttributePropagatingSpanProcessor)  # Application signals processors
-        self.assertIsInstance(processors[2], AwsSpanMetricsProcessor)
+        self.assertIsInstance(processors[0], BatchUnsampledSpanProcessor)
+        self.assertIsInstance(processors[1], BaggageSpanProcessor)
+        self.assertIsInstance(processors[2], AttributePropagatingSpanProcessor)
+        self.assertIsInstance(processors[3], AwsSpanMetricsProcessor)
+
+        os.environ.pop("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", None)
 
     def
test_customize_span_processors_lambda(self): mock_tracer_provider: TracerProvider = MagicMock() @@ -778,6 +812,81 @@ def test_export_unsampled_span_for_lambda(self): os.environ.pop("OTEL_AWS_APPLICATION_SIGNALS_ENABLED", None) os.environ.pop("AWS_LAMBDA_FUNCTION_NAME", None) + # pylint: disable=no-self-use + def test_export_unsampled_span_for_agent_observability(self): + mock_tracer_provider: TracerProvider = MagicMock() + + _export_unsampled_span_for_agent_observability(mock_tracer_provider, Resource.get_empty()) + self.assertEqual(mock_tracer_provider.add_span_processor.call_count, 0) + + mock_tracer_provider.reset_mock() + + os.environ["AGENT_OBSERVABILITY_ENABLED"] = "true" + os.environ["OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"] = "https://xray.us-east-1.amazonaws.com/v1/traces" + _export_unsampled_span_for_agent_observability(mock_tracer_provider, Resource.get_empty()) + self.assertEqual(mock_tracer_provider.add_span_processor.call_count, 1) + processor: SpanProcessor = mock_tracer_provider.add_span_processor.call_args_list[0].args[0] + self.assertIsInstance(processor, BatchUnsampledSpanProcessor) + + os.environ.pop("AGENT_OBSERVABILITY_ENABLED", None) + os.environ.pop("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", None) + + # pylint: disable=no-self-use + def test_export_unsampled_span_for_agent_observability_uses_aws_exporter(self): + """Test that OTLPAwsSpanExporter is used for AWS endpoints""" + mock_tracer_provider: TracerProvider = MagicMock() + + with patch( + "amazon.opentelemetry.distro.exporter.otlp.aws.traces.otlp_aws_span_exporter.OTLPAwsSpanExporter" + ) as mock_aws_exporter: + with patch( + "amazon.opentelemetry.distro.aws_opentelemetry_configurator.get_logger_provider" + ) as mock_logger_provider: + with patch( + "amazon.opentelemetry.distro.aws_opentelemetry_configurator.get_aws_session" + ) as mock_session: + mock_session.return_value = MagicMock() + os.environ["AGENT_OBSERVABILITY_ENABLED"] = "true" + os.environ["OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"] = "https://xray.us-east-1.amazonaws.com/v1/traces" + + _export_unsampled_span_for_agent_observability(mock_tracer_provider, Resource.get_empty()) + + # Verify OTLPAwsSpanExporter is created with correct parameters + mock_aws_exporter.assert_called_once_with( + session=mock_session.return_value, + endpoint="https://xray.us-east-1.amazonaws.com/v1/traces", + aws_region="us-east-1", + logger_provider=mock_logger_provider.return_value, + ) + # Verify processor is added to tracer provider + mock_tracer_provider.add_span_processor.assert_called_once() + + # Clean up + os.environ.pop("AGENT_OBSERVABILITY_ENABLED", None) + os.environ.pop("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", None) + + # pylint: disable=no-self-use + def test_customize_span_processors_calls_export_unsampled_span(self): + """Test that _customize_span_processors calls _export_unsampled_span_for_agent_observability""" + mock_tracer_provider: TracerProvider = MagicMock() + + with patch( + "amazon.opentelemetry.distro.aws_opentelemetry_configurator._export_unsampled_span_for_agent_observability" + ) as mock_agent_observability: + # Test that agent observability function is NOT called when disabled + os.environ.pop("AGENT_OBSERVABILITY_ENABLED", None) + _customize_span_processors(mock_tracer_provider, Resource.get_empty()) + mock_agent_observability.assert_not_called() + + # Test that agent observability function is called when enabled + mock_agent_observability.reset_mock() + os.environ["AGENT_OBSERVABILITY_ENABLED"] = "true" + _customize_span_processors(mock_tracer_provider, 
Resource.get_empty()) + mock_agent_observability.assert_called_once_with(mock_tracer_provider, Resource.get_empty()) + + # Clean up + os.environ.pop("AGENT_OBSERVABILITY_ENABLED", None) + def test_customize_metric_exporter(self): metric_readers = [] views = [] @@ -808,19 +917,13 @@ def test_customize_metric_exporter(self): os.environ.pop("OTEL_METRIC_EXPORT_INTERVAL", None) def customize_exporter_test( - self, - config, - executor, - default_exporter, - expected_exporter_type, - expected_session, - expected_compression, + self, config, executor, default_exporter, expected_exporter_type, expected_session, expected_compression, *args ): for key, value in config.items(): os.environ[key] = value try: - result = executor(default_exporter, Resource.get_empty()) + result = executor(default_exporter, *args) self.assertIsInstance(result, expected_exporter_type) self.assertIsInstance(result._session, expected_session) self.assertEqual(result._compression, expected_compression) @@ -828,6 +931,323 @@ def customize_exporter_test( for key in config.keys(): os.environ.pop(key, None) + def test_check_emf_exporter_enabled(self): + # Test when OTEL_METRICS_EXPORTER is not set + os.environ.pop("OTEL_METRICS_EXPORTER", None) + self.assertFalse(_check_emf_exporter_enabled()) + + # Test when OTEL_METRICS_EXPORTER is empty + os.environ["OTEL_METRICS_EXPORTER"] = "" + self.assertFalse(_check_emf_exporter_enabled()) + + # Test when awsemf is not in the list + os.environ["OTEL_METRICS_EXPORTER"] = "console,otlp" + self.assertFalse(_check_emf_exporter_enabled()) + + # Test when awsemf is in the list + os.environ["OTEL_METRICS_EXPORTER"] = "console,awsemf,otlp" + self.assertTrue(_check_emf_exporter_enabled()) + # Should remove awsemf from the list + self.assertEqual(os.environ["OTEL_METRICS_EXPORTER"], "console,otlp") + + # Test when awsemf is the only exporter + os.environ["OTEL_METRICS_EXPORTER"] = "awsemf" + self.assertTrue(_check_emf_exporter_enabled()) + # Should remove the environment variable entirely + self.assertNotIn("OTEL_METRICS_EXPORTER", os.environ) + + # Test with spaces in the list + os.environ["OTEL_METRICS_EXPORTER"] = " console , awsemf , otlp " + self.assertTrue(_check_emf_exporter_enabled()) + self.assertEqual(os.environ["OTEL_METRICS_EXPORTER"], "console,otlp") + + # Clean up + os.environ.pop("OTEL_METRICS_EXPORTER", None) + + def test_validate_and_fetch_logs_header(self): + # Test when headers are not set + os.environ.pop(OTEL_EXPORTER_OTLP_LOGS_HEADERS, None) + result = _validate_and_fetch_logs_header() + self.assertIsInstance(result, OtlpLogHeaderSetting) + self.assertIsNone(result.log_group) + self.assertIsNone(result.log_stream) + self.assertIsNone(result.namespace) + self.assertFalse(result.is_valid) + + # Test with valid headers + os.environ[OTEL_EXPORTER_OTLP_LOGS_HEADERS] = "x-aws-log-group=test-group,x-aws-log-stream=test-stream" + result = _validate_and_fetch_logs_header() + self.assertEqual(result.log_group, "test-group") + self.assertEqual(result.log_stream, "test-stream") + self.assertIsNone(result.namespace) + self.assertTrue(result.is_valid) + + # Test with valid headers including namespace + os.environ[OTEL_EXPORTER_OTLP_LOGS_HEADERS] = ( + "x-aws-log-group=test-group,x-aws-log-stream=test-stream,x-aws-metric-namespace=test-namespace" + ) + result = _validate_and_fetch_logs_header() + self.assertEqual(result.log_group, "test-group") + self.assertEqual(result.log_stream, "test-stream") + self.assertEqual(result.namespace, "test-namespace") + self.assertTrue(result.is_valid) + + # 
Test with missing log group + os.environ[OTEL_EXPORTER_OTLP_LOGS_HEADERS] = "x-aws-log-stream=test-stream" + result = _validate_and_fetch_logs_header() + self.assertIsNone(result.log_group) + self.assertEqual(result.log_stream, "test-stream") + self.assertFalse(result.is_valid) + + # Test with missing log stream + os.environ[OTEL_EXPORTER_OTLP_LOGS_HEADERS] = "x-aws-log-group=test-group" + result = _validate_and_fetch_logs_header() + self.assertEqual(result.log_group, "test-group") + self.assertIsNone(result.log_stream) + self.assertFalse(result.is_valid) + + # Test with empty value in log group + os.environ[OTEL_EXPORTER_OTLP_LOGS_HEADERS] = "x-aws-log-group=,x-aws-log-stream=test-stream" + result = _validate_and_fetch_logs_header() + self.assertIsNone(result.log_group) + self.assertEqual(result.log_stream, "test-stream") + self.assertFalse(result.is_valid) + + # Test with empty value in log stream + os.environ[OTEL_EXPORTER_OTLP_LOGS_HEADERS] = "x-aws-log-group=test-group,x-aws-log-stream=" + result = _validate_and_fetch_logs_header() + self.assertEqual(result.log_group, "test-group") + self.assertIsNone(result.log_stream) + self.assertFalse(result.is_valid) + + # Clean up + os.environ.pop(OTEL_EXPORTER_OTLP_LOGS_HEADERS, None) + + @patch( + "amazon.opentelemetry.distro.aws_opentelemetry_configurator.is_agent_observability_enabled", return_value=False + ) + def test_customize_log_record_processor_without_agent_observability(self, _): + """Test that BatchLogRecordProcessor is used when agent observability is not enabled""" + mock_logger_provider = MagicMock() + mock_exporter = MagicMock(spec=OTLPAwsLogExporter) + + _customize_log_record_processor(mock_logger_provider, mock_exporter) + + mock_logger_provider.add_log_record_processor.assert_called_once() + added_processor = mock_logger_provider.add_log_record_processor.call_args[0][0] + self.assertIsInstance(added_processor, BatchLogRecordProcessor) + + @patch( + "amazon.opentelemetry.distro.aws_opentelemetry_configurator.is_agent_observability_enabled", return_value=True + ) + def test_customize_log_record_processor_with_agent_observability(self, _): + """Test that AwsCloudWatchOtlpBatchLogRecordProcessor is used when agent observability is enabled""" + mock_logger_provider = MagicMock() + mock_exporter = MagicMock(spec=OTLPAwsLogExporter) + + _customize_log_record_processor(mock_logger_provider, mock_exporter) + + mock_logger_provider.add_log_record_processor.assert_called_once() + added_processor = mock_logger_provider.add_log_record_processor.call_args[0][0] + self.assertIsInstance(added_processor, AwsCloudWatchOtlpBatchLogRecordProcessor) + + @patch("amazon.opentelemetry.distro.aws_opentelemetry_configurator._validate_and_fetch_logs_header") + @patch("amazon.opentelemetry.distro.aws_opentelemetry_configurator.get_aws_session") + def test_create_emf_exporter(self, mock_get_session, mock_validate): + # Test when botocore is not installed + mock_get_session.return_value = None + result = create_emf_exporter() + self.assertIsNone(result) + + # Reset mock for subsequent tests + mock_get_session.reset_mock() + mock_get_session.return_value = MagicMock() + + # Mock the EMF exporter class import by patching the module import + with patch( + "amazon.opentelemetry.distro.exporter.aws.metrics.aws_cloudwatch_emf_exporter.AwsCloudWatchEmfExporter" + ) as mock_emf_exporter_class: + mock_exporter_instance = MagicMock() + mock_exporter_instance.namespace = "default" + mock_exporter_instance.log_group_name = "test-group" + 
mock_emf_exporter_class.return_value = mock_exporter_instance + + # Test when headers are invalid + mock_validate.return_value = OtlpLogHeaderSetting(None, None, None, False) + result = create_emf_exporter() + self.assertIsNone(result) + + # Test when namespace is missing (should still create exporter with default namespace) + mock_validate.return_value = OtlpLogHeaderSetting("test-group", "test-stream", None, True) + result = create_emf_exporter() + self.assertIsNotNone(result) + self.assertEqual(result, mock_exporter_instance) + # Verify that the EMF exporter was called with correct parameters + mock_emf_exporter_class.assert_called_with( + session=mock_get_session.return_value, + namespace=None, + log_group_name="test-group", + log_stream_name="test-stream", + ) + + # Test with valid configuration + mock_validate.return_value = OtlpLogHeaderSetting("test-group", "test-stream", "test-namespace", True) + result = create_emf_exporter() + self.assertIsNotNone(result) + self.assertEqual(result, mock_exporter_instance) + # Verify that the EMF exporter was called with correct parameters + mock_emf_exporter_class.assert_called_with( + session=mock_get_session.return_value, + namespace="test-namespace", + log_group_name="test-group", + log_stream_name="test-stream", + ) + + # Test exception handling + mock_validate.side_effect = Exception("Test exception") + result = create_emf_exporter() + self.assertIsNone(result) + + @patch("amazon.opentelemetry.distro.aws_opentelemetry_configurator.get_logger_provider") + @patch("amazon.opentelemetry.distro.aws_opentelemetry_configurator.is_agent_observability_enabled") + @patch("amazon.opentelemetry.distro.aws_opentelemetry_configurator.get_aws_session") + def test_create_aws_otlp_exporter(self, mock_get_session, mock_is_agent_enabled, mock_get_logger_provider): + # Test when botocore is not installed + mock_get_session.return_value = None + result = _create_aws_otlp_exporter("https://xray.us-east-1.amazonaws.com/v1/traces", "xray", "us-east-1") + self.assertIsNone(result) + + # Reset mock for subsequent tests + mock_get_session.reset_mock() + mock_get_session.return_value = MagicMock() + mock_get_logger_provider.return_value = MagicMock() + + # Test xray service without agent observability + mock_is_agent_enabled.return_value = False + with patch( + "amazon.opentelemetry.distro.exporter.otlp.aws.traces.otlp_aws_span_exporter.OTLPAwsSpanExporter" + ) as mock_span_exporter_class: + mock_exporter_instance = MagicMock() + mock_span_exporter_class.return_value = mock_exporter_instance + + result = _create_aws_otlp_exporter("https://xray.us-east-1.amazonaws.com/v1/traces", "xray", "us-east-1") + self.assertIsNotNone(result) + self.assertEqual(result, mock_exporter_instance) + mock_span_exporter_class.assert_called_with( + session=mock_get_session.return_value, + endpoint="https://xray.us-east-1.amazonaws.com/v1/traces", + aws_region="us-east-1", + ) + + # Test xray service with agent observability + mock_is_agent_enabled.return_value = True + with patch( + "amazon.opentelemetry.distro.exporter.otlp.aws.traces.otlp_aws_span_exporter.OTLPAwsSpanExporter" + ) as mock_span_exporter_class: + mock_exporter_instance = MagicMock() + mock_span_exporter_class.return_value = mock_exporter_instance + + result = _create_aws_otlp_exporter("https://xray.us-east-1.amazonaws.com/v1/traces", "xray", "us-east-1") + self.assertIsNotNone(result) + self.assertEqual(result, mock_exporter_instance) + mock_span_exporter_class.assert_called_with( + session=mock_get_session.return_value, 
+ endpoint="https://xray.us-east-1.amazonaws.com/v1/traces", + aws_region="us-east-1", + logger_provider=mock_get_logger_provider.return_value, + ) + + # Test logs service + with patch( + "amazon.opentelemetry.distro.exporter.otlp.aws.logs.otlp_aws_logs_exporter.OTLPAwsLogExporter" + ) as mock_log_exporter_class: + mock_exporter_instance = MagicMock() + mock_log_exporter_class.return_value = mock_exporter_instance + + result = _create_aws_otlp_exporter("https://logs.us-east-1.amazonaws.com/v1/logs", "logs", "us-east-1") + self.assertIsNotNone(result) + self.assertEqual(result, mock_exporter_instance) + mock_log_exporter_class.assert_called_with(session=mock_get_session.return_value, aws_region="us-east-1") + + # Test exception handling + mock_get_session.side_effect = Exception("Test exception") + result = _create_aws_otlp_exporter("https://xray.us-east-1.amazonaws.com/v1/traces", "xray", "us-east-1") + self.assertIsNone(result) + + def test_customize_metric_exporters_with_emf(self): + metric_readers = [] + views = [] + + # Test with EMF disabled + _customize_metric_exporters(metric_readers, views, is_emf_enabled=False) + self.assertEqual(len(metric_readers), 0) + + # Test with EMF enabled but create_emf_exporter returns None + with patch("amazon.opentelemetry.distro.aws_opentelemetry_configurator.create_emf_exporter", return_value=None): + _customize_metric_exporters(metric_readers, views, is_emf_enabled=True) + self.assertEqual(len(metric_readers), 0) + + # Test with EMF enabled and valid exporter + mock_emf_exporter = MagicMock(spec=AwsCloudWatchEmfExporter) + # Add the required attributes that PeriodicExportingMetricReader expects + mock_emf_exporter._preferred_temporality = {} + mock_emf_exporter._preferred_aggregation = {} + + with patch( + "amazon.opentelemetry.distro.aws_opentelemetry_configurator.create_emf_exporter", + return_value=mock_emf_exporter, + ): + _customize_metric_exporters(metric_readers, views, is_emf_enabled=True) + self.assertEqual(len(metric_readers), 1) + self.assertIsInstance(metric_readers[0], PeriodicExportingMetricReader) + + @patch("amazon.opentelemetry.distro.aws_opentelemetry_configurator.is_agent_observability_enabled") + @patch("amazon.opentelemetry.distro.aws_opentelemetry_configurator.get_service_attribute") + def test_customize_resource_without_agent_observability(self, mock_get_service_attribute, mock_is_agent_enabled): + """Test _customize_resource when agent observability is disabled""" + mock_is_agent_enabled.return_value = False + mock_get_service_attribute.return_value = ("test-service", False) + + resource = Resource.create({ResourceAttributes.SERVICE_NAME: "test-service"}) + result = _customize_resource(resource) + + # Should only have AWS_LOCAL_SERVICE added + self.assertEqual(result.attributes[AWS_LOCAL_SERVICE], "test-service") + self.assertNotIn(AWS_SERVICE_TYPE, result.attributes) + + @patch("amazon.opentelemetry.distro.aws_opentelemetry_configurator.is_agent_observability_enabled") + @patch("amazon.opentelemetry.distro.aws_opentelemetry_configurator.get_service_attribute") + def test_customize_resource_with_agent_observability_default( + self, mock_get_service_attribute, mock_is_agent_enabled + ): + """Test _customize_resource when agent observability is enabled with default agent type""" + mock_is_agent_enabled.return_value = True + mock_get_service_attribute.return_value = ("test-service", False) + + resource = Resource.create({ResourceAttributes.SERVICE_NAME: "test-service"}) + result = _customize_resource(resource) + + # Should have 
both AWS_LOCAL_SERVICE and AWS_SERVICE_TYPE with default value + self.assertEqual(result.attributes[AWS_LOCAL_SERVICE], "test-service") + self.assertEqual(result.attributes[AWS_SERVICE_TYPE], "gen_ai_agent") + + @patch("amazon.opentelemetry.distro.aws_opentelemetry_configurator.is_agent_observability_enabled") + @patch("amazon.opentelemetry.distro.aws_opentelemetry_configurator.get_service_attribute") + def test_customize_resource_with_existing_agent_type(self, mock_get_service_attribute, mock_is_agent_enabled): + """Test _customize_resource when agent type already exists in resource""" + mock_is_agent_enabled.return_value = True + mock_get_service_attribute.return_value = ("test-service", False) + + # Create resource with existing agent type + resource = Resource.create( + {ResourceAttributes.SERVICE_NAME: "test-service", AWS_SERVICE_TYPE: "existing-agent"} + ) + result = _customize_resource(resource) + + # Should preserve existing agent type and not override it + self.assertEqual(result.attributes[AWS_LOCAL_SERVICE], "test-service") + self.assertEqual(result.attributes[AWS_SERVICE_TYPE], "existing-agent") + def validate_distro_environ(): tc: TestCase = TestCase() diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_aws_opentelemetry_distro.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_aws_opentelemetry_distro.py index b77e4fbf8..5a044c5eb 100644 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_aws_opentelemetry_distro.py +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_aws_opentelemetry_distro.py @@ -1,13 +1,226 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +import os +from importlib.metadata import PackageNotFoundError, version from unittest import TestCase +from unittest.mock import patch -from pkg_resources import DistributionNotFound, require +from amazon.opentelemetry.distro.aws_opentelemetry_distro import AwsOpenTelemetryDistro class TestAwsOpenTelemetryDistro(TestCase): + def setUp(self): + # Store original env vars if they exist + self.env_vars_to_restore = {} + self.env_vars_to_check = [ + "OTEL_EXPORTER_OTLP_PROTOCOL", + "OTEL_PROPAGATORS", + "OTEL_PYTHON_ID_GENERATOR", + "OTEL_EXPORTER_OTLP_METRICS_DEFAULT_HISTOGRAM_AGGREGATION", + "AGENT_OBSERVABILITY_ENABLED", + "OTEL_TRACES_EXPORTER", + "OTEL_LOGS_EXPORTER", + "OTEL_METRICS_EXPORTER", + "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT", + "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", + "OTEL_EXPORTER_OTLP_LOGS_ENDPOINT", + "OTEL_TRACES_SAMPLER", + "OTEL_PYTHON_DISABLED_INSTRUMENTATIONS", + "OTEL_PYTHON_LOGGING_AUTO_INSTRUMENTATION_ENABLED", + "OTEL_AWS_APPLICATION_SIGNALS_ENABLED", + ] + + # First, save all current values + for var in self.env_vars_to_check: + if var in os.environ: + self.env_vars_to_restore[var] = os.environ[var] + + # Then clear ALL of them to ensure clean state + for var in self.env_vars_to_check: + if var in os.environ: + del os.environ[var] + + def tearDown(self): + # Clear all env vars first + for var in self.env_vars_to_check: + if var in os.environ: + del os.environ[var] + + # Then restore original values + for var, value in self.env_vars_to_restore.items(): + os.environ[var] = value + def test_package_available(self): try: - require(["aws-opentelemetry-distro"]) - except DistributionNotFound: + version("aws-opentelemetry-distro") + except PackageNotFoundError: self.fail("aws-opentelemetry-distro not installed") + + 
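The distro tests that follow all pin down a single precedence rule: _configure applies a default only when the user has not already exported the corresponding environment variable, and it skips the region-scoped OTLP endpoints entirely when no AWS region can be resolved. A minimal sketch of that pattern, consistent with the assertions below (the helper name _apply_agent_observability_defaults and the trimmed-down variable list are illustrative, not the distro's actual code):

import os
from typing import Optional


def _apply_agent_observability_defaults(region: Optional[str]) -> None:
    # os.environ.setdefault never overwrites a value the user exported first,
    # which is the behavior test_configure_preserves_existing_env_vars checks.
    os.environ.setdefault("OTEL_TRACES_EXPORTER", "otlp")
    os.environ.setdefault("OTEL_LOGS_EXPORTER", "otlp")
    os.environ.setdefault("OTEL_METRICS_EXPORTER", "awsemf")

    # Endpoint defaults are region-scoped, so they are set only when a region
    # was resolved; otherwise the OTLP exporters fall back to their defaults
    # (see test_configure_with_agent_observability_no_region).
    if region:
        os.environ.setdefault(
            "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", f"https://xray.{region}.amazonaws.com/v1/traces"
        )
        os.environ.setdefault(
            "OTEL_EXPORTER_OTLP_LOGS_ENDPOINT", f"https://logs.{region}.amazonaws.com/v1/logs"
        )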
@patch("amazon.opentelemetry.distro.aws_opentelemetry_distro.apply_instrumentation_patches") + @patch("amazon.opentelemetry.distro.aws_opentelemetry_distro.OpenTelemetryDistro._configure") + def test_configure_sets_default_values(self, mock_super_configure, mock_apply_patches): + """Test that _configure sets default environment variables""" + distro = AwsOpenTelemetryDistro() + distro._configure(apply_patches=True) + + # Check that default values are set + self.assertEqual(os.environ.get("OTEL_EXPORTER_OTLP_PROTOCOL"), "http/protobuf") + self.assertEqual(os.environ.get("OTEL_PROPAGATORS"), "xray,tracecontext,b3,b3multi") + self.assertEqual(os.environ.get("OTEL_PYTHON_ID_GENERATOR"), "xray") + self.assertEqual( + os.environ.get("OTEL_EXPORTER_OTLP_METRICS_DEFAULT_HISTOGRAM_AGGREGATION"), + "base2_exponential_bucket_histogram", + ) + + # Verify super()._configure() was called + mock_super_configure.assert_called_once() + + # Verify patches were applied + mock_apply_patches.assert_called_once() + + @patch("amazon.opentelemetry.distro.aws_opentelemetry_distro.apply_instrumentation_patches") + @patch("amazon.opentelemetry.distro.aws_opentelemetry_distro.OpenTelemetryDistro._configure") + def test_configure_without_patches(self, mock_super_configure, mock_apply_patches): # pylint: disable=no-self-use + """Test that _configure can skip applying patches""" + distro = AwsOpenTelemetryDistro() + distro._configure(apply_patches=False) + + # Verify patches were NOT applied + mock_apply_patches.assert_not_called() + + @patch("amazon.opentelemetry.distro.aws_opentelemetry_distro.get_aws_region") + @patch("amazon.opentelemetry.distro.aws_opentelemetry_distro.is_agent_observability_enabled") + @patch("amazon.opentelemetry.distro.aws_opentelemetry_distro.apply_instrumentation_patches") + @patch("amazon.opentelemetry.distro.aws_opentelemetry_distro.OpenTelemetryDistro._configure") + def test_configure_with_agent_observability_enabled( + self, mock_super_configure, mock_apply_patches, mock_is_agent_observability, mock_get_aws_region + ): + """Test that _configure sets agent observability defaults when enabled""" + mock_is_agent_observability.return_value = True + mock_get_aws_region.return_value = "us-west-2" + + distro = AwsOpenTelemetryDistro() + distro._configure() + + # Check agent observability defaults + self.assertEqual(os.environ.get("OTEL_TRACES_EXPORTER"), "otlp") + self.assertEqual(os.environ.get("OTEL_LOGS_EXPORTER"), "otlp") + self.assertEqual(os.environ.get("OTEL_METRICS_EXPORTER"), "awsemf") + self.assertEqual(os.environ.get("OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT"), "true") + self.assertEqual( + os.environ.get("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"), "https://xray.us-west-2.amazonaws.com/v1/traces" + ) + self.assertEqual( + os.environ.get("OTEL_EXPORTER_OTLP_LOGS_ENDPOINT"), "https://logs.us-west-2.amazonaws.com/v1/logs" + ) + self.assertEqual(os.environ.get("OTEL_TRACES_SAMPLER"), "parentbased_always_on") + self.assertEqual( + os.environ.get("OTEL_PYTHON_DISABLED_INSTRUMENTATIONS"), + "http,sqlalchemy,psycopg2,pymysql,sqlite3,aiopg,asyncpg,mysql_connector," + "botocore,boto3,urllib3,requests,starlette", + ) + self.assertEqual(os.environ.get("OTEL_PYTHON_LOGGING_AUTO_INSTRUMENTATION_ENABLED"), "true") + self.assertEqual(os.environ.get("OTEL_AWS_APPLICATION_SIGNALS_ENABLED"), "false") + + @patch("amazon.opentelemetry.distro.aws_opentelemetry_distro.get_aws_region") + @patch("amazon.opentelemetry.distro.aws_opentelemetry_distro.is_agent_observability_enabled") + 
@patch("amazon.opentelemetry.distro.aws_opentelemetry_distro.apply_instrumentation_patches") + @patch("amazon.opentelemetry.distro.aws_opentelemetry_distro.OpenTelemetryDistro._configure") + def test_configure_with_agent_observability_no_region( + self, mock_super_configure, mock_apply_patches, mock_is_agent_observability, mock_get_aws_region + ): + """Test that _configure handles missing AWS region gracefully""" + mock_is_agent_observability.return_value = True + mock_get_aws_region.return_value = None # No region found + + distro = AwsOpenTelemetryDistro() + distro._configure() + + # Check that OTLP endpoints are not set when region is not available + self.assertNotIn("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", os.environ) + self.assertNotIn("OTEL_EXPORTER_OTLP_LOGS_ENDPOINT", os.environ) + + # But verify that the exporters are still set to otlp (will use default endpoints) + self.assertEqual(os.environ.get("OTEL_TRACES_EXPORTER"), "otlp") + self.assertEqual(os.environ.get("OTEL_LOGS_EXPORTER"), "otlp") + self.assertEqual(os.environ.get("OTEL_METRICS_EXPORTER"), "awsemf") + + @patch("amazon.opentelemetry.distro.aws_opentelemetry_distro.is_agent_observability_enabled") + @patch("amazon.opentelemetry.distro.aws_opentelemetry_distro.apply_instrumentation_patches") + @patch("amazon.opentelemetry.distro.aws_opentelemetry_distro.OpenTelemetryDistro._configure") + def test_configure_with_agent_observability_disabled( + self, mock_super_configure, mock_apply_patches, mock_is_agent_observability + ): + """Test that _configure doesn't set agent observability defaults when disabled""" + mock_is_agent_observability.return_value = False + + distro = AwsOpenTelemetryDistro() + distro._configure() + + # Check that agent observability defaults are not set + self.assertNotIn("OTEL_TRACES_EXPORTER", os.environ) + self.assertNotIn("OTEL_LOGS_EXPORTER", os.environ) + self.assertNotIn("OTEL_METRICS_EXPORTER", os.environ) + self.assertNotIn("OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT", os.environ) + + @patch("amazon.opentelemetry.distro.aws_opentelemetry_distro.get_aws_region") + @patch("amazon.opentelemetry.distro.aws_opentelemetry_distro.is_agent_observability_enabled") + @patch("amazon.opentelemetry.distro.aws_opentelemetry_distro.apply_instrumentation_patches") + @patch("amazon.opentelemetry.distro.aws_opentelemetry_distro.OpenTelemetryDistro._configure") + def test_configure_preserves_existing_env_vars( + self, mock_super_configure, mock_apply_patches, mock_is_agent_observability, mock_get_aws_region + ): + """Test that _configure doesn't override existing environment variables""" + mock_is_agent_observability.return_value = True + mock_get_aws_region.return_value = "us-east-1" + + # Set existing values + os.environ["OTEL_TRACES_EXPORTER"] = "custom_exporter" + os.environ["OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"] = "https://custom.endpoint.com" + + distro = AwsOpenTelemetryDistro() + distro._configure() + + # Check that existing values are preserved + self.assertEqual(os.environ.get("OTEL_TRACES_EXPORTER"), "custom_exporter") + self.assertEqual(os.environ.get("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"), "https://custom.endpoint.com") + + @patch("amazon.opentelemetry.distro.aws_opentelemetry_distro.apply_instrumentation_patches") + @patch("amazon.opentelemetry.distro.aws_opentelemetry_distro.OpenTelemetryDistro._configure") + @patch("os.getcwd") + @patch("sys.path", new_callable=list) + def test_configure_adds_cwd_to_sys_path(self, mock_sys_path, mock_getcwd, mock_super_configure, mock_apply_patches): + """Test 
that _configure adds current working directory to sys.path""" + mock_getcwd.return_value = "/test/working/directory" + + distro = AwsOpenTelemetryDistro() + distro._configure() + + # Check that cwd was added to sys.path + self.assertIn("/test/working/directory", mock_sys_path) + + @patch("amazon.opentelemetry.distro.aws_opentelemetry_distro.get_aws_region") + @patch("amazon.opentelemetry.distro.aws_opentelemetry_distro.is_agent_observability_enabled") + @patch("amazon.opentelemetry.distro.aws_opentelemetry_distro.apply_instrumentation_patches") + @patch("amazon.opentelemetry.distro.aws_opentelemetry_distro.OpenTelemetryDistro._configure") + def test_configure_with_agent_observability_endpoints_already_set( + self, mock_super_configure, mock_apply_patches, mock_is_agent_observability, mock_get_aws_region + ): + """Test that user-provided OTLP endpoints are preserved even when region detection fails""" + mock_is_agent_observability.return_value = True + mock_get_aws_region.return_value = None # No region found + + # User has already set custom endpoints + os.environ["OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"] = "https://my-custom-traces.example.com" + os.environ["OTEL_EXPORTER_OTLP_LOGS_ENDPOINT"] = "https://my-custom-logs.example.com" + + distro = AwsOpenTelemetryDistro() + distro._configure() + + # Verify that user-provided endpoints are preserved + self.assertEqual(os.environ.get("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"), "https://my-custom-traces.example.com") + self.assertEqual(os.environ.get("OTEL_EXPORTER_OTLP_LOGS_ENDPOINT"), "https://my-custom-logs.example.com") + + # And exporters are still set to otlp + self.assertEqual(os.environ.get("OTEL_TRACES_EXPORTER"), "otlp") + self.assertEqual(os.environ.get("OTEL_LOGS_EXPORTER"), "otlp") diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_instrumentation_patch.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_instrumentation_patch.py index 87e6c4810..8eff6f2e6 100644 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_instrumentation_patch.py +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_instrumentation_patch.py @@ -1,17 +1,15 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
 # SPDX-License-Identifier: Apache-2.0
-import json
-import math
 import os
-from io import BytesIO
+from importlib.metadata import PackageNotFoundError
 from typing import Any, Dict
 from unittest import TestCase
 from unittest.mock import MagicMock, patch
 
 import gevent.monkey
-import pkg_resources
-from botocore.response import StreamingBody
+import opentelemetry.sdk.extension.aws.resource.ec2 as ec2_resource
+import opentelemetry.sdk.extension.aws.resource.eks as eks_resource
 from amazon.opentelemetry.distro.patches._instrumentation_patch import (
     AWS_GEVENT_PATCH_MODULES,
     apply_instrumentation_patches,
@@ -38,7 +36,7 @@
 _LAMBDA_SOURCE_MAPPING_ID: str = "lambdaEventSourceMappingID"
 
 # Patch names
-GET_DISTRIBUTION_PATCH: str = "amazon.opentelemetry.distro._utils.pkg_resources.get_distribution"
+IMPORTLIB_METADATA_VERSION_PATCH: str = "amazon.opentelemetry.distro._utils.version"
 
 
 class TestInstrumentationPatch(TestCase):
@@ -60,7 +58,7 @@ class TestInstrumentationPatch(TestCase):
 
     def test_instrumentation_patch(self):
         # Set up method patches used by all tests
-        self.method_patches[GET_DISTRIBUTION_PATCH] = patch(GET_DISTRIBUTION_PATCH).start()
+        self.method_patches[IMPORTLIB_METADATA_VERSION_PATCH] = patch(IMPORTLIB_METADATA_VERSION_PATCH).start()
 
         # Run tests that validate patch behaviour before and after patching
         self._run_patch_behaviour_tests()
@@ -73,7 +71,7 @@ def test_instrumentation_patch(self):
 
     def _run_patch_behaviour_tests(self):
         # Test setup
-        self.method_patches[GET_DISTRIBUTION_PATCH].return_value = "CorrectDistributionObject"
+        self.method_patches[IMPORTLIB_METADATA_VERSION_PATCH].return_value = "1.0.0"
 
         # Test setup to not patch gevent
         os.environ[AWS_GEVENT_PATCH_MODULES] = "none"
@@ -120,6 +118,8 @@ def _run_patch_mechanism_tests(self):
         """
         self._test_botocore_installed_flag()
         self._reset_mocks()
+        self._test_resource_detector_patches()
+        self._reset_mocks()
 
     def _test_unpatched_botocore_instrumentation(self):
         # Kinesis
@@ -147,7 +147,7 @@ def _test_unpatched_botocore_instrumentation(self):
         )
 
         # BedrockRuntime
-        self.assertFalse("bedrock-runtime" in _KNOWN_EXTENSIONS, "Upstream has added a bedrock-runtime extension")
+        self.assertTrue("bedrock-runtime" in _KNOWN_EXTENSIONS, "Upstream has removed the bedrock-runtime extension")
 
         # SecretsManager
         self.assertFalse("secretsmanager" in _KNOWN_EXTENSIONS, "Upstream has added a SecretsManager extension")
@@ -213,95 +213,9 @@ def _test_patched_botocore_instrumentation(self):
         bedrock_agent_runtime_sucess_attributes: Dict[str, str] = _do_on_success_bedrock("bedrock-agent-runtime")
         self.assertEqual(len(bedrock_agent_runtime_sucess_attributes), 0)
 
-        # BedrockRuntime - Amazon Titan
+        # BedrockRuntime
         self.assertTrue("bedrock-runtime" in _KNOWN_EXTENSIONS)
-        self._test_patched_bedrock_runtime_invoke_model(
-            model_id="amazon.titan-embed-text-v1",
-            max_tokens=512,
-            temperature=0.9,
-            top_p=0.75,
-            finish_reason="FINISH",
-            input_tokens=123,
-            output_tokens=456,
-        )
-
-        self._test_patched_bedrock_runtime_invoke_model(
-            model_id="amazon.nova-pro-v1:0",
-            max_tokens=500,
-            temperature=0.9,
-            top_p=0.7,
-            finish_reason="FINISH",
-            input_tokens=123,
-            output_tokens=456,
-        )
-
-        # BedrockRuntime - Anthropic Claude
-        self._test_patched_bedrock_runtime_invoke_model(
-            model_id="anthropic.claude-v2:1",
-            max_tokens=512,
-            temperature=0.5,
-            top_p=0.999,
-            finish_reason="end_turn",
-            input_tokens=23,
-            output_tokens=36,
-        )
-
-        # BedrockRuntime - Meta LLama
-        self._test_patched_bedrock_runtime_invoke_model(
-            model_id="meta.llama2-13b-chat-v1",
-
max_tokens=512, - temperature=0.5, - top_p=0.9, - finish_reason="stop", - input_tokens=31, - output_tokens=36, - ) - - # BedrockRuntime - Cohere Command-r - cohere_input = "Hello, world" - cohere_output = "Goodbye, world" - - self._test_patched_bedrock_runtime_invoke_model( - model_id="cohere.command-r-v1:0", - max_tokens=512, - temperature=0.5, - top_p=0.75, - finish_reason="COMPLETE", - input_tokens=math.ceil(len(cohere_input) / 6), - output_tokens=math.ceil(len(cohere_output) / 6), - input_prompt=cohere_input, - output_prompt=cohere_output, - ) - - # BedrockRuntime - AI21 Jambda - self._test_patched_bedrock_runtime_invoke_model( - model_id="ai21.jamba-1-5-large-v1:0", - max_tokens=512, - temperature=0.5, - top_p=0.999, - finish_reason="end_turn", - input_tokens=23, - output_tokens=36, - ) - - # BedrockRuntime - Mistral - msg = "Hello World" - mistral_input = f"[INST] {msg} [/INST]" - mistral_output = "Goodbye, World" - - self._test_patched_bedrock_runtime_invoke_model( - model_id="mistral.mistral-7b-instruct-v0:2", - max_tokens=512, - temperature=0.5, - top_p=0.9, - finish_reason="stop", - input_tokens=math.ceil(len(mistral_input) / 6), - output_tokens=math.ceil(len(mistral_output) / 6), - input_prompt=mistral_input, - output_prompt=mistral_output, - ) - # SecretsManager self.assertTrue("secretsmanager" in _KNOWN_EXTENSIONS) secretsmanager_attributes: Dict[str, str] = _do_extract_secretsmanager_attributes() @@ -369,17 +283,13 @@ def _test_botocore_installed_flag(self): with patch( "amazon.opentelemetry.distro.patches._botocore_patches._apply_botocore_instrumentation_patches" ) as mock_apply_patches: - get_distribution_patch: patch = self.method_patches[GET_DISTRIBUTION_PATCH] - get_distribution_patch.side_effect = pkg_resources.DistributionNotFound - apply_instrumentation_patches() - mock_apply_patches.assert_not_called() - - get_distribution_patch.side_effect = pkg_resources.VersionConflict("botocore==1.0.0", "botocore==0.0.1") + get_distribution_patch: patch = self.method_patches[IMPORTLIB_METADATA_VERSION_PATCH] + get_distribution_patch.side_effect = PackageNotFoundError apply_instrumentation_patches() mock_apply_patches.assert_not_called() get_distribution_patch.side_effect = None - get_distribution_patch.return_value = "CorrectDistributionObject" + get_distribution_patch.return_value = "1.0.0" apply_instrumentation_patches() mock_apply_patches.assert_called() @@ -389,146 +299,6 @@ def _test_patched_bedrock_instrumentation(self): self.assertEqual(len(bedrock_sucess_attributes), 1) self.assertEqual(bedrock_sucess_attributes["aws.bedrock.guardrail.id"], _BEDROCK_GUARDRAIL_ID) - def _test_patched_bedrock_runtime_invoke_model(self, **args): - model_id = args.get("model_id", None) - max_tokens = args.get("max_tokens", None) - temperature = args.get("temperature", None) - top_p = args.get("top_p", None) - finish_reason = args.get("finish_reason", None) - input_tokens = args.get("input_tokens", None) - output_tokens = args.get("output_tokens", None) - input_prompt = args.get("input_prompt", None) - output_prompt = args.get("output_prompt", None) - - def get_model_response_request(): - request_body = {} - response_body = {} - - if "amazon.titan" in model_id: - request_body = { - "textGenerationConfig": { - "maxTokenCount": max_tokens, - "temperature": temperature, - "topP": top_p, - } - } - - response_body = { - "inputTextTokenCount": input_tokens, - "results": [ - { - "tokenCount": output_tokens, - "outputText": "testing", - "completionReason": finish_reason, - } - ], - } - - if 
"amazon.nova" in model_id: - request_body = { - "inferenceConfig": { - "max_new_tokens": max_tokens, - "temperature": temperature, - "top_p": top_p, - } - } - - response_body = { - "output": {"message": {"content": [{"text": ""}], "role": "assistant"}}, - "stopReason": finish_reason, - "usage": {"inputTokens": input_tokens, "outputTokens": output_tokens}, - } - - if "anthropic.claude" in model_id: - request_body = { - "anthropic_version": "bedrock-2023-05-31", - "max_tokens": max_tokens, - "temperature": temperature, - "top_p": top_p, - } - - response_body = { - "stop_reason": finish_reason, - "stop_sequence": None, - "usage": {"input_tokens": input_tokens, "output_tokens": output_tokens}, - } - - if "ai21.jamba" in model_id: - request_body = { - "max_tokens": max_tokens, - "temperature": temperature, - "top_p": top_p, - } - - response_body = { - "choices": [{"finish_reason": finish_reason}], - "usage": { - "prompt_tokens": input_tokens, - "completion_tokens": output_tokens, - "total_tokens": (input_tokens + output_tokens), - }, - } - - if "meta.llama" in model_id: - request_body = { - "max_gen_len": max_tokens, - "temperature": temperature, - "top_p": top_p, - } - - response_body = { - "prompt_token_count": input_tokens, - "generation_token_count": output_tokens, - "stop_reason": finish_reason, - } - - if "cohere.command" in model_id: - request_body = { - "message": input_prompt, - "max_tokens": max_tokens, - "temperature": temperature, - "p": top_p, - } - - response_body = { - "text": output_prompt, - "finish_reason": finish_reason, - } - - if "mistral" in model_id: - request_body = { - "prompt": input_prompt, - "max_tokens": max_tokens, - "temperature": temperature, - "top_p": top_p, - } - - response_body = {"outputs": [{"text": output_prompt, "stop_reason": finish_reason}]} - - json_bytes = json.dumps(response_body).encode("utf-8") - - return json.dumps(request_body), StreamingBody(BytesIO(json_bytes), len(json_bytes)) - - request_body, response_body = get_model_response_request() - - bedrock_runtime_attributes: Dict[str, str] = _do_extract_attributes_bedrock( - "bedrock-runtime", model_id=model_id, request_body=request_body - ) - bedrock_runtime_success_attributes: Dict[str, str] = _do_on_success_bedrock( - "bedrock-runtime", model_id=model_id, streaming_body=response_body - ) - - bedrock_runtime_attributes.update(bedrock_runtime_success_attributes) - - self.assertEqual(bedrock_runtime_attributes["gen_ai.system"], _GEN_AI_SYSTEM) - self.assertEqual(bedrock_runtime_attributes["gen_ai.request.model"], model_id) - self.assertEqual(bedrock_runtime_attributes["gen_ai.request.max_tokens"], max_tokens) - self.assertEqual(bedrock_runtime_attributes["gen_ai.request.temperature"], temperature) - self.assertEqual(bedrock_runtime_attributes["gen_ai.request.top_p"], top_p) - self.assertEqual(bedrock_runtime_attributes["gen_ai.usage.input_tokens"], input_tokens) - self.assertEqual(bedrock_runtime_attributes["gen_ai.usage.output_tokens"], output_tokens) - self.assertEqual(bedrock_runtime_attributes["gen_ai.response.finish_reasons"], [finish_reason]) - def _test_patched_bedrock_agent_instrumentation(self): """For bedrock-agent service, both extract_attributes and on_success provides attributes, the attributes depend on the API being invoked.""" @@ -586,6 +356,53 @@ def _test_patched_bedrock_agent_instrumentation(self): self.assertEqual(len(bedrock_agent_success_attributes), 1) self.assertEqual(bedrock_agent_success_attributes[attribute_tuple[0]], attribute_tuple[1]) + def 
_test_resource_detector_patches(self): + """Test that resource detector patches are applied and work correctly""" + # Test that the functions were patched + self.assertIsNotNone(ec2_resource._aws_http_request) + self.assertIsNotNone(eks_resource._aws_http_request) + + # Test EC2 patched function + with patch("amazon.opentelemetry.distro.patches._resource_detector_patches.urlopen") as mock_urlopen: + mock_response = MagicMock() + mock_response.read.return_value = b'{"test": "ec2-data"}' + mock_urlopen.return_value.__enter__.return_value = mock_response + + result = ec2_resource._aws_http_request("GET", "/test/path", {"X-Test": "header"}) + self.assertEqual(result, '{"test": "ec2-data"}') + + # Verify the request was made correctly + args, kwargs = mock_urlopen.call_args + request = args[0] + self.assertEqual(request.full_url, "http://169.254.169.254/test/path") + self.assertEqual(request.headers, {"X-test": "header"}) + self.assertEqual(kwargs["timeout"], 5) + + # Test EKS patched function + with patch("amazon.opentelemetry.distro.patches._resource_detector_patches.urlopen") as mock_urlopen, patch( + "amazon.opentelemetry.distro.patches._resource_detector_patches.ssl.create_default_context" + ) as mock_ssl: + mock_response = MagicMock() + mock_response.read.return_value = b'{"test": "eks-data"}' + mock_urlopen.return_value.__enter__.return_value = mock_response + + mock_context = MagicMock() + mock_ssl.return_value = mock_context + + result = eks_resource._aws_http_request("GET", "/api/v1/test", "Bearer token123") + self.assertEqual(result, '{"test": "eks-data"}') + + # Verify the request was made correctly + args, kwargs = mock_urlopen.call_args + request = args[0] + self.assertEqual(request.full_url, "https://kubernetes.default.svc/api/v1/test") + self.assertEqual(request.headers, {"Authorization": "Bearer token123"}) + self.assertEqual(kwargs["timeout"], 5) + self.assertEqual(kwargs["context"], mock_context) + + # Verify SSL context was created with correct CA file + mock_ssl.assert_called_once_with(cafile="/var/run/secrets/kubernetes.io/serviceaccount/ca.crt") + def _reset_mocks(self): for method_patch in self.method_patches.values(): method_patch.reset_mock() @@ -678,6 +495,7 @@ def _do_on_success( ) -> Dict[str, str]: span_mock: Span = MagicMock() mock_call_context = MagicMock() + mock_instrumentor_context = MagicMock() span_attributes: Dict[str, str] = {} def set_side_effect(set_key, set_value): @@ -692,6 +510,6 @@ def set_side_effect(set_key, set_value): mock_call_context.params = params extension = _KNOWN_EXTENSIONS[service_name]()(mock_call_context) - extension.on_success(span_mock, result) + extension.on_success(span_mock, result, mock_instrumentor_context) return span_attributes diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_llo_handler.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_llo_handler.py new file mode 100644 index 000000000..f2c0b36b4 --- /dev/null +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_llo_handler.py @@ -0,0 +1,40 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0
+
+from unittest import TestCase
+from unittest.mock import MagicMock
+
+from amazon.opentelemetry.distro.llo_handler import LLOHandler
+from opentelemetry.sdk._logs import LoggerProvider
+
+
+class TestLLOHandler(TestCase):
+    def test_init_with_logger_provider(self):
+        # Test LLOHandler initialization with a logger provider
+        mock_logger_provider = MagicMock(spec=LoggerProvider)
+
+        handler = LLOHandler(logger_provider=mock_logger_provider)
+
+        # Since the __init__ method only has 'pass' in the implementation,
+        # we can only verify that the handler is created without errors
+        self.assertIsInstance(handler, LLOHandler)
+
+    def test_init_stores_logger_provider(self):
+        # Test that the logger provider is stored (once the implementation is added)
+        mock_logger_provider = MagicMock(spec=LoggerProvider)
+
+        handler = LLOHandler(logger_provider=mock_logger_provider)
+
+        # This test assumes the implementation will store the logger_provider;
+        # update it accordingly when the actual implementation is added
+        self.assertIsInstance(handler, LLOHandler)
+
+    def test_process_spans_method_exists(self):  # pylint: disable=no-self-use
+        # Test that the process_spans method exists (for the interface contract)
+        mock_logger_provider = MagicMock(spec=LoggerProvider)
+        LLOHandler(logger_provider=mock_logger_provider)
+
+        # The assertions below stay commented out until process_spans is
+        # implemented; enable them once the method exists
+        # self.assertTrue(hasattr(handler, 'process_spans'))
+        # self.assertTrue(callable(getattr(handler, 'process_spans', None)))
diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_otlp_aws_span_exporter.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_otlp_aws_span_exporter.py
deleted file mode 100644
index 973849f69..000000000
--- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_otlp_aws_span_exporter.py
+++ /dev/null
@@ -1,29 +0,0 @@
-# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-# SPDX-License-Identifier: Apache-2.0
-
-from unittest import TestCase
-from unittest.mock import MagicMock
-
-from amazon.opentelemetry.distro.exporter.otlp.aws.traces.otlp_aws_span_exporter import OTLPAwsSpanExporter
-from opentelemetry.sdk._logs import LoggerProvider
-
-
-class TestOTLPAwsSpanExporter(TestCase):
-    def test_init_with_logger_provider(self):
-        # Test initialization with logger_provider
-        mock_logger_provider = MagicMock(spec=LoggerProvider)
-        endpoint = "https://xray.us-east-1.amazonaws.com/v1/traces"
-
-        exporter = OTLPAwsSpanExporter(endpoint=endpoint, logger_provider=mock_logger_provider)
-
-        self.assertEqual(exporter._logger_provider, mock_logger_provider)
-        self.assertEqual(exporter._aws_region, "us-east-1")
-
-    def test_init_without_logger_provider(self):
-        # Test initialization without logger_provider (default behavior)
-        endpoint = "https://xray.us-west-2.amazonaws.com/v1/traces"
-
-        exporter = OTLPAwsSpanExporter(endpoint=endpoint)
-
-        self.assertIsNone(exporter._logger_provider)
-        self.assertEqual(exporter._aws_region, "us-west-2")
diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_utils.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_utils.py
new file mode 100644
index 000000000..4c0cd709f
--- /dev/null
+++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_utils.py
@@ -0,0 +1,175 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0 + +import os +from importlib.metadata import PackageNotFoundError +from unittest import TestCase +from unittest.mock import MagicMock, patch + +from amazon.opentelemetry.distro._utils import ( + AGENT_OBSERVABILITY_ENABLED, + get_aws_region, + get_aws_session, + is_agent_observability_enabled, + is_installed, +) + + +class TestUtils(TestCase): + def setUp(self): + # Store original env var if it exists + self.original_env = os.environ.get(AGENT_OBSERVABILITY_ENABLED) + # Clear it to ensure clean state + if AGENT_OBSERVABILITY_ENABLED in os.environ: + del os.environ[AGENT_OBSERVABILITY_ENABLED] + + def tearDown(self): + # First clear the env var + if AGENT_OBSERVABILITY_ENABLED in os.environ: + del os.environ[AGENT_OBSERVABILITY_ENABLED] + # Then restore original if it existed + if self.original_env is not None: + os.environ[AGENT_OBSERVABILITY_ENABLED] = self.original_env + + def test_is_installed_package_not_found(self): + """Test is_installed returns False when package is not found""" + with patch("amazon.opentelemetry.distro._utils.version") as mock_version: + # Simulate package not found + mock_version.side_effect = PackageNotFoundError("test-package") + + result = is_installed("test-package>=1.0.0") + self.assertFalse(result) + + def test_is_installed(self): + """Test is_installed returns True when version matches the specifier""" + with patch("amazon.opentelemetry.distro._utils.version") as mock_version: + # Package is installed and version matches requirement + mock_version.return_value = "2.5.0" + + # Test with compatible version requirement + result = is_installed("test-package>=2.0.0") + self.assertTrue(result) + + # Test with exact version match + mock_version.return_value = "1.0.0" + result = is_installed("test-package==1.0.0") + self.assertTrue(result) + + # Test with version range + mock_version.return_value = "1.5.0" + result = is_installed("test-package>=1.0,<2.0") + self.assertTrue(result) + + def test_is_installed_version_mismatch(self): + """Test is_installed returns False when version doesn't match""" + with patch("amazon.opentelemetry.distro._utils.version") as mock_version: + # Package is installed but version doesn't match requirement + mock_version.return_value = "1.0.0" + + # Test with incompatible version requirement + result = is_installed("test-package>=2.0.0") + self.assertFalse(result) + + def test_is_agent_observability_enabled_various_values(self): + """Test is_agent_observability_enabled with various environment variable values""" + # Test with "True" (uppercase) + os.environ[AGENT_OBSERVABILITY_ENABLED] = "True" + self.assertTrue(is_agent_observability_enabled()) + + # Test with "TRUE" (all caps) + os.environ[AGENT_OBSERVABILITY_ENABLED] = "TRUE" + self.assertTrue(is_agent_observability_enabled()) + + # Test with "true" (lowercase) + os.environ[AGENT_OBSERVABILITY_ENABLED] = "true" + self.assertTrue(is_agent_observability_enabled()) + + # Test with "false" + os.environ[AGENT_OBSERVABILITY_ENABLED] = "false" + self.assertFalse(is_agent_observability_enabled()) + + # Test with "False" + os.environ[AGENT_OBSERVABILITY_ENABLED] = "False" + self.assertFalse(is_agent_observability_enabled()) + + # Test with arbitrary string + os.environ[AGENT_OBSERVABILITY_ENABLED] = "yes" + self.assertFalse(is_agent_observability_enabled()) + + # Test with empty string + os.environ[AGENT_OBSERVABILITY_ENABLED] = "" + self.assertFalse(is_agent_observability_enabled()) + + # Test when env var is not set + if AGENT_OBSERVABILITY_ENABLED in 
os.environ: + del os.environ[AGENT_OBSERVABILITY_ENABLED] + self.assertFalse(is_agent_observability_enabled()) + + def test_get_aws_session_with_botocore(self): + """Test get_aws_session when botocore is installed""" + with patch("amazon.opentelemetry.distro._utils.IS_BOTOCORE_INSTALLED", True): + with patch("botocore.session.Session") as mock_session_class: + mock_session = MagicMock() + mock_session_class.return_value = mock_session + + session = get_aws_session() + self.assertEqual(session, mock_session) + mock_session_class.assert_called_once() + + def test_get_aws_session_without_botocore(self): + """Test get_aws_session when botocore is not installed""" + with patch("amazon.opentelemetry.distro._utils.IS_BOTOCORE_INSTALLED", False): + session = get_aws_session() + self.assertIsNone(session) + + def test_get_aws_region_with_botocore(self): + """Test get_aws_region when botocore is available and returns a region""" + with patch("amazon.opentelemetry.distro._utils.get_aws_session") as mock_get_session: + mock_session = MagicMock() + mock_session.get_config_variable.return_value = "us-east-1" + mock_get_session.return_value = mock_session + + region = get_aws_region() + self.assertEqual(region, "us-east-1") + mock_session.get_config_variable.assert_called_once_with("region") + + def test_get_aws_region_without_botocore(self): + """Test get_aws_region when botocore is not installed""" + with patch("amazon.opentelemetry.distro._utils.get_aws_session") as mock_get_session: + mock_get_session.return_value = None + + region = get_aws_region() + self.assertIsNone(region) + + def test_get_aws_region_botocore_no_region(self): + """Test get_aws_region when botocore is available but returns no region""" + with patch("amazon.opentelemetry.distro._utils.get_aws_session") as mock_get_session: + mock_session = MagicMock() + mock_session.get_config_variable.return_value = None + mock_get_session.return_value = mock_session + + region = get_aws_region() + self.assertIsNone(region) + mock_session.get_config_variable.assert_called_once_with("region") + + def test_get_aws_region_with_aws_region_env(self): + """Test get_aws_region when AWS_REGION environment variable is set""" + os.environ.pop("AWS_REGION", None) + os.environ.pop("AWS_DEFAULT_REGION", None) + os.environ["AWS_REGION"] = "us-west-2" + + region = get_aws_region() + self.assertEqual(region, "us-west-2") + + os.environ.pop("AWS_REGION", None) + + def test_get_aws_region_with_aws_default_region_env(self): + """Test get_aws_region when AWS_DEFAULT_REGION environment variable is set""" + os.environ.pop("AWS_REGION", None) + os.environ.pop("AWS_DEFAULT_REGION", None) + os.environ["AWS_DEFAULT_REGION"] = "eu-west-1" + + region = get_aws_region() + self.assertEqual(region, "eu-west-1") + + os.environ.pop("AWS_DEFAULT_REGION", None) diff --git a/contract-tests/images/applications/botocore/botocore_server.py b/contract-tests/images/applications/botocore/botocore_server.py index 6c315a4dc..80ecbc6fe 100644 --- a/contract-tests/images/applications/botocore/botocore_server.py +++ b/contract-tests/images/applications/botocore/botocore_server.py @@ -435,7 +435,7 @@ def get_model_request_response(path): "inferenceConfig": { "max_new_tokens": 800, "temperature": 0.9, - "top_p": 0.7, + "topP": 0.7, }, } @@ -496,32 +496,6 @@ def get_model_request_response(path): "text": "test-generation-text", } - if "ai21.jamba" in path: - model_id = "ai21.jamba-1-5-large-v1:0" - - request_body = { - "messages": [ - { - "role": "user", - "content": prompt, - }, - ], - "top_p": 
0.8, - "temperature": 0.6, - "max_tokens": 512, - } - - response_body = { - "stop_reason": "end_turn", - "usage": { - "prompt_tokens": 21, - "completion_tokens": 24, - }, - "choices": [ - {"finish_reason": "stop"}, - ], - } - if "mistral" in path: model_id = "mistral.mistral-7b-instruct-v0:2" diff --git a/contract-tests/images/applications/botocore/requirements.txt b/contract-tests/images/applications/botocore/requirements.txt index 25113e3f4..61ddebf98 100644 --- a/contract-tests/images/applications/botocore/requirements.txt +++ b/contract-tests/images/applications/botocore/requirements.txt @@ -1,5 +1,3 @@ -opentelemetry-distro==0.46b0 -opentelemetry-exporter-otlp-proto-grpc==1.25.0 typing-extensions==4.12.2 botocore==1.34.143 boto3==1.34.143 diff --git a/contract-tests/images/applications/django/requirements.txt b/contract-tests/images/applications/django/requirements.txt index 9b54a7736..84dfdeabb 100644 --- a/contract-tests/images/applications/django/requirements.txt +++ b/contract-tests/images/applications/django/requirements.txt @@ -1,4 +1,2 @@ -opentelemetry-distro==0.46b0 -opentelemetry-exporter-otlp-proto-grpc==1.25.0 typing-extensions==4.12.2 django==5.0.11 diff --git a/contract-tests/images/applications/mysql-connector/requirements.txt b/contract-tests/images/applications/mysql-connector/requirements.txt index 9ca44d2e4..f285dcb1f 100644 --- a/contract-tests/images/applications/mysql-connector/requirements.txt +++ b/contract-tests/images/applications/mysql-connector/requirements.txt @@ -1,4 +1,2 @@ -opentelemetry-distro==0.46b0 -opentelemetry-exporter-otlp-proto-grpc==1.25.0 typing-extensions==4.12.2 mysql-connector-python~=9.1.0 diff --git a/contract-tests/images/applications/mysqlclient/requirements.txt b/contract-tests/images/applications/mysqlclient/requirements.txt index 49c6b70f3..933e606b4 100644 --- a/contract-tests/images/applications/mysqlclient/requirements.txt +++ b/contract-tests/images/applications/mysqlclient/requirements.txt @@ -1,4 +1,2 @@ -opentelemetry-distro==0.46b0 -opentelemetry-exporter-otlp-proto-grpc==1.25.0 typing-extensions==4.12.2 mysqlclient==2.2.4 diff --git a/contract-tests/images/applications/psycopg2/requirements.txt b/contract-tests/images/applications/psycopg2/requirements.txt index f2d278475..8786aff35 100644 --- a/contract-tests/images/applications/psycopg2/requirements.txt +++ b/contract-tests/images/applications/psycopg2/requirements.txt @@ -1,4 +1,2 @@ -opentelemetry-distro==0.46b0 -opentelemetry-exporter-otlp-proto-grpc==1.25.0 typing-extensions==4.12.2 psycopg2==2.9.9 diff --git a/contract-tests/images/applications/pymysql/requirements.txt b/contract-tests/images/applications/pymysql/requirements.txt index ddda9b1fe..8ba76defb 100644 --- a/contract-tests/images/applications/pymysql/requirements.txt +++ b/contract-tests/images/applications/pymysql/requirements.txt @@ -1,4 +1,2 @@ -opentelemetry-distro==0.46b0 -opentelemetry-exporter-otlp-proto-grpc==1.25.0 typing-extensions==4.12.2 pymysql==1.1.1 diff --git a/contract-tests/images/applications/requests/requirements.txt b/contract-tests/images/applications/requests/requirements.txt index 369049d22..700b31404 100644 --- a/contract-tests/images/applications/requests/requirements.txt +++ b/contract-tests/images/applications/requests/requirements.txt @@ -1,4 +1,2 @@ -opentelemetry-distro==0.46b0 -opentelemetry-exporter-otlp-proto-grpc==1.25.0 typing-extensions==4.12.2 requests~=2.0 diff --git a/contract-tests/images/mock-collector/pyproject.toml 
b/contract-tests/images/mock-collector/pyproject.toml index 422e2a5b1..42e13c868 100644 --- a/contract-tests/images/mock-collector/pyproject.toml +++ b/contract-tests/images/mock-collector/pyproject.toml @@ -11,9 +11,9 @@ requires-python = ">=3.9" dependencies = [ "grpcio ~= 1.66.0", - "opentelemetry-proto==1.25.0", - "opentelemetry-sdk==1.25.0", - "protobuf==4.25.2", + "opentelemetry-proto==1.33.1", + "opentelemetry-sdk==1.33.1", + "protobuf==5.26.1", "typing-extensions==4.12.2" ] diff --git a/contract-tests/images/mock-collector/requirements.txt b/contract-tests/images/mock-collector/requirements.txt index a0c5454cd..12e69148b 100644 --- a/contract-tests/images/mock-collector/requirements.txt +++ b/contract-tests/images/mock-collector/requirements.txt @@ -1,5 +1,5 @@ grpcio==1.66.2 -opentelemetry-proto==1.25.0 -opentelemetry-sdk==1.25.0 -protobuf==4.25.2 +opentelemetry-proto==1.33.1 +opentelemetry-sdk==1.33.1 +protobuf==5.26.1 typing-extensions==4.12.2 diff --git a/contract-tests/tests/pyproject.toml b/contract-tests/tests/pyproject.toml index 0df6f6a1c..5c2895fab 100644 --- a/contract-tests/tests/pyproject.toml +++ b/contract-tests/tests/pyproject.toml @@ -10,8 +10,8 @@ license = "Apache-2.0" requires-python = ">=3.9" dependencies = [ - "opentelemetry-proto==1.25.0", - "opentelemetry-sdk==1.25.0", + "opentelemetry-proto==1.33.1", + "opentelemetry-sdk==1.33.1", "testcontainers==3.7.1", "grpcio==1.66.2", "docker==7.1.0", diff --git a/contract-tests/tests/test/amazon/botocore/botocore_test.py b/contract-tests/tests/test/amazon/botocore/botocore_test.py index ed04c9514..549ec3f50 100644 --- a/contract-tests/tests/test/amazon/botocore/botocore_test.py +++ b/contract-tests/tests/test/amazon/botocore/botocore_test.py @@ -440,7 +440,7 @@ def test_bedrock_runtime_invoke_model_amazon_titan(self): _GEN_AI_USAGE_INPUT_TOKENS: 15, _GEN_AI_USAGE_OUTPUT_TOKENS: 13, }, - span_name="Bedrock Runtime.InvokeModel", + span_name="text_completion amazon.titan-text-premier-v1:0", ) def test_bedrock_runtime_invoke_model_amazon_nova(self): @@ -458,6 +458,7 @@ def test_bedrock_runtime_invoke_model_amazon_nova(self): cloudformation_primary_identifier="amazon.nova-pro-v1:0", request_specific_attributes={ _GEN_AI_REQUEST_MODEL: "amazon.nova-pro-v1:0", + _GEN_AI_SYSTEM: "aws.bedrock", _GEN_AI_REQUEST_MAX_TOKENS: 800, _GEN_AI_REQUEST_TEMPERATURE: 0.9, _GEN_AI_REQUEST_TOP_P: 0.7, @@ -467,7 +468,7 @@ def test_bedrock_runtime_invoke_model_amazon_nova(self): _GEN_AI_USAGE_INPUT_TOKENS: 432, _GEN_AI_USAGE_OUTPUT_TOKENS: 681, }, - span_name="Bedrock Runtime.InvokeModel", + span_name="chat amazon.nova-pro-v1:0", ) def test_bedrock_runtime_invoke_model_anthropic_claude(self): @@ -495,7 +496,7 @@ def test_bedrock_runtime_invoke_model_anthropic_claude(self): _GEN_AI_USAGE_INPUT_TOKENS: 15, _GEN_AI_USAGE_OUTPUT_TOKENS: 13, }, - span_name="Bedrock Runtime.InvokeModel", + span_name="chat anthropic.claude-v2:1", ) def test_bedrock_runtime_invoke_model_meta_llama(self): @@ -523,7 +524,7 @@ def test_bedrock_runtime_invoke_model_meta_llama(self): _GEN_AI_USAGE_INPUT_TOKENS: 31, _GEN_AI_USAGE_OUTPUT_TOKENS: 49, }, - span_name="Bedrock Runtime.InvokeModel", + span_name="chat meta.llama2-13b-chat-v1", ) def test_bedrock_runtime_invoke_model_cohere_command(self): @@ -553,35 +554,7 @@ def test_bedrock_runtime_invoke_model_cohere_command(self): ), _GEN_AI_USAGE_OUTPUT_TOKENS: math.ceil(len("test-generation-text") / 6), }, - span_name="Bedrock Runtime.InvokeModel", - ) - - def test_bedrock_runtime_invoke_model_ai21_jamba(self): - 
self.do_test_requests( - "bedrock/invokemodel/invoke-model/ai21.jamba-1-5-large-v1:0", - "GET", - 200, - 0, - 0, - rpc_service="Bedrock Runtime", - remote_service="AWS::BedrockRuntime", - remote_operation="InvokeModel", - remote_resource_type="AWS::Bedrock::Model", - remote_resource_identifier="ai21.jamba-1-5-large-v1:0", - cloudformation_primary_identifier="ai21.jamba-1-5-large-v1:0", - request_specific_attributes={ - _GEN_AI_REQUEST_MODEL: "ai21.jamba-1-5-large-v1:0", - _GEN_AI_SYSTEM: "aws.bedrock", - _GEN_AI_REQUEST_MAX_TOKENS: 512, - _GEN_AI_REQUEST_TEMPERATURE: 0.6, - _GEN_AI_REQUEST_TOP_P: 0.8, - }, - response_specific_attributes={ - _GEN_AI_RESPONSE_FINISH_REASONS: ["stop"], - _GEN_AI_USAGE_INPUT_TOKENS: 21, - _GEN_AI_USAGE_OUTPUT_TOKENS: 24, - }, - span_name="Bedrock Runtime.InvokeModel", + span_name="chat cohere.command-r-v1:0", ) def test_bedrock_runtime_invoke_model_mistral(self): @@ -611,7 +584,7 @@ def test_bedrock_runtime_invoke_model_mistral(self): ), _GEN_AI_USAGE_OUTPUT_TOKENS: math.ceil(len("test-output-text") / 6), }, - span_name="Bedrock Runtime.InvokeModel", + span_name="chat mistral.mistral-7b-instruct-v0:2", ) def test_bedrock_get_guardrail(self):
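
A note on the renamed span assertions above: the expected names now follow the OpenTelemetry GenAI semantic convention of "{gen_ai.operation.name} {gen_ai.request.model}" (for example, "chat amazon.nova-pro-v1:0"), replacing the generic "Bedrock Runtime.InvokeModel". A minimal sketch of that naming rule follows, using a hypothetical helper for illustration only; the real names are emitted by the upstream botocore instrumentation, not by this patch:

    # Illustrative sketch (not part of this patch): the GenAI semantic
    # convention names spans "{gen_ai.operation.name} {gen_ai.request.model}".
    # The helper below is hypothetical; the actual names come from the
    # instrumentation layer.
    def gen_ai_span_name(operation: str, model_id: str) -> str:
        """Build a span name such as 'chat amazon.nova-pro-v1:0'."""
        return f"{operation} {model_id}"

    # Mirrors the expectations asserted in botocore_test.py above.
    assert gen_ai_span_name("chat", "amazon.nova-pro-v1:0") == "chat amazon.nova-pro-v1:0"
    assert gen_ai_span_name("text_completion", "amazon.titan-text-premier-v1:0") == (
        "text_completion amazon.titan-text-premier-v1:0"
    )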