diff --git a/propagator/opentelemetry-propagator-aws-xray/CHANGELOG.md b/propagator/opentelemetry-propagator-aws-xray/CHANGELOG.md index bdbcdef927..c3aa07d891 100644 --- a/propagator/opentelemetry-propagator-aws-xray/CHANGELOG.md +++ b/propagator/opentelemetry-propagator-aws-xray/CHANGELOG.md @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Update `opentelemetry-api` version to 1.16 ([#2961](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2961)) +- aws-xray-propagator: ensure trace state is not overwritten + ([#3774](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3774)) ## Version 1.0.2 (2024-08-05) diff --git a/propagator/opentelemetry-propagator-aws-xray/src/opentelemetry/propagators/aws/aws_xray_propagator.py b/propagator/opentelemetry-propagator-aws-xray/src/opentelemetry/propagators/aws/aws_xray_propagator.py index d9b99f35ca..77ecbbf12e 100644 --- a/propagator/opentelemetry-propagator-aws-xray/src/opentelemetry/propagators/aws/aws_xray_propagator.py +++ b/propagator/opentelemetry-propagator-aws-xray/src/opentelemetry/propagators/aws/aws_xray_propagator.py @@ -144,12 +144,17 @@ def extract( if sampled: options |= trace.TraceFlags.SAMPLED + try: + tracestate = trace.get_current_span(context=context).get_span_context().trace_state + except AttributeError: + tracestate = trace.TraceState() + span_context = trace.SpanContext( trace_id=trace_id, span_id=span_id, is_remote=True, trace_flags=trace.TraceFlags(options), - trace_state=trace.TraceState(), + trace_state=tracestate, ) if not span_context.is_valid: diff --git a/propagator/opentelemetry-propagator-aws-xray/tests/test_aws_xray_propagator.py b/propagator/opentelemetry-propagator-aws-xray/tests/test_aws_xray_propagator.py index 2fb8a4925c..a39caf54af 100644 --- a/propagator/opentelemetry-propagator-aws-xray/tests/test_aws_xray_propagator.py +++ b/propagator/opentelemetry-propagator-aws-xray/tests/test_aws_xray_propagator.py @@ -335,3 +335,42 @@ def test_fields(self, mock_trace): self.assertEqual( AwsXRayPropagatorTest.XRAY_PROPAGATOR.fields, inject_fields ) + + def test_extract_trace_state_from_context(self): + """Test that extract properly propagates the trace state extracted by other propagators.""" + context_with_extracted = AwsXRayPropagatorTest.XRAY_PROPAGATOR.extract( + CaseInsensitiveDict( + { + TRACE_HEADER_KEY: "Root=1-8a3c60f7-d188f8fa79d48a391a778fa6;Parent=53995c3f42cd8ad8;Sampled=0", + } + ), + context=set_span_in_context( + trace_api.NonRecordingSpan( + SpanContext( + int(TRACE_ID_BASE16, 16), + int(SPAN_ID_BASE16, 16), + True, + DEFAULT_TRACE_OPTIONS, + TraceState([("foo", "bar"), ("baz", "qux")]) + ) + ) + ) + ) + + extracted_span_context = get_nested_span_context(context_with_extracted) + expected_trace_state = TraceState([("foo", "bar"), ("baz", "qux")]) + + self.assertEqual(extracted_span_context.trace_state, expected_trace_state) + + def test_extract_no_trace_state_from_context(self): + """Test that extract defaults to an empty trace state correctly.""" + context_with_extracted = AwsXRayPropagatorTest.XRAY_PROPAGATOR.extract( + CaseInsensitiveDict( + { + TRACE_HEADER_KEY: "Root=1-8a3c60f7-d188f8fa79d48a391a778fa6;Parent=53995c3f42cd8ad8;Sampled=0", + } + ) + ) + + extracted_span_context = get_nested_span_context(context_with_extracted) + self.assertEqual(extracted_span_context.trace_state, TraceState([]))