Skip to content

Commit e4f8dd2

Browse files
committed
Add async support
1 parent 5cfbfb1 commit e4f8dd2

File tree

4 files changed

+80
-47
lines changed

4 files changed

+80
-47
lines changed

core/aws-core/src/main/java/software/amazon/awssdk/awscore/interceptor/TraceIdExecutionInterceptor.java

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import software.amazon.awssdk.annotations.SdkProtectedApi;
2121
import software.amazon.awssdk.awscore.internal.interceptor.TracingSystemSetting;
2222
import software.amazon.awssdk.core.interceptor.Context;
23+
import software.amazon.awssdk.core.interceptor.ExecutionAttribute;
2324
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
2425
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
2526
import software.amazon.awssdk.http.SdkHttpRequest;
@@ -34,29 +35,38 @@ public class TraceIdExecutionInterceptor implements ExecutionInterceptor {
3435
private static final String TRACE_ID_HEADER = "X-Amzn-Trace-Id";
3536
private static final String LAMBDA_FUNCTION_NAME_ENVIRONMENT_VARIABLE = "AWS_LAMBDA_FUNCTION_NAME";
3637
private static final String CONCURRENT_TRACE_ID_KEY = "AWS_LAMBDA_X_TraceId";
38+
protected static final ExecutionAttribute<String> CACHED_TRACE_ID = new ExecutionAttribute<>("CachedTraceId");
39+
40+
@Override
41+
public void beforeExecution(Context.BeforeExecution context, ExecutionAttributes executionAttributes) {
42+
String traceId = MDC.get(CONCURRENT_TRACE_ID_KEY);
43+
if (traceId != null) {
44+
executionAttributes.putAttribute(CACHED_TRACE_ID, traceId);
45+
}
46+
}
3747

3848
@Override
3949
public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, ExecutionAttributes executionAttributes) {
4050
Optional<String> traceIdHeader = traceIdHeader(context);
4151
if (!traceIdHeader.isPresent()) {
42-
Optional<String> lambdafunctionName = lambdaFunctionNameEnvironmentVariable();
43-
Optional<String> traceId = traceId();
52+
Optional<String> lambdaFunctionName = lambdaFunctionNameEnvironmentVariable();
53+
Optional<String> traceId = traceId(executionAttributes);
4454

45-
if (lambdafunctionName.isPresent() && traceId.isPresent()) {
55+
if (lambdaFunctionName.isPresent() && traceId.isPresent()) {
4656
return context.httpRequest().copy(r -> r.putHeader(TRACE_ID_HEADER, traceId.get()));
4757
}
4858
}
49-
5059
return context.httpRequest();
5160
}
5261

5362
private Optional<String> traceIdHeader(Context.ModifyHttpRequest context) {
5463
return context.httpRequest().firstMatchingHeader(TRACE_ID_HEADER);
5564
}
5665

57-
private Optional<String> traceId() {
58-
if (TracingSystemSetting.AWS_LAMBDA_MAX_CONCURRENCY.getStringValue().isPresent()) {
59-
return Optional.ofNullable(MDC.get(CONCURRENT_TRACE_ID_KEY));
66+
private Optional<String> traceId(ExecutionAttributes executionAttributes) {
67+
Optional<String> traceId = Optional.ofNullable(executionAttributes.getAttribute(CACHED_TRACE_ID));
68+
if (traceId.isPresent()) {
69+
return traceId;
6070
}
6171
return TracingSystemSetting._X_AMZN_TRACE_ID.getStringValue();
6272
}

core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/interceptor/TracingSystemSetting.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,7 @@
2424
@SdkInternalApi
2525
public enum TracingSystemSetting implements SystemSetting {
2626
// See: https://github.com/aws/aws-xray-sdk-java/issues/251
27-
_X_AMZN_TRACE_ID("com.amazonaws.xray.traceHeader", null),
28-
// Environment variable to detect Lambda multi concurrency mode ("elevator"). This value is set by the Lambda runtime.
29-
AWS_LAMBDA_MAX_CONCURRENCY("aws.lambda.maxConcurrency", null);
27+
_X_AMZN_TRACE_ID("com.amazonaws.xray.traceHeader", null);
3028

3129
private final String systemProperty;
3230
private final String defaultValue;

core/aws-core/src/test/java/software/amazon/awssdk/awscore/interceptor/TraceIdExecutionInterceptorTest.java

Lines changed: 21 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -117,13 +117,17 @@ public void modifyHttpRequest_whenMultiConcurrencyModeWithMdc_shouldAddTraceIdHe
117117
EnvironmentVariableHelper.run(env -> {
118118
resetRelevantEnvVars(env);
119119
env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");
120-
env.set("AWS_LAMBDA_MAX_CONCURRENCY", "10");
121-
122120
MDC.put("AWS_LAMBDA_X_TraceId", "mdc-trace-123");
123121

124122
try {
123+
TraceIdExecutionInterceptor interceptor = new TraceIdExecutionInterceptor();
124+
ExecutionAttributes executionAttributes = new ExecutionAttributes();
125+
126+
interceptor.beforeExecution(null, executionAttributes);
125127
Context.ModifyHttpRequest context = context();
126-
assertThat(modifyHttpRequest(context).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("mdc-trace-123");
128+
129+
SdkHttpRequest request = interceptor.modifyHttpRequest(context, executionAttributes);
130+
assertThat(request.firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("mdc-trace-123");
127131
} finally {
128132
MDC.remove("AWS_LAMBDA_X_TraceId");
129133
}
@@ -135,65 +139,45 @@ public void modifyHttpRequest_whenMultiConcurrencyModeWithBothMdcAndSystemProper
135139
EnvironmentVariableHelper.run(env -> {
136140
resetRelevantEnvVars(env);
137141
env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");
138-
env.set("AWS_LAMBDA_MAX_CONCURRENCY", "10");
139142

140143
MDC.put("AWS_LAMBDA_X_TraceId", "mdc-trace-123");
141144
Properties props = System.getProperties();
142145
props.setProperty("com.amazonaws.xray.traceHeader", "sys-prop-345");
143146

144147
try {
148+
TraceIdExecutionInterceptor interceptor = new TraceIdExecutionInterceptor();
149+
ExecutionAttributes executionAttributes = new ExecutionAttributes();
150+
151+
interceptor.beforeExecution(null, executionAttributes);
152+
145153
Context.ModifyHttpRequest context = context();
146-
assertThat(modifyHttpRequest(context).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("mdc-trace-123");
154+
SdkHttpRequest request = interceptor.modifyHttpRequest(context, executionAttributes);
155+
156+
assertThat(request.firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("mdc-trace-123");
147157
} finally {
148158
MDC.remove("AWS_LAMBDA_X_TraceId");
149159
props.remove("com.amazonaws.xray.traceHeader");
150160
}
151161
});
152162
}
153163

154-
@Test
155-
public void modifyHttpRequest_whenMultiConcurrencyModeWithEmptyMdc_shouldNotAddHeader() {
156-
EnvironmentVariableHelper.run(env -> {
157-
resetRelevantEnvVars(env);
158-
env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");
159-
env.set("AWS_LAMBDA_MAX_CONCURRENCY", "10");
160-
161-
MDC.clear();
162-
163-
Context.ModifyHttpRequest context = context();
164-
assertThat(modifyHttpRequest(context)).isSameAs(context.httpRequest());
165-
});
166-
}
167-
168164
@Test
169165
public void modifyHttpRequest_whenNotInLambdaEnvironmentWithMdc_shouldNotAddHeader() {
170166
EnvironmentVariableHelper.run(env -> {
171167
resetRelevantEnvVars(env);
172-
env.set("AWS_LAMBDA_MAX_CONCURRENCY", "10");
173168

174169
MDC.put("AWS_LAMBDA_X_TraceId", "should-be-ignored");
175170

176171
try {
177-
Context.ModifyHttpRequest context = context();
178-
assertThat(modifyHttpRequest(context)).isSameAs(context.httpRequest());
179-
} finally {
180-
MDC.remove("AWS_LAMBDA_X_TraceId");
181-
}
182-
});
183-
}
172+
TraceIdExecutionInterceptor interceptor = new TraceIdExecutionInterceptor();
173+
ExecutionAttributes executionAttributes = new ExecutionAttributes();
184174

185-
@Test
186-
public void modifyHttpRequest_whenConcurrencyModeIsEmptyString_shouldUseMdcValue() {
187-
EnvironmentVariableHelper.run(env -> {
188-
resetRelevantEnvVars(env);
189-
env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");
190-
env.set("AWS_LAMBDA_MAX_CONCURRENCY", "");
175+
interceptor.beforeExecution(null, executionAttributes);
191176

192-
MDC.put("AWS_LAMBDA_X_TraceId", "empty-string-test");
193-
194-
try {
195177
Context.ModifyHttpRequest context = context();
196-
assertThat(modifyHttpRequest(context).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("empty-string-test");
178+
SdkHttpRequest request = interceptor.modifyHttpRequest(context, executionAttributes);
179+
180+
assertThat(request.firstMatchingHeader("X-Amzn-Trace-Id")).isEmpty();
197181
} finally {
198182
MDC.remove("AWS_LAMBDA_X_TraceId");
199183
}

test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/TraceIdTest.java

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,20 @@
1717

1818
import static org.assertj.core.api.Assertions.assertThat;
1919

20+
import java.util.List;
2021
import org.junit.jupiter.api.Test;
22+
import org.slf4j.MDC;
2123
import software.amazon.awssdk.auth.credentials.AnonymousCredentialsProvider;
2224
import software.amazon.awssdk.awscore.interceptor.TraceIdExecutionInterceptor;
2325
import software.amazon.awssdk.http.AbortableInputStream;
2426
import software.amazon.awssdk.http.HttpExecuteResponse;
27+
import software.amazon.awssdk.http.SdkHttpRequest;
2528
import software.amazon.awssdk.http.SdkHttpResponse;
2629
import software.amazon.awssdk.regions.Region;
30+
import software.amazon.awssdk.services.protocolrestjson.ProtocolRestJsonAsyncClient;
2731
import software.amazon.awssdk.services.protocolrestjson.ProtocolRestJsonClient;
2832
import software.amazon.awssdk.testutils.EnvironmentVariableHelper;
33+
import software.amazon.awssdk.testutils.service.http.MockAsyncHttpClient;
2934
import software.amazon.awssdk.testutils.service.http.MockSyncHttpClient;
3035
import software.amazon.awssdk.utils.StringInputStream;
3136

@@ -56,4 +61,40 @@ public void traceIdInterceptorIsEnabled() {
5661
}
5762
});
5863
}
64+
65+
@Test
66+
public void traceIdInterceptorPreservesTraceIdAcrossRetries() {
67+
EnvironmentVariableHelper.run(env -> {
68+
env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");
69+
MDC.put("AWS_LAMBDA_X_TraceId", "mdc-trace-123");
70+
71+
try (MockAsyncHttpClient mockHttpClient = new MockAsyncHttpClient();
72+
ProtocolRestJsonAsyncClient client = ProtocolRestJsonAsyncClient.builder()
73+
.region(Region.US_WEST_2)
74+
.credentialsProvider(AnonymousCredentialsProvider.create())
75+
.httpClient(mockHttpClient)
76+
.build()) {
77+
78+
mockHttpClient.stubResponses(
79+
HttpExecuteResponse.builder()
80+
.response(SdkHttpResponse.builder().statusCode(500).build())
81+
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
82+
.build(),
83+
HttpExecuteResponse.builder().response(SdkHttpResponse.builder().statusCode(200).build())
84+
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
85+
.build());
86+
87+
client.allTypes().join();
88+
89+
List<SdkHttpRequest> requests = mockHttpClient.getRequests();
90+
assertThat(requests).hasSize(2);
91+
92+
assertThat(requests.get(0).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("mdc-trace-123");
93+
assertThat(requests.get(1).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("mdc-trace-123");
94+
95+
} finally {
96+
MDC.clear();
97+
}
98+
});
99+
}
59100
}

0 commit comments

Comments
 (0)