1717
1818import static org .assertj .core .api .Assertions .assertThat ;
1919
20+ import java .util .List ;
2021import org .junit .jupiter .api .Test ;
2122import software .amazon .awssdk .auth .credentials .AnonymousCredentialsProvider ;
2223import software .amazon .awssdk .awscore .interceptor .TraceIdExecutionInterceptor ;
24+ import software .amazon .awssdk .core .interceptor .Context ;
25+ import software .amazon .awssdk .core .interceptor .ExecutionAttributes ;
26+ import software .amazon .awssdk .core .interceptor .ExecutionInterceptor ;
2327import software .amazon .awssdk .http .AbortableInputStream ;
2428import software .amazon .awssdk .http .HttpExecuteResponse ;
29+ import software .amazon .awssdk .http .SdkHttpRequest ;
2530import software .amazon .awssdk .http .SdkHttpResponse ;
2631import software .amazon .awssdk .regions .Region ;
32+ import software .amazon .awssdk .services .protocolrestjson .ProtocolRestJsonAsyncClient ;
2733import software .amazon .awssdk .services .protocolrestjson .ProtocolRestJsonClient ;
2834import software .amazon .awssdk .testutils .EnvironmentVariableHelper ;
35+ import software .amazon .awssdk .testutils .service .http .MockAsyncHttpClient ;
2936import software .amazon .awssdk .testutils .service .http .MockSyncHttpClient ;
3037import software .amazon .awssdk .utils .StringInputStream ;
38+ import software .amazon .awssdk .utilslite .SdkInternalThreadLocal ;
3139
3240/**
3341 * Verifies that the {@link TraceIdExecutionInterceptor} is actually wired up for AWS services.
@@ -56,4 +64,181 @@ public void traceIdInterceptorIsEnabled() {
5664 }
5765 });
5866 }
59- }
67+
68+ @ Test
69+ public void traceIdInterceptorPreservesTraceIdAcrossRetries () {
70+ EnvironmentVariableHelper .run (env -> {
71+ env .set ("AWS_LAMBDA_FUNCTION_NAME" , "foo" );
72+ SdkInternalThreadLocal .put ("AWS_LAMBDA_X_TRACE_ID" , "SdkInternalThreadLocal-trace-123" );
73+
74+ try (MockAsyncHttpClient mockHttpClient = new MockAsyncHttpClient ();
75+ ProtocolRestJsonAsyncClient client = ProtocolRestJsonAsyncClient .builder ()
76+ .region (Region .US_WEST_2 )
77+ .credentialsProvider (AnonymousCredentialsProvider .create ())
78+ .httpClient (mockHttpClient )
79+ .build ()) {
80+
81+ mockHttpClient .stubResponses (
82+ HttpExecuteResponse .builder ()
83+ .response (SdkHttpResponse .builder ().statusCode (500 ).build ())
84+ .responseBody (AbortableInputStream .create (new StringInputStream ("{}" )))
85+ .build (),
86+ HttpExecuteResponse .builder ()
87+ .response (SdkHttpResponse .builder ().statusCode (500 ).build ())
88+ .responseBody (AbortableInputStream .create (new StringInputStream ("{}" )))
89+ .build (),
90+ HttpExecuteResponse .builder ().response (SdkHttpResponse .builder ().statusCode (200 ).build ())
91+ .responseBody (AbortableInputStream .create (new StringInputStream ("{}" )))
92+ .build ());
93+
94+ client .allTypes ().join ();
95+
96+ List <SdkHttpRequest > requests = mockHttpClient .getRequests ();
97+ assertThat (requests ).hasSize (3 );
98+
99+ assertThat (requests .get (0 ).firstMatchingHeader ("X-Amzn-Trace-Id" )).hasValue ("SdkInternalThreadLocal-trace-123" );
100+ assertThat (requests .get (1 ).firstMatchingHeader ("X-Amzn-Trace-Id" )).hasValue ("SdkInternalThreadLocal-trace-123" );
101+ assertThat (requests .get (2 ).firstMatchingHeader ("X-Amzn-Trace-Id" )).hasValue ("SdkInternalThreadLocal-trace-123" );
102+
103+ } finally {
104+ SdkInternalThreadLocal .clear ();
105+ }
106+ });
107+ }
108+
109+ @ Test
110+ public void traceIdInterceptorPreservesTraceIdAcrossChainedFutures () {
111+ EnvironmentVariableHelper .run (env -> {
112+ env .set ("AWS_LAMBDA_FUNCTION_NAME" , "foo" );
113+ SdkInternalThreadLocal .put ("AWS_LAMBDA_X_TRACE_ID" , "SdkInternalThreadLocal-trace-123" );
114+
115+ try (MockAsyncHttpClient mockHttpClient = new MockAsyncHttpClient ();
116+ ProtocolRestJsonAsyncClient client = ProtocolRestJsonAsyncClient .builder ()
117+ .region (Region .US_WEST_2 )
118+ .credentialsProvider (AnonymousCredentialsProvider .create ())
119+ .httpClient (mockHttpClient )
120+ .build ()) {
121+
122+ mockHttpClient .stubResponses (
123+ HttpExecuteResponse .builder ()
124+ .response (SdkHttpResponse .builder ().statusCode (200 ).build ())
125+ .responseBody (AbortableInputStream .create (new StringInputStream ("{}" )))
126+ .build (),
127+ HttpExecuteResponse .builder ()
128+ .response (SdkHttpResponse .builder ().statusCode (200 ).build ())
129+ .responseBody (AbortableInputStream .create (new StringInputStream ("{}" )))
130+ .build ()
131+ );
132+
133+ client .allTypes ()
134+ .thenRun (() -> {
135+ client .allTypes ().join ();
136+ })
137+ .join ();
138+
139+ List <SdkHttpRequest > requests = mockHttpClient .getRequests ();
140+
141+ assertThat (requests ).hasSize (2 );
142+
143+ assertThat (requests .get (0 ).firstMatchingHeader ("X-Amzn-Trace-Id" )).hasValue ("SdkInternalThreadLocal-trace-123" );
144+ assertThat (requests .get (1 ).firstMatchingHeader ("X-Amzn-Trace-Id" )).hasValue ("SdkInternalThreadLocal-trace-123" );
145+
146+ } finally {
147+ SdkInternalThreadLocal .clear ();
148+ }
149+ });
150+ }
151+
152+ @ Test
153+ public void traceIdInterceptorPreservesTraceIdAcrossExceptionallyCompletedFutures () {
154+ EnvironmentVariableHelper .run (env -> {
155+ env .set ("AWS_LAMBDA_FUNCTION_NAME" , "foo" );
156+ SdkInternalThreadLocal .put ("AWS_LAMBDA_X_TRACE_ID" , "SdkInternalThreadLocal-trace-123" );
157+
158+ try (MockAsyncHttpClient mockHttpClient = new MockAsyncHttpClient ();
159+ ProtocolRestJsonAsyncClient client = ProtocolRestJsonAsyncClient .builder ()
160+ .region (Region .US_WEST_2 )
161+ .credentialsProvider (AnonymousCredentialsProvider .create ())
162+ .httpClient (mockHttpClient )
163+ .build ()) {
164+
165+ mockHttpClient .stubResponses (
166+ HttpExecuteResponse .builder ()
167+ .response (SdkHttpResponse .builder ().statusCode (400 ).build ())
168+ .responseBody (AbortableInputStream .create (new StringInputStream ("{}" )))
169+ .build (),
170+ HttpExecuteResponse .builder ()
171+ .response (SdkHttpResponse .builder ().statusCode (200 ).build ())
172+ .responseBody (AbortableInputStream .create (new StringInputStream ("{}" )))
173+ .build ()
174+ );
175+
176+ client .allTypes ()
177+ .exceptionally (throwable -> {
178+ client .allTypes ().join ();
179+ return null ;
180+ }).join ();
181+
182+ List <SdkHttpRequest > requests = mockHttpClient .getRequests ();
183+
184+ assertThat (requests ).hasSize (2 );
185+
186+ assertThat (requests .get (0 ).firstMatchingHeader ("X-Amzn-Trace-Id" )).hasValue ("SdkInternalThreadLocal-trace-123" );
187+ assertThat (requests .get (1 ).firstMatchingHeader ("X-Amzn-Trace-Id" )).hasValue ("SdkInternalThreadLocal-trace-123" );
188+
189+ } finally {
190+ SdkInternalThreadLocal .clear ();
191+ }
192+ });
193+ }
194+
195+ @ Test
196+ public void traceIdInterceptorPreservesTraceIdAcrossExceptionallyCompletedFuturesThrownInPreExecution () {
197+ EnvironmentVariableHelper .run (env -> {
198+ env .set ("AWS_LAMBDA_FUNCTION_NAME" , "foo" );
199+ SdkInternalThreadLocal .put ("AWS_LAMBDA_X_TRACE_ID" , "SdkInternalThreadLocal-trace-123" );
200+
201+ ExecutionInterceptor throwingInterceptor = new ExecutionInterceptor () {
202+ private boolean hasThrown = false ;
203+
204+ @ Override
205+ public void beforeMarshalling (Context .BeforeMarshalling context , ExecutionAttributes executionAttributes ) {
206+ if (!hasThrown ) {
207+ hasThrown = true ;
208+ throw new RuntimeException ("failing in pre execution" );
209+ }
210+ }
211+ };
212+
213+ try (MockAsyncHttpClient mockHttpClient = new MockAsyncHttpClient ();
214+ ProtocolRestJsonAsyncClient client = ProtocolRestJsonAsyncClient .builder ()
215+ .region (Region .US_WEST_2 )
216+ .credentialsProvider (AnonymousCredentialsProvider .create ())
217+ .overrideConfiguration (o -> o .addExecutionInterceptor (throwingInterceptor ))
218+ .httpClient (mockHttpClient )
219+ .build ()) {
220+
221+ mockHttpClient .stubResponses (
222+ HttpExecuteResponse .builder ()
223+ .response (SdkHttpResponse .builder ().statusCode (200 ).build ())
224+ .responseBody (AbortableInputStream .create (new StringInputStream ("{}" )))
225+ .build ()
226+ );
227+
228+ client .allTypes ()
229+ .exceptionally (throwable -> {
230+ client .allTypes ().join ();
231+ return null ;
232+ }).join ();
233+
234+ List <SdkHttpRequest > requests = mockHttpClient .getRequests ();
235+
236+ assertThat (requests ).hasSize (1 );
237+ assertThat (requests .get (0 ).firstMatchingHeader ("X-Amzn-Trace-Id" )).hasValue ("SdkInternalThreadLocal-trace-123" );
238+
239+ } finally {
240+ SdkInternalThreadLocal .clear ();
241+ }
242+ });
243+ }
244+ }
0 commit comments