Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@
import software.amazon.awssdk.annotations.SdkProtectedApi;
import software.amazon.awssdk.awscore.internal.interceptor.TracingSystemSetting;
import software.amazon.awssdk.core.interceptor.Context;
import software.amazon.awssdk.core.interceptor.ExecutionAttribute;
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
import software.amazon.awssdk.http.SdkHttpRequest;
import software.amazon.awssdk.utils.SystemSetting;
import software.amazon.awssdk.utils.ThreadStorage;

/**
* The {@code TraceIdExecutionInterceptor} copies the trace details to the {@link #TRACE_ID_HEADER} header, assuming we seem to
Expand All @@ -32,27 +34,57 @@
public class TraceIdExecutionInterceptor implements ExecutionInterceptor {
private static final String TRACE_ID_HEADER = "X-Amzn-Trace-Id";
private static final String LAMBDA_FUNCTION_NAME_ENVIRONMENT_VARIABLE = "AWS_LAMBDA_FUNCTION_NAME";
private static final String CONCURRENT_TRACE_ID_KEY = "AWS_LAMBDA_X_TRACE_ID";
private static final ExecutionAttribute<String> TRACE_ID = new ExecutionAttribute<>("TraceId");

@Override
public void beforeExecution(Context.BeforeExecution context, ExecutionAttributes executionAttributes) {
String traceId = ThreadStorage.get(CONCURRENT_TRACE_ID_KEY);
if (traceId != null) {
executionAttributes.putAttribute(TRACE_ID, traceId);
}
}

@Override
public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, ExecutionAttributes executionAttributes) {
Optional<String> traceIdHeader = traceIdHeader(context);
if (!traceIdHeader.isPresent()) {
Optional<String> lambdafunctionName = lambdaFunctionNameEnvironmentVariable();
Optional<String> traceId = traceId();
Optional<String> traceId = traceId(executionAttributes);

if (lambdafunctionName.isPresent() && traceId.isPresent()) {
return context.httpRequest().copy(r -> r.putHeader(TRACE_ID_HEADER, traceId.get()));
}
}

return context.httpRequest();
}

@Override
public void afterExecution(Context.AfterExecution context, ExecutionAttributes executionAttributes) {
saveTraceId(executionAttributes);
}

@Override
public void onExecutionFailure(Context.FailedExecution context, ExecutionAttributes executionAttributes) {
saveTraceId(executionAttributes);
}

private static void saveTraceId(ExecutionAttributes executionAttributes) {
String traceId = executionAttributes.getAttribute(TRACE_ID);
if (traceId != null) {
ThreadStorage.put(CONCURRENT_TRACE_ID_KEY, executionAttributes.getAttribute(TRACE_ID));
}
}

private Optional<String> traceIdHeader(Context.ModifyHttpRequest context) {
return context.httpRequest().firstMatchingHeader(TRACE_ID_HEADER);
}

private Optional<String> traceId() {
private Optional<String> traceId(ExecutionAttributes executionAttributes) {
Optional<String> traceId = Optional.ofNullable(executionAttributes.getAttribute(TRACE_ID));
if (traceId.isPresent()) {
return traceId;
}
return TracingSystemSetting._X_AMZN_TRACE_ID.getStringValue();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import software.amazon.awssdk.http.SdkHttpMethod;
import software.amazon.awssdk.http.SdkHttpRequest;
import software.amazon.awssdk.testutils.EnvironmentVariableHelper;
import software.amazon.awssdk.utils.ThreadStorage;

public class TraceIdExecutionInterceptorTest {
@Test
Expand Down Expand Up @@ -111,6 +112,78 @@ public void headerNotAddedIfNoTraceIdEnvVar() {
});
}

@Test
public void modifyHttpRequest_whenMultiConcurrencyModeWithThreadStorage_shouldAddTraceIdHeader() {
EnvironmentVariableHelper.run(env -> {
resetRelevantEnvVars(env);
env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");
ThreadStorage.put("AWS_LAMBDA_X_TRACE_ID", "ThreadStorage-trace-123");

try {
TraceIdExecutionInterceptor interceptor = new TraceIdExecutionInterceptor();
ExecutionAttributes executionAttributes = new ExecutionAttributes();

interceptor.beforeExecution(null, executionAttributes);
Context.ModifyHttpRequest context = context();

SdkHttpRequest request = interceptor.modifyHttpRequest(context, executionAttributes);
assertThat(request.firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("ThreadStorage-trace-123");
} finally {
ThreadStorage.remove("AWS_LAMBDA_X_TRACE_ID");
}
});
}

@Test
public void modifyHttpRequest_whenMultiConcurrencyModeWithBothThreadStorageAndSystemProperty_shouldUseThreadStorageValue() {
EnvironmentVariableHelper.run(env -> {
resetRelevantEnvVars(env);
env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");

ThreadStorage.put("AWS_LAMBDA_X_TRACE_ID", "ThreadStorage-trace-123");
Properties props = System.getProperties();
props.setProperty("com.amazonaws.xray.traceHeader", "sys-prop-345");

try {
TraceIdExecutionInterceptor interceptor = new TraceIdExecutionInterceptor();
ExecutionAttributes executionAttributes = new ExecutionAttributes();

interceptor.beforeExecution(null, executionAttributes);

Context.ModifyHttpRequest context = context();
SdkHttpRequest request = interceptor.modifyHttpRequest(context, executionAttributes);

assertThat(request.firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("ThreadStorage-trace-123");
} finally {
ThreadStorage.remove("AWS_LAMBDA_X_TRACE_ID");
props.remove("com.amazonaws.xray.traceHeader");
}
});
}

@Test
public void modifyHttpRequest_whenNotInLambdaEnvironmentWithThreadStorage_shouldNotAddHeader() {
EnvironmentVariableHelper.run(env -> {
resetRelevantEnvVars(env);

ThreadStorage.put("AWS_LAMBDA_X_TRACE_ID", "should-be-ignored");

try {
TraceIdExecutionInterceptor interceptor = new TraceIdExecutionInterceptor();
ExecutionAttributes executionAttributes = new ExecutionAttributes();

interceptor.beforeExecution(null, executionAttributes);

Context.ModifyHttpRequest context = context();
SdkHttpRequest request = interceptor.modifyHttpRequest(context, executionAttributes);

assertThat(request.firstMatchingHeader("X-Amzn-Trace-Id")).isEmpty();
} finally {
ThreadStorage.remove("AWS_LAMBDA_X_TRACE_ID");
}
});
}

private Context.ModifyHttpRequest context() {
return context(SdkHttpRequest.builder()
.uri(URI.create("https://localhost"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,25 @@

import static org.assertj.core.api.Assertions.assertThat;

import java.util.List;
import org.junit.jupiter.api.Test;
import software.amazon.awssdk.auth.credentials.AnonymousCredentialsProvider;
import software.amazon.awssdk.awscore.interceptor.TraceIdExecutionInterceptor;
import software.amazon.awssdk.core.interceptor.Context;
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
import software.amazon.awssdk.http.AbortableInputStream;
import software.amazon.awssdk.http.HttpExecuteResponse;
import software.amazon.awssdk.http.SdkHttpRequest;
import software.amazon.awssdk.http.SdkHttpResponse;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.protocolrestjson.ProtocolRestJsonAsyncClient;
import software.amazon.awssdk.services.protocolrestjson.ProtocolRestJsonClient;
import software.amazon.awssdk.testutils.EnvironmentVariableHelper;
import software.amazon.awssdk.testutils.service.http.MockAsyncHttpClient;
import software.amazon.awssdk.testutils.service.http.MockSyncHttpClient;
import software.amazon.awssdk.utils.StringInputStream;
import software.amazon.awssdk.utils.ThreadStorage;

/**
* Verifies that the {@link TraceIdExecutionInterceptor} is actually wired up for AWS services.
Expand Down Expand Up @@ -56,4 +64,182 @@ public void traceIdInterceptorIsEnabled() {
}
});
}

@Test
public void traceIdInterceptorPreservesTraceIdAcrossRetries() {
EnvironmentVariableHelper.run(env -> {
env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");
ThreadStorage.put("AWS_LAMBDA_X_TRACE_ID", "ThreadStorage-trace-123");

try (MockAsyncHttpClient mockHttpClient = new MockAsyncHttpClient();
ProtocolRestJsonAsyncClient client = ProtocolRestJsonAsyncClient.builder()
.region(Region.US_WEST_2)
.credentialsProvider(AnonymousCredentialsProvider.create())
.httpClient(mockHttpClient)
.build()) {

mockHttpClient.stubResponses(
HttpExecuteResponse.builder()
.response(SdkHttpResponse.builder().statusCode(500).build())
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
.build(),
HttpExecuteResponse.builder()
.response(SdkHttpResponse.builder().statusCode(500).build())
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
.build(),
HttpExecuteResponse.builder().response(SdkHttpResponse.builder().statusCode(200).build())
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
.build());

client.allTypes().join();

List<SdkHttpRequest> requests = mockHttpClient.getRequests();
assertThat(requests).hasSize(3);

assertThat(requests.get(0).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("ThreadStorage-trace-123");
assertThat(requests.get(1).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("ThreadStorage-trace-123");
assertThat(requests.get(2).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("ThreadStorage-trace-123");

} finally {
ThreadStorage.clear();
}
});
}

@Test
public void traceIdInterceptorPreservesTraceIdAcrossChainedFutures() {
EnvironmentVariableHelper.run(env -> {
env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");
ThreadStorage.put("AWS_LAMBDA_X_TRACE_ID", "ThreadStorage-trace-123");

try (MockAsyncHttpClient mockHttpClient = new MockAsyncHttpClient();
ProtocolRestJsonAsyncClient client = ProtocolRestJsonAsyncClient.builder()
.region(Region.US_WEST_2)
.credentialsProvider(AnonymousCredentialsProvider.create())
.httpClient(mockHttpClient)
.build()) {

mockHttpClient.stubResponses(
HttpExecuteResponse.builder()
.response(SdkHttpResponse.builder().statusCode(200).build())
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
.build(),
HttpExecuteResponse.builder()
.response(SdkHttpResponse.builder().statusCode(200).build())
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
.build()
);

client.allTypes()
.thenRun(() -> {
client.allTypes().join();
})
.join();

List<SdkHttpRequest> requests = mockHttpClient.getRequests();

assertThat(requests).hasSize(2);

assertThat(requests.get(0).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("ThreadStorage-trace-123");
assertThat(requests.get(1).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("ThreadStorage-trace-123");

} finally {
ThreadStorage.clear();
}
});
}

@Test
public void traceIdInterceptorPreservesTraceIdAcrossExceptionallyCompletedFutures() {
EnvironmentVariableHelper.run(env -> {
env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");
ThreadStorage.put("AWS_LAMBDA_X_TRACE_ID", "ThreadStorage-trace-123");

try (MockAsyncHttpClient mockHttpClient = new MockAsyncHttpClient();
ProtocolRestJsonAsyncClient client = ProtocolRestJsonAsyncClient.builder()
.region(Region.US_WEST_2)
.credentialsProvider(AnonymousCredentialsProvider.create())
.httpClient(mockHttpClient)
.build()) {

mockHttpClient.stubResponses(
HttpExecuteResponse.builder()
.response(SdkHttpResponse.builder().statusCode(400).build())
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
.build(),
HttpExecuteResponse.builder()
.response(SdkHttpResponse.builder().statusCode(200).build())
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
.build()
);

client.allTypes()
.exceptionally(throwable -> {
client.allTypes().join();
return null;
}).join();

List<SdkHttpRequest> requests = mockHttpClient.getRequests();

assertThat(requests).hasSize(2);

assertThat(requests.get(0).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("ThreadStorage-trace-123");
assertThat(requests.get(1).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("ThreadStorage-trace-123");

} finally {
ThreadStorage.clear();
}
});
}

@Test
public void traceIdInterceptorPreservesTraceIdAcrossExceptionallyCompletedFuturesThrownInPreExecution() {
EnvironmentVariableHelper.run(env -> {
env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");
ThreadStorage.put("AWS_LAMBDA_X_TRACE_ID", "ThreadStorage-trace-123");

ExecutionInterceptor throwingInterceptor = new ExecutionInterceptor() {
private boolean hasThrown = false;

@Override
public void beforeMarshalling(Context.BeforeMarshalling context, ExecutionAttributes executionAttributes) {
if (!hasThrown) {
hasThrown = true;
throw new RuntimeException("failing in pre execution");
}
}
};

try (MockAsyncHttpClient mockHttpClient = new MockAsyncHttpClient();
ProtocolRestJsonAsyncClient client = ProtocolRestJsonAsyncClient.builder()
.region(Region.US_WEST_2)
.credentialsProvider(AnonymousCredentialsProvider.create())
.overrideConfiguration(o -> o.addExecutionInterceptor(throwingInterceptor))
.httpClient(mockHttpClient)
.build()) {

mockHttpClient.stubResponses(
HttpExecuteResponse.builder()
.response(SdkHttpResponse.builder().statusCode(200).build())
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
.build()
);

client.allTypes()
.exceptionally(throwable -> {
client.allTypes().join();
return null;
}).join();

List<SdkHttpRequest> requests = mockHttpClient.getRequests();

assertThat(requests).hasSize(1);
assertThat(requests.get(0).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("ThreadStorage-trace-123");

} finally {
ThreadStorage.clear();
}
});
}
}

Loading
Loading