Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,9 @@
package software.amazon.awssdk.awscore.interceptor;

import java.util.Optional;
import org.slf4j.MDC;
import software.amazon.awssdk.annotations.SdkProtectedApi;
import software.amazon.awssdk.awscore.internal.interceptor.TracingSystemSetting;
import software.amazon.awssdk.core.interceptor.Context;
import software.amazon.awssdk.core.interceptor.ExecutionAttribute;
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
import software.amazon.awssdk.http.SdkHttpRequest;
Expand All @@ -34,57 +32,27 @@
public class TraceIdExecutionInterceptor implements ExecutionInterceptor {
private static final String TRACE_ID_HEADER = "X-Amzn-Trace-Id";
private static final String LAMBDA_FUNCTION_NAME_ENVIRONMENT_VARIABLE = "AWS_LAMBDA_FUNCTION_NAME";
private static final String CONCURRENT_TRACE_ID_KEY = "AWS_LAMBDA_X_TRACE_ID";
private static final ExecutionAttribute<String> TRACE_ID = new ExecutionAttribute<>("TraceId");

@Override
public void beforeExecution(Context.BeforeExecution context, ExecutionAttributes executionAttributes) {
String traceId = MDC.get(CONCURRENT_TRACE_ID_KEY);
if (traceId != null) {
executionAttributes.putAttribute(TRACE_ID, traceId);
}
}

@Override
public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, ExecutionAttributes executionAttributes) {
Optional<String> traceIdHeader = traceIdHeader(context);
if (!traceIdHeader.isPresent()) {
Optional<String> lambdaFunctionName = lambdaFunctionNameEnvironmentVariable();
Optional<String> traceId = traceId(executionAttributes);
Optional<String> lambdafunctionName = lambdaFunctionNameEnvironmentVariable();
Optional<String> traceId = traceId();

if (lambdaFunctionName.isPresent() && traceId.isPresent()) {
if (lambdafunctionName.isPresent() && traceId.isPresent()) {
return context.httpRequest().copy(r -> r.putHeader(TRACE_ID_HEADER, traceId.get()));
}
}
return context.httpRequest();
}

@Override
public void afterExecution(Context.AfterExecution context, ExecutionAttributes executionAttributes) {
saveTraceId(executionAttributes);
}

@Override
public void onExecutionFailure(Context.FailedExecution context, ExecutionAttributes executionAttributes) {
saveTraceId(executionAttributes);
}

private static void saveTraceId(ExecutionAttributes executionAttributes) {
String traceId = executionAttributes.getAttribute(TRACE_ID);
if (traceId != null) {
MDC.put(CONCURRENT_TRACE_ID_KEY, executionAttributes.getAttribute(TRACE_ID));
}
return context.httpRequest();
}

private Optional<String> traceIdHeader(Context.ModifyHttpRequest context) {
return context.httpRequest().firstMatchingHeader(TRACE_ID_HEADER);
}

private Optional<String> traceId(ExecutionAttributes executionAttributes) {
Optional<String> traceId = Optional.ofNullable(executionAttributes.getAttribute(TRACE_ID));
if (traceId.isPresent()) {
return traceId;
}
private Optional<String> traceId() {
return TracingSystemSetting._X_AMZN_TRACE_ID.getStringValue();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import java.util.Properties;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import org.slf4j.MDC;
import software.amazon.awssdk.core.SdkRequest;
import software.amazon.awssdk.core.interceptor.Context;
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
Expand Down Expand Up @@ -112,78 +111,6 @@ public void headerNotAddedIfNoTraceIdEnvVar() {
});
}

@Test
public void modifyHttpRequest_whenMultiConcurrencyModeWithMdc_shouldAddTraceIdHeader() {
EnvironmentVariableHelper.run(env -> {
resetRelevantEnvVars(env);
env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");
MDC.put("AWS_LAMBDA_X_TRACE_ID", "mdc-trace-123");

try {
TraceIdExecutionInterceptor interceptor = new TraceIdExecutionInterceptor();
ExecutionAttributes executionAttributes = new ExecutionAttributes();

interceptor.beforeExecution(null, executionAttributes);
Context.ModifyHttpRequest context = context();

SdkHttpRequest request = interceptor.modifyHttpRequest(context, executionAttributes);
assertThat(request.firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("mdc-trace-123");
} finally {
MDC.remove("AWS_LAMBDA_X_TRACE_ID");
}
});
}

@Test
public void modifyHttpRequest_whenMultiConcurrencyModeWithBothMdcAndSystemProperty_shouldUseMdcValue() {
EnvironmentVariableHelper.run(env -> {
resetRelevantEnvVars(env);
env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");

MDC.put("AWS_LAMBDA_X_TRACE_ID", "mdc-trace-123");
Properties props = System.getProperties();
props.setProperty("com.amazonaws.xray.traceHeader", "sys-prop-345");

try {
TraceIdExecutionInterceptor interceptor = new TraceIdExecutionInterceptor();
ExecutionAttributes executionAttributes = new ExecutionAttributes();

interceptor.beforeExecution(null, executionAttributes);

Context.ModifyHttpRequest context = context();
SdkHttpRequest request = interceptor.modifyHttpRequest(context, executionAttributes);

assertThat(request.firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("mdc-trace-123");
} finally {
MDC.remove("AWS_LAMBDA_X_TRACE_ID");
props.remove("com.amazonaws.xray.traceHeader");
}
});
}

@Test
public void modifyHttpRequest_whenNotInLambdaEnvironmentWithMdc_shouldNotAddHeader() {
EnvironmentVariableHelper.run(env -> {
resetRelevantEnvVars(env);

MDC.put("AWS_LAMBDA_X_TRACE_ID", "should-be-ignored");

try {
TraceIdExecutionInterceptor interceptor = new TraceIdExecutionInterceptor();
ExecutionAttributes executionAttributes = new ExecutionAttributes();

interceptor.beforeExecution(null, executionAttributes);

Context.ModifyHttpRequest context = context();
SdkHttpRequest request = interceptor.modifyHttpRequest(context, executionAttributes);

assertThat(request.firstMatchingHeader("X-Amzn-Trace-Id")).isEmpty();
} finally {
MDC.remove("AWS_LAMBDA_X_TRACE_ID");
}
});
}

private Context.ModifyHttpRequest context() {
return context(SdkHttpRequest.builder()
.uri(URI.create("https://localhost"))
Expand Down
5 changes: 2 additions & 3 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -684,12 +684,11 @@
<includeModule>polly</includeModule>
</includeModules>
<excludes>
<!-- TODO remove after release -->
<exclude>software.amazon.awssdk.enhanced.dynamodb.extensions.annotations.DynamoDbVersionAttribute#incrementBy()</exclude>
<exclude>software.amazon.awssdk.enhanced.dynamodb.extensions.annotations.DynamoDbVersionAttribute#startAt()</exclude>
<exclude>*.internal.*</exclude>
<exclude>software.amazon.awssdk.thirdparty.*</exclude>
<exclude>software.amazon.awssdk.regions.*</exclude>
<!-- TODO remove after release -->
<exclude>software.amazon.awssdk.awscore.interceptor.TraceIdExecutionInterceptor</exclude>
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

japicmp is flagging this as a breaking change, however this is an SdkProtectedApi class, and the removal is of override methods whose implementation was not supported server side.

</excludes>

<ignoreMissingOldVersion>true</ignoreMissingOldVersion>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,15 @@

import static org.assertj.core.api.Assertions.assertThat;

import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import org.junit.jupiter.api.Test;
import org.slf4j.MDC;
import software.amazon.awssdk.auth.credentials.AnonymousCredentialsProvider;
import software.amazon.awssdk.awscore.interceptor.TraceIdExecutionInterceptor;
import software.amazon.awssdk.core.interceptor.Context;
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
import software.amazon.awssdk.http.AbortableInputStream;
import software.amazon.awssdk.http.HttpExecuteResponse;
import software.amazon.awssdk.http.SdkHttpRequest;
import software.amazon.awssdk.http.SdkHttpResponse;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.protocolrestjson.ProtocolRestJsonAsyncClient;
import software.amazon.awssdk.services.protocolrestjson.ProtocolRestJsonClient;
import software.amazon.awssdk.testutils.EnvironmentVariableHelper;
import software.amazon.awssdk.testutils.service.http.MockAsyncHttpClient;
import software.amazon.awssdk.testutils.service.http.MockSyncHttpClient;
import software.amazon.awssdk.utils.StringInputStream;

Expand Down Expand Up @@ -65,182 +56,4 @@ public void traceIdInterceptorIsEnabled() {
}
});
}

@Test
public void traceIdInterceptorPreservesTraceIdAcrossRetries() {
EnvironmentVariableHelper.run(env -> {
env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");
MDC.put("AWS_LAMBDA_X_TRACE_ID", "mdc-trace-123");

try (MockAsyncHttpClient mockHttpClient = new MockAsyncHttpClient();
ProtocolRestJsonAsyncClient client = ProtocolRestJsonAsyncClient.builder()
.region(Region.US_WEST_2)
.credentialsProvider(AnonymousCredentialsProvider.create())
.httpClient(mockHttpClient)
.build()) {

mockHttpClient.stubResponses(
HttpExecuteResponse.builder()
.response(SdkHttpResponse.builder().statusCode(500).build())
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
.build(),
HttpExecuteResponse.builder()
.response(SdkHttpResponse.builder().statusCode(500).build())
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
.build(),
HttpExecuteResponse.builder().response(SdkHttpResponse.builder().statusCode(200).build())
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
.build());

client.allTypes().join();

List<SdkHttpRequest> requests = mockHttpClient.getRequests();
assertThat(requests).hasSize(3);

assertThat(requests.get(0).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("mdc-trace-123");
assertThat(requests.get(1).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("mdc-trace-123");
assertThat(requests.get(2).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("mdc-trace-123");

} finally {
MDC.clear();
}
});
}

@Test
public void traceIdInterceptorPreservesTraceIdAcrossChainedFutures() {
EnvironmentVariableHelper.run(env -> {
env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");
MDC.put("AWS_LAMBDA_X_TRACE_ID", "mdc-trace-123");

try (MockAsyncHttpClient mockHttpClient = new MockAsyncHttpClient();
ProtocolRestJsonAsyncClient client = ProtocolRestJsonAsyncClient.builder()
.region(Region.US_WEST_2)
.credentialsProvider(AnonymousCredentialsProvider.create())
.httpClient(mockHttpClient)
.build()) {

mockHttpClient.stubResponses(
HttpExecuteResponse.builder()
.response(SdkHttpResponse.builder().statusCode(200).build())
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
.build(),
HttpExecuteResponse.builder()
.response(SdkHttpResponse.builder().statusCode(200).build())
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
.build()
);

client.allTypes()
.thenRun(() -> {
client.allTypes().join();
})
.join();

List<SdkHttpRequest> requests = mockHttpClient.getRequests();

assertThat(requests).hasSize(2);

assertThat(requests.get(0).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("mdc-trace-123");
assertThat(requests.get(1).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("mdc-trace-123");

} finally {
MDC.clear();
}
});
}

@Test
public void traceIdInterceptorPreservesTraceIdAcrossExceptionallyCompletedFutures() {
EnvironmentVariableHelper.run(env -> {
env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");
MDC.put("AWS_LAMBDA_X_TRACE_ID", "mdc-trace-123");

try (MockAsyncHttpClient mockHttpClient = new MockAsyncHttpClient();
ProtocolRestJsonAsyncClient client = ProtocolRestJsonAsyncClient.builder()
.region(Region.US_WEST_2)
.credentialsProvider(AnonymousCredentialsProvider.create())
.httpClient(mockHttpClient)
.build()) {

mockHttpClient.stubResponses(
HttpExecuteResponse.builder()
.response(SdkHttpResponse.builder().statusCode(400).build())
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
.build(),
HttpExecuteResponse.builder()
.response(SdkHttpResponse.builder().statusCode(200).build())
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
.build()
);

client.allTypes()
.exceptionally(throwable -> {
client.allTypes().join();
return null;
}).join();

List<SdkHttpRequest> requests = mockHttpClient.getRequests();

assertThat(requests).hasSize(2);

assertThat(requests.get(0).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("mdc-trace-123");
assertThat(requests.get(1).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("mdc-trace-123");

} finally {
MDC.clear();
}
});
}

@Test
public void traceIdInterceptorPreservesTraceIdAcrossExceptionallyCompletedFuturesThrownInPreExecution() {
EnvironmentVariableHelper.run(env -> {
env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");
MDC.put("AWS_LAMBDA_X_TRACE_ID", "mdc-trace-123");

ExecutionInterceptor throwingInterceptor = new ExecutionInterceptor() {
private boolean hasThrown = false;

@Override
public void beforeMarshalling(Context.BeforeMarshalling context, ExecutionAttributes executionAttributes) {
if (!hasThrown) {
hasThrown = true;
throw new RuntimeException("failing in pre execution");
}
}
};

try (MockAsyncHttpClient mockHttpClient = new MockAsyncHttpClient();
ProtocolRestJsonAsyncClient client = ProtocolRestJsonAsyncClient.builder()
.region(Region.US_WEST_2)
.credentialsProvider(AnonymousCredentialsProvider.create())
.overrideConfiguration(o -> o.addExecutionInterceptor(throwingInterceptor))
.httpClient(mockHttpClient)
.build()) {

mockHttpClient.stubResponses(
HttpExecuteResponse.builder()
.response(SdkHttpResponse.builder().statusCode(200).build())
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
.build()
);

client.allTypes()
.exceptionally(throwable -> {
client.allTypes().join();
return null;
}).join();

List<SdkHttpRequest> requests = mockHttpClient.getRequests();

assertThat(requests).hasSize(1);
assertThat(requests.get(0).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("mdc-trace-123");

} finally {
MDC.clear();
}
});
}
}

Loading