Skip to content

Add support for lambda traces via threadLocal #6295

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
package software.amazon.awssdk.awscore.interceptor;

import java.util.Optional;
import org.slf4j.MDC;
import software.amazon.awssdk.annotations.SdkProtectedApi;
import software.amazon.awssdk.awscore.internal.interceptor.TracingSystemSetting;
import software.amazon.awssdk.core.interceptor.Context;
import software.amazon.awssdk.core.interceptor.ExecutionAttribute;
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
import software.amazon.awssdk.http.SdkHttpRequest;
Expand All @@ -32,27 +34,57 @@
public class TraceIdExecutionInterceptor implements ExecutionInterceptor {
private static final String TRACE_ID_HEADER = "X-Amzn-Trace-Id";
private static final String LAMBDA_FUNCTION_NAME_ENVIRONMENT_VARIABLE = "AWS_LAMBDA_FUNCTION_NAME";
private static final String CONCURRENT_TRACE_ID_KEY = "AWS_LAMBDA_X_TRACE_ID";
private static final ExecutionAttribute<String> TRACE_ID = new ExecutionAttribute<>("TraceId");

@Override
public void beforeExecution(Context.BeforeExecution context, ExecutionAttributes executionAttributes) {
String traceId = MDC.get(CONCURRENT_TRACE_ID_KEY);
if (traceId != null) {
executionAttributes.putAttribute(TRACE_ID, traceId);
}
}

@Override
public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, ExecutionAttributes executionAttributes) {
Optional<String> traceIdHeader = traceIdHeader(context);
if (!traceIdHeader.isPresent()) {
Optional<String> lambdafunctionName = lambdaFunctionNameEnvironmentVariable();
Optional<String> traceId = traceId();
Optional<String> lambdaFunctionName = lambdaFunctionNameEnvironmentVariable();
Optional<String> traceId = traceId(executionAttributes);

if (lambdafunctionName.isPresent() && traceId.isPresent()) {
if (lambdaFunctionName.isPresent() && traceId.isPresent()) {
return context.httpRequest().copy(r -> r.putHeader(TRACE_ID_HEADER, traceId.get()));
}
}

return context.httpRequest();
}

@Override
public void afterExecution(Context.AfterExecution context, ExecutionAttributes executionAttributes) {
saveTraceId(executionAttributes);
}

@Override
public void onExecutionFailure(Context.FailedExecution context, ExecutionAttributes executionAttributes) {
saveTraceId(executionAttributes);
}

private static void saveTraceId(ExecutionAttributes executionAttributes) {
String traceId = executionAttributes.getAttribute(TRACE_ID);
if (traceId != null) {
MDC.put(CONCURRENT_TRACE_ID_KEY, executionAttributes.getAttribute(TRACE_ID));
}
}

private Optional<String> traceIdHeader(Context.ModifyHttpRequest context) {
return context.httpRequest().firstMatchingHeader(TRACE_ID_HEADER);
}

private Optional<String> traceId() {
private Optional<String> traceId(ExecutionAttributes executionAttributes) {
Optional<String> traceId = Optional.ofNullable(executionAttributes.getAttribute(TRACE_ID));
if (traceId.isPresent()) {
return traceId;
}
return TracingSystemSetting._X_AMZN_TRACE_ID.getStringValue();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.util.Properties;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import org.slf4j.MDC;
import software.amazon.awssdk.core.SdkRequest;
import software.amazon.awssdk.core.interceptor.Context;
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
Expand Down Expand Up @@ -111,6 +112,78 @@ public void headerNotAddedIfNoTraceIdEnvVar() {
});
}

@Test
public void modifyHttpRequest_whenMultiConcurrencyModeWithMdc_shouldAddTraceIdHeader() {
EnvironmentVariableHelper.run(env -> {
resetRelevantEnvVars(env);
env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");
MDC.put("AWS_LAMBDA_X_TRACE_ID", "mdc-trace-123");

try {
TraceIdExecutionInterceptor interceptor = new TraceIdExecutionInterceptor();
ExecutionAttributes executionAttributes = new ExecutionAttributes();

interceptor.beforeExecution(null, executionAttributes);
Context.ModifyHttpRequest context = context();

SdkHttpRequest request = interceptor.modifyHttpRequest(context, executionAttributes);
assertThat(request.firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("mdc-trace-123");
} finally {
MDC.remove("AWS_LAMBDA_X_TRACE_ID");
}
});
}

@Test
public void modifyHttpRequest_whenMultiConcurrencyModeWithBothMdcAndSystemProperty_shouldUseMdcValue() {
EnvironmentVariableHelper.run(env -> {
resetRelevantEnvVars(env);
env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");

MDC.put("AWS_LAMBDA_X_TRACE_ID", "mdc-trace-123");
Properties props = System.getProperties();
props.setProperty("com.amazonaws.xray.traceHeader", "sys-prop-345");

try {
TraceIdExecutionInterceptor interceptor = new TraceIdExecutionInterceptor();
ExecutionAttributes executionAttributes = new ExecutionAttributes();

interceptor.beforeExecution(null, executionAttributes);

Context.ModifyHttpRequest context = context();
SdkHttpRequest request = interceptor.modifyHttpRequest(context, executionAttributes);

assertThat(request.firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("mdc-trace-123");
} finally {
MDC.remove("AWS_LAMBDA_X_TRACE_ID");
props.remove("com.amazonaws.xray.traceHeader");
}
});
}

@Test
public void modifyHttpRequest_whenNotInLambdaEnvironmentWithMdc_shouldNotAddHeader() {
EnvironmentVariableHelper.run(env -> {
resetRelevantEnvVars(env);

MDC.put("AWS_LAMBDA_X_TRACE_ID", "should-be-ignored");

try {
TraceIdExecutionInterceptor interceptor = new TraceIdExecutionInterceptor();
ExecutionAttributes executionAttributes = new ExecutionAttributes();

interceptor.beforeExecution(null, executionAttributes);

Context.ModifyHttpRequest context = context();
SdkHttpRequest request = interceptor.modifyHttpRequest(context, executionAttributes);

assertThat(request.firstMatchingHeader("X-Amzn-Trace-Id")).isEmpty();
} finally {
MDC.remove("AWS_LAMBDA_X_TRACE_ID");
}
});
}

private Context.ModifyHttpRequest context() {
return context(SdkHttpRequest.builder()
.uri(URI.create("https://localhost"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,24 @@

import static org.assertj.core.api.Assertions.assertThat;

import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import org.junit.jupiter.api.Test;
import org.slf4j.MDC;
import software.amazon.awssdk.auth.credentials.AnonymousCredentialsProvider;
import software.amazon.awssdk.awscore.interceptor.TraceIdExecutionInterceptor;
import software.amazon.awssdk.core.interceptor.Context;
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
import software.amazon.awssdk.http.AbortableInputStream;
import software.amazon.awssdk.http.HttpExecuteResponse;
import software.amazon.awssdk.http.SdkHttpRequest;
import software.amazon.awssdk.http.SdkHttpResponse;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.protocolrestjson.ProtocolRestJsonAsyncClient;
import software.amazon.awssdk.services.protocolrestjson.ProtocolRestJsonClient;
import software.amazon.awssdk.testutils.EnvironmentVariableHelper;
import software.amazon.awssdk.testutils.service.http.MockAsyncHttpClient;
import software.amazon.awssdk.testutils.service.http.MockSyncHttpClient;
import software.amazon.awssdk.utils.StringInputStream;

Expand Down Expand Up @@ -56,4 +65,182 @@ public void traceIdInterceptorIsEnabled() {
}
});
}

@Test
public void traceIdInterceptorPreservesTraceIdAcrossRetries() {
EnvironmentVariableHelper.run(env -> {
env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");
MDC.put("AWS_LAMBDA_X_TRACE_ID", "mdc-trace-123");

try (MockAsyncHttpClient mockHttpClient = new MockAsyncHttpClient();
ProtocolRestJsonAsyncClient client = ProtocolRestJsonAsyncClient.builder()
.region(Region.US_WEST_2)
.credentialsProvider(AnonymousCredentialsProvider.create())
.httpClient(mockHttpClient)
.build()) {

mockHttpClient.stubResponses(
HttpExecuteResponse.builder()
.response(SdkHttpResponse.builder().statusCode(500).build())
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
.build(),
HttpExecuteResponse.builder()
.response(SdkHttpResponse.builder().statusCode(500).build())
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
.build(),
HttpExecuteResponse.builder().response(SdkHttpResponse.builder().statusCode(200).build())
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
.build());

client.allTypes().join();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should also add a test sending a request in the result future:

client.allTypes()
      .thenRun(() -> client.allTypes() /* Did this request have the right trace ID? */)

I suspect we'll need to plumb the thread local context to there as well (and clear it, so that reuses of that thread don't have it set).


List<SdkHttpRequest> requests = mockHttpClient.getRequests();
assertThat(requests).hasSize(3);

assertThat(requests.get(0).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("mdc-trace-123");
assertThat(requests.get(1).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("mdc-trace-123");
assertThat(requests.get(2).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("mdc-trace-123");

} finally {
MDC.clear();
}
});
}

@Test
public void traceIdInterceptorPreservesTraceIdAcrossChainedFutures() {
EnvironmentVariableHelper.run(env -> {
env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");
MDC.put("AWS_LAMBDA_X_TRACE_ID", "mdc-trace-123");

try (MockAsyncHttpClient mockHttpClient = new MockAsyncHttpClient();
ProtocolRestJsonAsyncClient client = ProtocolRestJsonAsyncClient.builder()
.region(Region.US_WEST_2)
.credentialsProvider(AnonymousCredentialsProvider.create())
.httpClient(mockHttpClient)
.build()) {

mockHttpClient.stubResponses(
HttpExecuteResponse.builder()
.response(SdkHttpResponse.builder().statusCode(200).build())
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
.build(),
HttpExecuteResponse.builder()
.response(SdkHttpResponse.builder().statusCode(200).build())
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
.build()
);

client.allTypes()
.thenRun(() -> {
client.allTypes().join();
})
.join();

List<SdkHttpRequest> requests = mockHttpClient.getRequests();

assertThat(requests).hasSize(2);

assertThat(requests.get(0).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("mdc-trace-123");
assertThat(requests.get(1).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("mdc-trace-123");

} finally {
MDC.clear();
}
});
}

@Test
public void traceIdInterceptorPreservesTraceIdAcrossExceptionallyCompletedFutures() {
EnvironmentVariableHelper.run(env -> {
env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");
MDC.put("AWS_LAMBDA_X_TRACE_ID", "mdc-trace-123");

try (MockAsyncHttpClient mockHttpClient = new MockAsyncHttpClient();
ProtocolRestJsonAsyncClient client = ProtocolRestJsonAsyncClient.builder()
.region(Region.US_WEST_2)
.credentialsProvider(AnonymousCredentialsProvider.create())
.httpClient(mockHttpClient)
.build()) {

mockHttpClient.stubResponses(
HttpExecuteResponse.builder()
.response(SdkHttpResponse.builder().statusCode(400).build())
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
.build(),
HttpExecuteResponse.builder()
.response(SdkHttpResponse.builder().statusCode(200).build())
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
.build()
);

client.allTypes()
.exceptionally(throwable -> {
client.allTypes().join();
return null;
}).join();

List<SdkHttpRequest> requests = mockHttpClient.getRequests();

assertThat(requests).hasSize(2);

assertThat(requests.get(0).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("mdc-trace-123");
assertThat(requests.get(1).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("mdc-trace-123");

} finally {
MDC.clear();
}
});
}

@Test
public void traceIdInterceptorPreservesTraceIdAcrossExceptionallyCompletedFuturesThrownInPreExecution() {
EnvironmentVariableHelper.run(env -> {
env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");
MDC.put("AWS_LAMBDA_X_TRACE_ID", "mdc-trace-123");

ExecutionInterceptor throwingInterceptor = new ExecutionInterceptor() {
private boolean hasThrown = false;

@Override
public void beforeMarshalling(Context.BeforeMarshalling context, ExecutionAttributes executionAttributes) {
if (!hasThrown) {
hasThrown = true;
throw new RuntimeException("failing in pre execution");
}
}
};

try (MockAsyncHttpClient mockHttpClient = new MockAsyncHttpClient();
ProtocolRestJsonAsyncClient client = ProtocolRestJsonAsyncClient.builder()
.region(Region.US_WEST_2)
.credentialsProvider(AnonymousCredentialsProvider.create())
.overrideConfiguration(o -> o.addExecutionInterceptor(throwingInterceptor))
.httpClient(mockHttpClient)
.build()) {

mockHttpClient.stubResponses(
HttpExecuteResponse.builder()
.response(SdkHttpResponse.builder().statusCode(200).build())
.responseBody(AbortableInputStream.create(new StringInputStream("{}")))
.build()
);

client.allTypes()
.exceptionally(throwable -> {
client.allTypes().join();
return null;
}).join();

List<SdkHttpRequest> requests = mockHttpClient.getRequests();

assertThat(requests).hasSize(1);
assertThat(requests.get(0).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("mdc-trace-123");

} finally {
MDC.clear();
}
});
}
}

Loading