Skip to content

Commit e0814cb

Browse files
committed
Add afterExecution/onExecutionFailure hooks to restore MDC trace ID
1 parent d696756 commit e0814cb

File tree

3 files changed

+120
-13
lines changed

3 files changed

+120
-13
lines changed

core/aws-core/src/main/java/software/amazon/awssdk/awscore/interceptor/TraceIdExecutionInterceptor.java

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,14 @@
3434
public class TraceIdExecutionInterceptor implements ExecutionInterceptor {
3535
private static final String TRACE_ID_HEADER = "X-Amzn-Trace-Id";
3636
private static final String LAMBDA_FUNCTION_NAME_ENVIRONMENT_VARIABLE = "AWS_LAMBDA_FUNCTION_NAME";
37-
private static final String CONCURRENT_TRACE_ID_KEY = "AWS_LAMBDA_X_TraceId";
38-
private static final ExecutionAttribute<String> CACHED_TRACE_ID = new ExecutionAttribute<>("CachedTraceId");
37+
private static final String CONCURRENT_TRACE_ID_KEY = "AWS_LAMBDA_X_TRACE_ID";
38+
private static final ExecutionAttribute<String> TRACE_ID = new ExecutionAttribute<>("TraceId");
3939

4040
@Override
4141
public void beforeExecution(Context.BeforeExecution context, ExecutionAttributes executionAttributes) {
4242
String traceId = MDC.get(CONCURRENT_TRACE_ID_KEY);
4343
if (traceId != null) {
44-
executionAttributes.putAttribute(CACHED_TRACE_ID, traceId);
44+
executionAttributes.putAttribute(TRACE_ID, traceId);
4545
}
4646
}
4747

@@ -59,12 +59,26 @@ public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, Execu
5959
return context.httpRequest();
6060
}
6161

62+
@Override
63+
public void afterExecution(Context.AfterExecution context, ExecutionAttributes executionAttributes) {
64+
saveTraceId(executionAttributes);
65+
}
66+
67+
@Override
68+
public void onExecutionFailure(Context.FailedExecution context, ExecutionAttributes executionAttributes) {
69+
saveTraceId(executionAttributes);
70+
}
71+
72+
private static void saveTraceId(ExecutionAttributes executionAttributes) {
73+
MDC.put(CONCURRENT_TRACE_ID_KEY, executionAttributes.getAttribute(TRACE_ID));
74+
}
75+
6276
private Optional<String> traceIdHeader(Context.ModifyHttpRequest context) {
6377
return context.httpRequest().firstMatchingHeader(TRACE_ID_HEADER);
6478
}
6579

6680
private Optional<String> traceId(ExecutionAttributes executionAttributes) {
67-
Optional<String> traceId = Optional.ofNullable(executionAttributes.getAttribute(CACHED_TRACE_ID));
81+
Optional<String> traceId = Optional.ofNullable(executionAttributes.getAttribute(TRACE_ID));
6882
if (traceId.isPresent()) {
6983
return traceId;
7084
}

core/aws-core/src/test/java/software/amazon/awssdk/awscore/interceptor/TraceIdExecutionInterceptorTest.java

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ public void modifyHttpRequest_whenMultiConcurrencyModeWithMdc_shouldAddTraceIdHe
117117
EnvironmentVariableHelper.run(env -> {
118118
resetRelevantEnvVars(env);
119119
env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");
120-
MDC.put("AWS_LAMBDA_X_TraceId", "mdc-trace-123");
120+
MDC.put("AWS_LAMBDA_X_TRACE_ID", "mdc-trace-123");
121121

122122
try {
123123
TraceIdExecutionInterceptor interceptor = new TraceIdExecutionInterceptor();
@@ -129,7 +129,7 @@ public void modifyHttpRequest_whenMultiConcurrencyModeWithMdc_shouldAddTraceIdHe
129129
SdkHttpRequest request = interceptor.modifyHttpRequest(context, executionAttributes);
130130
assertThat(request.firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("mdc-trace-123");
131131
} finally {
132-
MDC.remove("AWS_LAMBDA_X_TraceId");
132+
MDC.remove("AWS_LAMBDA_X_TRACE_ID");
133133
}
134134
});
135135
}
@@ -140,7 +140,7 @@ public void modifyHttpRequest_whenMultiConcurrencyModeWithBothMdcAndSystemProper
140140
resetRelevantEnvVars(env);
141141
env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");
142142

143-
MDC.put("AWS_LAMBDA_X_TraceId", "mdc-trace-123");
143+
MDC.put("AWS_LAMBDA_X_TRACE_ID", "mdc-trace-123");
144144
Properties props = System.getProperties();
145145
props.setProperty("com.amazonaws.xray.traceHeader", "sys-prop-345");
146146

@@ -155,7 +155,7 @@ public void modifyHttpRequest_whenMultiConcurrencyModeWithBothMdcAndSystemProper
155155

156156
assertThat(request.firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("mdc-trace-123");
157157
} finally {
158-
MDC.remove("AWS_LAMBDA_X_TraceId");
158+
MDC.remove("AWS_LAMBDA_X_TRACE_ID");
159159
props.remove("com.amazonaws.xray.traceHeader");
160160
}
161161
});
@@ -166,7 +166,7 @@ public void modifyHttpRequest_whenNotInLambdaEnvironmentWithMdc_shouldNotAddHead
166166
EnvironmentVariableHelper.run(env -> {
167167
resetRelevantEnvVars(env);
168168

169-
MDC.put("AWS_LAMBDA_X_TraceId", "should-be-ignored");
169+
MDC.put("AWS_LAMBDA_X_TRACE_ID", "should-be-ignored");
170170

171171
try {
172172
TraceIdExecutionInterceptor interceptor = new TraceIdExecutionInterceptor();
@@ -179,7 +179,7 @@ public void modifyHttpRequest_whenNotInLambdaEnvironmentWithMdc_shouldNotAddHead
179179

180180
assertThat(request.firstMatchingHeader("X-Amzn-Trace-Id")).isEmpty();
181181
} finally {
182-
MDC.remove("AWS_LAMBDA_X_TraceId");
182+
MDC.remove("AWS_LAMBDA_X_TRACE_ID");
183183
}
184184
});
185185
}
@@ -206,6 +206,5 @@ private SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context) {
206206
private void resetRelevantEnvVars(EnvironmentVariableHelper env) {
207207
env.remove("AWS_LAMBDA_FUNCTION_NAME");
208208
env.remove("_X_AMZN_TRACE_ID");
209-
env.remove("AWS_LAMBDA_MAX_CONCURRENCY");
210209
}
211210
}

test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/TraceIdTest.java

Lines changed: 96 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@
3737
/**
3838
* Verifies that the {@link TraceIdExecutionInterceptor} is actually wired up for AWS services.
3939
*/
40-
public class TraceIdTest {
40+
public class
41+
TraceIdTest {
4142
@Test
4243
public void traceIdInterceptorIsEnabled() {
4344
EnvironmentVariableHelper.run(env -> {
@@ -66,7 +67,7 @@ public void traceIdInterceptorIsEnabled() {
6667
public void traceIdInterceptorPreservesTraceIdAcrossRetries() {
6768
EnvironmentVariableHelper.run(env -> {
6869
env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");
69-
MDC.put("AWS_LAMBDA_X_TraceId", "mdc-trace-123");
70+
MDC.put("AWS_LAMBDA_X_TRACE_ID", "mdc-trace-123");
7071

7172
try (MockAsyncHttpClient mockHttpClient = new MockAsyncHttpClient();
7273
ProtocolRestJsonAsyncClient client = ProtocolRestJsonAsyncClient.builder()
@@ -76,6 +77,10 @@ public void traceIdInterceptorPreservesTraceIdAcrossRetries() {
7677
.build()) {
7778

7879
mockHttpClient.stubResponses(
80+
HttpExecuteResponse.builder()
81+
.response(SdkHttpResponse.builder().statusCode(500).build())
82+
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
83+
.build(),
7984
HttpExecuteResponse.builder()
8085
.response(SdkHttpResponse.builder().statusCode(500).build())
8186
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
@@ -87,6 +92,94 @@ public void traceIdInterceptorPreservesTraceIdAcrossRetries() {
8792
client.allTypes().join();
8893

8994
List<SdkHttpRequest> requests = mockHttpClient.getRequests();
95+
assertThat(requests).hasSize(3);
96+
97+
assertThat(requests.get(0).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("mdc-trace-123");
98+
assertThat(requests.get(1).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("mdc-trace-123");
99+
assertThat(requests.get(2).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("mdc-trace-123");
100+
101+
} finally {
102+
MDC.clear();
103+
}
104+
});
105+
}
106+
107+
@Test
108+
public void traceIdInterceptorPreservesTraceIdAcrossChainedFutures() {
109+
EnvironmentVariableHelper.run(env -> {
110+
env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");
111+
MDC.put("AWS_LAMBDA_X_TRACE_ID", "mdc-trace-123");
112+
113+
try (MockAsyncHttpClient mockHttpClient = new MockAsyncHttpClient();
114+
ProtocolRestJsonAsyncClient client = ProtocolRestJsonAsyncClient.builder()
115+
.region(Region.US_WEST_2)
116+
.credentialsProvider(AnonymousCredentialsProvider.create())
117+
.httpClient(mockHttpClient)
118+
.build()) {
119+
120+
mockHttpClient.stubResponses(
121+
HttpExecuteResponse.builder()
122+
.response(SdkHttpResponse.builder().statusCode(200).build())
123+
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
124+
.build(),
125+
HttpExecuteResponse.builder()
126+
.response(SdkHttpResponse.builder().statusCode(200).build())
127+
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
128+
.build()
129+
);
130+
131+
client.allTypes()
132+
.thenRun(() -> {
133+
String traceId = MDC.get("AWS_LAMBDA_X_TRACE_ID");
134+
client.allTypes().join();
135+
})
136+
.join();
137+
138+
List<SdkHttpRequest> requests = mockHttpClient.getRequests();
139+
140+
assertThat(requests).hasSize(2);
141+
142+
assertThat(requests.get(0).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("mdc-trace-123");
143+
assertThat(requests.get(1).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("mdc-trace-123");
144+
145+
} finally {
146+
MDC.clear();
147+
}
148+
});
149+
}
150+
151+
@Test
152+
public void traceIdInterceptorPreservesTraceIdAcrossExceptionallyCompletedFutures() {
153+
EnvironmentVariableHelper.run(env -> {
154+
env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");
155+
MDC.put("AWS_LAMBDA_X_TRACE_ID", "mdc-trace-123");
156+
157+
try (MockAsyncHttpClient mockHttpClient = new MockAsyncHttpClient();
158+
ProtocolRestJsonAsyncClient client = ProtocolRestJsonAsyncClient.builder()
159+
.region(Region.US_WEST_2)
160+
.credentialsProvider(AnonymousCredentialsProvider.create())
161+
.httpClient(mockHttpClient)
162+
.build()) {
163+
164+
mockHttpClient.stubResponses(
165+
HttpExecuteResponse.builder()
166+
.response(SdkHttpResponse.builder().statusCode(400).build())
167+
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
168+
.build(),
169+
HttpExecuteResponse.builder()
170+
.response(SdkHttpResponse.builder().statusCode(200).build())
171+
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
172+
.build()
173+
);
174+
175+
client.allTypes()
176+
.exceptionally(throwable -> {
177+
client.allTypes().join();
178+
return null;
179+
}).join();
180+
181+
List<SdkHttpRequest> requests = mockHttpClient.getRequests();
182+
90183
assertThat(requests).hasSize(2);
91184

92185
assertThat(requests.get(0).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("mdc-trace-123");
@@ -98,3 +191,4 @@ public void traceIdInterceptorPreservesTraceIdAcrossRetries() {
98191
});
99192
}
100193
}
194+

0 commit comments

Comments
 (0)