Skip to content

Commit 8ce45de

Browse files
committed
Add support for concurrent trace id propagation (#6403)
Add support for concurrent trace id propagation
1 parent f0d71d5 commit 8ce45de

File tree

4 files changed

+300
-5
lines changed

4 files changed

+300
-5
lines changed

core/aws-core/pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,11 @@
113113
<groupId>software.amazon.eventstream</groupId>
114114
<artifactId>eventstream</artifactId>
115115
</dependency>
116+
<dependency>
117+
<groupId>software.amazon.awssdk</groupId>
118+
<artifactId>utils-lite</artifactId>
119+
<version>${awsjavasdk.version}</version>
120+
</dependency>
116121

117122
<dependency>
118123
<groupId>software.amazon.awssdk</groupId>

core/aws-core/src/main/java/software/amazon/awssdk/awscore/interceptor/TraceIdExecutionInterceptor.java

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,12 @@
1919
import software.amazon.awssdk.annotations.SdkProtectedApi;
2020
import software.amazon.awssdk.awscore.internal.interceptor.TracingSystemSetting;
2121
import software.amazon.awssdk.core.interceptor.Context;
22+
import software.amazon.awssdk.core.interceptor.ExecutionAttribute;
2223
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
2324
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
2425
import software.amazon.awssdk.http.SdkHttpRequest;
2526
import software.amazon.awssdk.utils.SystemSetting;
27+
import software.amazon.awssdk.utilslite.SdkInternalThreadLocal;
2628

2729
/**
2830
* The {@code TraceIdExecutionInterceptor} copies the trace details to the {@link #TRACE_ID_HEADER} header, assuming we seem to
@@ -32,27 +34,57 @@
3234
public class TraceIdExecutionInterceptor implements ExecutionInterceptor {
3335
private static final String TRACE_ID_HEADER = "X-Amzn-Trace-Id";
3436
private static final String LAMBDA_FUNCTION_NAME_ENVIRONMENT_VARIABLE = "AWS_LAMBDA_FUNCTION_NAME";
37+
private static final String CONCURRENT_TRACE_ID_KEY = "AWS_LAMBDA_X_TRACE_ID";
38+
private static final ExecutionAttribute<String> TRACE_ID = new ExecutionAttribute<>("TraceId");
39+
40+
@Override
41+
public void beforeExecution(Context.BeforeExecution context, ExecutionAttributes executionAttributes) {
42+
String traceId = SdkInternalThreadLocal.get(CONCURRENT_TRACE_ID_KEY);
43+
if (traceId != null) {
44+
executionAttributes.putAttribute(TRACE_ID, traceId);
45+
}
46+
}
3547

3648
@Override
3749
public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, ExecutionAttributes executionAttributes) {
3850
Optional<String> traceIdHeader = traceIdHeader(context);
3951
if (!traceIdHeader.isPresent()) {
4052
Optional<String> lambdafunctionName = lambdaFunctionNameEnvironmentVariable();
41-
Optional<String> traceId = traceId();
53+
Optional<String> traceId = traceId(executionAttributes);
4254

4355
if (lambdafunctionName.isPresent() && traceId.isPresent()) {
4456
return context.httpRequest().copy(r -> r.putHeader(TRACE_ID_HEADER, traceId.get()));
4557
}
4658
}
47-
4859
return context.httpRequest();
4960
}
5061

62+
@Override
63+
public void afterExecution(Context.AfterExecution context, ExecutionAttributes executionAttributes) {
64+
saveTraceId(executionAttributes);
65+
}
66+
67+
@Override
68+
public void onExecutionFailure(Context.FailedExecution context, ExecutionAttributes executionAttributes) {
69+
saveTraceId(executionAttributes);
70+
}
71+
72+
private static void saveTraceId(ExecutionAttributes executionAttributes) {
73+
String traceId = executionAttributes.getAttribute(TRACE_ID);
74+
if (traceId != null) {
75+
SdkInternalThreadLocal.put(CONCURRENT_TRACE_ID_KEY, executionAttributes.getAttribute(TRACE_ID));
76+
}
77+
}
78+
5179
private Optional<String> traceIdHeader(Context.ModifyHttpRequest context) {
5280
return context.httpRequest().firstMatchingHeader(TRACE_ID_HEADER);
5381
}
5482

55-
private Optional<String> traceId() {
83+
private Optional<String> traceId(ExecutionAttributes executionAttributes) {
84+
Optional<String> traceId = Optional.ofNullable(executionAttributes.getAttribute(TRACE_ID));
85+
if (traceId.isPresent()) {
86+
return traceId;
87+
}
5688
return TracingSystemSetting._X_AMZN_TRACE_ID.getStringValue();
5789
}
5890

@@ -61,4 +93,4 @@ private Optional<String> lambdaFunctionNameEnvironmentVariable() {
6193
return SystemSetting.getStringValueFromEnvironmentVariable(LAMBDA_FUNCTION_NAME_ENVIRONMENT_VARIABLE);
6294
// CHECKSTYLE:ON
6395
}
64-
}
96+
}

core/aws-core/src/test/java/software/amazon/awssdk/awscore/interceptor/TraceIdExecutionInterceptorTest.java

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import software.amazon.awssdk.http.SdkHttpMethod;
2929
import software.amazon.awssdk.http.SdkHttpRequest;
3030
import software.amazon.awssdk.testutils.EnvironmentVariableHelper;
31+
import software.amazon.awssdk.utilslite.SdkInternalThreadLocal;
3132

3233
public class TraceIdExecutionInterceptorTest {
3334
@Test
@@ -111,6 +112,78 @@ public void headerNotAddedIfNoTraceIdEnvVar() {
111112
});
112113
}
113114

115+
@Test
116+
public void modifyHttpRequest_whenMultiConcurrencyModeWithInternalThreadLocal_shouldAddTraceIdHeader() {
117+
EnvironmentVariableHelper.run(env -> {
118+
resetRelevantEnvVars(env);
119+
env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");
120+
SdkInternalThreadLocal.put("AWS_LAMBDA_X_TRACE_ID", "SdkInternalThreadLocal-trace-123");
121+
122+
try {
123+
TraceIdExecutionInterceptor interceptor = new TraceIdExecutionInterceptor();
124+
ExecutionAttributes executionAttributes = new ExecutionAttributes();
125+
126+
interceptor.beforeExecution(null, executionAttributes);
127+
Context.ModifyHttpRequest context = context();
128+
129+
SdkHttpRequest request = interceptor.modifyHttpRequest(context, executionAttributes);
130+
assertThat(request.firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("SdkInternalThreadLocal-trace-123");
131+
} finally {
132+
SdkInternalThreadLocal.remove("AWS_LAMBDA_X_TRACE_ID");
133+
}
134+
});
135+
}
136+
137+
@Test
138+
public void modifyHttpRequest_whenMultiConcurrencyModeWithBothInternalThreadLocalAndSystemProperty_shouldUseInternalThreadLocalValue() {
139+
EnvironmentVariableHelper.run(env -> {
140+
resetRelevantEnvVars(env);
141+
env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");
142+
143+
SdkInternalThreadLocal.put("AWS_LAMBDA_X_TRACE_ID", "SdkInternalThreadLocal-trace-123");
144+
Properties props = System.getProperties();
145+
props.setProperty("com.amazonaws.xray.traceHeader", "sys-prop-345");
146+
147+
try {
148+
TraceIdExecutionInterceptor interceptor = new TraceIdExecutionInterceptor();
149+
ExecutionAttributes executionAttributes = new ExecutionAttributes();
150+
151+
interceptor.beforeExecution(null, executionAttributes);
152+
153+
Context.ModifyHttpRequest context = context();
154+
SdkHttpRequest request = interceptor.modifyHttpRequest(context, executionAttributes);
155+
156+
assertThat(request.firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("SdkInternalThreadLocal-trace-123");
157+
} finally {
158+
SdkInternalThreadLocal.remove("AWS_LAMBDA_X_TRACE_ID");
159+
props.remove("com.amazonaws.xray.traceHeader");
160+
}
161+
});
162+
}
163+
164+
@Test
165+
public void modifyHttpRequest_whenNotInLambdaEnvironmentWithInternalThreadLocal_shouldNotAddHeader() {
166+
EnvironmentVariableHelper.run(env -> {
167+
resetRelevantEnvVars(env);
168+
169+
SdkInternalThreadLocal.put("AWS_LAMBDA_X_TRACE_ID", "should-be-ignored");
170+
171+
try {
172+
TraceIdExecutionInterceptor interceptor = new TraceIdExecutionInterceptor();
173+
ExecutionAttributes executionAttributes = new ExecutionAttributes();
174+
175+
interceptor.beforeExecution(null, executionAttributes);
176+
177+
Context.ModifyHttpRequest context = context();
178+
SdkHttpRequest request = interceptor.modifyHttpRequest(context, executionAttributes);
179+
180+
assertThat(request.firstMatchingHeader("X-Amzn-Trace-Id")).isEmpty();
181+
} finally {
182+
SdkInternalThreadLocal.remove("AWS_LAMBDA_X_TRACE_ID");
183+
}
184+
});
185+
}
186+
114187
private Context.ModifyHttpRequest context() {
115188
return context(SdkHttpRequest.builder()
116189
.uri(URI.create("https://localhost"))

test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/TraceIdTest.java

Lines changed: 186 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,25 @@
1717

1818
import static org.assertj.core.api.Assertions.assertThat;
1919

20+
import java.util.List;
2021
import org.junit.jupiter.api.Test;
2122
import software.amazon.awssdk.auth.credentials.AnonymousCredentialsProvider;
2223
import software.amazon.awssdk.awscore.interceptor.TraceIdExecutionInterceptor;
24+
import software.amazon.awssdk.core.interceptor.Context;
25+
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
26+
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
2327
import software.amazon.awssdk.http.AbortableInputStream;
2428
import software.amazon.awssdk.http.HttpExecuteResponse;
29+
import software.amazon.awssdk.http.SdkHttpRequest;
2530
import software.amazon.awssdk.http.SdkHttpResponse;
2631
import software.amazon.awssdk.regions.Region;
32+
import software.amazon.awssdk.services.protocolrestjson.ProtocolRestJsonAsyncClient;
2733
import software.amazon.awssdk.services.protocolrestjson.ProtocolRestJsonClient;
2834
import software.amazon.awssdk.testutils.EnvironmentVariableHelper;
35+
import software.amazon.awssdk.testutils.service.http.MockAsyncHttpClient;
2936
import software.amazon.awssdk.testutils.service.http.MockSyncHttpClient;
3037
import software.amazon.awssdk.utils.StringInputStream;
38+
import software.amazon.awssdk.utilslite.SdkInternalThreadLocal;
3139

3240
/**
3341
* Verifies that the {@link TraceIdExecutionInterceptor} is actually wired up for AWS services.
@@ -56,4 +64,181 @@ public void traceIdInterceptorIsEnabled() {
5664
}
5765
});
5866
}
59-
}
67+
68+
@Test
69+
public void traceIdInterceptorPreservesTraceIdAcrossRetries() {
70+
EnvironmentVariableHelper.run(env -> {
71+
env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");
72+
SdkInternalThreadLocal.put("AWS_LAMBDA_X_TRACE_ID", "SdkInternalThreadLocal-trace-123");
73+
74+
try (MockAsyncHttpClient mockHttpClient = new MockAsyncHttpClient();
75+
ProtocolRestJsonAsyncClient client = ProtocolRestJsonAsyncClient.builder()
76+
.region(Region.US_WEST_2)
77+
.credentialsProvider(AnonymousCredentialsProvider.create())
78+
.httpClient(mockHttpClient)
79+
.build()) {
80+
81+
mockHttpClient.stubResponses(
82+
HttpExecuteResponse.builder()
83+
.response(SdkHttpResponse.builder().statusCode(500).build())
84+
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
85+
.build(),
86+
HttpExecuteResponse.builder()
87+
.response(SdkHttpResponse.builder().statusCode(500).build())
88+
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
89+
.build(),
90+
HttpExecuteResponse.builder().response(SdkHttpResponse.builder().statusCode(200).build())
91+
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
92+
.build());
93+
94+
client.allTypes().join();
95+
96+
List<SdkHttpRequest> requests = mockHttpClient.getRequests();
97+
assertThat(requests).hasSize(3);
98+
99+
assertThat(requests.get(0).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("SdkInternalThreadLocal-trace-123");
100+
assertThat(requests.get(1).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("SdkInternalThreadLocal-trace-123");
101+
assertThat(requests.get(2).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("SdkInternalThreadLocal-trace-123");
102+
103+
} finally {
104+
SdkInternalThreadLocal.clear();
105+
}
106+
});
107+
}
108+
109+
@Test
110+
public void traceIdInterceptorPreservesTraceIdAcrossChainedFutures() {
111+
EnvironmentVariableHelper.run(env -> {
112+
env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");
113+
SdkInternalThreadLocal.put("AWS_LAMBDA_X_TRACE_ID", "SdkInternalThreadLocal-trace-123");
114+
115+
try (MockAsyncHttpClient mockHttpClient = new MockAsyncHttpClient();
116+
ProtocolRestJsonAsyncClient client = ProtocolRestJsonAsyncClient.builder()
117+
.region(Region.US_WEST_2)
118+
.credentialsProvider(AnonymousCredentialsProvider.create())
119+
.httpClient(mockHttpClient)
120+
.build()) {
121+
122+
mockHttpClient.stubResponses(
123+
HttpExecuteResponse.builder()
124+
.response(SdkHttpResponse.builder().statusCode(200).build())
125+
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
126+
.build(),
127+
HttpExecuteResponse.builder()
128+
.response(SdkHttpResponse.builder().statusCode(200).build())
129+
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
130+
.build()
131+
);
132+
133+
client.allTypes()
134+
.thenRun(() -> {
135+
client.allTypes().join();
136+
})
137+
.join();
138+
139+
List<SdkHttpRequest> requests = mockHttpClient.getRequests();
140+
141+
assertThat(requests).hasSize(2);
142+
143+
assertThat(requests.get(0).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("SdkInternalThreadLocal-trace-123");
144+
assertThat(requests.get(1).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("SdkInternalThreadLocal-trace-123");
145+
146+
} finally {
147+
SdkInternalThreadLocal.clear();
148+
}
149+
});
150+
}
151+
152+
@Test
153+
public void traceIdInterceptorPreservesTraceIdAcrossExceptionallyCompletedFutures() {
154+
EnvironmentVariableHelper.run(env -> {
155+
env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");
156+
SdkInternalThreadLocal.put("AWS_LAMBDA_X_TRACE_ID", "SdkInternalThreadLocal-trace-123");
157+
158+
try (MockAsyncHttpClient mockHttpClient = new MockAsyncHttpClient();
159+
ProtocolRestJsonAsyncClient client = ProtocolRestJsonAsyncClient.builder()
160+
.region(Region.US_WEST_2)
161+
.credentialsProvider(AnonymousCredentialsProvider.create())
162+
.httpClient(mockHttpClient)
163+
.build()) {
164+
165+
mockHttpClient.stubResponses(
166+
HttpExecuteResponse.builder()
167+
.response(SdkHttpResponse.builder().statusCode(400).build())
168+
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
169+
.build(),
170+
HttpExecuteResponse.builder()
171+
.response(SdkHttpResponse.builder().statusCode(200).build())
172+
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
173+
.build()
174+
);
175+
176+
client.allTypes()
177+
.exceptionally(throwable -> {
178+
client.allTypes().join();
179+
return null;
180+
}).join();
181+
182+
List<SdkHttpRequest> requests = mockHttpClient.getRequests();
183+
184+
assertThat(requests).hasSize(2);
185+
186+
assertThat(requests.get(0).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("SdkInternalThreadLocal-trace-123");
187+
assertThat(requests.get(1).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("SdkInternalThreadLocal-trace-123");
188+
189+
} finally {
190+
SdkInternalThreadLocal.clear();
191+
}
192+
});
193+
}
194+
195+
@Test
196+
public void traceIdInterceptorPreservesTraceIdAcrossExceptionallyCompletedFuturesThrownInPreExecution() {
197+
EnvironmentVariableHelper.run(env -> {
198+
env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");
199+
SdkInternalThreadLocal.put("AWS_LAMBDA_X_TRACE_ID", "SdkInternalThreadLocal-trace-123");
200+
201+
ExecutionInterceptor throwingInterceptor = new ExecutionInterceptor() {
202+
private boolean hasThrown = false;
203+
204+
@Override
205+
public void beforeMarshalling(Context.BeforeMarshalling context, ExecutionAttributes executionAttributes) {
206+
if (!hasThrown) {
207+
hasThrown = true;
208+
throw new RuntimeException("failing in pre execution");
209+
}
210+
}
211+
};
212+
213+
try (MockAsyncHttpClient mockHttpClient = new MockAsyncHttpClient();
214+
ProtocolRestJsonAsyncClient client = ProtocolRestJsonAsyncClient.builder()
215+
.region(Region.US_WEST_2)
216+
.credentialsProvider(AnonymousCredentialsProvider.create())
217+
.overrideConfiguration(o -> o.addExecutionInterceptor(throwingInterceptor))
218+
.httpClient(mockHttpClient)
219+
.build()) {
220+
221+
mockHttpClient.stubResponses(
222+
HttpExecuteResponse.builder()
223+
.response(SdkHttpResponse.builder().statusCode(200).build())
224+
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
225+
.build()
226+
);
227+
228+
client.allTypes()
229+
.exceptionally(throwable -> {
230+
client.allTypes().join();
231+
return null;
232+
}).join();
233+
234+
List<SdkHttpRequest> requests = mockHttpClient.getRequests();
235+
236+
assertThat(requests).hasSize(1);
237+
assertThat(requests.get(0).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("SdkInternalThreadLocal-trace-123");
238+
239+
} finally {
240+
SdkInternalThreadLocal.clear();
241+
}
242+
});
243+
}
244+
}

0 commit comments

Comments
 (0)